Coverage for src/flag_gems/ops/index_copy_.py: 99%

160 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer 

10 

# Module-level logger named after this module; the op entry points emit a
# debug message through it on every call.
logger = logging.getLogger(__name__)

12 

13 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of the generated module into ``code``.

    The header provides ``triton``, ``triton.language as tl`` and FlagGems'
    ``libentry`` decorator, followed by two blank lines.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    header = (
        "import triton",
        "import triton.language as tl",
        "from flag_gems.utils import libentry",
    )
    for import_line in header:
        code.writeline(import_line)

    # Two blank lines separate the header from the first definition.
    for _ in range(2):
        code.newline()

    return code

23 

24 

def generate_index_copy_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Append the Triton ``index_copy`` kernel specialized for ``rank`` to ``code``.

    Each program of the generated kernel handles ``BLOCK_SIZE`` consecutive
    flat element positions (masked by ``N``).

    Addressing ``src``:
        Each position is split into per-dimension coordinates
        ``src_offset{i}`` using the ``src_shape_{i}`` arguments. It is then
        turned into a ``src`` offset with the ``src_stride_{i}`` arguments.

    Remapping into ``out``:
        - ``dim_idx`` is the element's coordinate along the copy dimension.
        - ``pre_idx`` counts the ``inp_stride_dim * src_shape_dim``-sized
          chunks before the element.
        - The destination offset replaces ``dim_idx`` with
          ``index[dim_idx]`` in a dimension ``delta`` elements larger.

    The remapping arithmetic treats offsets as row-major positions.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    # Decorators and signature.
    code.writeline("@libentry()")
    code.writeline("@triton.jit")
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            scalar_params = (
                "index",
                "src",
                "out",
                "N",
                "inp_numel",
                "inp_stride_dim",
                "inp_shape_dim",
                "src_shape_dim",
                "delta",
            )
            for param in scalar_params:
                code.writeline(f"{param},")

            dims = range(rank)
            code.writeline(
                ", ".join(f"src_stride_{d}: int" for d in dims) + ", # stride for src"
            )
            code.writeline(
                ", ".join(f"src_shape_{d}: int" for d in dims) + ", # shape for src"
            )
        code.writeline("BLOCK_SIZE: tl.constexpr,")
    code.writeline("):")

    # Kernel body.
    with code.indent():
        code.writeline("pid = tl.program_id(axis=0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offsets < N")

        # Peel coordinates off the flat position, innermost dimension first.
        for d in reversed(range(rank)):
            code.writeline(f"src_offset{d} = offsets % src_shape_{d}")
            code.writeline(f"offsets = offsets // src_shape_{d}")
        code.newline()

        strided_terms = " + ".join(
            f"src_offset{d} * src_stride_{d}" for d in range(rank)
        )
        code.writeline(f"src_offset = {strided_terms}")

        # Remap the src offset into out and perform the masked copy.
        body_statements = (
            "pre_cal = (inp_stride_dim * src_shape_dim)",
            "pre_idx = (src_offset // pre_cal).to(tl.int64)",
            "dim_idx = (src_offset % pre_cal // inp_stride_dim).to(tl.int64)",
            "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)",
            'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"',
            "input_idx = (src_offset + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)",
            "input_mask = (input_idx >= 0) & (input_idx < inp_numel)",
            "store_mask = mask & input_mask",
            "src_val = tl.load(src + src_offset, mask=mask, other=0)",
            "tl.store(out + input_idx, src_val, mask=store_mask)",
        )
        for statement in body_statements:
            code.writeline(statement)

    code.newline()
    code.newline()
    return code

96 

97 

def parameter_for_wrapper() -> str:
    """Return the comma-separated parameter list of the generated wrapper.

    Order: out, index, src, dim, inp_stride_dim, inp_shape_dim,
    src_shape_dim, delta, N, inp_numel.
    """
    wrapper_params = (
        "out",
        "index",
        "src",
        "dim",
        "inp_stride_dim",
        "inp_shape_dim",
        "src_shape_dim",
        "delta",
        "N",
        "inp_numel",
    )
    return ", ".join(wrapper_params)

113 

114 

def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Append the Python wrapper that launches ``kernel_name`` to ``code``.

    The wrapper takes the parameters listed by ``parameter_for_wrapper``.
    It launches ``triton.cdiv(N, BLOCK_SIZE)`` programs with
    ``BLOCK_SIZE = 128``, forwarding ``src``'s per-dimension strides and sizes
    as trailing scalar arguments when ``rank > 0``. It returns ``out``.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    code.writeline(f"def {wrapper_name} ({parameter_for_wrapper()}):")

    with code.indent():
        prologue = (
            "src_strides = list(src.stride())",
            "src_shapes = list(src.shape)",
            "BLOCK_SIZE = 128",
            "grid = (triton.cdiv(N, BLOCK_SIZE),)",
            f"{kernel_name}[grid](",
        )
        for line in prologue:
            code.writeline(line)

        # Launch arguments, in the order of the kernel signature.
        with code.indent():
            code.writeline(
                "index, src, out, N, inp_numel, inp_stride_dim, "
                "inp_shape_dim, src_shape_dim, delta, "
            )
            if rank > 0:
                for per_dim_list in ("src_strides", "src_shapes"):
                    items = ", ".join(f"{per_dim_list}[{d}]" for d in range(rank))
                    code.writeline(items + ",")
            code.writeline("BLOCK_SIZE=BLOCK_SIZE")

        code.writeline(")")
        code.writeline("return out")

    return code

149 

150 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the imports, kernel and launch wrapper for one ``index_copy`` call.

    ``inputs`` uses the wrapper argument layout (out, index, src, dim,
    inp_stride_dim, inp_shape_dim, src_shape_dim, delta, N, inp_numel). Only
    ``inputs[2]`` (``src``) is inspected, to specialize on its rank.

    Returns:
        The buffer returned by the last generator, for chaining.
    """
    src_rank = len(inputs[2].shape)

    code = generate_imports(code)
    code = generate_index_copy_kernel(src_rank, kernel_name, code)
    return generate_destination_passing_wrapper(
        src_rank, wrapper_name, kernel_name, code
    )

165 

166 

class IndexCopyFunction:
    """Rank-specialized launcher for the generated ``index_copy`` code.

    The cache key is the largest ``ndim`` among the tensor arguments. On the
    first call for a given key:

    1. The kernel and its launch wrapper are generated with ``generate_code``.
    2. They are written to the code cache directory and imported.
    3. The wrapper is memoized in ``self.overloads``.

    Later calls with the same key reuse the memoized wrapper.
    """

    def __init__(self):
        # Embedded in generated file and module names so that separate
        # processes write separate files.
        self.pid = os.getpid()
        # Maps the string form of ``arg_key(*args)`` to the loaded wrapper.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Run the wrapper for ``args``, generating and loading it on first use.

        Raises:
            RuntimeError: if writing or importing the generated module fails.
        """
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._load_overload(key, args)
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def _load_overload(self, key, args):
        """Generate, write and import the module for ``key``; return its wrapper."""
        # ``import importlib`` does not by itself load the ``importlib.util``
        # submodule, so import it explicitly instead of relying on some other
        # module having done so.
        import importlib.util

        code = generate_code(
            args,
            "_index_copy_wrapper",
            "_index_copy_jit_function",
            IndentedBuffer(),
        )
        file_name = f"index_copy_rank_{key}_pid_{self.pid}.py"

        try:
            with open(code_cache_dir() / file_name, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{self.pid}",
                f.name,
            )
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            return getattr(module, "_index_copy_wrapper")
        except Exception as e:
            raise RuntimeError(
                f"Failed to generate or load index_copy kernel: {e}"
            ) from e

    def arg_key(self, *args):
        """Return the largest ``ndim`` among the tensor arguments in ``args``."""
        return max(arg.ndim for arg in args if torch.is_tensor(arg))

212 

213 

# Shared launcher instance; memoizes one generated wrapper per key returned by
# ``IndexCopyFunction.arg_key`` (the largest tensor ndim among the arguments).
_index_copy_func = IndexCopyFunction()


# Dispatch key set passed to ``OpOverload.redispatch`` so that an ATen op
# (e.g. ``clone``) runs its native CompositeExplicitAutograd kernel rather than
# a FlagGems registration for that op.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)

220 

221 

def _check_index_copy_args(inp, dim, index, src):
    """Validate ``index_copy`` arguments and return ``dim`` normalized to ``[0, inp.ndim)``.

    Raises:
        AssertionError: if ``dim``, the shapes, or the index values are invalid.
    """
    assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
    # Normalize before the per-dimension checks. They compare ``dim`` with
    # non-negative loop indices, so a negative ``dim`` would never match and
    # valid calls such as ``dim=-1`` would be rejected.
    dim %= inp.ndim
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    assert (
        index.numel() == src.size(dim)
    ), "The dimth dimension of source must have the same size as the length of index"
    assert all(
        inp.size(d) == src.size(d) for d in range(inp.ndim) if d != dim
    ), "src.size(d) == self.size(d) for all dimensions d != dim"
    assert bool(
        ((index >= 0) & (index < inp.size(dim))).all()
    ), "0 <= index < self.size(dim)"
    return dim


def _launch_index_copy(out, dim, index, src):
    """Copy ``src`` into the contiguous tensor ``out`` along the normalized ``dim``.

    The generated kernel maps flat element positions to memory offsets
    assuming row-major layouts. Therefore ``out`` must be contiguous, and
    ``src``/``index`` are made contiguous here (a no-op when they already are).
    """
    src = src.contiguous()
    index = index.reshape(-1).contiguous()

    n = src.numel()
    if n == 0:
        # Nothing to copy; skip kernel generation and launch entirely.
        return

    src_shape_dim = src.size(dim)
    inp_shape_dim = out.size(dim)
    # Row-major stride of ``dim`` in ``out``, derived from the shape.
    # ``is_contiguous()`` does not constrain the stride of size-1 dimensions,
    # so ``out.stride(dim)`` cannot be trusted here.
    inp_stride_dim = out.shape[dim + 1 :].numel()

    _index_copy_func(
        out,
        index,
        src,
        dim,
        inp_stride_dim,
        inp_shape_dim,
        src_shape_dim,
        inp_shape_dim - src_shape_dim,  # delta
        n,
        out.numel(),
    )


def index_copy(inp, dim, index, src):
    """Out-of-place ``index_copy``: return a copy of ``inp`` with rows replaced.

    Along ``dim``, slice ``index[i]`` of the result is taken from slice ``i``
    of ``src``. The result is always contiguous.
    """
    logger.debug("GEMS INDEX_COPY")
    dim = _check_index_copy_args(inp, dim, index, src)

    # Use native clone to avoid potential issues with FlagGems copy_ dispatch.
    # Force a contiguous result because the kernel addresses ``out`` linearly.
    out = torch.ops.aten.clone.default.redispatch(
        _FALLBACK_KEYSET, inp, memory_format=torch.contiguous_format
    )
    _launch_index_copy(out, dim, index, src)
    return out


def index_copy_(inp, dim, index, src):
    """In-place ``index_copy``: write slice ``i`` of ``src`` to slice ``index[i]`` of ``inp``.

    Returns:
        ``inp`` itself.
    """
    logger.debug("GEMS INDEX_COPY_")
    dim = _check_index_copy_args(inp, dim, index, src)

    if inp.is_contiguous():
        _launch_index_copy(inp, dim, index, src)
    else:
        # The kernel needs a row-major destination. Work on a contiguous copy,
        # then write the result back into ``inp``'s own layout with the native
        # copy_ kernel.
        tmp = torch.ops.aten.clone.default.redispatch(
            _FALLBACK_KEYSET, inp, memory_format=torch.contiguous_format
        )
        _launch_index_copy(tmp, dim, index, src)
        torch.ops.aten.copy_.default.redispatch(_FALLBACK_KEYSET, inp, tmp)
    return inp