Coverage for src/flag_gems/utils/pointwise_dynamic.py: 94%

1016 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import importlib 

2import os 

3from dataclasses import dataclass 

4from enum import Enum, auto 

5from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple 

6 

7import torch 

8import triton 

9from triton.runtime.jit import JITFunction 

10 

11from flag_gems.utils.code_cache import code_cache_dir 

12from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

13from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config 

14from flag_gems.utils.device_info import get_device_capability 

15from flag_gems.utils.shape_utils import ( 

16 MemOverlap, 

17 all_c_contiguous, 

18 all_the_same_shape, 

19 all_the_same_stride, 

20 broadcast_shapes, 

21 broadcasted_stride, 

22 check_tensor_attributes, 

23 has_internal_overlapping, 

24) 

25from flag_gems.utils.tensor_wrapper import StridedBuffer 

26from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion 

27 

28 

29# ------------------ Operation Description --------------------------- 

30def _type_name(type) -> str: 

31 "Render typename as string, work for both (bool, int, float, str) and torch.dtype object" 

32 if type in (bool, int, float, str): 

33 return type.__name__ 

34 if isinstance(type, torch.dtype): 

35 return str(type) 

36 return str(type) 

37 

38 

39def _check_typed_list(container, type): 

40 for item in container: 

41 assert isinstance(item, type) 

42 

43 

44def _check_sized_list(container, size): 

45 assert len(container) == size 

46 

47 

48def _tuple_content(strings: Sequence[str]) -> str: 

49 # comma separated list 

50 if len(strings) == 0: 

51 return "" 

52 if len(strings) == 1: 

53 return f"{strings[0]}," 

54 else: 

55 return ", ".join(strings) 

56 

57 

58def _cs(strings: Iterable[str]) -> str: 

59 return ", ".join(strings) 

60 

61 

def _broadcast_vec(i, ndim):
    """Build the subscript text that spreads a vector along axis `i`.

    Position `i` keeps the data (":") and every other one of the `ndim`
    positions gets a new axis ("None"); for example ``_broadcast_vec(1, 3)``
    gives ``"[None, :, None]"``.
    """
    subscripts = []
    for axis in range(ndim):
        subscripts.append(":" if axis == i else "None")
    return "[" + _cs(subscripts) + "]"

65 

66 

class FunctionSchema:
    """Description of the arguments and results of a pointwise operation.

    The schema records how many inputs there are, which of them are tensors,
    an optional Python type per input, the number of outputs, and one type
    promotion method per output.
    """

    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        """Build a schema from whichever of the descriptions are given.

        At least one of `num_inputs`, `is_tensor` and `dtypes` is required;
        the number of inputs is taken from the first one provided, in that
        order, and the remaining ones must have a matching length.

        Args:
            num_inputs: number of inputs. When given without `is_tensor`,
                every input is treated as a tensor.
            is_tensor: one bool per input, True for tensor inputs.
            dtypes: one Python type (or None) per input. When it is the only
                description given, inputs whose entry is None are treated as
                tensors.
            num_outputs: number of outputs; defaults to the number of
                promotion methods.
            promotion_methods: one entry per output; each entry is a sequence
                of argument indices followed by a key of
                ELEMENTWISE_TYPE_PROMOTION_KIND.

        Raises:
            ValueError: if `promotion_methods` is missing, or if none of
                `num_inputs`, `is_tensor`, `dtypes` is given.
            AssertionError: if a list has a wrong element type or length, or
                if there is not at least one input and one output.
        """
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        self._promotion_methods = self.canonicalize_promotion_methods(
            promotion_methods
        )

        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            # is_tensor is necessarily None on this branch, so tensor-ness is
            # derived from the dtypes: an input without a type is a tensor
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Replace the trailing method key of each entry by its enum member.

        Each entry is unpacked as ``(*arg_indices, method)`` and `method` is
        looked up in ELEMENTWISE_TYPE_PROMOTION_KIND.
        """
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        """Return the number of arguments, outputs not included."""
        return self._num_inputs

    def num_outputs(self):
        """Return the number of outputs."""
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        """Tell whether input `arg_id` is a tensor."""
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        """Return the declared type of input `arg_id`, or None if undeclared."""
        return self._dtypes[arg_id]

    def output_type(self, i):
        """Return the canonicalized promotion method of output `i`."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        """Return the number of tensor inputs."""
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        """Return the number of outputs (every output is a tensor)."""
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        """Return the number of non-tensor inputs."""
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Render the schema as ``Pointwise: <inputs> -> <outputs>``.

        Tensor inputs are shown as ``StridedBuffer``, untyped non-tensor
        inputs as ``scalar`` and typed ones by their type name. With
        `outputs_in_arg` set, each output is shown as ``StridedBuffer(a<i>!)``
        (``i`` being the output index) and is also appended to the inputs.
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            elif dtype is None:
                input_types.append("scalar")
            else:
                input_types.append(_type_name(dtype))

        if outputs_in_arg:
            # label each output with its own index (the original formatted
            # the literal 1 here, so every output read "a1!")
            output_types = [
                f"StridedBuffer(a{i}!)" for i in range(self.num_outputs())
            ]
            input_types.extend(output_types)
        else:
            output_types = ["StridedBuffer"] * self.num_outputs()
        sig = f'Pointwise: {", ".join(input_types)} -> {", ".join(output_types)}'
        return sig

    def _compute_input_id(self):
        """Map each input position to its index within its own kind.

        Tensor inputs and non-tensor inputs are numbered independently, each
        starting from 0, in argument order.
        """
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        """Return the per-kind index of input `idx` (see _compute_input_id)."""
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)

223 

224 

class KernelGenerator:
    """Emit the source of a Triton kernel applying a scalar function pointwise.

    The generator writes text into a code buffer, an object offering
    ``writeline``, ``writemultiline``, ``newline`` and an ``indent()`` context
    manager. Three kernel flavours can be generated:

    * ``codegen_nd_tile_with_bptr``: nd tiles addressed through block pointers,
    * ``codegen_nd_tile_without_bptr``: nd tiles addressed with offsets and masks,
    * ``codegen_1d_tile``: 1d tiles over the flattened task space.

    For rank 0 every flavour emits a single load/compute/store body. For
    higher ranks the kernel contains two branches selected by the
    ``one_tile_per_cta`` constexpr: one tile per program, or a loop over
    ``tiles_per_cta`` tiles per program.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        """Remember what to generate.

        Args:
            function_schema: schema of the pointwise operation.
            scalar_fn: the jit function applied to the loaded values; its
                source is copied into the generated code.
            rank: number of dimensions of the task space.
            name: name of the generated kernel function.
            config: code generation configuration.
        """
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

    def gen_import_function(self, code: IndentedBuffer):
        """Copy the scalar function's source, decorated with ``@triton.jit``."""
        code.writeline("@triton.jit")
        code.writemultiline(self.fn.src)
        code.newline()

    def gen_decorators(self, code):
        """Emit the decorators of the generated kernel."""
        code.writeline("@libentry()")
        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            # we do not specialize non tensor args since they are passed into the inlined function
            # which means that their values may not deserve specialization
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")

    def input_name(self, i):
        """Return the kernel-side name of input `i`: ``in<k>`` or ``val<k>``."""
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        """Return the kernel-side name of output `i`."""
        return f"out{i}"

    def _gen_pointer_params(self, code):
        """Emit the input pointers, non-tensor inputs and output pointers."""
        schema = self.fx
        # signature: inputs ptrs & non tensor inputs
        for i in range(schema.num_inputs()):
            if schema.is_tensor(i):
                code.writeline(
                    f"{self.input_name(i)}_ptr: tl.tensor, # of tl.pointer_type"
                )
            elif schema.input_type(i) is not None:
                code.writeline(
                    f"{self.input_name(i)}: {_type_name(schema.input_type(i))},"
                )
            else:
                code.writeline(f"{self.input_name(i)},")

        # signature: output ptrs
        for i in range(schema.num_outputs()):
            code.writeline(
                f"{self.output_name(i)}_ptr: tl.tensor, # of tl.pointer_type"
            )

    def _gen_stride_and_task_params(self, code, with_block_pointer):
        """Emit strides (and stride orders), task space and task count.

        Nothing is emitted for rank 0.
        """
        ndim = self.ndim
        if ndim <= 0:
            return
        schema = self.fx
        operands = (
            ("in", schema.num_input_tensors()),
            ("out", schema.num_output_tensors()),
        )
        # strides for inputs first, then for outputs
        for prefix, count in operands:
            for i in range(count):
                stride_args = _cs(f"{prefix}{i}_stride{j}: int" for j in range(ndim))
                code.writeline(f"{stride_args}, # strides for {prefix}{i}")
                if with_block_pointer:
                    stride_order_args = _cs(
                        f"{prefix}{i}_stride_order{j}: tl.constexpr"
                        for j in range(ndim)
                    )
                    code.writeline(
                        f"{stride_order_args}, # stride order for {prefix}{i}"
                    )

        # task space, used to reconstruct multi index
        task_space_args = _cs(f"s{i}" for i in range(ndim))
        code.writeline(f"{task_space_args}, # task_space")

        # number of tasks, used to compute mask
        code.writeline("num_tasks,")

    def gen_signature(self, code, with_block_pointer=False):
        """Emit the signature of the nd-tile kernel.

        With `with_block_pointer` set, a constexpr stride order is added for
        every tensor operand.
        """
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_pointer_params(code)
            self._gen_stride_and_task_params(code, with_block_pointer)

            # tile size & tiles_per_cta, gsl style
            ndim = self.ndim
            if ndim > 0:
                code.writeline("tiles_per_cta: int,")
                tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(ndim))
                code.writeline(f"{tile_sizes},")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_signature_1d_tile(self, code):
        """Emit the signature of the 1d-tile kernel (a single ``tile_size``)."""
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_pointer_params(code)
            self._gen_stride_and_task_params(code, with_block_pointer=False)

            # tile size & tiles_per_cta, gsl style
            if self.ndim > 0:
                code.writeline("tiles_per_cta: int,")
                code.writeline("tile_size: tl.constexpr,")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_num_tiles(self, code):
        """Emit the number of tiles along each axis (tile-grid size)."""
        for i in range(self.ndim):
            code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")

    def _gen_compute(self, code):
        """Emit the call of the scalar function on the loaded values."""
        schema = self.fx
        inputs_to_scalar_fn = _cs(
            self.input_name(i) for i in range(schema.num_inputs())
        )
        outputs_to_scalar_fn = _cs(
            self.output_name(i) for i in range(schema.num_output_tensors())
        )

        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

    def gen_body_for_0d(self, code):
        """Emit the rank-0 body: plain loads, compute, plain stores."""
        schema = self.fx

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the pointer's dtype)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

    def _gen_tile_id_reconstruction(self, code):
        """Emit the split of the flat ``tile_id`` into one tile id per axis."""
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(self.ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

    def _gen_gsl_loop(self, code, gen_one_tile_body):
        """Emit a loop over ``tiles_per_cta`` tiles around a one-tile body."""
        code.writeline("num_ctas = tle.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            gen_one_tile_body(code)

    # nd tile 1d grid kernel with block pointer
    def gen_body_one_tile_per_cta_with_bptr(self, code):
        """Emit the body processing one nd tile through block pointers."""
        ndim = self.ndim
        schema = self.fx

        # block pointer for each operand
        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        # reconstruct pid multi index
        self._gen_tile_id_reconstruction(code)

        # cta_offsets
        code.writeline("# tile offsets")
        for i in range(ndim):
            # Or else: AssertionError: Block pointers only support 32 bit
            # `offsets/block_shape`, add a `.to(tl.int32)` or use regular indexing
            # for 64 bit support
            code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(tuple(f"in{i}_stride_order{j}" for j in range(ndim)))
            code.writeline(
                f"in{i}_bptr = tl.make_block_ptr("
                f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)"
            )
        code.newline()

        # compute
        self._gen_compute(code)

        # stores
        code.writeline(
            "# stores, note that store to block pointer does not automatically cast the value to the pointer's dtype"
        )
        for i in range(schema.num_output_tensors()):
            strides = _tuple_content(tuple(f"out{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(
                tuple(f"out{i}_stride_order{j}" for j in range(ndim))
            )
            code.writeline(
                f"out{i}_bptr = tl.make_block_ptr("
                f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
            )

    def gen_body_gsl_with_bptr(self, code):
        """Emit the looped variant of the block pointer body."""
        self._gen_gsl_loop(code, self.gen_body_one_tile_per_cta_with_bptr)

    def gen_body_one_tile_per_cta_without_bptr(self, code):
        """Emit the body processing one nd tile with offsets and masks."""
        ndim = self.ndim
        schema = self.fx

        # reconstruct pid multi index
        self._gen_tile_id_reconstruction(code)

        # offsets
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # masks
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim))
        mask_combine = " & ".join(masks)
        code.writeline(f"mask = {mask_combine}")

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offsets = tuple(
                f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}"
                for j in range(ndim)
            )
            offset_combine = " + ".join(offsets)
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )

        code.newline()

        # compute
        self._gen_compute(code)

        # stores
        for i in range(schema.num_output_tensors()):
            offsets = tuple(
                f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}"
                for j in range(ndim)
            )
            offset_combine = " + ".join(offsets)
            # emitted as a bare statement, like the block pointer variant: the
            # store's result must not be bound to an input name
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_without_bptr(self, code):
        """Emit the looped variant of the offsets-and-masks body."""
        self._gen_gsl_loop(code, self.gen_body_one_tile_per_cta_without_bptr)

    def _gen_kernel_body(self, code, gen_one_tile_body, gen_gsl_body, tile_grid=True):
        """Emit everything below the signature and return `code`.

        Rank 0 gets the dedicated 0d body. Otherwise both execution styles
        are emitted, selected by ``one_tile_per_cta``; `tile_grid` tells
        whether the per-axis tile counts are emitted first.
        """
        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = tle.program_id(0)")
            if tile_grid:
                self.gen_num_tiles(code)
            # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                gen_one_tile_body(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                gen_gsl_body(code)
        code.newline()
        return code

    def codegen_nd_tile_with_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=True)
        return self._gen_kernel_body(
            code,
            self.gen_body_one_tile_per_cta_with_bptr,
            self.gen_body_gsl_with_bptr,
        )

    def codegen_nd_tile_without_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support without block pointer."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature(code, with_block_pointer=False)
        return self._gen_kernel_body(
            code,
            self.gen_body_one_tile_per_cta_without_bptr,
            self.gen_body_gsl_without_bptr,
        )

    def codegen_nd_tile(self, code):
        """Generate the nd-tile kernel flavour chosen by ``config.prefer_block_pointer``."""
        use_block_pointer = self.config.prefer_block_pointer
        if use_block_pointer:
            self.codegen_nd_tile_with_bptr(code)
        else:
            self.codegen_nd_tile_without_bptr(code)
        return code

    def gen_body_one_tile_per_cta_1d_tile(self, code):
        """Emit the body processing one 1d tile of the flattened task space."""
        ndim = self.ndim
        schema = self.fx

        # tile id
        code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
        code.writeline("mask = tid < num_tasks")

        # multi index reconstruction
        for i in reversed(range(ndim)):
            if i > 0:
                code.writeline(f"i{i} = tid % s{i}")
                code.writeline(f"tid //= s{i}")
            else:
                code.writeline(f"i{i} = tid")
        code.newline()

        # loads
        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offsets = tuple(f"i{j} * in{i}_stride{j}" for j in range(ndim))
            offset_combine = " + ".join(offsets)
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )

        code.newline()

        # compute
        self._gen_compute(code)

        # stores
        for i in range(schema.num_output_tensors()):
            offsets = tuple(f"i{j} * out{i}_stride{j}" for j in range(ndim))
            offset_combine = " + ".join(offsets)
            # emitted as a bare statement: the store's result must not be
            # bound to an input name
            code.writeline(
                f"tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def gen_body_gsl_1d_tile(self, code):
        """Emit the looped variant of the 1d tile body."""
        self._gen_gsl_loop(code, self.gen_body_one_tile_per_cta_1d_tile)

    def codegen_1d_tile(self, code):
        """Generate kernel 1d tile & 1d grid with gsl support."""
        self.gen_import_function(code)
        self.gen_decorators(code)
        self.gen_signature_1d_tile(code)
        # the 1d tile kernel has no per-axis tile grid, hence tile_grid=False
        return self._gen_kernel_body(
            code,
            self.gen_body_one_tile_per_cta_1d_tile,
            self.gen_body_gsl_1d_tile,
            tile_grid=False,
        )

740 

741 

742class WrapperGenerator: 

743 def __init__( 

744 self, 

745 function_schema: FunctionSchema, 

746 jit_fn_name: str, 

747 ndim: int, 

748 name: str, 

749 config: CodeGenConfig, 

750 ): 

751 self.fx = function_schema 

752 self.jit_fn_name = jit_fn_name 

753 self.ndim = ndim 

754 self.name = name 

755 self.config = config 

756 

757 def input_name(self, i): 

758 is_tensor = self.fx.is_tensor(i) 

759 name = "in" if is_tensor else "val" 

760 index = self.fx.input_index(i) 

761 return f"{name}{index}" 

762 

def output_name(self, i):
    """Return the wrapper-side name of output `i` (``out<i>``)."""
    return "out{}".format(i)

765 

def gen_signature(self, code: IndentedBuffer):
    """Write the ``def`` line of the generated wrapper function.

    Parameter passing rules of the generated wrapper:

    * Inputs are positional-only. They are renamed to in0, in1, val0,
      val1, ..., so passing them by keyword would be confusing; a ``/``
      marker enforces positional passing.
    * Outputs are keyword-only (after a ``*`` marker). The scalar function
      has no output parameters; they are added on top so that the wrapper
      can be used in destination passing style, which is handy when the
      outputs are pre-defined tensors (for example views of other
      tensors). Since the scalar function has no such parameters, the
      names out0, out1, ... cannot clash with any of its names.
    """
    # TODO: check if triton handles constexprs transitively
    schema = self.fx
    tensor_annotation = "Union[torch.Tensor, StridedBuffer]"

    def render_input(i):
        # tensors are always annotated; other inputs only when a type is declared
        name = self.input_name(i)
        if schema.is_tensor(i):
            return f"{name}: {tensor_annotation}"
        arg_type = schema.input_type(i)
        if arg_type is None:
            return name
        return f"{name}: {_type_name(arg_type)}"

    params: List[str] = [render_input(i) for i in range(schema.num_inputs())]
    # "/" closes the positional-only inputs, "*" opens the keyword-only outputs
    params += ["/", "*"]
    params += [
        f"{self.output_name(i)}: {tensor_annotation}"
        for i in range(schema.num_output_tensors())
    ]
    code.writeline(f"def {self.name}({_cs(params)}): ")

800 

801 def gen_docstring(self, code: IndentedBuffer): 

802 schema = self.fx 

803 doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""' 

804 code.writeline(doc) 

805 

806 def gen_same_shape_check(self, code: IndentedBuffer): 

807 schema: FunctionSchema = self.fx 

808 params = [f"in{i}.shape" for i in range(schema.num_input_tensors())] + [ 

809 f"out{i}.shape" for i in range(schema.num_output_tensors()) 

810 ] 

811 check: str = " == ".join(params) 

812 code.writeline(f"assert {check}, 'operand shapes mismatch'") 

813 

814 def gen_task_partition(self, code: IndentedBuffer): 

815 code.writeline("# task partitioning") 

816 ndim = self.ndim 

817 if ndim == 0: 

818 code.writeline("num_warps = 1") 

819 code.writeline("num_ctas = 1") 

820 else: 

821 code.writeline("shape = out0.shape") 

822 code.writeline("num_tasks = out0.numel()") 

823 code.writeline("if num_tasks == 0:") 

824 with code.indent(): 

825 self.gen_return(code) 

826 max_tile_size = self.config.max_tile_size 

827 # Check if all input and output dtypes are complex 

828 all_complex = True 

829 for i in range(self.fx.num_inputs()): 

830 if self.fx.is_tensor(i): 

831 input_dtype = self.fx.input_type(i) 

832 if input_dtype is not None and not ( 

833 input_dtype == torch.complex64 

834 or input_dtype == torch.complex128 

835 ): 

836 all_complex = False 

837 break 

838 if all_complex: 

839 # If all inputs are complex, set max_tile_size to half 

840 max_tile_size = max_tile_size // 2 

841 major, _ = get_device_capability() 

842 if self.name.find("fill_scalar") != -1 and major >= 9: 

843 code.writeline("tile_sizes = tuple([64])") 

844 else: 

845 code.writeline( 

846 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)" 

847 ) 

848 code.writeline("tile_size = math.prod(tile_sizes)") 

849 code.writeline( 

850 "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))" 

851 ) 

852 

853 if self.name.find("fill_scalar") != -1 and major >= 9: 

854 code.writeline("num_ctas = num_tiles") 

855 else: 

856 max_grid_size0 = self.config.max_grid_size[0] 

857 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)") 

858 

859 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)") 

860 code.writeline("num_warps = heuristics_for_num_warps(tile_size)") 

861 code.writeline("one_tile_per_cta = tiles_per_cta==1") 

862 code.writeline("grid = (num_ctas, 1, 1)") 

863 

864 def gen_task_partition_1d(self, code: IndentedBuffer): 

865 code.writeline("# task partitioning") 

866 ndim = self.ndim 

867 if ndim == 0: 

868 code.writeline("num_warps = 1") 

869 code.writeline("num_ctas = 1") 

870 else: 

871 code.writeline("shape = out0.shape") 

872 code.writeline("num_tasks = out0.numel()") 

873 code.writeline("if num_tasks == 0:") 

874 with code.indent(): 

875 self.gen_return(code) 

876 max_tile_size = self.config.max_tile_size 

877 # Check if all input and output dtypes are complex 

878 all_complex = True 

879 for i in range(self.fx.num_inputs()): 

880 if self.fx.is_tensor(i): 

881 input_dtype = self.fx.input_type(i) 

882 if input_dtype is not None and not ( 

883 input_dtype == torch.complex64 

884 or input_dtype == torch.complex128 

885 ): 

886 all_complex = False 

887 break 

888 if all_complex: 

889 max_tile_size = max_tile_size // 2 

890 major, _ = get_device_capability() 

891 if self.name.find("fill_scalar") != -1 and major >= 9: 

892 code.writeline("tile_sizes = tuple([1024])") 

893 else: 

894 code.writeline( 

895 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)" 

896 ) 

897 

898 code.writeline("tile_size = tile_sizes[0]") 

899 code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)") 

900 

901 if self.name.find("fill_scalar") != -1 and major >= 9: 

902 code.writeline("num_ctas = num_tiles") 

903 else: 

904 max_grid_size0 = self.config.max_grid_size[0] 

905 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)") 

906 

907 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)") 

908 code.writeline("num_warps = heuristics_for_num_warps(tile_size)") 

909 code.writeline("one_tile_per_cta = tiles_per_cta==1") 

910 code.writeline("grid = (num_ctas, 1, 1)") 

911 

def gen_kernel_launch(
    self,
    code: IndentedBuffer,
):
    """Emit the wrapper statements that launch the nd-tiled kernel.

    The emitted code first binds ``<name>_strides`` for every input and output
    tensor (plus ``<name>_stride_order`` when block pointers are preferred),
    then calls ``<jit_fn_name>[grid](...)`` inside a
    ``torch_device_fn.device`` context. Stride, shape and tiling arguments are
    only passed when the task rank is greater than zero.

    Args:
        code: buffer that receives the generated lines.
    """
    schema = self.fx
    ndim = self.ndim
    with_block_pointer = self.config.prefer_block_pointer

    # inputs are always emitted before outputs
    tensor_groups = (
        ("in", schema.num_input_tensors()),
        ("out", schema.num_output_tensors()),
    )

    code.writeline("# kernel launch")
    for prefix, count in tensor_groups:
        for i in range(count):
            name = f"{prefix}{i}"
            code.writeline(f"{name}_strides = {name}.stride()")
            if not with_block_pointer:
                continue
            if ndim >= 2:  # where ndim is 1, we don't need to compute stride order
                code.writeline(f"{name}_stride_order = stride_order({name}_strides)")
            else:
                code.writeline(f"{name}_stride_order = (0,)")

    code.writeline("with torch_device_fn.device(in0.device.index):")
    with code.indent():
        code.writeline(f"{self.jit_fn_name}[grid](")
        with code.indent():
            # NOTE: WRAP
            # tensor and non-tensor inputs are passed by name in the same way
            params = [f"{self.input_name(i)}" for i in range(schema.num_inputs())]
            params.extend(
                f"{self.output_name(i)}" for i in range(schema.num_output_tensors())
            )
            code.writeline(f"{', '.join(params)},")

            if ndim > 0:
                for prefix, count in tensor_groups:
                    for i in range(count):
                        name = f"{prefix}{i}"
                        s = ", ".join(f"{name}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for {name}")
                        if not with_block_pointer:
                            continue
                        order = ", ".join(
                            f"{name}_stride_order[{j}]" for j in range(ndim)
                        )
                        code.writeline(f"{order}, # stride order for {name}")

                shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
                code.writeline(f"{shape_args}, # task indexing space")
                code.writeline("num_tasks, # num tasks")
                code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                for i in range(ndim):
                    code.writeline(f"tile_size{i}=tile_sizes[{i}],")
                code.writeline("one_tile_per_cta=one_tile_per_cta,")
            code.writeline("num_warps=num_warps,")
        code.writeline(")")

985 

def gen_kernel_launch_1d(
    self,
    code: IndentedBuffer,
):
    """Emit the wrapper statements that launch the 1d-tiled kernel.

    The emitted code binds ``<name>_strides`` for every input and output
    tensor and then calls ``<jit_fn_name>[grid](...)`` inside a
    ``torch_device_fn.device`` context. Stride, shape and tiling arguments are
    only passed when the task rank is greater than zero.

    Args:
        code: buffer that receives the generated lines.
    """
    fx = self.fx
    rank = self.ndim

    code.writeline("# kernel launch")
    for k in range(fx.num_input_tensors()):
        code.writeline(f"in{k}_strides = in{k}.stride()")
    for k in range(fx.num_output_tensors()):
        code.writeline(f"out{k}_strides = out{k}.stride()")

    code.writeline("with torch_device_fn.device(in0.device.index):")
    with code.indent():
        code.writeline(f"{self.jit_fn_name}[grid](")
        with code.indent():
            # NOTE: WRAP
            call_args = [
                f"{self.input_name(k)}" if fx.is_tensor(k) else self.input_name(k)
                for k in range(fx.num_inputs())
            ]
            call_args += [
                f"{self.output_name(k)}" for k in range(fx.num_output_tensors())
            ]
            code.writeline(f"{_cs(call_args)},")

            if rank > 0:
                for k in range(fx.num_input_tensors()):
                    stride_list = ", ".join(
                        [f"in{k}_strides[{d}]" for d in range(rank)]
                    )
                    code.writeline(f"{stride_list}, # stride for in{k}")
                for k in range(fx.num_output_tensors()):
                    stride_list = ", ".join(
                        [f"out{k}_strides[{d}]" for d in range(rank)]
                    )
                    code.writeline(f"{stride_list}, # stride for out{k}")

                dims = ", ".join([f"shape[{d}]" for d in range(rank)])
                code.writeline(f"{dims}, # task indexing space")
                for argument in (
                    "num_tasks, # num tasks",
                    "tiles_per_cta=tiles_per_cta, # tiles_per_cta",
                    "tile_size=tile_size,",
                    "one_tile_per_cta=one_tile_per_cta,",
                ):
                    code.writeline(argument)
            code.writeline("num_warps=num_warps,")
        code.writeline(")")

1031 

def gen_return(self, code: IndentedBuffer):
    """Emit the wrapper's ``return`` statement naming every output (``out0``, ``out1``, ...)."""
    output_names = (f"out{k}" for k in range(self.fx.num_output_tensors()))
    code.writeline("return " + _cs(output_names))

1035 

def codegen_nd_tile(self, code):
    """Emit the complete wrapper function for the nd-tile kernel variant.

    The signature is written first, then the body parts in order (docstring,
    same-shape check, task partition, kernel launch, return) one indentation
    level deeper, followed by a blank line. Returns the same ``code`` buffer.
    """
    self.gen_signature(code)

    body_parts = (
        self.gen_docstring,
        self.gen_same_shape_check,
        self.gen_task_partition,
        self.gen_kernel_launch,
        self.gen_return,
    )
    with code.indent():
        for emit in body_parts:
            emit(code)
    code.newline()
    return code

1047 

def codegen_1d_tile(self, code):
    """Emit the complete wrapper function for the 1d-tile kernel variant.

    The signature is written first, then the body parts in order (docstring,
    same-shape check, 1d task partition, 1d kernel launch, return) one
    indentation level deeper, followed by a blank line. Returns the same
    ``code`` buffer.
    """
    self.gen_signature(code)

    body_parts = (
        self.gen_docstring,
        self.gen_same_shape_check,
        self.gen_task_partition_1d,
        self.gen_kernel_launch_1d,
        self.gen_return,
    )
    with code.indent():
        for emit in body_parts:
            emit(code)
    code.newline()
    return code

1059 

1060 

class ModuleGenerator:
    """Assemble the source of one generated module: imports, wrapper and kernel.

    The wrapper and kernel bodies are produced by ``WrapperGenerator`` and
    ``KernelGenerator``. This class writes the import header and copies the
    extra imports and local ``@triton.jit`` helpers found in the module that
    defines the scalar function.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: CodeGenConfig,
    ):
        self.config = config
        self.scalar_fn = scalar_fn
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def _resolve_import_module(node, package):
        """Return the absolute module path named by an ``ast.ImportFrom`` node.

        Relative imports (``node.level > 0``) are resolved against ``package``.
        Returns None when the node names no module (``from . import x``) or
        when a relative import cannot be resolved.
        """
        if not node.module:
            return None
        if node.level == 0:
            return node.module
        if not package:
            return None
        parents = package.split(".")
        # level 1 is the package itself; each extra level drops one component
        if node.level > len(parents):
            return None
        base = parents[: len(parents) - (node.level - 1)]
        return ".".join(base + [node.module])

    @staticmethod
    def _collect_jit_deps(scalar_fn):
        """Collect extra imports and local @triton.jit helper sources.

        Parses the source module where scalar_fn is defined using AST.
        Returns a tuple of:
          - extra_imports: dict of absolute module_path -> set of imported
            names; an aliased import is stored as ``"name as alias"`` so the
            alias is preserved when the import line is re-emitted
          - local_sources: list of source strings for local @triton.jit
            functions (those NOT decorated with @pointwise_dynamic)

        Returns ``({}, [])`` when the module cannot be located, read or parsed.
        """
        import ast
        import inspect
        import tokenize

        py_fn = getattr(scalar_fn, "fn", scalar_fn)
        module_name = getattr(py_fn, "__module__", None)
        if not module_name:
            return {}, []
        try:
            mod = importlib.import_module(module_name)
            source_file = inspect.getfile(mod)
        except (ImportError, TypeError, OSError):
            return {}, []
        try:
            # tokenize.open decodes with the encoding declared by the source
            # file (UTF-8 by default) instead of the locale's default encoding
            with tokenize.open(source_file) as f:
                module_source = f.read()
            source_lines = module_source.splitlines(keepends=True)
            tree = ast.parse(module_source)
        except (OSError, SyntaxError, ValueError):
            # ValueError also covers UnicodeDecodeError
            return {}, []

        # Collect non-standard import-from lines
        ALREADY_IMPORTED = {
            "math",
            "typing",
            "torch",
            "triton",
            "triton.language",
            "flag_gems.utils.shape_utils",
            "flag_gems.utils.tensor_wrapper",
            "flag_gems.utils.libentry",
            "flag_gems.utils",
            "flag_gems.runtime",
            "flag_gems.utils.pointwise_dynamic",
        }
        package = getattr(mod, "__package__", None)
        extra_imports = {}
        for node in ast.iter_child_nodes(tree):
            if not isinstance(node, ast.ImportFrom):
                continue
            module_path = ModuleGenerator._resolve_import_module(node, package)
            if module_path is None or module_path in ALREADY_IMPORTED:
                continue
            names = {
                f"{alias.name} as {alias.asname}" if alias.asname else alias.name
                for alias in node.names
            }
            extra_imports.setdefault(module_path, set()).update(names)

        def _decorator_texts(func_node):
            # raw source lines spanned by each decorator expression
            return [
                "".join(source_lines[dec.lineno - 1 : dec.end_lineno])
                for dec in func_node.decorator_list
            ]

        def _extract_source(func_node):
            # include the decorator lines that precede the ``def``
            start = func_node.lineno - 1
            if func_node.decorator_list:
                start = func_node.decorator_list[0].lineno - 1
            end = func_node.end_lineno
            return "".join(source_lines[start:end])

        # Collect local @triton.jit functions (without @pointwise_dynamic)
        local_sources = []
        for node in ast.iter_child_nodes(tree):
            if not isinstance(node, ast.FunctionDef):
                continue
            decorators = _decorator_texts(node)
            # substring match: "jit" also covers "triton.jit"
            if not any("jit" in text for text in decorators):
                continue
            if any("pointwise_dynamic" in text for text in decorators):
                continue
            local_sources.append(_extract_source(node))

        return extra_imports, local_sources

    def generate_imports(self, code: IndentedBuffer) -> IndentedBuffer:
        """Write the import header and the scalar function's local helpers to ``code``."""
        code.writeline("import math")
        code.writeline("from typing import Union")
        code.writeline("import torch")
        code.writeline("import triton")
        code.writeline("from triton import language as tl")
        code.newline()
        code.writeline("from flag_gems.utils.shape_utils import (")
        code.writeline("    heuristics_for_tile_size,")
        code.writeline("    heuristics_for_num_warps,")
        code.writeline("    stride_order,")
        code.writeline(")")
        code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer")
        code.writeline("from flag_gems.utils.libentry import libentry")
        code.writeline("from flag_gems.utils import triton_lang_extension as tle")
        code.writeline("from flag_gems.runtime import torch_device_fn")

        # Generate extra imports and local JIT deps of the scalar function
        jit_dep_imports, local_jit_sources = self._collect_jit_deps(self.scalar_fn)
        for module_path, names in sorted(jit_dep_imports.items()):
            sorted_names = ", ".join(sorted(names))
            code.writeline(f"from {module_path} import {sorted_names}")

        code.newline()
        code.newline()

        # Emit local @triton.jit helper functions
        for source in local_jit_sources:
            for line in source.splitlines():
                code.writeline(line)
            code.newline()

        return code

    def codegen(self, code: IndentedBuffer):
        """Write imports, wrapper and kernel to ``code`` and return it.

        The 1d-tile or nd-tile generators are chosen by ``config.prefer_1d_tile``.
        """
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            code = self.wrapper_gen.codegen_1d_tile(code)
            code = self.kernel_gen.codegen_1d_tile(code)
        else:
            code = self.wrapper_gen.codegen_nd_tile(code)
            code = self.kernel_gen.codegen_nd_tile(code)
        return code

1204 

1205 

@dataclass
class KernelInfo:
    """Information about a generated kernel for C++ integration."""

    # path of the generated Python module file
    file_path: str
    # name of the generated kernel function
    kernel_name: str
    # name of the generated wrapper function
    wrapper_name: str
    # rank of the task space the kernel was generated for
    ndim: int

1214 

1215 

class ComplexMode(Enum):
    """How a pointwise function handles complex operands."""

    NONE = auto()  # no complex dispatch; the real kernel path is always used
    ELEMENTWISE = auto()  # add/sub: view_as_real → same kernel → view_as_complex
    CROSS = auto()  # mul/div: split ar/ai/br/bi → cross_kernel

1220 

1221 

@dataclass
class ComplexStrategy:
    """Complex-number dispatch settings attached to a pointwise function."""

    # which complex dispatch mode is active
    mode: ComplexMode = ComplexMode.NONE
    # callable used for cross-term ops; it is called with (ar, ai, br, bi) and
    # its result is unpacked as (real, imag)
    cross_kernel: object = None
    # when True (and fallback_target is set) scalar operands are converted to
    # tensors and the call is delegated to fallback_target
    tensorize_scalars: bool = False
    # callable that receives the tensorized arguments
    fallback_target: object = None

1228 

1229 

# Complex dtype chosen when a real floating-point tensor is lifted to a complex
# tensor. Floating dtypes missing from this table fall back to complex64 at the
# lookup site.
_REAL_TO_COMPLEX = {
    torch.float16: torch.complex32,
    torch.bfloat16: torch.complex32,
    torch.float32: torch.complex64,
    torch.float64: torch.complex128,
}

1236 

1237 

1238class PointwiseDynamicFunction: 

    """Utility that generates functions for general pointwise operations. It generates a wrapper & a JITFunction
    which are specialized according to the rank of the task space (the broadcasted shape of all input tensors).
    The generated code is written out to the cache directory (defaults to ~/.flaggems).
    """

1243 

def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None):
    """Bind a schema and a scalar JIT function and set up empty caches.

    Args:
        op_desc: schema describing inputs, outputs and promotion methods.
        scalar_fn: the scalar function; must be a ``JITFunction``.
        config: code generation config; when falsy, ``get_codegen_config()``
            supplies it.
    """
    assert isinstance(scalar_fn, JITFunction)

    self.fx = op_desc
    self._scalar_fn = scalar_fn
    self._scalar_fn_cache_key = scalar_fn.cache_key
    self.pid = os.getpid()

    if config:
        self.config: CodeGenConfig = config
    else:
        self.config: CodeGenConfig = get_codegen_config()

    # overloads that were already instantiated
    self.overloads: Mapping[str, Callable] = {}
    # kernel info recorded per instantiated overload, for C++ integration
    self._kernel_info_cache: Mapping[str, KernelInfo] = {}

    # complex dispatch is off until register_complex() is called
    self.complex_strategy = ComplexStrategy()
    self._operand_indices = self._infer_operand_indices()

1262 

1263 # -------------------- operand index inference -------------------- 

1264 

1265 def _infer_operand_indices(self): 

1266 """Infer operand indices from schema._promotion_methods, done once at init.""" 

1267 indices = set() 

1268 for pm in self.fx._promotion_methods: 

1269 for idx in pm[:-1]: 

1270 indices.add(idx) 

1271 return frozenset(indices) 

1272 

1273 # -------------------- register_complex -------------------- 

1274 

def register_complex(
    self, mode, cross_kernel=None, tensorize_scalars=False, fallback_target=None
):
    """Enable complex-number dispatch for this function and return ``self``.

    Args:
        mode: ComplexMode.ELEMENTWISE (add/sub) or ComplexMode.CROSS (mul/div).
        cross_kernel: A PointwiseDynamicFunction for cross-term ops (mul/div).
        tensorize_scalars: If True, scalar operands are converted to tensors
            before delegating to fallback_target.
        fallback_target: A PointwiseDynamicFunction (tensor-tensor version)
            to delegate to after tensorizing scalar operands.
    """
    strategy = ComplexStrategy(
        mode=mode,
        cross_kernel=cross_kernel,
        tensorize_scalars=tensorize_scalars,
        fallback_target=fallback_target,
    )
    self.complex_strategy = strategy
    return self

1295 

1296 # -------------------- call entry -------------------- 

1297 

1298 def __call__(self, *args, **kwargs): 

1299 if self._should_use_complex_path(args): 

1300 return self._call_complex_dispatch(*args, **kwargs) 

1301 return self._call_real_impl(*args, **kwargs) 

1302 

1303 def _call_real_impl(self, *args, **kwargs): 

1304 """Single entry point for real kernel invocation.""" 

1305 ndim, args, kwargs = self.prepare_args(*args, **kwargs) 

1306 overload = self.instantiate(ndim) 

1307 out = overload(*args, **kwargs) 

1308 return self._unwrap(out) 

1309 

1310 # -------------------- complex helpers -------------------- 

1311 

1312 @staticmethod 

1313 def _is_complex_arg(a): 

1314 return (isinstance(a, torch.Tensor) and a.is_complex()) or isinstance( 

1315 a, complex 

1316 ) 

1317 

def _should_use_complex_path(self, args):
    """Return True when complex dispatch is registered and some operand is complex.

    Only positions listed in ``self._operand_indices`` that are present in
    ``args`` are inspected.
    """
    if self.complex_strategy.mode == ComplexMode.NONE:
        return False
    for position in self._operand_indices:
        if position < len(args) and self._is_complex_arg(args[position]):
            return True
    return False

1326 

1327 def _split_args(self, args): 

1328 """Split args into operands and others by original position index.""" 

1329 operands = {} 

1330 others = {} 

1331 for i, a in enumerate(args): 

1332 if i in self._operand_indices: 

1333 operands[i] = a 

1334 else: 

1335 others[i] = a 

1336 return operands, others 

1337 

1338 def _merge_args(self, operands, others): 

1339 """Rebuild args tuple from operands and others by original position index.""" 

1340 total = len(operands) + len(others) 

1341 merged = [None] * total 

1342 for i, v in operands.items(): 

1343 merged[i] = v 

1344 for i, v in others.items(): 

1345 merged[i] = v 

1346 return tuple(merged) 

1347 

1348 def _classify_complex_inputs(self, operands): 

1349 """Classify operands as 'all_complex', 'mixed', or 'real'.""" 

1350 complex_count = sum(1 for v in operands.values() if self._is_complex_arg(v)) 

1351 if complex_count == len(operands): 

1352 return "all_complex" 

1353 elif complex_count > 0: 

1354 return "mixed" 

1355 return "real" 

1356 

1357 def _infer_device(self, operands): 

1358 for v in operands.values(): 

1359 if isinstance(v, torch.Tensor): 

1360 return v.device 

1361 return None 

1362 

1363 def _infer_complex_dtype(self, operands): 

1364 return torch.result_type(*operands.values()) 

1365 

1366 def _tensorize_scalar_operands(self, operands, dtype, device): 

1367 """Convert scalar operands to tensors.""" 

1368 result = {} 

1369 for i, v in operands.items(): 

1370 if not isinstance(v, torch.Tensor): 

1371 if isinstance(v, complex): 

1372 result[i] = torch.tensor(v, dtype=dtype, device=device) 

1373 elif isinstance(v, float): 

1374 result[i] = torch.tensor(v, dtype=torch.float32, device=device) 

1375 elif isinstance(v, (int, bool)): 

1376 result[i] = torch.tensor(v, dtype=torch.int64, device=device) 

1377 else: 

1378 result[i] = v 

1379 else: 

1380 result[i] = v 

1381 return result 

1382 

1383 def _to_complex_tensor(self, a, target_dtype, device): 

1384 """Convert a scalar or real tensor to a complex tensor.""" 

1385 if isinstance(a, torch.Tensor): 

1386 if a.is_complex(): 

1387 return a 

1388 if a.is_floating_point(): 

1389 cdtype = _REAL_TO_COMPLEX.get(a.dtype, torch.complex64) 

1390 else: 

1391 a = a.to(torch.float32) 

1392 cdtype = torch.complex64 

1393 return torch.complex(a, torch.zeros_like(a)).to(cdtype) 

1394 elif isinstance(a, complex): 

1395 return torch.tensor(a, dtype=target_dtype, device=device) 

1396 elif isinstance(a, (int, float)): 

1397 return torch.tensor(complex(a, 0), dtype=target_dtype, device=device) 

1398 return a 

1399 

1400 # -------------------- complex dispatch -------------------- 

1401 

def _call_complex_dispatch(self, *args, **kwargs):
    """Unified complex dispatch entry point.

    When the strategy asks for scalar tensorization and has a fallback target,
    scalar operands are converted to tensors and the call is delegated to that
    target. Otherwise every operand is converted to a complex tensor, the
    operands are broadcast together, and the call is routed to the cross,
    elementwise or real implementation.
    """
    strategy = self.complex_strategy
    operands, others = self._split_args(args)

    device = self._infer_device(operands)
    result_dtype = self._infer_complex_dtype(operands)

    # delegate to the fallback target with scalar operands tensorized
    if strategy.tensorize_scalars and strategy.fallback_target is not None:
        tensorized = self._tensorize_scalar_operands(operands, result_dtype, device)
        return strategy.fallback_target(
            *self._merge_args(tensorized, others), **kwargs
        )

    # convert every operand to a complex tensor
    for position in list(operands.keys()):
        operands[position] = self._to_complex_tensor(
            operands[position], result_dtype, device
        )

    # broadcast the operands against each other, in position order
    ordered_positions = sorted(operands.keys())
    broadcasted = torch.broadcast_tensors(
        *[operands[position] for position in ordered_positions]
    )
    for position, tensor in zip(ordered_positions, broadcasted):
        operands[position] = tensor

    kind = self._classify_complex_inputs(operands)

    if kind == "all_complex" and strategy.mode == ComplexMode.CROSS:
        return self._call_complex_cross(operands, result_dtype)
    if kind in ("all_complex", "mixed"):
        return self._call_complex_elementwise(operands, others, result_dtype, kwargs)
    return self._call_real_impl(*self._merge_args(operands, others), **kwargs)

1437 

1438 def _call_complex_elementwise(self, operands, others, result_dtype, kwargs): 

1439 """Elementwise: view_as_real -> call real kernel -> view_as_complex.""" 

1440 real_tensors = {i: torch.view_as_real(t) for i, t in operands.items()} 

1441 

1442 # promote to common real dtype 

1443 dtypes = [t.dtype for t in real_tensors.values()] 

1444 common_dtype = dtypes[0] 

1445 for d in dtypes[1:]: 

1446 common_dtype = torch.promote_types(common_dtype, d) 

1447 real_tensors = {i: t.to(common_dtype) for i, t in real_tensors.items()} 

1448 

1449 new_args = self._merge_args(real_tensors, others) 

1450 out_real = self._call_real_impl(*new_args, **kwargs) 

1451 return torch.view_as_complex(out_real.contiguous()).to(result_dtype) 

1452 

1453 def _call_complex_cross(self, operands, result_dtype): 

1454 """Cross-term: split ar/ai/br/bi -> call cross_kernel -> stack -> view_as_complex.""" 

1455 sorted_keys = sorted(operands.keys()) 

1456 A, B = operands[sorted_keys[0]], operands[sorted_keys[1]] 

1457 Ar = torch.view_as_real(A) 

1458 Br = torch.view_as_real(B) 

1459 ar, ai = Ar[..., 0], Ar[..., 1] 

1460 br, bi = Br[..., 0], Br[..., 1] 

1461 

1462 common_dtype = torch.promote_types(ar.dtype, br.dtype) 

1463 ar, ai = ar.to(common_dtype), ai.to(common_dtype) 

1464 br, bi = br.to(common_dtype), bi.to(common_dtype) 

1465 

1466 real, imag = self.complex_strategy.cross_kernel(ar, ai, br, bi) 

1467 

1468 out = torch.stack((real, imag), dim=-1) 

1469 return torch.view_as_complex(out.contiguous()).to(result_dtype) 

1470 

@staticmethod
def use_fast_path(tensors):
    """Decide whether the tensors can be processed as one flat 1-D task.

    Requires every tensor to have the same shape, and either all of them to be
    C-contiguous or all of them to share the same strides with the first
    tensor being non-overlapping and dense.
    """
    same_shape = all_the_same_shape(tensors)
    if not same_shape:
        return same_shape
    contiguous = all_c_contiguous(tensors)
    if contiguous:
        return contiguous
    return all_the_same_stride(
        tensors
    ) and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])

1480 

def prepare_args(self, *args, _skip_tensor_check=False, **kwargs):
    """Allocate missing outputs and wrap all tensors for a kernel launch.

    Outputs may be supplied as keyword arguments named ``out{i}``; outputs that
    are not supplied are allocated here with the dtype given by the schema's
    promotion method. Every tensor argument is wrapped in a ``StridedBuffer``
    that carries the task shape and the strides to use.

    Args:
        *args: inputs, passed by position.
        _skip_tensor_check: when True, skip the ``check_tensor_attributes``
            validation of ``args``.
        **kwargs: pre-allocated outputs; every value is wrapped in a
            ``StridedBuffer``.

    Returns:
        ``(ndim, args, kwargs)`` where ``ndim`` is the rank of the task space.

    Raises:
        ValueError: if ``check_tensor_attributes`` rejects ``args``.
        RuntimeError: on the broadcasting path, if a supplied output does not
            have the broadcast shape or has internal overlapping.
    """
    # output allocation(when needed)
    # task simplification & task-rank inference & input-output reinterpretation
    schema = self.fx
    outputs_that_need_allocation: List[int] = []
    out_tensors = []
    for i in range(schema.num_output_tensors()):
        k = f"out{i}"
        if k in kwargs:
            out_tensors.append(kwargs[k])
        else:
            outputs_that_need_allocation.append(i)
    # input arguments must be passed by position
    if not _skip_tensor_check and schema._is_tensor is not None:
        if not check_tensor_attributes(args, (schema._is_tensor)):
            raise ValueError(
                "Input arguments must be passed by position, and the corresponding dtype must be specified."
            )
    in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

    # output dtype promotions
    outputs_dtypes_for_allocation = []
    for i in outputs_that_need_allocation:
        # the last element is the promotion method, the rest are argument indices
        *arg_indices, method = schema._promotion_methods[i]
        promote_args = (args[j] for j in arg_indices)
        _, dtype = type_promotion(*promote_args, type_promotion=method)
        outputs_dtypes_for_allocation.append(dtype)

    tensors = out_tensors + in_tensors
    INT32_MAX = torch.iinfo(torch.int32).max
    # NOTE(review): this assigns to the config object itself, so the setting
    # stays off for later calls that use the same config -- confirm intended.
    if tensors[0].numel() > INT32_MAX:
        self.config.prefer_block_pointer = False
    if self.use_fast_path(tensors):  # dimension collapse & use physical ordering
        allocated_outputs = [
            torch.empty_like(tensors[0], dtype=dtype)
            for dtype in outputs_dtypes_for_allocation
        ]
        # every tensor is treated as a flat 1-D buffer with unit stride
        task_shape = (tensors[0].numel(),)
        strides = (1,)
        ndim = 1
        args = tuple(
            (
                StridedBuffer(item, task_shape, strides)
                if schema.is_tensor(i)
                else item
            )
            for i, item in enumerate(args)
        )
        kwargs = {
            k: StridedBuffer(item, task_shape, strides)
            for k, item in kwargs.items()
        }
        for seq_id, output_id in enumerate(outputs_that_need_allocation):
            kwargs[f"out{output_id}"] = StridedBuffer(
                allocated_outputs[seq_id], task_shape, strides
            )
    else:
        # a simple strategy: all the undefined tensors will follow the first
        # tensor that is not broadcasted, no attempts to simplify task, no reordering,
        # no dimension collapsing
        shapes = tuple(item.shape for item in in_tensors)

        task_shape = broadcast_shapes(shapes)

        if out_tensors:
            for index, item in enumerate(out_tensors):
                if list(item.shape) != list(task_shape):
                    raise RuntimeError(
                        f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!"
                    )
                # output arguments must not have internal overlapping for pointwise operation
                if has_internal_overlapping(item) == MemOverlap.Yes:
                    raise RuntimeError(
                        "Pointwise Input arguments should not have internal overlapping."
                    )

        ndim = len(task_shape)
        # allocate like the first tensor that already has the task shape
        for item in tensors:
            if item.shape == task_shape:
                allocated_outputs = [
                    torch.empty_like(item, dtype=dtype)
                    for dtype in outputs_dtypes_for_allocation
                ]
                break
        else:  # nobreak
            # no tensor has the task shape: allocate on the first tensor's device
            device = tensors[0].device
            allocated_outputs = [
                torch.empty(task_shape, dtype=dtype, device=device)
                for dtype in outputs_dtypes_for_allocation
            ]
        args = tuple(
            (
                StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
                if schema.is_tensor(i)
                else item
            )
            for i, item in enumerate(args)
        )
        kwargs = {
            k: StridedBuffer(
                item,
                task_shape,
                broadcasted_stride(item.shape, item.stride(), task_shape),
            )
            for k, item in kwargs.items()
        }
        for seq_id, output_id in enumerate(outputs_that_need_allocation):
            item = allocated_outputs[seq_id]
            kwargs[f"out{output_id}"] = StridedBuffer(
                item,
                task_shape,
                broadcasted_stride(item.shape, item.stride(), task_shape),
            )
    return (ndim, args, kwargs)

1599 

1600 def _unwrap(self, tensors): 

1601 # unwrap StridedBuffer to get Tensor 

1602 if self.fx.num_output_tensors() == 1: 

1603 item = tensors 

1604 return item.unwrap() 

1605 return tuple(item.unwrap() for item in tensors) 

1606 

def _compute_kernel_names(self, ndim: int) -> Tuple[str, str, str]:
    """Compute kernel name, wrapper name, and file path for a given ndim.

    This is the single source of truth for naming, used by both instantiate()
    and get_kernel_info() to ensure consistency.

    Returns:
        Tuple of (kernel_name, wrapper_name, file_path)
    """
    base_name = self._scalar_fn.__name__
    kernel_name = f"{base_name}_kernel_rank_{ndim}"
    wrapper_name = f"{base_name}_wrapper_rank_{ndim}"

    # the file name suffix records which code generation variant was used
    if self.config.prefer_1d_tile:
        variant = "1d_tile_"
    elif self.config.prefer_block_pointer:
        variant = "bptr"
    else:
        variant = ""

    file_name = "".join(
        (
            f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_",
            variant,
            ".py",
        )
    )
    file_path = str(code_cache_dir() / file_name)

    return kernel_name, wrapper_name, file_path

1629 

def instantiate(self, ndim):
    """Return the wrapper overload for task rank ``ndim``, generating it if needed.

    On a cache miss the module source is generated, written to the file given
    by ``_compute_kernel_names``, loaded as a module, and the wrapper function
    is cached together with its ``KernelInfo``. The cache key combines
    ``ndim`` with the current ``config.prefer_block_pointer`` value.
    """
    # NOTE: manually instantiated overload does not have `prepare_args` as
    # preprocessing, so you have to manually allocate output and make sure that
    # the inputs & outputs actually fit the manually instantiated overload
    key = f"{ndim}_{self.config.prefer_block_pointer}"
    if key in self.overloads:
        return self.overloads[key]

    # `import importlib` alone does not guarantee that the `importlib.util`
    # submodule is loaded, so import it explicitly before using it below
    import importlib.util

    code = IndentedBuffer()

    # Use helper to compute names (single source of truth)
    kernel_name, wrapper_name, file_path = self._compute_kernel_names(ndim)

    module_gen = ModuleGenerator(
        self.fx,
        self._scalar_fn,
        ndim,
        kernel_name,
        wrapper_name,
        self.config,
    )
    module_gen.codegen(code)

    # NOTE: [why write the generated code to a file]
    # triton uses inspect to get the source of the jitted function, which requires
    # that the source code can be found by inspect
    # We write it into a file, since inspect cannot find the source of functions dynamically
    # created via exec string. We can help inspect to find the source by hacking linecache
    # library, but we find generating a module simpler, since we can generate 2 functions,
    # the kernel and the wrapper, and the wrapper calls the kernel.
    write_atomic(file_path, code.getvalue())

    # load
    spec = importlib.util.spec_from_file_location(
        f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
        file_path,
    )
    m = importlib.util.module_from_spec(spec)
    # the module is deliberately not registered in sys.modules

    # NOTE: [why not import the scalar function]
    # we do not re-import the scalar function, although the generated kernel **calls** it.
    # A function's __name__ may have been changed, so importing it by __name__ from the
    # module where it is defined may not find it; the name may also have been rebound to
    # something else, so importing via name cannot guarantee that the scalar function is
    # what gets imported.
    # So we copy the scalar function and its __globals__ to the generated module to do this
    # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
    spec.loader.exec_module(m)
    m.__dict__.update(self._scalar_fn.__globals__)
    m.__dict__[self._scalar_fn.__name__] = self._scalar_fn

    overload = getattr(m, wrapper_name)
    self.overloads[key] = overload

    # Cache kernel info for C++ integration
    self._kernel_info_cache[key] = KernelInfo(
        file_path=file_path,
        kernel_name=kernel_name,
        wrapper_name=wrapper_name,
        ndim=ndim,
    )

    return overload

1694 

1695 def get_kernel_info(self, ndim: int) -> KernelInfo: 

1696 """Get kernel information for a given ndim. 

1697 

1698 This method is useful for C++ integration to get the file path and 

1699 kernel name without duplicating the naming logic. 

1700 

1701 If the kernel hasn't been instantiated yet, this will instantiate it first. 

1702 

1703 Args: 

1704 ndim: The rank of the task space 

1705 

1706 Returns: 

1707 KernelInfo with file_path, kernel_name, wrapper_name, and ndim 

1708 """ 

1709 key = f"{ndim}_{self.config.prefer_block_pointer}" 

1710 

1711 # Ensure the kernel is instantiated 

1712 if key not in self._kernel_info_cache: 

1713 self.instantiate(ndim) 

1714 

1715 return self._kernel_info_cache[key] 

1716 

1717 

def pointwise_dynamic(
    f: Optional[JITFunction] = None,
    *,
    num_inputs: Optional[int] = None,
    is_tensor: Optional[List[bool]] = None,
    dtypes: Optional[List[Optional[type]]] = None,
    num_outputs: Optional[int] = None,
    promotion_methods: Optional[Tuple[int, ...]] = None,
    config: Optional[CodeGenConfig] = None,
):
    """Decorator that turns a scalar JIT function into a ``PointwiseDynamicFunction``.

    It can be applied directly (``@pointwise_dynamic``) or called with keyword
    options first (``@pointwise_dynamic(num_inputs=...)``). When none of
    ``num_inputs``, ``is_tensor`` and ``dtypes`` is given, the number of inputs
    is taken from ``len(fn.arg_names)``.
    """

    def decorator(fn):
        nonlocal num_inputs
        nothing_declared = num_inputs is None and is_tensor is None and dtypes is None
        if nothing_declared:
            num_inputs = len(fn.arg_names)
        schema = FunctionSchema(
            num_inputs=num_inputs,
            is_tensor=is_tensor,
            dtypes=dtypes,
            num_outputs=num_outputs,
            promotion_methods=promotion_methods,
        )
        return PointwiseDynamicFunction(schema, fn, config)

    return decorator if f is None else decorator(f)