Coverage for src/flag_gems/ops/select_backward.py: 62%
50 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
# Number of elements each Triton program instance processes; also the divisor
# used to size the launch grid in select_backward.
_BLOCK = 1024
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
@triton.jit
def _select_backward_kernel(
    grad_ptr,
    out_ptr,
    total,
    inner_size,
    dim_stride,
    index,
    BLOCK: tl.constexpr,
):
    """Copy ``grad`` into one slice of ``out`` along the selected dimension.

    Both buffers are addressed as flat arrays. Flat element ``offs`` of
    ``grad`` is split into ``outer = offs // inner_size`` and
    ``inner = offs % inner_size`` and written to flat element
    ``outer * dim_stride + index * inner_size + inner`` of ``out``.

    Args:
        grad_ptr: Pointer to the source gradient buffer.
        out_ptr: Pointer to the destination buffer.
        total: Number of elements to copy from ``grad``.
        inner_size: Number of elements per ``outer`` step in ``grad``.
        dim_stride: Number of ``out`` elements between consecutive ``outer``
            steps.
        index: Position of the destination slice along the selected dim.
        BLOCK: Compile-time number of elements handled per program.

    Only ``BLOCK`` is a ``tl.constexpr``: Triton compiles a separate kernel
    for every distinct constexpr value, so shape- and index-dependent
    arguments are passed at run time to avoid a recompile per call.
    """
    pid = tl.program_id(0)

    # Offsets are widened to 64 bits because destination offsets scale with
    # dim_stride and can exceed the int32 range even when ``total`` does not.
    offs = pid.to(tl.int64) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < total

    outer = offs // inner_size
    inner = offs % inner_size

    vals = tl.load(grad_ptr + offs, mask=mask)

    # Route ``index`` through an int64 tensor so that ``index * inner_size``
    # is not evaluated as a 32-bit scalar product.
    slice_start = (tl.zeros_like(offs) + index) * inner_size
    out_offset = outer * dim_stride + slice_start + inner

    tl.store(out_ptr + out_offset, vals, mask=mask)
def select_backward(grad, input_sizes, dim, index, out=None):
    """Scatter ``grad`` into a zero tensor at ``index`` along ``dim``.

    Args:
        grad: Gradient tensor holding ``prod(input_sizes) // input_sizes[dim]``
            elements. It is made contiguous before being read.
        input_sizes: Sizes of the tensor to produce.
        dim: Dimension that was selected; negative values count from the end.
        index: Position along ``dim`` that receives ``grad``; negative values
            count from the end.
        out: Optional destination tensor. It must match ``input_sizes`` in
            shape and ``grad`` in dtype and device. Its previous contents
            are overwritten.

    Returns:
        ``out`` (or a newly allocated tensor when ``out`` is ``None``), zero
        everywhere except the selected slice, which holds ``grad``.

    Raises:
        ValueError: If ``dim`` or ``index`` is out of range, or if ``out``
            does not match the expected shape, dtype or device.
    """
    logger.debug("GEMS SELECT_BACKWARD")
    dim = int(dim)
    index = int(index)
    sizes = list(input_sizes)
    ndim = len(sizes)

    if dim < 0:
        dim += ndim
    if dim < 0 or dim >= ndim:
        raise ValueError("invalid dim")

    dim_size = sizes[dim]

    if index < 0:
        index += dim_size
    if index < 0 or index >= dim_size:
        raise ValueError("index out of range")

    if out is None:
        out = torch.zeros(sizes, dtype=grad.dtype, device=grad.device)
        dense_out = out
    else:
        if tuple(out.shape) != tuple(sizes):
            raise ValueError("out shape mismatch")
        if out.dtype != grad.dtype:
            raise ValueError("dtype mismatch")
        if out.device != grad.device:
            raise ValueError("device mismatch")

        if out.is_contiguous():
            out.zero_()
            dense_out = out
        else:
            # The kernel computes dense row-major offsets, so a strided
            # ``out`` cannot be written directly. Build the result in a
            # contiguous scratch tensor and copy it back afterwards.
            dense_out = torch.zeros(sizes, dtype=grad.dtype, device=grad.device)

    # math.prod of an empty sequence is 1, which covers dim == 0 and
    # dim == ndim - 1 without special cases.
    outer_size = math.prod(sizes[:dim])
    inner_size = math.prod(sizes[dim + 1 :])
    total = outer_size * inner_size

    # The view also rejects a ``grad`` whose element count does not match.
    grad_view = grad.contiguous().view(outer_size, inner_size)

    # Nothing to copy when any non-selected dimension is empty; skipping the
    # launch avoids an empty grid and a zero ``inner_size`` divisor.
    if total > 0:
        dim_stride = dim_size * inner_size
        grid = (triton.cdiv(total, _BLOCK),)
        # NOTE(review): no device guard is set here; presumably callers run
        # with grad's device as the current device - confirm.
        _select_backward_kernel[grid](
            grad_view,
            dense_out,
            total,
            inner_size,
            dim_stride,
            index,
            BLOCK=_BLOCK,
        )

    if dense_out is not out:
        out.copy_(dense_out)
    return out