Coverage for src/flag_gems/ops/slice_backward.py: 67%
36 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def slice_backward_kernel(
    grad_output_ptr,
    grad_input_ptr,
    numel,
    inner,
    slice_len,
    dim_size,
    start,
    step,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter a sliced gradient back into the full-size gradient buffer.

    Each program handles ``BLOCK_SIZE`` consecutive elements of
    ``grad_output``. A flat index into ``grad_output`` is decomposed into
    ``(outer_idx, slice_idx, inner_idx)``:

    - ``inner`` is the number of elements per step along the sliced
      dimension.
    - ``slice_len`` is the size of the sliced dimension in ``grad_output``.

    The element is then written to position ``start + slice_idx * step``
    along the sliced dimension of ``grad_input``, whose size there is
    ``dim_size``.

    Both buffers are addressed as dense row-major memory, so ``grad_output``
    and ``grad_input`` must be contiguous. Positions of ``grad_input`` that
    the slice does not select are never written.
    """
    pid = tl.program_id(0)
    # Promote to int64 before any multiplication. With int32 indices,
    # pid * BLOCK_SIZE and the grad_input offset below overflow once either
    # tensor holds more than 2**31 - 1 elements.
    offsets = pid.to(tl.int64) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last program may run past the end of grad_output.
    mask = offsets < numel

    grad = tl.load(grad_output_ptr + offsets, mask=mask)

    # Split the flat grad_output index into (outer, slice, inner) coordinates.
    outer_idx = offsets // (slice_len * inner)
    slice_idx = (offsets // inner) % slice_len
    inner_idx = offsets % inner

    # Map the slice coordinate back to its position in the unsliced dimension.
    dim_index = start + slice_idx * step

    input_offset = outer_idx * dim_size * inner + dim_index * inner + inner_idx

    tl.store(grad_input_ptr + input_offset, grad, mask=mask)
def slice_backward(
    grad_output,
    input_sizes,
    dim,
    start,
    end,
    step,
):
    """Compute the gradient of ``input.slice(dim, start, end, step)``.

    Allocates a zero tensor of shape ``input_sizes`` on the device and with
    the dtype of ``grad_output``. The values of ``grad_output`` are scattered
    into the positions the slice selected; every other position keeps a zero
    gradient.

    Args:
        grad_output: Gradient of the sliced tensor. Its size along ``dim``
            is the number of selected positions. It is made contiguous
            because the kernel uses flat row-major indexing.
        input_sizes: Shape of the tensor that was sliced.
        dim: Sliced dimension. Negative values count from the end.
        start: Slice start. A negative value counts from the end of ``dim``,
            and the result is clamped to ``[0, input_sizes[dim]]``.
        end: Slice end. Unused, because the slice length is taken from
            ``grad_output.shape[dim]``. Kept for signature compatibility.
        step: Slice step. Must be positive.

    Returns:
        The gradient with respect to the unsliced input, with shape
        ``input_sizes``.

    Raises:
        IndexError: If ``dim`` is out of range for ``input_sizes``.
        RuntimeError: If ``step`` is not positive.
    """
    shape = list(input_sizes)
    ndim = len(shape)

    # Reject out-of-range dims up front. Otherwise a too-negative dim would
    # wrap to some other valid dimension and silently produce a wrong result.
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    if dim < 0:
        dim += ndim

    # A non-positive step would make the kernel write to repeated or negative
    # positions along ``dim``.
    if step <= 0:
        raise RuntimeError("slice step must be positive")

    grad_input = torch.zeros(
        input_sizes,
        device=grad_output.device,
        dtype=grad_output.dtype,
    )

    numel = grad_output.numel()
    # Nothing to scatter: the all-zero gradient is already the answer.
    if numel == 0:
        return grad_input

    # The kernel reads grad_output as a dense row-major buffer.
    grad_output = grad_output.contiguous()

    dim_size = shape[dim]
    # Number of elements in one step along ``dim`` (product of the trailing
    # sizes).
    inner = 1
    for size in shape[dim + 1:]:
        inner *= size
    slice_len = grad_output.shape[dim]

    if start < 0:
        start += dim_size
    start = max(0, min(start, dim_size))

    BLOCK = 1024
    grid = (triton.cdiv(numel, BLOCK),)

    slice_backward_kernel[grid](
        grad_output,
        grad_input,
        numel,
        inner,
        slice_len,
        dim_size,
        start,
        step,
        BLOCK_SIZE=BLOCK,
    )

    return grad_input