Coverage for src/flag_gems/runtime/backend/_sunrise/ops/index_select.py: 0%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import dim_compress, libentry 

9from flag_gems.utils import triton_lang_extension as ext 

10 

# Module logger: a child of the package-level "flag_gems" logger, using this
# module's __name__ (with any leading dots stripped) as the child suffix.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

12 

13 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("index_select"))
@triton.jit
def index_select_kernel(
    inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    """Gather ``out[m, j] = inp[m, index[j]]`` for a row-major (M, N) input.

    Runs on a 2-D grid. Axis 0 tiles the M rows in blocks of ``BLOCK_M``.
    Axis 1 tiles the ``index_len`` selected columns in blocks of ``BLOCK_N``.
    Both block sizes come from the backend's "index_select" heuristics.

    Positions whose index value falls outside ``[0, N)`` are neither loaded
    nor stored. Those output elements keep whatever the output buffer held
    before the launch.
    """
    # Widen to int64 so ``rows_offsets * N`` / ``* index_len`` cannot overflow
    # for large inputs (a no-op if ext.program_id already yields int64).
    pid_x = ext.program_id(axis=0).to(tl.int64)
    pid_y = ext.program_id(axis=1)
    rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]  # (BLOCK_M, 1)
    rows_mask = rows_offsets < M
    cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)  # (BLOCK_N,)
    cols_mask = cols_offsets < index_len

    # Element-wise ``&`` (rather than Python ``and``) broadcasts the row and
    # column masks into a (BLOCK_M, BLOCK_N) tile mask, like the masks below.
    out_mask = rows_mask & cols_mask[None, :]

    indices = tl.load(index + cols_offsets, mask=cols_mask, other=0)
    index_valid_mask = (indices >= 0) & (indices < N)

    inp_off = rows_offsets * N + indices[None, :]
    out_off = rows_offsets * index_len + cols_offsets[None, :]

    final_mask = out_mask & index_valid_mask[None, :]
    selected = tl.load(inp + inp_off, mask=final_mask, other=0.0)
    tl.store(out + out_off, selected, mask=final_mask)

39 

40 

def index_select_heur_block_m(args):
    """Choose the column tile size for ``index_select_dim0_kernel``.

    Despite the ``_block_m`` suffix, the returned value is used as the
    kernel's ``BLOCK_N``. Longer rows (larger ``args["N"]``) get wider tiles:
    2048 for N >= 8192, 1024 for N >= 4096, 512 for N >= 1024, and 256
    otherwise.
    """
    row_len = args["N"]
    # (minimum row length, tile width), checked from the largest threshold down.
    tiers = ((8192, 2048), (4096, 1024), (1024, 512))
    for min_len, tile in tiers:
        if row_len >= min_len:
            return tile
    return 256

50 

51 

@libentry()
@triton.heuristics({"BLOCK_N": index_select_heur_block_m})
@triton.jit
def index_select_dim0_kernel(inp, out, N, index, BLOCK_N: tl.constexpr):
    """Copy row ``index[pid_x]`` of ``inp`` into row ``pid_x`` of ``out``.

    Both tensors are addressed as row-major with rows of length ``N``. Grid
    axis 0 enumerates the selected rows (one program per index entry). Grid
    axis 1 tiles each row in chunks of ``BLOCK_N`` columns.

    NOTE(review): the index value is not range-checked here, so an
    out-of-range index reads outside ``inp``.
    """
    pid_x = ext.program_id(axis=0).to(tl.int64)
    pid_y = ext.program_id(axis=1)
    cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = cols_offsets < N
    # Widen to int64 before multiplying by N. With an int32 index tensor,
    # ``row * N`` would otherwise be computed in int32 and overflow once the
    # input exceeds 2**31 elements (e.g. a 250k x 16k embedding table).
    row = tl.load(index + pid_x).to(tl.int64)
    selected = tl.load(inp + row * N + cols_offsets, mask=mask, other=0.0)
    tl.store(out + pid_x * N + cols_offsets, selected, mask=mask)

65 

66 

67def _dim_compress(inp, dim): 

68 batch_dim = [dim] 

69 reduction_dim = [i for i in range(inp.ndim) if i != dim] 

70 order = batch_dim + reduction_dim 

71 return inp.permute(order).contiguous() 

72 

73 

def index_select_dim0(inp, dim, index):
    """Gather slices of ``inp`` along ``dim`` using ``index_select_dim0_kernel``.

    Steps:

    1. ``inp`` is permuted so that ``dim`` becomes axis 0, then made contiguous.
    2. The kernel copies, for each entry of ``index``, one whole row of length
       ``N`` (the product of the remaining dims).
    3. The result is permuted back so the selected axis sits at ``dim``.

    Args:
        inp: input tensor. In this file, ``index_select`` only routes 2-D and
            3-D inputs here.
        dim: axis to select along, already normalised to be non-negative by
            the caller in this file.
        index: 1-D tensor of positions along ``dim``. It must be non-empty,
            because ``N`` is computed by dividing by its length.

    Returns:
        A new contiguous tensor shaped like ``inp``, except that
        ``shape[dim] == index.numel()``.
    """
    # The kernel reads entry ``pid`` as ``index + pid``, i.e. it assumes a
    # unit-stride buffer; a strided view would otherwise read wrong entries.
    index = index.contiguous()
    inp = _dim_compress(inp, dim)
    index_len = index.numel()
    out_shape = [index_len, *inp.shape[1:]]
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    N = out.numel() // index_len
    grid = lambda meta: (index_len, triton.cdiv(N, meta["BLOCK_N"]))
    index_select_dim0_kernel[grid](inp, out, N, index)
    if dim == 0:
        return out
    # Axis 0 of ``out`` is the selected axis; move it back to position ``dim``.
    order = list(range(1, out.ndim))
    order.insert(dim, 0)
    return out.permute(order).contiguous()

93 

94 

def index_select(inp, dim, index):
    """Select slices of ``inp`` along ``dim`` at the positions listed in ``index``.

    The output has ``shape[dim] == index.numel()``; every other dimension
    matches ``inp``.

    Args:
        inp: input tensor with at least one dimension.
        dim: axis to select along; negative values count from the end.
        index: 0-D or 1-D tensor of positions along ``dim``.

    Returns:
        A new tensor holding the selected slices.

    Raises:
        AssertionError: if ``dim`` is out of range or ``index`` has more
            than one dimension.
        IndexError: if ``index`` is non-empty but ``inp.shape[dim] == 0``.
    """
    logger.debug("GEMS INDEX SELECT")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert index.ndim <= 1, "Index should have dimension 1 or 0"

    if index.ndim == 0:
        index = index.unsqueeze(0)
    # Both kernels address ``index`` as a flat unit-stride buffer, so a
    # strided view (e.g. ``idx[::2]``) must be materialised first.
    index = index.contiguous()
    dim = dim % inp.ndim
    inp_shape = list(inp.shape)
    index_len = index.numel()
    N = inp_shape[dim]

    if index_len > 0 and N == 0:
        # Every position is out of range for an empty axis. Fail loudly
        # instead of dividing by zero / reading out of bounds below.
        raise IndexError("index out of range in self")

    if index_len == 0 or inp.numel() == 0:
        # The result is empty. Skip the kernel launch and the ``numel // N``
        # division, which would fail for an empty selected axis.
        empty_shape = list(inp_shape)
        empty_shape[dim] = index_len
        return torch.empty(empty_shape, dtype=inp.dtype, device=inp.device)

    if inp.ndim in (2, 3):
        return index_select_dim0(inp, dim, index)

    # General path. The code below assumes ``dim_compress`` moves ``dim`` to
    # the last axis, so the kernel sees a contiguous (M, N) matrix and writes
    # an (M, index_len) result.
    inp = dim_compress(inp, dim)
    M = inp.numel() // N
    out_shape = list(inp.shape)
    out_shape[-1] = index_len
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        triton.cdiv(index_len, meta["BLOCK_N"]),
    )
    index_select_kernel[grid](inp, out, M, N, index, index_len)
    if dim == out.ndim - 1:
        return out
    # Undo the compression: put the last (selected) axis back at ``dim``.
    order = list(range(out.ndim - 1))
    order.insert(dim, out.ndim - 1)
    return out.permute(order).contiguous()