Coverage for src/flag_gems/utils/pointwise_dynamic.py: 94%

1019 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import importlib 

2import os 

3from dataclasses import dataclass 

4from enum import Enum, auto 

5from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple 

6 

7import torch 

8import triton 

9from triton.runtime.jit import JITFunction 

10 

11from flag_gems.utils.code_cache import code_cache_dir 

12from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

13from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config 

14from flag_gems.utils.device_info import get_device_capability 

15from flag_gems.utils.shape_utils import ( 

16 MemOverlap, 

17 all_c_contiguous, 

18 all_the_same_shape, 

19 all_the_same_stride, 

20 broadcast_shapes, 

21 broadcasted_stride, 

22 check_tensor_attributes, 

23 has_internal_overlapping, 

24) 

25from flag_gems.utils.tensor_wrapper import StridedBuffer 

26from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion 

27 

28 

29# ------------------ Operation Description --------------------------- 

def _type_name(type) -> str:
    """Return the textual name of a builtin scalar type or a ``torch.dtype``.

    ``bool``, ``int``, ``float`` and ``str`` render as their bare ``__name__``
    (e.g. ``"float"``). Anything else, including ``torch.dtype`` objects,
    falls back to ``str()`` (e.g. ``"torch.float32"``).
    """
    builtin_scalars = (bool, int, float, str)
    return type.__name__ if type in builtin_scalars else str(type)

37 

38 

def _check_typed_list(container, type):
    """Assert that every element of ``container`` is an instance of ``type``.

    ``type`` may be a single class or a tuple of classes, exactly as accepted
    by ``isinstance``. An ``AssertionError`` is raised on the first mismatch.
    """
    assert all(isinstance(element, type) for element in container)

42 

43 

def _check_sized_list(container, size):
    """Assert that ``container`` holds exactly ``size`` elements."""
    actual_size = len(container)
    assert actual_size == size

46 

47 

def _tuple_content(strings: Sequence[str]) -> str:
    """Render the inside of a Python tuple literal from element strings.

    Elements are comma separated. A single element gets a trailing comma so
    that wrapping the result in parentheses still yields a 1-tuple
    (``"a,"`` -> ``(a,)``). An empty sequence yields ``""``.
    """
    joined = ", ".join(strings)
    return joined + "," if len(strings) == 1 else joined

56 

57 

def _cs(strings: Iterable[str]) -> str:
    """Join ``strings`` into a comma-separated list such as ``"a, b, c"``."""
    separator = ", "
    return separator.join(strings)

60 

61 

def _broadcast_vec(i, ndim):
    """Return an indexing suffix that keeps axis ``i`` and adds new axes elsewhere.

    For example, ``_broadcast_vec(1, 3) == "[None, :, None]"``. Generated
    kernels use it to lift a per-axis 1-D vector (offsets or mask) into an
    ``ndim``-D broadcastable view.
    """
    slots = []
    for axis in range(ndim):
        slots.append(":" if axis == i else "None")
    return "[" + ", ".join(slots) + "]"

65 

66 

class FunctionSchema:
    """Describe the operands of a pointwise function for code generation.

    For each positional input the schema records whether it is a tensor and
    an optional declared type. For each output it records a type-promotion
    rule. The number of inputs comes from the first of ``num_inputs``,
    ``is_tensor`` or ``dtypes`` that is given. The number of outputs comes
    from ``num_outputs`` or, if that is omitted, from the number of promotion
    rules.
    """

    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    # For each input, its index among inputs of the same kind. Tensor inputs
    # and non-tensor inputs are numbered independently.
    _input_id: List[int]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        """Build a schema. All arguments are keyword-only.

        Args:
            num_inputs: number of inputs, outputs excluded.
            is_tensor: per-input flag, ``True`` for tensor inputs.
            dtypes: per-input declared type, or ``None`` when unspecified.
            num_outputs: number of outputs. Must equal ``len(promotion_methods)``.
            promotion_methods: one ``(*arg_indices, method_name)`` item per output.

        Raises:
            ValueError: if ``promotion_methods`` is None, or if none of
                ``num_inputs``/``is_tensor``/``dtypes`` is given.
            AssertionError: on wrongly typed or sized lists, or when there
                are no inputs or no outputs.
        """
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        self._promotion_methods = self.canonicalize_promotion_methods(
            promotion_methods
        )

        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                # NOTE(review): unlike the dtypes-only case below, is_tensor is
                # not inferred from dtypes here; every input defaults to tensor.
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            # is_tensor is necessarily None here because the branch above
            # handles it. Infer it: inputs without a declared type are tensors.
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Replace the trailing method name of each rule with its promotion kind.

        Each item is ``(*arg_indices, method_name)``. The result keeps the
        argument indices and maps ``method_name`` through
        ``ELEMENTWISE_TYPE_PROMOTION_KIND``.
        """
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        """Number of inputs; outputs are not included."""
        return self._num_inputs

    def num_outputs(self):
        """Number of outputs."""
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        """Whether input ``arg_id`` is a tensor."""
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        """Declared type of input ``arg_id``, or ``None`` if unspecified."""
        return self._dtypes[arg_id]

    def output_type(self, i):
        """Canonicalized promotion rule of output ``i``."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        """Number of inputs flagged as tensors."""
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        """Number of outputs; every output is a tensor."""
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        """Number of inputs not flagged as tensors."""
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Render a readable signature.

        Example: ``Pointwise: StridedBuffer, scalar -> StridedBuffer``.
        Tensor inputs render as ``StridedBuffer``. Untyped non-tensor inputs
        render as ``scalar``, and typed ones render by their type name. With
        ``outputs_in_arg``, output ``i`` is also appended to the argument list
        as ``StridedBuffer(a<i>!)``.
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            elif dtype is None:
                input_types.append("scalar")
            else:
                input_types.append(_type_name(dtype))

        if outputs_in_arg:
            # Each output gets its own alias index (a0!, a1!, ...).
            output_types = [
                f"StridedBuffer(a{i}!)" for i in range(self.num_outputs())
            ]
            input_types.extend(output_types)
        else:
            output_types = ["StridedBuffer"] * self.num_outputs()
        return f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'

    def _compute_input_id(self):
        """Map each input to its index among inputs of the same kind."""
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        """Index of input ``idx`` among inputs of the same kind."""
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)

223 

224 

class KernelGenerator:
    """Emit the source of a Triton kernel applying a scalar function pointwise.

    The scalar function's source is inlined into the generated module and
    called on loaded operand tiles. Three strategies are available:
    - an N-d tile addressed through block pointers;
    - an N-d tile addressed with explicit offsets and masks;
    - a flat 1-d tile.

    Each strategy emits a one-tile-per-program path and a loop path. In the
    loop path, program ``pid`` handles tiles ``pid, pid + num_programs(0), ...``.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

    def gen_import_function(self, code: IndentedBuffer):
        """Write ``@triton.jit`` followed by the scalar function's source (``self.fn.src``)."""
        code.writeline("@triton.jit")
        code.writemultiline(self.fn.src)
        code.newline()

    def gen_decorators(self, code):
        """Write the kernel decorators: ``@libentry()`` then ``@triton.jit``.

        Non-tensor arguments (``val0``, ``val1``, ...) are listed in
        ``do_not_specialize``. They are only forwarded to the inlined scalar
        function, so specializing on their values is not worthwhile.
        """
        code.writeline("@libentry()")
        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")

    def input_name(self, i):
        """Name of input ``i`` in generated code: ``in<k>`` or ``val<k>``."""
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        """Name of output ``i`` in generated code: ``out<i>``."""
        return f"out{i}"

    # ------------------------------------------------------------------
    # signature
    # ------------------------------------------------------------------
    def _gen_operand_params(self, code):
        """Write parameters for input pointers/scalars, then output pointers."""
        schema = self.fx
        for i in range(schema.num_inputs()):
            arg_name = self.input_name(i)
            if schema.is_tensor(i):
                code.writeline(f"{arg_name}_ptr: tl.tensor, # of tl.pointer_type")
                continue
            arg_type = schema.input_type(i)
            if arg_type is not None:
                code.writeline(f"{arg_name}: {_type_name(arg_type)},")
            else:
                code.writeline(f"{arg_name},")
        for i in range(schema.num_outputs()):
            code.writeline(
                f"{self.output_name(i)}_ptr: tl.tensor, # of tl.pointer_type"
            )

    def _gen_stride_params(self, code, with_block_pointer):
        """Write per-axis stride parameters for every input, then every output.

        With block pointers each operand also gets ``tl.constexpr`` stride-order
        parameters, which feed ``tl.make_block_ptr(order=...)``.
        """
        ndim = self.ndim
        operands = [f"in{i}" for i in range(self.fx.num_input_tensors())] + [
            f"out{i}" for i in range(self.fx.num_output_tensors())
        ]
        for operand in operands:
            stride_args = _cs(f"{operand}_stride{j}: int" for j in range(ndim))
            code.writeline(f"{stride_args}, # strides for {operand}")
            if with_block_pointer:
                order_args = _cs(
                    f"{operand}_stride_order{j}: tl.constexpr" for j in range(ndim)
                )
                code.writeline(f"{order_args}, # stride order for {operand}")

    def _gen_task_space_params(self, code):
        """Write the task-space sizes ``s0..`` and the ``num_tasks`` parameter."""
        task_space_args = _cs(f"s{i}" for i in range(self.ndim))
        code.writeline(f"{task_space_args}, # task_space")
        code.writeline("num_tasks,")

    def gen_signature(self, code, with_block_pointer=False):
        """Write the ``def`` line and parameter list of the N-d tile kernel.

        Rank-0 kernels take only operand pointers and scalars. For rank > 0
        these follow: strides (plus stride orders when
        ``with_block_pointer``), task-space sizes, ``num_tasks``,
        ``tiles_per_cta``, per-axis ``tile_size<i>`` and ``one_tile_per_cta``.
        """
        ndim = self.ndim
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_operand_params(code)
            if ndim > 0:
                self._gen_stride_params(code, with_block_pointer)
                self._gen_task_space_params(code)
                code.writeline("tiles_per_cta: int,")
                tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(ndim))
                code.writeline(f"{tile_sizes},")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_signature_1d_tile(self, code):
        """Write the ``def`` line and parameters of the 1-d tile kernel.

        Same as ``gen_signature`` without block pointers, except that a single
        ``tile_size`` replaces the per-axis tile sizes.
        """
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_operand_params(code)
            if self.ndim > 0:
                self._gen_stride_params(code, with_block_pointer=False)
                self._gen_task_space_params(code)
                code.writeline("tiles_per_cta: int,")
                code.writeline("tile_size: tl.constexpr,")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    # ------------------------------------------------------------------
    # body pieces
    # ------------------------------------------------------------------
    def gen_num_tiles(self, code):
        """Write ``num_tiles<i> = tl.cdiv(s<i>, tile_size<i>)`` for every axis."""
        for i in range(self.ndim):
            code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")

    def _gen_tile_multi_index(self, code):
        """Decompose the flat ``tile_id`` into per-axis ``tile_id<i>`` (C order)."""
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(self.ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

    def _gen_compute(self, code):
        """Write the call of the inlined scalar function on the loaded operands."""
        schema = self.fx
        inputs_to_scalar_fn = _cs(
            self.input_name(i) for i in range(schema.num_inputs())
        )
        outputs_to_scalar_fn = _cs(
            self.output_name(i) for i in range(schema.num_output_tensors())
        )
        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

    def _gen_grid_stride_loop(self, code, gen_tile_body):
        """Wrap ``gen_tile_body`` in a loop over ``tiles_per_cta`` tiles.

        Iteration ``j`` processes tile ``pid + j * num_programs(0)``.
        """
        code.writeline("num_ctas = ext.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            gen_tile_body(code)

    def gen_body_for_0d(self, code):
        """Write the body of a rank-0 kernel: scalar loads, compute, stores."""
        schema = self.fx
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the pointer's dtype)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

    # nd tile 1d grid kernel with block pointer
    def gen_body_one_tile_per_cta_with_bptr(self, code):
        """Write the body for one N-d tile addressed through block pointers.

        Expects ``tile_id`` to be defined in the generated code. Loads and
        stores use ``boundary_check`` over the axes listed in the stride order.
        """
        ndim = self.ndim
        schema = self.fx

        # block pointer pieces shared by every operand
        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        self._gen_tile_multi_index(code)

        code.writeline("# tile offsets")
        for i in range(ndim):
            # Or else: AssertionError: Block pointers only support 32 bit
            # `offsets/block_shape`, add a `.to(tl.int32)` or use regular
            # indexing for 64 bit support
            code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            # Local import, presumably to avoid a circular import at module
            # load time -- confirm.
            import flag_gems

            strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
            if flag_gems.vendor_name == "spacemit":
                # On spacemit, input block pointers use a fixed reversed-axis
                # order instead of the runtime stride order.
                order = _tuple_content(tuple(f"{ndim - j - 1}" for j in range(ndim)))
            else:
                order = _tuple_content(
                    tuple(f"in{i}_stride_order{j}" for j in range(ndim))
                )
            code.writeline(
                f"in{i}_bptr = tl.make_block_ptr("
                f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline(
            "# stores, note that store to block pointer does not automatically cast the value to the pointer's dtype"
        )
        for i in range(schema.num_output_tensors()):
            strides = _tuple_content(tuple(f"out{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(
                tuple(f"out{i}_stride_order{j}" for j in range(ndim))
            )
            code.writeline(
                f"out{i}_bptr = tl.make_block_ptr("
                f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
            )

    def gen_body_gsl_with_bptr(self, code):
        """Loop version of ``gen_body_one_tile_per_cta_with_bptr``."""
        self._gen_grid_stride_loop(code, self.gen_body_one_tile_per_cta_with_bptr)

    def gen_body_one_tile_per_cta_without_bptr(self, code):
        """Write the body for one N-d tile using explicit offsets and masks.

        Expects ``tile_id`` to be defined in the generated code.
        """
        ndim = self.ndim
        schema = self.fx

        self._gen_tile_multi_index(code)

        # offsets
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # masks: per-axis bounds, broadcast and combined into one N-d mask
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        mask_combine = " & ".join(
            f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim)
        )
        code.writeline(f"mask = {mask_combine}")

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = " + ".join(
                f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                for j in range(ndim)
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        # stores: a plain statement; the result of tl.store is not needed
        # (previously it was bound to in<i>, clobbering the input variable)
        for i in range(schema.num_output_tensors()):
            offset_combine = " + ".join(
                f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                for j in range(ndim)
            )
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_without_bptr(self, code):
        """Loop version of ``gen_body_one_tile_per_cta_without_bptr``."""
        self._gen_grid_stride_loop(code, self.gen_body_one_tile_per_cta_without_bptr)

    # ------------------------------------------------------------------
    # whole kernels
    # ------------------------------------------------------------------
    def _codegen_nd_tile(self, code, with_block_pointer):
        """Generate the N-d tile kernel, with or without block pointers."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=with_block_pointer)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        if with_block_pointer:
            gen_one_tile = self.gen_body_one_tile_per_cta_with_bptr
            gen_gsl = self.gen_body_gsl_with_bptr
        else:
            gen_one_tile = self.gen_body_one_tile_per_cta_without_bptr
            gen_gsl = self.gen_body_gsl_without_bptr

        with code.indent():
            code.writeline("pid = ext.program_id(0)")
            self.gen_num_tiles(code)
            # one tile per program; may require a very large grid
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                gen_one_tile(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                gen_gsl(code)
            code.newline()
        return code

    def codegen_nd_tile_with_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
        return self._codegen_nd_tile(code, with_block_pointer=True)

    def codegen_nd_tile_without_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support without block pointer."""
        return self._codegen_nd_tile(code, with_block_pointer=False)

    def codegen_nd_tile(self, code):
        """Generate the N-d tile kernel, using block pointers if the config prefers them."""
        if self.config.prefer_block_pointer:
            self.codegen_nd_tile_with_bptr(code)
        else:
            self.codegen_nd_tile_without_bptr(code)
        return code

    def gen_body_one_tile_per_cta_1d_tile(self, code):
        """Write the body for one flat 1-d tile of ``tile_size`` tasks.

        Expects ``tile_id`` to be defined in the generated code. Per-axis
        indices are recovered from the flat task id in C order.
        """
        ndim = self.ndim
        schema = self.fx

        # tile id
        code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
        code.writeline("mask = tid < num_tasks")

        # multi index reconstruction
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"i{i} = tid % s{i}")
                code.writeline(f"tid //= s{i}")
            else:
                code.writeline(f"i{i} = tid")
        code.newline()

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = " + ".join(f"i{j} * in{i}_stride{j}" for j in range(ndim))
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        # stores: a plain statement; do not clobber in<i> with tl.store's result
        for i in range(schema.num_output_tensors()):
            offset_combine = " + ".join(f"i{j} * out{i}_stride{j}" for j in range(ndim))
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_1d_tile(self, code):
        """Loop version of ``gen_body_one_tile_per_cta_1d_tile``."""
        self._gen_grid_stride_loop(code, self.gen_body_one_tile_per_cta_1d_tile)

    def codegen_1d_tile(self, code):
        """Generate kernel 1d tile & 1d grid with gsl support."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature_1d_tile(code)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = ext.program_id(0)")
            # one tile per program; may require a very large grid
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                self.gen_body_one_tile_per_cta_1d_tile(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                self.gen_body_gsl_1d_tile(code)
            code.newline()
        return code

747 

748 

749class WrapperGenerator: 

def __init__(
    self,
    function_schema: FunctionSchema,
    jit_fn_name: str,
    ndim: int,
    name: str,
    config: CodeGenConfig,
):
    """Store the schema, kernel name, rank, wrapper name and codegen config."""
    self.fx, self.jit_fn_name = function_schema, jit_fn_name
    self.ndim, self.name, self.config = ndim, name, config

763 

def input_name(self, i):
    """Return the wrapper's local name for input ``i``: ``in<k>`` or ``val<k>``.

    ``k`` is the input's index among inputs of the same kind, as reported by
    ``self.fx.input_index``.
    """
    prefix = "in" if self.fx.is_tensor(i) else "val"
    return "{}{}".format(prefix, self.fx.input_index(i))

769 

def output_name(self, i):
    """Return the wrapper's local name for output ``i``: ``out<i>``."""
    return "out{}".format(i)

772 

def gen_signature(self, code: IndentedBuffer):
    """Write the ``def`` line of the generated Python wrapper.

    Inputs are positional-only (``/``). They are renamed to ``in0``, ``val0``,
    ..., so passing them by keyword would be confusing. Outputs ``out0``,
    ``out1``, ... are keyword-only (``*``). They let callers pass
    pre-allocated destinations, such as views of other tensors, and cannot
    clash with the scalar function, which has no output parameters.
    """
    # TODO: check if triton handles constexprs transitively
    schema = self.fx
    tensor_annotation = "Union[torch.Tensor, StridedBuffer]"

    def describe_input(idx):
        arg_name = self.input_name(idx)
        if schema.is_tensor(idx):
            return f"{arg_name}: {tensor_annotation}"
        arg_type = schema.input_type(idx)
        if arg_type is None:
            return arg_name
        return f"{arg_name}: {_type_name(arg_type)}"

    params: List[str] = [describe_input(idx) for idx in range(schema.num_inputs())]
    params += ["/", "*"]
    params += [
        f"{self.output_name(idx)}: {tensor_annotation}"
        for idx in range(schema.num_output_tensors())
    ]
    code.writeline(f"def {self.name}({_cs(params)}): ")

807 

def gen_docstring(self, code: IndentedBuffer):
    """Write a one-line docstring for the wrapper, built from its schema signature."""
    signature_text = self.fx.signature(outputs_in_arg=True)
    code.writeline(f'"""Generated wrapper function with {signature_text}"""')

812 

def gen_same_shape_check(self, code: IndentedBuffer):
    """Write an assertion that all input and output tensors have equal shapes."""
    operand_names = [f"in{i}" for i in range(self.fx.num_input_tensors())]
    operand_names.extend(f"out{i}" for i in range(self.fx.num_output_tensors()))
    check = " == ".join(f"{operand}.shape" for operand in operand_names)
    code.writeline(f"assert {check}, 'operand shapes mismatch'")

820 

def gen_task_partition(self, code: IndentedBuffer):
    """Write wrapper code choosing tile sizes, program count and warps (N-d tile).

    For rank 0, one program with one warp is used. Otherwise the generated
    code works as follows:
    - it returns early when ``out0`` is empty;
    - it picks per-axis ``tile_sizes`` with ``heuristics_for_tile_size``;
    - it caps the number of programs at ``config.max_grid_size[0]``, with
      each program handling ``tiles_per_cta`` tiles.

    ``grid`` is written in both cases.
    """
    code.writeline("# task partitioning")
    ndim = self.ndim
    if ndim == 0:
        code.writeline("num_warps = 1")
        code.writeline("num_ctas = 1")
    else:
        code.writeline("shape = out0.shape")
        code.writeline("num_tasks = out0.numel()")
        code.writeline("if num_tasks == 0:")
        with code.indent():
            self.gen_return(code)

        max_tile_size = self.config.max_tile_size
        # Halve the tile budget only when every tensor input is *known* to be
        # complex, since complex elements are twice as wide. Undeclared (None)
        # dtypes no longer count as complex. Previously they did, which halved
        # the tile size for almost every operator.
        tensor_dtypes = [
            self.fx.input_type(i)
            for i in range(self.fx.num_inputs())
            if self.fx.is_tensor(i)
        ]
        complex_dtypes = (torch.complex64, torch.complex128)
        all_complex = bool(tensor_dtypes) and all(
            dtype in complex_dtypes for dtype in tensor_dtypes
        )
        if all_complex:
            max_tile_size = max_tile_size // 2

        major, _ = get_device_capability()
        # Special case for fill_scalar kernels on devices with major capability >= 9.
        fill_scalar_fast_path = "fill_scalar" in self.name and major >= 9
        if fill_scalar_fast_path:
            # NOTE(review): a single tile size is emitted regardless of ndim.
            # For ndim > 1 the zip() below only covers axis 0 -- confirm this
            # path is only reached for rank-1 kernels.
            code.writeline("tile_sizes = tuple([64])")
        else:
            code.writeline(
                f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)"
            )
        code.writeline("tile_size = math.prod(tile_sizes)")
        code.writeline(
            "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))"
        )

        if fill_scalar_fast_path:
            code.writeline("num_ctas = num_tiles")
        else:
            max_grid_size0 = self.config.max_grid_size[0]
            code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")

        code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
        code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
        code.writeline("one_tile_per_cta = tiles_per_cta==1")
    code.writeline("grid = (num_ctas, 1, 1)")

870 

def gen_task_partition_1d(self, code: IndentedBuffer):
    """Write wrapper code choosing tile size, program count and warps (1-d tile).

    For rank 0, one program with one warp is used. Otherwise the generated
    code works as follows:
    - it returns early when ``out0`` is empty;
    - it tiles the flat task space of ``num_tasks`` elements with a single
      ``tile_size``;
    - it caps the number of programs at ``config.max_grid_size[0]``.

    ``grid`` is written in both cases.
    """
    code.writeline("# task partitioning")
    ndim = self.ndim
    if ndim == 0:
        code.writeline("num_warps = 1")
        code.writeline("num_ctas = 1")
    else:
        code.writeline("shape = out0.shape")
        code.writeline("num_tasks = out0.numel()")
        code.writeline("if num_tasks == 0:")
        with code.indent():
            self.gen_return(code)

        max_tile_size = self.config.max_tile_size
        # Halve the tile budget only when every tensor input is *known* to be
        # complex. Undeclared (None) dtypes do not count as complex.
        tensor_dtypes = [
            self.fx.input_type(i)
            for i in range(self.fx.num_inputs())
            if self.fx.is_tensor(i)
        ]
        complex_dtypes = (torch.complex64, torch.complex128)
        all_complex = bool(tensor_dtypes) and all(
            dtype in complex_dtypes for dtype in tensor_dtypes
        )
        if all_complex:
            max_tile_size = max_tile_size // 2

        major, _ = get_device_capability()
        # Special case for fill_scalar kernels on devices with major capability >= 9.
        fill_scalar_fast_path = "fill_scalar" in self.name and major >= 9
        if fill_scalar_fast_path:
            code.writeline("tile_sizes = tuple([1024])")
        else:
            code.writeline(
                f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)"
            )

        code.writeline("tile_size = tile_sizes[0]")
        code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)")

        if fill_scalar_fast_path:
            code.writeline("num_ctas = num_tiles")
        else:
            max_grid_size0 = self.config.max_grid_size[0]
            code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")

        code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
        code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
        code.writeline("one_tile_per_cta = tiles_per_cta==1")
    code.writeline("grid = (num_ctas, 1, 1)")

918 

919 def gen_kernel_launch( 

920 self, 

921 code: IndentedBuffer, 

922 ): 

923 schema = self.fx 

924 ndim = self.ndim 

925 

926 with_block_pointer = self.config.prefer_block_pointer 

927 

928 code.writeline("# kernel launch") 

929 for i in range(schema.num_input_tensors()): 

930 code.writeline(f"in{i}_strides = in{i}.stride()") 

931 if not with_block_pointer: 

932 continue 

933 if ndim >= 2: # where ndim is 1, we don't need to compute stride order 

934 code.writeline(f"in{i}_stride_order = stride_order(in{i}_strides)") 

935 else: 

936 code.writeline(f"in{i}_stride_order = (0,)") 

937 for i in range(schema.num_output_tensors()): 

938 code.writeline(f"out{i}_strides = out{i}.stride()") 

939 if not with_block_pointer: 

940 continue 

941 if ndim >= 2: 

942 code.writeline(f"out{i}_stride_order = stride_order(out{i}_strides)") 

943 else: 

944 code.writeline(f"out{i}_stride_order = (0,)") 

945 

946 code.writeline("with torch_device_fn.device(in0.device.index):") 

947 with code.indent(): 

948 code.writeline(f"{self.jit_fn_name}[grid](") 

949 with code.indent(): 

950 params = [] 

951 # NOTE: WRAP 

952 for i in range(schema.num_inputs()): 

953 if schema.is_tensor(i): 

954 params.append(f"{self.input_name(i)}") 

955 else: 

956 params.append(self.input_name(i)) 

957 for i in range(schema.num_output_tensors()): 

958 params.append(f"{self.output_name(i)}") 

959 

960 code.writeline(f"{_cs(params)},") 

961 

962 if ndim > 0: 

963 for i in range(schema.num_input_tensors()): 

964 s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim)) 

965 code.writeline(f"{s}, # stride for in{i}") 

966 if not with_block_pointer: 

967 continue 

968 order = ", ".join( 

969 f"in{i}_stride_order[{j}]" for j in range(ndim) 

970 ) 

971 code.writeline(f"{order}, # stride order for in{i}") 

972 

973 for i in range(schema.num_output_tensors()): 

974 s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim)) 

975 code.writeline(f"{s}, # stride for out{i}") 

976 if not with_block_pointer: 

977 continue 

978 order = ", ".join( 

979 f"out{i}_stride_order[{j}]" for j in range(ndim) 

980 ) 

981 code.writeline(f"{order}, # stride orderfor out{i}") 

982 

983 shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim)) 

984 code.writeline(f"{shape_args}, # task indexing space") 

985 code.writeline("num_tasks, # num tasks") 

986 code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta") 

987 for i in range(ndim): 

988 code.writeline(f"tile_size{i}=tile_sizes[{i}],") 

989 code.writeline("one_tile_per_cta=one_tile_per_cta,") 

990 code.writeline("num_warps=num_warps,") 

991 code.writeline(")") 

992 

993 def gen_kernel_launch_1d( 

994 self, 

995 code: IndentedBuffer, 

996 ): 

997 schema = self.fx 

998 ndim = self.ndim 

999 

1000 code.writeline("# kernel launch") 

1001 for i in range(schema.num_input_tensors()): 

1002 code.writeline(f"in{i}_strides = in{i}.stride()") 

1003 for i in range(schema.num_output_tensors()): 

1004 code.writeline(f"out{i}_strides = out{i}.stride()") 

1005 

1006 code.writeline("with torch_device_fn.device(in0.device.index):") 

1007 with code.indent(): 

1008 code.writeline(f"{self.jit_fn_name}[grid](") 

1009 with code.indent(): 

1010 params = [] 

1011 # NOTE: WRAP 

1012 for i in range(schema.num_inputs()): 

1013 if schema.is_tensor(i): 

1014 params.append(f"{self.input_name(i)}") 

1015 else: 

1016 params.append(self.input_name(i)) 

1017 for i in range(schema.num_output_tensors()): 

1018 params.append(f"{self.output_name(i)}") 

1019 

1020 code.writeline(f"{_cs(params)},") 

1021 

1022 if ndim > 0: 

1023 for i in range(schema.num_input_tensors()): 

1024 s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim)) 

1025 code.writeline(f"{s}, # stride for in{i}") 

1026 for i in range(schema.num_output_tensors()): 

1027 s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim)) 

1028 code.writeline(f"{s}, # stride for out{i}") 

1029 

1030 shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim)) 

1031 code.writeline(f"{shape_args}, # task indexing space") 

1032 code.writeline("num_tasks, # num tasks") 

1033 code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta") 

1034 code.writeline("tile_size=tile_size,") 

1035 code.writeline("one_tile_per_cta=one_tile_per_cta,") 

1036 code.writeline("num_warps=num_warps,") 

1037 code.writeline(")") 

1038 

1039 def gen_return(self, code: IndentedBuffer): 

1040 return_exprs = _cs(f"out{i}" for i in range(self.fx.num_output_tensors())) 

1041 code.writeline(f"return {return_exprs}") 

1042 

1043 def codegen_nd_tile(self, code): 

1044 self.gen_signature(code) 

1045 

1046 with code.indent(): 

1047 self.gen_docstring(code) 

1048 self.gen_same_shape_check(code) 

1049 self.gen_task_partition(code) 

1050 self.gen_kernel_launch(code) 

1051 self.gen_return(code) 

1052 code.newline() 

1053 return code 

1054 

1055 def codegen_1d_tile(self, code): 

1056 self.gen_signature(code) 

1057 

1058 with code.indent(): 

1059 self.gen_docstring(code) 

1060 self.gen_same_shape_check(code) 

1061 self.gen_task_partition_1d(code) 

1062 self.gen_kernel_launch_1d(code) 

1063 self.gen_return(code) 

1064 code.newline() 

1065 return code 

1066 

1067 

1068class ModuleGenerator: 

1069 def __init__( 

1070 self, 

1071 function_schema: FunctionSchema, 

1072 scalar_fn: triton.JITFunction, 

1073 ndim: int, 

1074 jit_fn_name: str, 

1075 wrapper_name: str, 

1076 config: CodeGenConfig, 

1077 ): 

1078 self.config = config 

1079 self.scalar_fn = scalar_fn 

1080 self.wrapper_gen = WrapperGenerator( 

1081 function_schema, jit_fn_name, ndim, wrapper_name, config 

1082 ) 

1083 self.kernel_gen = KernelGenerator( 

1084 function_schema, scalar_fn, ndim, jit_fn_name, config 

1085 ) 

1086 

1087 @staticmethod 

1088 def _collect_jit_deps(scalar_fn): 

1089 """Collect extra imports and local @triton.jit helper sources. 

1090 

1091 Parses the source module where scalar_fn is defined using AST. 

1092 Returns a tuple of: 

1093 - extra_imports: dict of module_path -> set of names 

1094 - local_sources: list of source strings for local @triton.jit 

1095 functions (those NOT decorated with @pointwise_dynamic) 

1096 """ 

1097 import ast 

1098 import inspect 

1099 

1100 py_fn = getattr(scalar_fn, "fn", scalar_fn) 

1101 module_name = getattr(py_fn, "__module__", None) 

1102 if not module_name: 

1103 return {}, [] 

1104 try: 

1105 mod = importlib.import_module(module_name) 

1106 source_file = inspect.getfile(mod) 

1107 except (ImportError, TypeError, OSError): 

1108 return {}, [] 

1109 try: 

1110 with open(source_file) as f: 

1111 module_source = f.read() 

1112 source_lines = module_source.splitlines(keepends=True) 

1113 tree = ast.parse(module_source) 

1114 except (OSError, SyntaxError): 

1115 return {}, [] 

1116 

1117 # Collect non-standard import-from lines 

1118 ALREADY_IMPORTED = { 

1119 "math", 

1120 "typing", 

1121 "torch", 

1122 "triton", 

1123 "triton.language", 

1124 "flag_gems.utils.shape_utils", 

1125 "flag_gems.utils.tensor_wrapper", 

1126 "flag_gems.utils.libentry", 

1127 "flag_gems.utils", 

1128 "flag_gems.runtime", 

1129 "flag_gems.utils.pointwise_dynamic", 

1130 } 

1131 extra_imports = {} 

1132 for node in ast.iter_child_nodes(tree): 

1133 if isinstance(node, ast.ImportFrom) and node.module: 

1134 if node.module in ALREADY_IMPORTED: 

1135 continue 

1136 names = {alias.name for alias in node.names} 

1137 extra_imports.setdefault(node.module, set()).update(names) 

1138 

1139 # Collect local @triton.jit functions (without @pointwise_dynamic) 

1140 def _has_decorator(func_node, name): 

1141 for dec in func_node.decorator_list: 

1142 src = "".join(source_lines[dec.lineno - 1 : dec.end_lineno]) 

1143 if name in src: 

1144 return True 

1145 return False 

1146 

1147 def _extract_source(func_node): 

1148 start = func_node.lineno - 1 

1149 if func_node.decorator_list: 

1150 start = func_node.decorator_list[0].lineno - 1 

1151 end = func_node.end_lineno 

1152 return "".join(source_lines[start:end]) 

1153 

1154 local_sources = [] 

1155 for node in ast.iter_child_nodes(tree): 

1156 if not isinstance(node, ast.FunctionDef): 

1157 continue 

1158 if not _has_decorator(node, "triton.jit") and not _has_decorator( 

1159 node, "jit" 

1160 ): 

1161 continue 

1162 if _has_decorator(node, "pointwise_dynamic"): 

1163 continue 

1164 local_sources.append(_extract_source(node)) 

1165 

1166 return extra_imports, local_sources 

1167 

1168 def generate_imports(self, code: IndentedBuffer) -> IndentedBuffer: 

1169 code.writeline("import math") 

1170 code.writeline("from typing import Union") 

1171 code.writeline("import torch") 

1172 code.writeline("import triton") 

1173 code.writeline("from triton import language as tl") 

1174 code.newline() 

1175 code.writeline("from flag_gems.utils.shape_utils import (") 

1176 code.writeline(" heuristics_for_tile_size,") 

1177 code.writeline(" heuristics_for_num_warps,") 

1178 code.writeline(" stride_order,") 

1179 code.writeline(")") 

1180 code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer") 

1181 code.writeline("from flag_gems.utils.libentry import libentry") 

1182 code.writeline("from flag_gems.utils import triton_lang_extension as ext") 

1183 code.writeline("from flag_gems.runtime import torch_device_fn") 

1184 

1185 # Generate extra imports and local JIT deps of the scalar function 

1186 jit_dep_imports, local_jit_sources = self._collect_jit_deps(self.scalar_fn) 

1187 for module_path, names in sorted(jit_dep_imports.items()): 

1188 sorted_names = ", ".join(sorted(names)) 

1189 code.writeline(f"from {module_path} import {sorted_names}") 

1190 

1191 code.newline() 

1192 code.newline() 

1193 

1194 # Emit local @triton.jit helper functions 

1195 for source in local_jit_sources: 

1196 for line in source.splitlines(): 

1197 code.writeline(line) 

1198 code.newline() 

1199 

1200 return code 

1201 

1202 def codegen(self, code: IndentedBuffer): 

1203 code = self.generate_imports(code) 

1204 if self.config.prefer_1d_tile: 

1205 code = self.wrapper_gen.codegen_1d_tile(code) 

1206 code = self.kernel_gen.codegen_1d_tile(code) 

1207 else: 

1208 code = self.wrapper_gen.codegen_nd_tile(code) 

1209 code = self.kernel_gen.codegen_nd_tile(code) 

1210 return code 

1211 

1212 

1213@dataclass 

1214class KernelInfo: 

1215 """Information about a generated kernel for C++ integration.""" 

1216 

1217 file_path: str 

1218 kernel_name: str 

1219 wrapper_name: str 

1220 ndim: int 

1221 

1222 

1223class ComplexMode(Enum): 

1224 NONE = auto() 

1225 ELEMENTWISE = auto() # add/sub: view_as_real → same kernel → view_as_complex 

1226 CROSS = auto() # mul/div: split ar/ai/br/bi → cross_kernel 

1227 

1228 

1229@dataclass 

1230class ComplexStrategy: 

1231 mode: ComplexMode = ComplexMode.NONE 

1232 cross_kernel: object = None 

1233 tensorize_scalars: bool = False 

1234 fallback_target: object = None 

1235 

1236 

1237_REAL_TO_COMPLEX = { 

1238 torch.float16: torch.complex32, 

1239 torch.bfloat16: torch.complex32, 

1240 torch.float32: torch.complex64, 

1241 torch.float64: torch.complex128, 

1242} 

1243 

1244 

1245class PointwiseDynamicFunction: 

1246 """Utility to generate function for general pointwise operation. It generate wrapper & JITFunction 

1247 which are specialized according to the rank of the task space(the broadcasted shape of all input tensors). 

1248 The generated code are written out to the cache directory (defaults to ~/.flaggems). 

1249 """ 

1250 

1251 def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None): 

1252 self.fx = op_desc 

1253 

1254 assert isinstance(scalar_fn, JITFunction) 

1255 self._scalar_fn = scalar_fn 

1256 self._scalar_fn_cache_key = scalar_fn.cache_key 

1257 self.pid = os.getpid() 

1258 

1259 self.config: CodeGenConfig = config or get_codegen_config() 

1260 

1261 # instantiated & cached overloads 

1262 self.overloads: Mapping[str, Callable] = {} 

1263 # cached kernel info for C++ integration 

1264 self._kernel_info_cache: Mapping[str, KernelInfo] = {} 

1265 

1266 # complex dispatch support 

1267 self.complex_strategy = ComplexStrategy() 

1268 self._operand_indices = self._infer_operand_indices() 

1269 

1270 # -------------------- operand index inference -------------------- 

1271 

1272 def _infer_operand_indices(self): 

1273 """Infer operand indices from schema._promotion_methods, done once at init.""" 

1274 indices = set() 

1275 for pm in self.fx._promotion_methods: 

1276 for idx in pm[:-1]: 

1277 indices.add(idx) 

1278 return frozenset(indices) 

1279 

1280 # -------------------- register_complex -------------------- 

1281 

1282 def register_complex( 

1283 self, mode, cross_kernel=None, tensorize_scalars=False, fallback_target=None 

1284 ): 

1285 """Register complex number support for this kernel. 

1286 

1287 Args: 

1288 mode: ComplexMode.ELEMENTWISE (add/sub) or ComplexMode.CROSS (mul/div). 

1289 cross_kernel: A PointwiseDynamicFunction for cross-term ops (mul/div). 

1290 tensorize_scalars: If True, scalar operands are converted to tensors 

1291 before delegating to fallback_target. 

1292 fallback_target: A PointwiseDynamicFunction (tensor-tensor version) 

1293 to delegate to after tensorizing scalar operands. 

1294 """ 

1295 self.complex_strategy = ComplexStrategy( 

1296 mode=mode, 

1297 cross_kernel=cross_kernel, 

1298 tensorize_scalars=tensorize_scalars, 

1299 fallback_target=fallback_target, 

1300 ) 

1301 return self 

1302 

1303 # -------------------- call entry -------------------- 

1304 

1305 def __call__(self, *args, **kwargs): 

1306 if self._should_use_complex_path(args): 

1307 return self._call_complex_dispatch(*args, **kwargs) 

1308 return self._call_real_impl(*args, **kwargs) 

1309 

1310 def _call_real_impl(self, *args, **kwargs): 

1311 """Single entry point for real kernel invocation.""" 

1312 ndim, args, kwargs = self.prepare_args(*args, **kwargs) 

1313 overload = self.instantiate(ndim) 

1314 out = overload(*args, **kwargs) 

1315 return self._unwrap(out) 

1316 

1317 # -------------------- complex helpers -------------------- 

1318 

1319 @staticmethod 

1320 def _is_complex_arg(a): 

1321 return (isinstance(a, torch.Tensor) and a.is_complex()) or isinstance( 

1322 a, complex 

1323 ) 

1324 

1325 def _should_use_complex_path(self, args): 

1326 if self.complex_strategy.mode == ComplexMode.NONE: 

1327 return False 

1328 return any( 

1329 self._is_complex_arg(args[i]) 

1330 for i in self._operand_indices 

1331 if i < len(args) 

1332 ) 

1333 

1334 def _split_args(self, args): 

1335 """Split args into operands and others by original position index.""" 

1336 operands = {} 

1337 others = {} 

1338 for i, a in enumerate(args): 

1339 if i in self._operand_indices: 

1340 operands[i] = a 

1341 else: 

1342 others[i] = a 

1343 return operands, others 

1344 

1345 def _merge_args(self, operands, others): 

1346 """Rebuild args tuple from operands and others by original position index.""" 

1347 total = len(operands) + len(others) 

1348 merged = [None] * total 

1349 for i, v in operands.items(): 

1350 merged[i] = v 

1351 for i, v in others.items(): 

1352 merged[i] = v 

1353 return tuple(merged) 

1354 

1355 def _classify_complex_inputs(self, operands): 

1356 """Classify operands as 'all_complex', 'mixed', or 'real'.""" 

1357 complex_count = sum(1 for v in operands.values() if self._is_complex_arg(v)) 

1358 if complex_count == len(operands): 

1359 return "all_complex" 

1360 elif complex_count > 0: 

1361 return "mixed" 

1362 return "real" 

1363 

1364 def _infer_device(self, operands): 

1365 for v in operands.values(): 

1366 if isinstance(v, torch.Tensor): 

1367 return v.device 

1368 return None 

1369 

1370 def _infer_complex_dtype(self, operands): 

1371 return torch.result_type(*operands.values()) 

1372 

1373 def _tensorize_scalar_operands(self, operands, dtype, device): 

1374 """Convert scalar operands to tensors.""" 

1375 result = {} 

1376 for i, v in operands.items(): 

1377 if not isinstance(v, torch.Tensor): 

1378 if isinstance(v, complex): 

1379 result[i] = torch.tensor(v, dtype=dtype, device=device) 

1380 elif isinstance(v, float): 

1381 result[i] = torch.tensor(v, dtype=torch.float32, device=device) 

1382 elif isinstance(v, (int, bool)): 

1383 result[i] = torch.tensor(v, dtype=torch.int64, device=device) 

1384 else: 

1385 result[i] = v 

1386 else: 

1387 result[i] = v 

1388 return result 

1389 

1390 def _to_complex_tensor(self, a, target_dtype, device): 

1391 """Convert a scalar or real tensor to a complex tensor.""" 

1392 if isinstance(a, torch.Tensor): 

1393 if a.is_complex(): 

1394 return a 

1395 if a.is_floating_point(): 

1396 cdtype = _REAL_TO_COMPLEX.get(a.dtype, torch.complex64) 

1397 else: 

1398 a = a.to(torch.float32) 

1399 cdtype = torch.complex64 

1400 return torch.complex(a, torch.zeros_like(a)).to(cdtype) 

1401 elif isinstance(a, complex): 

1402 return torch.tensor(a, dtype=target_dtype, device=device) 

1403 elif isinstance(a, (int, float)): 

1404 return torch.tensor(complex(a, 0), dtype=target_dtype, device=device) 

1405 return a 

1406 

1407 # -------------------- complex dispatch -------------------- 

1408 

1409 def _call_complex_dispatch(self, *args, **kwargs): 

1410 """Unified complex dispatch entry point.""" 

1411 strategy = self.complex_strategy 

1412 operands, others = self._split_args(args) 

1413 

1414 device = self._infer_device(operands) 

1415 result_dtype = self._infer_complex_dtype(operands) 

1416 

1417 # tensorize scalar operands and delegate to fallback_target 

1418 if strategy.tensorize_scalars and strategy.fallback_target is not None: 

1419 operands = self._tensorize_scalar_operands(operands, result_dtype, device) 

1420 new_args = self._merge_args(operands, others) 

1421 return strategy.fallback_target(*new_args, **kwargs) 

1422 

1423 # convert all operands to complex tensors 

1424 for i in list(operands.keys()): 

1425 operands[i] = self._to_complex_tensor(operands[i], result_dtype, device) 

1426 

1427 # broadcast complex tensor operands 

1428 complex_tensors = [operands[i] for i in sorted(operands.keys())] 

1429 complex_tensors = torch.broadcast_tensors(*complex_tensors) 

1430 for idx, key in enumerate(sorted(operands.keys())): 

1431 operands[key] = complex_tensors[idx] 

1432 

1433 classification = self._classify_complex_inputs(operands) 

1434 

1435 if strategy.mode == ComplexMode.CROSS and classification == "all_complex": 

1436 return self._call_complex_cross(operands, result_dtype) 

1437 elif classification in ("all_complex", "mixed"): 

1438 return self._call_complex_elementwise( 

1439 operands, others, result_dtype, kwargs 

1440 ) 

1441 else: 

1442 new_args = self._merge_args(operands, others) 

1443 return self._call_real_impl(*new_args, **kwargs) 

1444 

1445 def _call_complex_elementwise(self, operands, others, result_dtype, kwargs): 

1446 """Elementwise: view_as_real -> call real kernel -> view_as_complex.""" 

1447 real_tensors = {i: torch.view_as_real(t) for i, t in operands.items()} 

1448 

1449 # promote to common real dtype 

1450 dtypes = [t.dtype for t in real_tensors.values()] 

1451 common_dtype = dtypes[0] 

1452 for d in dtypes[1:]: 

1453 common_dtype = torch.promote_types(common_dtype, d) 

1454 real_tensors = {i: t.to(common_dtype) for i, t in real_tensors.items()} 

1455 

1456 new_args = self._merge_args(real_tensors, others) 

1457 out_real = self._call_real_impl(*new_args, **kwargs) 

1458 return torch.view_as_complex(out_real.contiguous()).to(result_dtype) 

1459 

1460 def _call_complex_cross(self, operands, result_dtype): 

1461 """Cross-term: split ar/ai/br/bi -> call cross_kernel -> stack -> view_as_complex.""" 

1462 sorted_keys = sorted(operands.keys()) 

1463 A, B = operands[sorted_keys[0]], operands[sorted_keys[1]] 

1464 Ar = torch.view_as_real(A) 

1465 Br = torch.view_as_real(B) 

1466 ar, ai = Ar[..., 0], Ar[..., 1] 

1467 br, bi = Br[..., 0], Br[..., 1] 

1468 

1469 common_dtype = torch.promote_types(ar.dtype, br.dtype) 

1470 ar, ai = ar.to(common_dtype), ai.to(common_dtype) 

1471 br, bi = br.to(common_dtype), bi.to(common_dtype) 

1472 

1473 real, imag = self.complex_strategy.cross_kernel(ar, ai, br, bi) 

1474 

1475 out = torch.stack((real, imag), dim=-1) 

1476 return torch.view_as_complex(out.contiguous()).to(result_dtype) 

1477 

1478 @staticmethod 

1479 def use_fast_path(tensors): 

1480 return all_the_same_shape(tensors) and ( 

1481 all_c_contiguous(tensors) 

1482 or ( 

1483 all_the_same_stride(tensors) 

1484 and torch.ops.aten.is_non_overlapping_and_dense(tensors[0]) 

1485 ) 

1486 ) 

1487 

1488 def prepare_args(self, *args, _skip_tensor_check=False, **kwargs): 

1489 # output allocation(when needed) 

1490 # task simplification & task-rank infernece & input-output reinterpretation 

1491 schema = self.fx 

1492 outputs_that_need_allocation: List[int] = [] 

1493 out_tensors = [] 

1494 for i in range(schema.num_output_tensors()): 

1495 k = f"out{i}" 

1496 if k in kwargs: 

1497 out_tensors.append(kwargs[k]) 

1498 else: 

1499 outputs_that_need_allocation.append(i) 

1500 # input arguments must be passed by position 

1501 if not _skip_tensor_check and schema._is_tensor is not None: 

1502 if not check_tensor_attributes(args, (schema._is_tensor)): 

1503 raise ValueError( 

1504 "Input arguments must be passed by position, and the corresponding dtype must be specified." 

1505 ) 

1506 in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)] 

1507 

1508 # output dtype promotions 

1509 outputs_dtypes_for_allocation = [] 

1510 for i in outputs_that_need_allocation: 

1511 *arg_indices, method = schema._promotion_methods[i] 

1512 promote_args = (args[j] for j in arg_indices) 

1513 _, dtype = type_promotion(*promote_args, type_promotion=method) 

1514 outputs_dtypes_for_allocation.append(dtype) 

1515 

1516 tensors = out_tensors + in_tensors 

1517 INT32_MAX = torch.iinfo(torch.int32).max 

1518 if tensors[0].numel() > INT32_MAX: 

1519 self.config.prefer_block_pointer = False 

1520 if self.use_fast_path(tensors): # dimension collapse & use physical ordering 

1521 allocated_outputs = [ 

1522 torch.empty_like(tensors[0], dtype=dtype) 

1523 for dtype in outputs_dtypes_for_allocation 

1524 ] 

1525 task_shape = (tensors[0].numel(),) 

1526 strides = (1,) 

1527 ndim = 1 

1528 args = tuple( 

1529 ( 

1530 StridedBuffer(item, task_shape, strides) 

1531 if schema.is_tensor(i) 

1532 else item 

1533 ) 

1534 for i, item in enumerate(args) 

1535 ) 

1536 kwargs = { 

1537 k: StridedBuffer(item, task_shape, strides) 

1538 for k, item in kwargs.items() 

1539 } 

1540 for seq_id, output_id in enumerate(outputs_that_need_allocation): 

1541 kwargs[f"out{output_id}"] = StridedBuffer( 

1542 allocated_outputs[seq_id], task_shape, strides 

1543 ) 

1544 else: 

1545 # a simple strategy: all the undefined tensors will follow the first 

1546 # tensor that is not broadcated, no attempts to simplify task, no reordering, 

1547 # no dimenion collapsing 

1548 shapes = tuple(item.shape for item in in_tensors) 

1549 

1550 task_shape = broadcast_shapes(shapes) 

1551 

1552 if out_tensors: 

1553 for index, item in enumerate(out_tensors): 

1554 if list(item.shape) != list(task_shape): 

1555 raise RuntimeError( 

1556 f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!" 

1557 ) 

1558 # output arguments must not have internal overlapping for pointwise operation 

1559 if has_internal_overlapping(item) == MemOverlap.Yes: 

1560 raise RuntimeError( 

1561 "Pointwise Input arguments should not have internal overlapping." 

1562 ) 

1563 

1564 ndim = len(task_shape) 

1565 for item in tensors: 

1566 if item.shape == task_shape: 

1567 allocated_outputs = [ 

1568 torch.empty_like(item, dtype=dtype) 

1569 for dtype in outputs_dtypes_for_allocation 

1570 ] 

1571 break 

1572 else: # nobreak 

1573 device = tensors[0].device 

1574 allocated_outputs = [ 

1575 torch.empty(task_shape, dtype=dtype, device=device) 

1576 for dtype in outputs_dtypes_for_allocation 

1577 ] 

1578 args = tuple( 

1579 ( 

1580 StridedBuffer( 

1581 item, 

1582 task_shape, 

1583 broadcasted_stride(item.shape, item.stride(), task_shape), 

1584 ) 

1585 if schema.is_tensor(i) 

1586 else item 

1587 ) 

1588 for i, item in enumerate(args) 

1589 ) 

1590 kwargs = { 

1591 k: StridedBuffer( 

1592 item, 

1593 task_shape, 

1594 broadcasted_stride(item.shape, item.stride(), task_shape), 

1595 ) 

1596 for k, item in kwargs.items() 

1597 } 

1598 for seq_id, output_id in enumerate(outputs_that_need_allocation): 

1599 item = allocated_outputs[seq_id] 

1600 kwargs[f"out{output_id}"] = StridedBuffer( 

1601 item, 

1602 task_shape, 

1603 broadcasted_stride(item.shape, item.stride(), task_shape), 

1604 ) 

1605 return (ndim, args, kwargs) 

1606 

1607 def _unwrap(self, tensors): 

1608 # unwrap StridedBuffer to get Tensor 

1609 if self.fx.num_output_tensors() == 1: 

1610 item = tensors 

1611 return item.unwrap() 

1612 return tuple(item.unwrap() for item in tensors) 

1613 

1614 def _compute_kernel_names(self, ndim: int) -> Tuple[str, str, str]: 

1615 """Compute kernel name, wrapper name, and file path for a given ndim. 

1616 

1617 This is the single source of truth for naming, used by both instantiate() 

1618 and get_kernel_info() to ensure consistency. 

1619 

1620 Returns: 

1621 Tuple of (kernel_name, wrapper_name, file_path) 

1622 """ 

1623 scalar_fn_name = self._scalar_fn.__name__ 

1624 kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}" 

1625 wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}" 

1626 

1627 file_name = ( 

1628 f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_" 

1629 f"{'1d_tile_' if self.config.prefer_1d_tile else ''}" 

1630 f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}" 

1631 ".py" 

1632 ) 

1633 file_path = str(code_cache_dir() / file_name) 

1634 

1635 return kernel_name, wrapper_name, file_path 

1636 

1637 def instantiate(self, ndim): 

1638 # NOTE: manually instantiated overload does not have `prepare_args` as 

1639 # preprocessing, so you have to manually allocate output and make sure that 

1640 # the inputs & ouputs actually fits the manually instantiated overload 

1641 key = f"{ndim}_{self.config.prefer_block_pointer}" 

1642 if key in self.overloads: 

1643 return self.overloads[key] 

1644 

1645 code = IndentedBuffer() 

1646 

1647 # Use helper to compute names (single source of truth) 

1648 kernel_name, wrapper_name, file_path = self._compute_kernel_names(ndim) 

1649 

1650 module_gen = ModuleGenerator( 

1651 self.fx, 

1652 self._scalar_fn, 

1653 ndim, 

1654 kernel_name, 

1655 wrapper_name, 

1656 self.config, 

1657 ) 

1658 module_gen.codegen(code) 

1659 

1660 # NOTE: [why write the generated code to a file] 

1661 # triton uses inpsect to get the source of the jitted function, which requires 

1662 # that the source code can be found by inspect 

1663 # We write it into a file, since inspect cannot find the source of functions dynamically 

1664 # created via exec string. We can help inspect to find the source by hacking linecache 

1665 # library, but we find generating a module simpler, since we can generating 2 functions 

1666 # the kernel and the wrapper, and the wrapper calls the kernel. 

1667 write_atomic(file_path, code.getvalue()) 

1668 

1669 # load 

1670 spec = importlib.util.spec_from_file_location( 

1671 f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}", 

1672 file_path, 

1673 ) 

1674 m = importlib.util.module_from_spec(spec) 

1675 # do not expose it to sys.modules 

1676 # sys.modules["_add_module"] = m 

1677 

1678 # NOTE: [why not import the scalar function] 

1679 # we do not re-import the scalar function, although the generated kernel **calls** it 

1680 # Since a function's __name__ may be changed, from the module where it is defined import its 

1681 # __name__ is not same; Also the same may be rebind to something else, importing via name 

1682 # cannot guarantee that scalar function is imported. 

1683 # So we copy the scalar function and its __globals__ to the generated module to do this 

1684 # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime 

1685 spec.loader.exec_module(m) 

1686 m.__dict__.update(self._scalar_fn.__globals__) 

1687 m.__dict__[self._scalar_fn.__name__] = self._scalar_fn 

1688 

1689 overload = getattr(m, wrapper_name) 

1690 self.overloads[key] = overload 

1691 

1692 # Cache kernel info for C++ integration 

1693 self._kernel_info_cache[key] = KernelInfo( 

1694 file_path=file_path, 

1695 kernel_name=kernel_name, 

1696 wrapper_name=wrapper_name, 

1697 ndim=ndim, 

1698 ) 

1699 

1700 return overload 

1701 

1702 def get_kernel_info(self, ndim: int) -> KernelInfo: 

1703 """Get kernel information for a given ndim. 

1704 

1705 This method is useful for C++ integration to get the file path and 

1706 kernel name without duplicating the naming logic. 

1707 

1708 If the kernel hasn't been instantiated yet, this will instantiate it first. 

1709 

1710 Args: 

1711 ndim: The rank of the task space 

1712 

1713 Returns: 

1714 KernelInfo with file_path, kernel_name, wrapper_name, and ndim 

1715 """ 

1716 key = f"{ndim}_{self.config.prefer_block_pointer}" 

1717 

1718 # Ensure the kernel is instantiated 

1719 if key not in self._kernel_info_cache: 

1720 self.instantiate(ndim) 

1721 

1722 return self._kernel_info_cache[key] 

1723 

1724 

1725def pointwise_dynamic( 

1726 f: Optional[JITFunction] = None, 

1727 *, 

1728 num_inputs: Optional[int] = None, 

1729 is_tensor: Optional[List[bool]] = None, 

1730 dtypes: Optional[List[Optional[type]]] = None, 

1731 num_outputs: Optional[int] = None, 

1732 promotion_methods: Optional[Tuple[int, ...]] = None, 

1733 config: Optional[CodeGenConfig] = None, 

1734): 

1735 def decorator(fn): 

1736 nonlocal num_inputs 

1737 if (num_inputs is None) and (is_tensor is None) and (dtypes is None): 

1738 num_inputs = len(fn.arg_names) 

1739 op_desc = FunctionSchema( 

1740 num_inputs=num_inputs, 

1741 is_tensor=is_tensor, 

1742 dtypes=dtypes, 

1743 num_outputs=num_outputs, 

1744 promotion_methods=promotion_methods, 

1745 ) 

1746 return PointwiseDynamicFunction(op_desc, fn, config) 

1747 

1748 if f is not None: 

1749 return decorator(f) 

1750 return decorator