Coverage for src/flag_gems/runtime/backend/_sunrise/ops/scatter.py: 0%

269 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

10from flag_gems.utils.shape_utils import ( 

11 MemOverlap, 

12 has_internal_overlapping, 

13 restride_dim, 

14) 

15 

# Module logger: a child of the "flag_gems" logger named after this module,
# with any leading dots stripped from ``__name__``.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

17 

18 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import preamble of a generated scatter module into ``code``.

    The emitted source imports ``torch``, tries to import ``torch_ptpu`` and
    falls back to aliasing ``torch.cuda`` under that name, then imports
    ``triton``, ``triton.language as tl`` and the ``flag_gems`` names the
    generated kernel and wrapper refer to.

    Args:
        code: Buffer the lines are appended to.

    Returns:
        The same ``code`` buffer, to allow chaining.
    """
    code.writeline("import torch")
    # Indent the try/except bodies through the buffer, like the rest of this
    # file, instead of embedding leading whitespace in the emitted strings.
    code.writeline("try:")
    with code.indent():
        code.writeline("import torch_ptpu")
    code.writeline("except ImportError:")
    with code.indent():
        code.writeline("import torch.cuda as torch_ptpu")
    code.writeline("import triton")
    code.writeline("import triton.language as tl")
    code.newline()
    code.writeline("from flag_gems.utils import libentry")
    code.writeline("from flag_gems import runtime")
    code.writeline("import flag_gems")
    code.newline()
    code.newline()
    return code

35 

36 

def generate_scatter_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write the Triton scatter kernel and its heuristic helpers into ``code``.

    Three definitions are emitted:

    * ``heur_block(args)``: returns 256 when ``flag_gems.vendor_name`` is
      ``'metax'`` or ``'iluvatar'``, otherwise 128.
    * ``loop_count(args)``: returns 4.
    * ``kernel_name``: a kernel decorated with ``@libentry()``,
      ``@triton.heuristics`` (supplying ``BLOCK`` and ``LOOP``) and
      ``@triton.jit`` with all size/stride scalars in ``do_not_specialize``.

    In the generated kernel each program walks ``LOOP`` chunks of ``BLOCK``
    consecutive linear offsets. Every linear offset is decomposed from the
    last dimension to the first with ``shape_i`` to build element offsets for
    ``out``, ``index`` and ``src_strided``. The loaded index value times
    ``stride_dim`` is then added to the ``out`` offset, and the source value
    is combined into ``out`` by ``tl.atomic_add`` (``IS_ADD``), a
    load/``tl.atomic_cas`` retry loop (``IS_MUL``) or a plain ``tl.store``.
    The ``inp`` pointer is part of the signature but is never dereferenced by
    the generated code.

    NOTE(review): for ``rank == 0`` the ``do_not_specialize`` list is emitted
    with empty entries, so the generated source is not valid Python; callers
    are presumably expected to pass ``rank >= 1`` — confirm.

    Args:
        rank: Number of dimensions the kernel is specialised for.
        kernel_name: Name given to the generated kernel function.
        code: Buffer the lines are appended to.

    Returns:
        The same ``code`` buffer, to allow chaining.
    """
    # make the inlined function visible in the context
    code.newline()

    # heuristics: BLOCK size depends on the vendor
    code.writeline("def heur_block(args):")
    with code.indent():
        code.writeline("if(flag_gems.vendor_name in ['metax', 'iluvatar']):")
        with code.indent():
            code.writeline("return 256")
        code.writeline("return 128")
    code.newline()
    code.newline()

    # heuristics: number of BLOCK-sized chunks handled per program
    code.writeline("def loop_count(args):")
    with code.indent():
        code.writeline("return 4")
    code.newline()
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("{")
        with code.indent():
            code.writeline('"BLOCK": heur_block,')
            code.writeline('"LOOP": loop_count,')
        code.writeline("}")
    code.writeline(")")
    inp_stride_vars = ",".join(f"'inp_stride_{i}'" for i in range(rank))
    index_stride_vars = ",".join(f"'index_stride_{i}'" for i in range(rank))
    src_stride_vars = ",".join(f"'src_stride_{i}'" for i in range(rank))
    shape_vars = ",".join(f"'shape_{i}'" for i in range(rank))
    code.writeline(
        f"@triton.jit(do_not_specialize=['N','stride_dim','inp_size_dim',"
        f"{inp_stride_vars},{index_stride_vars},{src_stride_vars},{shape_vars}])"
    )

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("src_strided,")
            code.writeline("index,")
            code.writeline("inp,")
            code.writeline("out,")

            stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(f"index_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for index")

            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape")
            code.writeline("inp_size_dim,")
            code.writeline("stride_dim,")
            code.writeline("N,")
            # reduce options
            code.writeline("IS_ADD: tl.constexpr,")
            code.writeline("IS_MUL: tl.constexpr,")
            code.writeline("BLOCK: tl.constexpr,")
            code.writeline("LOOP: tl.constexpr,")
            code.writeline("INT32_OFFSET: tl.constexpr")

    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        # Widen the program id so all derived offsets are 64-bit.
        code.writeline("if not INT32_OFFSET:")
        with code.indent():
            code.writeline("pid = pid.to(tl.int64)")
        code.writeline("offsets = pid * LOOP * BLOCK + tl.arange(0, BLOCK)")

        # 1. Calculate inp_offsets and idx_offsets
        code.writeline("for loop_iter in tl.static_range(LOOP):")
        with code.indent():
            code.writeline("mask = offsets < N")
            code.writeline("cur_idx = offsets")
            code.writeline("if INT32_OFFSET:")
            with code.indent():
                code.writeline("inp_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
                code.writeline("idx_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
                code.writeline("src_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
            code.writeline("else:")
            with code.indent():
                code.writeline("inp_offsets = tl.zeros((BLOCK, ), dtype=tl.int64)")
                code.writeline("idx_offsets = tl.zeros((BLOCK, ), dtype=tl.int64)")
                code.writeline("src_offsets = tl.zeros((BLOCK, ), dtype=tl.int64)")
            # Peel dimensions off the linear index, innermost first.
            for i in reversed(range(rank)):
                code.writeline("if INT32_OFFSET:")
                with code.indent():
                    code.writeline(f"shape_{i} = shape_{i}.to(tl.int32)")
                    code.writeline(f"inp_stride_{i} = inp_stride_{i}.to(tl.int32)")
                    code.writeline(f"index_stride_{i} = index_stride_{i}.to(tl.int32)")
                    code.writeline(f"src_stride_{i} = src_stride_{i}.to(tl.int32)")
                code.writeline(f"mod = cur_idx % shape_{i}")
                code.writeline(f"inp_offsets += mod * inp_stride_{i}")
                code.writeline(f"idx_offsets += mod * index_stride_{i}")
                code.writeline(f"src_offsets += mod * src_stride_{i}")
                if i != 0:
                    code.writeline(f"cur_idx = cur_idx // shape_{i}")

            # 2. Use offsets to scatter
            code.writeline(
                "cur_src = tl.load(src_strided + src_offsets, mask=mask, other=0)"
            )
            code.writeline(
                "cur_index = tl.load(index + idx_offsets, mask=mask, other=0)"
            )
            code.writeline("if INT32_OFFSET:")
            with code.indent():
                code.writeline("cur_index = cur_index.to(tl.int32)")
                code.writeline("stride_dim = stride_dim.to(tl.int32)")

            code.writeline("dim_offsets = cur_index * stride_dim")
            code.writeline("inp_offsets += dim_offsets")
            code.newline()
            code.writeline("if IS_ADD: ")
            with code.indent():
                code.writeline(
                    "tl.atomic_add(out + inp_offsets, cur_src, mask=mask, sem='relaxed')"
                )
            code.writeline("elif IS_MUL: ")
            with code.indent():
                # Masked-out lanes start as "done"; the loop retries the
                # compare-and-swap until every lane has succeeded once.
                code.writeline("stop = tl.where(mask, 0, 1).to(tl.int1)")
                code.writeline("block_stop = False")
                code.writeline("while not block_stop:")
                with code.indent():
                    code.writeline(
                        "cur_inp = tl.load(out + inp_offsets, mask=mask, other=0)"
                    )
                    code.writeline("res = tl.where(stop, cur_inp, cur_inp * cur_src)")
                    code.writeline(
                        "cas_res = tl.atomic_cas(out + inp_offsets, cur_inp, res, sem='relaxed')"
                    )
                    code.writeline("stop |= cur_inp == cas_res")
                    code.writeline("block_stop = tl.sum(stop.to(tl.int32)) == BLOCK")

            code.writeline("else: ")
            with code.indent():
                code.writeline("tl.store(out + inp_offsets, cur_src, mask=mask)")

            # Advance to the next BLOCK-sized chunk of this program.
            code.writeline("offsets += BLOCK")

    code.newline()
    code.newline()
    return code

196 

197 

def parameter_for_wrapper() -> str:
    """Return the parameter list of the generated wrapper as source text.

    The names are joined with ``", "`` in the order the wrapper expects them:
    ``src_strided, index, inp, out, dim_size, dim_stride, N`` followed by the
    two optional ``reduce`` and ``int32_offset`` parameters.
    """
    wrapper_params = (
        "src_strided",
        "index",
        "inp",
        "out",
        "dim_size",
        "dim_stride",
        "N",
        "reduce: tl.constexpr=None",
        "int32_offset: tl.constexpr=None",
    )
    return ", ".join(wrapper_params)

213 

214 

def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write the Python wrapper that launches the generated kernel.

    The emitted wrapper takes the parameters produced by
    ``parameter_for_wrapper()``. It reads the strides of ``inp``, ``index``
    and ``src_strided`` and the shape of ``index``, derives ``IS_ADD`` /
    ``IS_MUL`` from ``reduce`` (``"add"`` / ``"multiply"``), launches
    ``kernel_name`` on a 1-D grid of ``cdiv(N, BLOCK * LOOP)`` programs and
    returns ``out``.

    ``int32_offset`` defaults to ``True`` only when the caller leaves it as
    ``None``; an explicit ``False`` is forwarded to the kernel so the 64-bit
    offset path is taken.

    Args:
        rank: Number of per-dimension stride/shape arguments to forward.
        wrapper_name: Name given to the generated wrapper function.
        kernel_name: Name of the generated kernel to launch.
        code: Buffer the lines are appended to.

    Returns:
        The same ``code`` buffer, to allow chaining.
    """
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("inp_strides = list(inp.stride())")
        code.writeline("index_strides = index.stride()")
        code.writeline("src_strides = src_strided.stride()")
        code.writeline("index_shapes = list(index.shape)")
        code.writeline("inp_size_dim = dim_size")
        code.writeline("stride_dim = dim_stride")

        code.writeline('IS_ADD = reduce == "add"')
        code.writeline('IS_MUL = reduce == "multiply"')
        # `x or True` is always True and would discard an explicit False.
        code.writeline(
            "int32_offset = True if int32_offset is None else int32_offset"
        )

        # kernel launch
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]), ')
        code.writeline(")")

        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)

        with code.indent():
            code.writeline("src_strided, index, inp, out, ")
            if rank > 0:
                s = ", ".join(f"inp_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                code.writeline("inp_size_dim,")
                code.writeline("stride_dim,")
                code.writeline("N,")
                # reduce options
                code.writeline("IS_ADD,")
                code.writeline("IS_MUL,")
                code.writeline("INT32_OFFSET=int32_offset,")
        code.writeline(")")
        code.writeline("return out")

    return code

272 

273 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write a complete generated scatter module into ``code``.

    The rank the module is specialised for is the number of dimensions of
    ``inputs[1]`` (the second positional argument of the call being
    dispatched). Imports, kernel and wrapper are emitted in that order.

    Args:
        inputs: Positional arguments of the call being dispatched.
        wrapper_name: Name given to the generated wrapper function.
        kernel_name: Name given to the generated kernel function.
        code: Buffer the lines are appended to.

    Returns:
        The buffer returned by the last generation step.
    """
    ndim = len(inputs[1].shape)

    buffer = generate_imports(code)
    buffer = generate_scatter_kernel(ndim, kernel_name, buffer)
    return generate_destination_passing_wrapper(
        ndim, wrapper_name, kernel_name, buffer
    )

288 

289 

class ScatterFunction:
    """Callable that generates, caches and dispatches rank-specific wrappers.

    On the first call for a given key (see ``arg_key``) the source of a
    scatter module is generated, written to the code cache directory,
    imported, and its ``_scatter_wrapper`` is stored in ``overloads``. Later
    calls with the same key reuse the stored wrapper.
    """

    def __init__(self):
        # Process id captured at construction time.
        self.pid = os.getpid()
        # Maps the stringified key to the loaded ``_scatter_wrapper``.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch to the wrapper for these arguments, generating it if needed.

        All positional and keyword arguments are forwarded unchanged to the
        wrapper, and its result is returned.
        """
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            # A bare `import importlib` does not guarantee that the `util`
            # submodule is loaded, so import it explicitly before use.
            import importlib.util

            code = IndentedBuffer()
            code = generate_code(
                args,
                "_scatter_wrapper",
                "_scatter_jit_function",
                code,
            )

            file_name = f"scatter_rank_{key}.py"
            file_path = code_cache_dir() / file_name
            write_atomic(file_path, code.getvalue())

            # load
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}",
                file_path,
            )

            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_scatter_wrapper")
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def arg_key(self, *args):
        """Return the largest ``ndim`` among the tensor positional arguments.

        Non-tensor arguments are ignored. ``max`` raises ``ValueError`` when
        no tensor is passed.
        """
        return max(item.ndim for item in args if torch.is_tensor(item))

329 

330 

# Module-level dispatcher instance used by ``scatter`` and ``scatter_``.
_scatter_func = ScatterFunction()

332 

333 

# Atomic operations do not support fp16, so fp16 data is converted before the
# reduction is computed and converted back to fp16 afterwards.
def scatter(inp, dim, index, src, reduce=None):
    """Out-of-place scatter of ``src`` into a copy of ``inp`` along ``dim``.

    Args:
        inp: Tensor whose copy receives the scattered values.
        dim: Dimension along which ``index`` selects positions.
        index: Tensor of positions; its shape drives the launch
            (``N = index.numel()``).
        src: Tensor providing the values, viewed with the shape of ``index``.
        reduce: ``None`` for plain assignment, or a reduction name forwarded
            to the generated wrapper (which recognises ``"add"`` and
            ``"multiply"``).

    Returns:
        A new tensor holding the result. When ``inp`` is fp16 and ``reduce``
        is given, the work is done on a float32 copy and the result is
        converted back with ``half()``.

    Raises:
        AssertionError: If ``reduce`` is given and ``inp`` is bfloat16.
    """
    logger.debug("GEMS SCATTER")
    is_fp16 = inp.dtype == torch.float16 and (reduce is not None)
    if is_fp16:
        inp = inp.float()
        src = src.float()
    out = inp.clone()

    if reduce is not None:
        assert inp.dtype not in (
            torch.bfloat16,
        ), "Unsupported operation: reduce scatter bfloat tensors."

    if has_internal_overlapping(out) == MemOverlap.Yes:
        out = out.contiguous()

    src_strided = src.as_strided(index.shape, src.stride())
    # The kernel only ever addresses `out`, and `clone()` / `contiguous()` may
    # give `out` strides that differ from `inp`, so every offset-related
    # quantity is derived from `out`.
    # NOTE(review): assumes restride_dim only re-views its argument (shape and
    # strides) without touching data — confirm against shape_utils.
    out_restrided = restride_dim(out, dim, index.shape)
    dim_size = out.size(dim)
    dim_stride = out.stride(dim)
    N = index.numel()

    def fits_int32(t):
        # Kernel offsets are signed 32-bit, hence the 2**31 bound.
        return t.stride(dim) * t.size(dim) < 2**31

    use_int32_offset = all(fits_int32(t) for t in (out, index, src))
    _scatter_func(
        src_strided,
        index,
        out_restrided,
        out,
        dim_size,
        dim_stride,
        N,
        reduce,
        int32_offset=use_int32_offset,
    )
    if is_fp16:
        out = out.half()
    return out

373 

374 

def scatter_(inp, dim, index, src, reduce=None):
    """In-place scatter of ``src`` into ``inp`` along ``dim``.

    Args:
        inp: Tensor that receives the scattered values.
        dim: Dimension along which ``index`` selects positions.
        index: Tensor of positions; its shape drives the launch
            (``N = index.numel()``).
        src: Tensor providing the values, viewed with the shape of ``index``.
        reduce: ``None`` for plain assignment, or a reduction name forwarded
            to the generated wrapper (which recognises ``"add"`` and
            ``"multiply"``).

    Returns:
        The tensor originally passed as ``inp``. When it is fp16 and
        ``reduce`` is given, the work is done on a float32 copy whose result
        is copied back into it.

    Raises:
        AssertionError: If ``reduce`` is given and ``inp`` is bfloat16, or if
            the tensor being written reports internal overlap.
    """
    logger.debug("GEMS SCATTER_")
    base_inp = inp
    # Atomic operations do not support fp16: reduce on a float32 copy instead.
    is_fp16 = inp.dtype == torch.float16 and (reduce is not None)
    if is_fp16:
        inp = inp.float()
        src = src.float()
    out = inp

    if reduce is not None:
        assert inp.dtype not in (
            torch.bfloat16,
        ), "Unsupported operation: reduce scatter bfloat tensors."

    # NOTE(review): in the fp16 path this inspects the float32 copy rather
    # than the tensor finally written by copy_ — confirm that is intended.
    assert (
        has_internal_overlapping(out) != MemOverlap.Yes
    ), "Unsupported operation: trying to inplace write to an internally overlapping tensor."

    src_restrided = src.as_strided(index.shape, src.stride())
    inp_restrided = restride_dim(inp, dim, index.shape)
    dim_size = inp.size(dim)
    dim_stride = inp.stride(dim)
    N = index.numel()

    def fits_int32(t):
        # Kernel offsets are signed 32-bit, hence the 2**31 bound.
        return t.stride(dim) * t.size(dim) < 2**31

    use_int32_offset = all(fits_int32(t) for t in (inp, index, src))
    _scatter_func(
        src_restrided,
        index,
        inp_restrided,
        out,
        dim_size,
        dim_stride,
        N,
        reduce,
        int32_offset=use_int32_offset,
    )
    if is_fp16:
        # Write the float32 result back into the caller's fp16 tensor.
        base_inp.copy_(out)
    # Outside the fp16 path `inp` is still the caller's tensor, so returning
    # `base_inp` covers both cases.
    return base_inp