Coverage for src/flag_gems/runtime/backend/_ascend/ops/index.py: 0%

116 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

# Module-level logger. The logger name is the fixed Ascend ops prefix followed
# by the last dotted component of this module's __name__.
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')

8 

9 

# Row-gather kernel: for every entry of the index buffer, copy `stride`
# consecutive elements from the input buffer to the output buffer.
#
# Parameters:
#   input_ptr      base pointer of the source data
#   stride         number of elements copied per index entry (compile-time)
#   index_len      number of entries in the index buffer
#   index_ptr      base pointer of the index buffer; entry `k` is read at
#                  `index_ptr + k`
#   out_ptr        base pointer of the destination data
#   BLOCK_SIZE     number of index entries handled by one program (compile-time)
#   MAX_DATA_SIZE  number of elements moved per load/store (compile-time)
#
# NOTE(review): all three buffers are addressed with plain linear offsets, so
# this presumably expects densely packed (contiguous) data -- confirm with the
# caller. Index values are used as loaded: nothing here wraps negative values
# or checks them against the input size.
@triton.jit
def index_kernel_func(
    input_ptr,
    stride: tl.constexpr,
    index_len,
    index_ptr,
    out_ptr,
    BLOCK_SIZE: tl.constexpr,
    MAX_DATA_SIZE: tl.constexpr,
):
    pid0 = tl.program_id(axis=0)

    # Each program walks BLOCK_SIZE consecutive index entries one at a time.
    for i in range(0, BLOCK_SIZE):
        offset = pid0 * BLOCK_SIZE + i

        # The last program may run past the end of the index buffer.
        if offset < index_len:
            # Source row start: index value scaled by the row length.
            in_start_index = tl.load(index_ptr + offset) * stride
            # Destination row start: rows are written back to back.
            out_start_offset = offset * stride
            # Ceiling division: number of MAX_DATA_SIZE chunks covering a row.
            loop_num = (stride - 1) // MAX_DATA_SIZE + 1

            for loop_idx in range(0, loop_num):
                inner_offset = loop_idx * MAX_DATA_SIZE + tl.arange(0, MAX_DATA_SIZE)
                # Mask off the lanes of the final chunk that lie beyond the row.
                mask = inner_offset < stride
                cur_value = tl.load(
                    input_ptr + in_start_index + inner_offset, mask=mask
                )
                tl.store(
                    out_ptr + out_start_offset + inner_offset, cur_value, mask=mask
                )

39 

40 

def index_wrapper(input, indices, out):
    """
    Simple kernel wrapper for contiguous tensor indices starting from dim 0.

    Gathers ``input[indices[0], ..., indices[k - 1]]`` into ``out`` by
    collapsing the ``k`` index tensors into one flat row index and launching
    ``index_kernel_func`` over it.

    Args:
        input: Source tensor. A contiguous copy is made if it is not already
            contiguous, because the kernel uses row-major pointer arithmetic.
        indices: Sequence of ``k`` integer index tensors addressing dims
            ``0 .. k - 1`` of ``input``. They are combined element-wise, so
            they are expected to share one shape (already broadcast).
            Negative values are wrapped by the size of their dimension.
        out: Destination tensor, written in place through linear offsets.
            NOTE(review): presumably contiguous and of shape
            ``indices[0].shape + input.shape[k:]`` -- confirm with the caller.

    Returns:
        None. Nothing is launched when the first index tensor is empty.
    """
    indices_dim = len(indices)

    index_len = indices[0].numel()
    if index_len <= 0:
        return

    input_shape = input.shape

    # Number of elements in one gathered row: product of the non-indexed
    # trailing dimensions.
    stride = 1
    for dim_size in input_shape[indices_dim:]:
        stride *= dim_size

    # The kernel computes source addresses as `row * stride + i`, which is
    # only valid for a row-major contiguous layout.
    input = input.contiguous()

    # Fold the per-dimension indices into a single row index over the
    # flattened leading dims. Work in int64 so the products cannot overflow
    # for int32 index tensors, and wrap negative values first.
    flat_index = None
    for dim, idx in enumerate(indices):
        dim_size = input_shape[dim]
        idx = idx.to(torch.int64)
        idx = torch.where(idx < 0, idx + dim_size, idx)
        flat_index = idx if flat_index is None else flat_index * dim_size + idx

    # The kernel reads entry `k` at `index_ptr + k`; a strided or expanded
    # view (e.g. a lone, caller-supplied index tensor) must be packed first.
    flat_index = flat_index.contiguous()

    BLOCK_SIZE = 32
    MAX_DATA_SIZE = 8 * 1024

    # One program per BLOCK_SIZE index entries.
    grid = (triton.cdiv(index_len, BLOCK_SIZE),)

    index_kernel_func[grid](
        input,
        stride,
        index_len,
        flat_index,
        out,
        BLOCK_SIZE=BLOCK_SIZE,
        MAX_DATA_SIZE=MAX_DATA_SIZE,
    )

75 

76 

def index(inp, indices):
    """
    Advanced (tensor) indexing of ``inp``, equivalent to ``inp[indices]``.

    Args:
        inp: Tensor to index.
        indices: Iterable of index entries applied to the dimensions of
            ``inp`` in order. An entry is ``None`` (keep the dimension), an
            int32/int64 index tensor, or a bool/uint8/int8 mask whose shape
            must match the dimensions it covers. Entries on another device
            are moved to ``inp.device``.

    Returns:
        The indexed tensor. When the tensor indices do not form one run
        starting at dim 0, the result comes from ``inp``'s own ``__getitem__``.

    Raises:
        ValueError: ``indices`` is empty.
        IndexError: too many indices, a mask shape mismatch, or a non-empty
            index into a zero-sized dimension.
        TypeError: an index tensor has an unsupported dtype.
    """
    logger.debug("GEMS_ASCEND INDEX")
    indices = list(indices)
    if not indices:
        raise ValueError("at least one index must be provided")

    indices = [
        idx.to(inp.device) if idx is not None and idx.device != inp.device else idx
        for idx in indices
    ]

    # Step 1: Process indices (expand masks to long indices, keep None).
    # uint8 is the "byte" mask dtype named in the error message below; int8
    # is kept because this function has always treated it as a mask.
    mask_dtypes = (torch.bool, torch.uint8, torch.int8)
    integer_dtypes = (torch.int32, torch.int64)
    processed_indices = []
    for i, idx in enumerate(indices):
        if idx is None:
            processed_indices.append(None)
        elif idx.dtype in mask_dtypes:
            # A mask covers `idx.ndim` dimensions starting at the current one.
            k = len(processed_indices)
            if k + idx.ndim > inp.ndim:
                raise IndexError(
                    f"too many indices for tensor of dimension {inp.ndim}"
                )
            for j in range(idx.ndim):
                if idx.shape[j] != inp.shape[k + j]:
                    raise IndexError(
                        f"The shape of the mask {idx.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {inp.shape} at index {k + j}"
                    )
            # One coordinate column of nonzero() per masked dimension.
            nonzero = idx.nonzero()
            processed_indices.extend(nonzero.select(1, j) for j in range(idx.ndim))
        elif idx.dtype in integer_dtypes:
            processed_indices.append(idx)
        else:
            raise TypeError(
                "tensors used as indices must be long, int, byte or bool tensors"
            )

    indices = processed_indices

    # Check indices count
    if len(indices) > inp.ndim:
        raise IndexError(
            f"too many indices for tensor of dimension {inp.ndim} (got {len(indices)})"
        )

    # Step 2: Broadcast the tensor indices against each other (None untouched)
    tensor_indices = [idx for idx in indices if idx is not None]
    if len(tensor_indices) > 1:
        broadcasted = iter(torch.broadcast_tensors(*tensor_indices))
        indices = [None if idx is None else next(broadcasted) for idx in indices]

    # Step 3: Pad with None so there is one entry per dimension of inp
    indices.extend([None] * (inp.ndim - len(indices)))

    # Step 4: The kernel path only supports tensor indices that occupy
    # dims 0 .. k-1 with no None in between.
    indexed_dims = [dim for dim, idx in enumerate(indices) if idx is not None]
    num_indexed = len(indexed_dims)
    is_leading_run = num_indexed > 0 and indexed_dims == list(range(num_indexed))

    # Step 5: Any other layout is delegated to the tensor's own indexing
    # (None -> full slice).
    # NOTE(review): if this function is itself registered as the backend's
    # aten::index implementation, this call may dispatch back here -- confirm.
    if not is_leading_run:
        return inp[tuple(slice(None) if idx is None else idx for idx in indices)]

    # Step 6: Output shape = broadcast index shape + non-indexed trailing dims
    tensor_indices = indices[:num_indexed]
    out_shape = list(tensor_indices[0].shape) + list(inp.shape[num_indexed:])

    # Step 7: Empty results need no kernel launch
    if 0 in out_shape:
        return torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    if inp.numel() == 0:
        # The output is non-empty, so a zero-sized dimension is being indexed
        # by at least one element; every such index is out of range.
        raise IndexError("index is out of bounds for dimension with size 0")

    # Step 8: 1-D input -> plain gather. gather needs an index with the same
    # rank as the input and does not wrap negative values, so flatten, cast
    # and wrap here, then restore the index shape.
    if inp.ndim == 1:
        idx = tensor_indices[0]
        flat = idx.reshape(-1).to(torch.int64)
        flat = torch.where(flat < 0, flat + inp.shape[0], flat)
        return torch.gather(inp, 0, flat).reshape(idx.shape)

    # Step 9: General case -> Triton kernel over the leading indexed dims
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
    index_wrapper(inp, tensor_indices, out)
    return out

220 return out