Coverage for src/flag_gems/ops/select_backward.py: 62%

50 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

# Number of elements handled by one Triton program instance; passed to
# _select_backward_kernel as its BLOCK meta-parameter and used to size the grid.
_BLOCK: int = 1024

# Module-level logger; select_backward emits a debug record on every call.
logger = logging.getLogger(name=__name__)

10 

11 

# Scatter a flat gradient slice into position ``index`` along the selected
# dimension of a dense row-major output.
#
# grad_ptr is read as ``total`` contiguous elements. Element ``offs`` is split
# into (outer, inner) = divmod(offs, inner_size) and written to flat output
# position ``outer * dim_stride + index * inner_size + inner``. With
# ``dim_stride == dim_size * inner_size`` (as select_backward computes it),
# that is element (outer, index, inner) of a contiguous
# (outer_size, dim_size, inner_size) tensor. Other output positions are never
# written, so the output must be zero-filled before the launch.
#
# total / inner_size / dim_stride / index are ordinary runtime arguments, not
# tl.constexpr. As constexpr, every new shape and every new ``index`` value
# forced a separate JIT compilation of this kernel.
@triton.jit
def _select_backward_kernel(
    grad_ptr,
    out_ptr,
    total,
    inner_size,
    dim_stride,
    index,
    BLOCK: tl.constexpr,
):
    # Do all offset math in int64. The output can hold more than 2**31 - 1
    # elements even when the gradient slice does not, and then both
    # ``outer * dim_stride`` and ``index * inner_size`` overflow int32.
    pid = tl.program_id(0).to(tl.int64)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < total

    outer = offs // inner_size
    inner = offs % inner_size

    vals = tl.load(grad_ptr + offs, mask=mask)
    # tl.cast also accepts a scalar that Triton specialized to a constant.
    slot = tl.cast(index, tl.int64) * inner_size
    out_offset = outer * dim_stride + slot + inner
    tl.store(out_ptr + out_offset, vals, mask=mask)

34 

35 

def select_backward(grad, input_sizes, dim, index, out=None):
    """Backward of ``select``: place ``grad`` at ``index`` along ``dim`` in zeros.

    Produces a tensor of shape ``input_sizes`` that is zero everywhere except
    the slice at position ``index`` of dimension ``dim``. That slice is filled
    with the elements of ``grad``.

    Args:
        grad: Incoming gradient. Its elements are taken in contiguous order
            and viewed as ``(prod(sizes[:dim]), prod(sizes[dim + 1:]))``. Only
            the element count is checked (by ``view``), not the exact shape;
            presumably it is ``input_sizes`` with ``dim`` removed.
        input_sizes: Shape of the tensor that was selected from.
        dim: Selected dimension; negative values count from the end.
        index: Selected position along ``dim``; negative values count from
            the end.
        out: Optional destination of shape ``input_sizes`` with ``grad``'s
            dtype and device. It is fully overwritten and may be
            non-contiguous.

    Returns:
        ``out`` when given, otherwise a newly allocated contiguous tensor.

    Raises:
        ValueError: ``dim``/``index`` out of range, or ``out`` has the wrong
            shape, dtype or device.
        RuntimeError: raised by ``view`` when ``grad`` has the wrong number
            of elements.
    """
    logger.debug("GEMS SELECT_BACKWARD")
    dim = int(dim)
    index = int(index)
    sizes = list(input_sizes)
    ndim = len(sizes)

    if dim < 0:
        dim += ndim
    if not 0 <= dim < ndim:
        raise ValueError("invalid dim")

    dim_size = sizes[dim]
    if index < 0:
        index += dim_size
    if not 0 <= index < dim_size:
        raise ValueError("index out of range")

    if out is not None:
        if tuple(out.shape) != tuple(sizes):
            raise ValueError("out shape mismatch")
        if out.dtype != grad.dtype:
            raise ValueError("dtype mismatch")
        if out.device != grad.device:
            raise ValueError("device mismatch")

    # The kernel addresses its destination as a dense row-major buffer, so it
    # must write into a contiguous tensor. A contiguous ``out`` is used
    # directly. Otherwise (no ``out``, or a strided one) a zero-filled buffer
    # is allocated -- zeroed exactly once -- and copied into ``out`` at the end.
    if out is not None and out.is_contiguous():
        dst = out.zero_()
    else:
        dst = torch.zeros(sizes, dtype=grad.dtype, device=grad.device)

    # math.prod of an empty slice is 1, which covers dim == 0 and dim == ndim-1.
    outer_size = math.prod(sizes[:dim])
    inner_size = math.prod(sizes[dim + 1 :])
    total = outer_size * inner_size

    # Also validates grad's element count: view raises on a mismatch.
    grad_view = grad.contiguous().view(outer_size, inner_size)

    # Empty slice: nothing to scatter, so skip compiling and launching.
    if total > 0:
        grid = (triton.cdiv(total, _BLOCK),)
        _select_backward_kernel[grid](
            grad_view,
            dst,
            total,
            inner_size,
            dim_size * inner_size,  # elements between consecutive outer rows
            index,
            BLOCK=_BLOCK,
        )

    if out is None:
        return dst
    if dst is not out:
        out.copy_(dst)
    return out