Coverage for src/flag_gems/ops/slice_backward.py: 67%

36 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

6@triton.jit 

7def slice_backward_kernel( 

8 grad_output_ptr, 

9 grad_input_ptr, 

10 numel, 

11 inner, 

12 slice_len, 

13 dim_size, 

14 start, 

15 step, 

16 BLOCK_SIZE: tl.constexpr, 

17): 

18 pid = tl.program_id(0) 

19 

20 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

21 

22 mask = offsets < numel 

23 

24 grad = tl.load(grad_output_ptr + offsets, mask=mask) 

25 

26 outer_idx = offsets // (slice_len * inner) 

27 

28 slice_idx = (offsets // inner) % slice_len 

29 

30 inner_idx = offsets % inner 

31 

32 dim_index = start + slice_idx * step 

33 

34 input_offset = outer_idx * dim_size * inner + dim_index * inner + inner_idx 

35 

36 tl.store(grad_input_ptr + input_offset, grad, mask=mask) 

37 

38 

39def slice_backward( 

40 grad_output, 

41 input_sizes, 

42 dim, 

43 start, 

44 end, 

45 step, 

46): 

47 grad_input = torch.zeros( 

48 input_sizes, 

49 device=grad_output.device, 

50 dtype=grad_output.dtype, 

51 ) 

52 

53 shape = list(input_sizes) 

54 

55 if dim < 0: 

56 dim += len(shape) 

57 

58 outer = 1 

59 for i in range(dim): 

60 outer *= shape[i] 

61 

62 inner = 1 

63 for i in range(dim + 1, len(shape)): 

64 inner *= shape[i] 

65 

66 dim_size = shape[dim] 

67 

68 slice_len = grad_output.shape[dim] 

69 if start < 0: 

70 start += dim_size 

71 start = max(0, min(start, dim_size)) 

72 

73 numel = grad_output.numel() 

74 

75 BLOCK = 1024 

76 

77 grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),) 

78 

79 slice_backward_kernel[grid]( 

80 grad_output, 

81 grad_input, 

82 numel, 

83 inner, 

84 slice_len, 

85 dim_size, 

86 start, 

87 step, 

88 BLOCK_SIZE=BLOCK, 

89 ) 

90 

91 return grad_input