Coverage for src/flag_gems/utils/pointwise_dynamic.py: 94%

1019 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import importlib 

2import os 

3from dataclasses import dataclass 

4from enum import Enum, auto 

5from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple 

6 

7import torch 

8import triton 

9from triton.runtime.jit import JITFunction 

10 

11from flag_gems.utils.code_cache import code_cache_dir 

12from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

13from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config 

14from flag_gems.utils.device_info import get_device_capability 

15from flag_gems.utils.shape_utils import ( 

16 MemOverlap, 

17 all_c_contiguous, 

18 all_the_same_shape, 

19 all_the_same_stride, 

20 broadcast_shapes, 

21 broadcasted_stride, 

22 check_tensor_attributes, 

23 has_internal_overlapping, 

24) 

25from flag_gems.utils.tensor_wrapper import StridedBuffer 

26from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion 

27 

28 

29# ------------------ Operation Description --------------------------- 

30def _type_name(type) -> str: 

31 "Render typename as string, work for both (bool, int, float, str) and torch.dtype object" 

32 if type in (bool, int, float, str): 

33 return type.__name__ 

34 if isinstance(type, torch.dtype): 

35 return str(type) 

36 return str(type) 

37 

38 

39def _check_typed_list(container, type): 

40 for item in container: 

41 assert isinstance(item, type) 

42 

43 

44def _check_sized_list(container, size): 

45 assert len(container) == size 

46 

47 

48def _tuple_content(strings: Sequence[str]) -> str: 

49 # comma separated list 

50 if len(strings) == 0: 

51 return "" 

52 if len(strings) == 1: 

53 return f"{strings[0]}," 

54 else: 

55 return ", ".join(strings) 

56 

57 

58def _cs(strings: Iterable[str]) -> str: 

59 return ", ".join(strings) 

60 

61 

62def _broadcast_vec(i, ndim): 

63 axes = [":" if j == i else "None" for j in range(ndim)] 

64 return f"[{_cs(axes)}]" 

65 

66 

class FunctionSchema:
    """Describe the argument and result layout of a pointwise operation.

    For each input the schema records whether it is a tensor and, optionally,
    its type; for each output it records a type-promotion rule.
    """

    _num_inputs: int
    _is_tensor: List[bool]
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    _promotion_methods: List[Tuple[int, ...]]

    # position of every input among the inputs of its own kind (tensor / non-tensor)
    _input_id: List[int]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        """Build a schema from whichever of the input descriptions is given.

        Args:
            num_inputs: number of inputs. When omitted it is taken from the
                length of ``is_tensor`` or, failing that, of ``dtypes``.
            is_tensor: one flag per input. When omitted, all inputs are tensors
                if ``num_inputs`` is given; otherwise an input is a tensor
                exactly when its ``dtypes`` entry is ``None``.
            dtypes: one entry per input, each a type or ``None``. Defaults to
                all ``None``.
            num_outputs: number of outputs. Defaults to ``len(promotion_methods)``.
            promotion_methods: one item per output; the last element of each
                item is a key of ``ELEMENTWISE_TYPE_PROMOTION_KIND`` and the
                preceding elements are kept unchanged.

        Raises:
            ValueError: if ``promotion_methods`` is missing, or if none of
                ``num_inputs``, ``is_tensor`` and ``dtypes`` is given.
            AssertionError: if a list has the wrong element type or length, or
                if there is not at least one input and one output.
        """
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        self._promotion_methods = self.canonicalize_promotion_methods(
            promotion_methods
        )

        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = is_tensor
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = is_tensor
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = dtypes
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            # is_tensor is necessarily None in this branch, so the tensor flags
            # are derived from dtypes: an entry of None marks a tensor input.
            self._num_inputs = len(dtypes)
            self._dtypes = dtypes
            self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Replace the trailing method key of every item by its ``ELEMENTWISE_TYPE_PROMOTION_KIND`` entry."""
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        """Return the number of arguments, outputs not included."""
        return self._num_inputs

    def num_outputs(self):
        """Return the number of outputs."""
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        """Return whether input ``arg_id`` is a tensor."""
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        """Return the recorded type of input ``arg_id`` (``None`` when unspecified)."""
        return self._dtypes[arg_id]

    def output_type(self, i):
        """Return the canonicalized promotion rule of output ``i``."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        """Return the number of tensor inputs."""
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        """Return the number of output tensors (same value as ``num_outputs``)."""
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        """Return the number of non-tensor inputs."""
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Render a one-line, human-readable signature of the operation.

        Args:
            outputs_in_arg: when true, every output is also listed among the
                arguments and rendered as ``StridedBuffer(a<i>!)``, where
                ``<i>`` is the output index.
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            elif dtype is None:
                input_types.append("scalar")
            else:
                input_types.append(_type_name(dtype))

        if outputs_in_arg:
            # Each output gets its own label. The previous code formatted the
            # literal 1 here, so every output was rendered as "a1!".
            output_types = [
                f"StridedBuffer(a{i}!)" for i in range(self.num_outputs())
            ]
            input_types.extend(output_types)
        else:
            output_types = ["StridedBuffer"] * self.num_outputs()

        rendered_inputs = ", ".join(input_types)
        rendered_outputs = ", ".join(output_types)
        return f"Pointwise: {rendered_inputs} -> {rendered_outputs}"

    def _compute_input_id(self):
        """Map every input to its index among the inputs of the same kind.

        Tensor inputs and non-tensor inputs are counted separately, each
        starting from zero.
        """
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        """Return the per-kind index of input ``idx`` (see ``_compute_input_id``)."""
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)

223 

224 

class KernelGenerator:
    """Write the source code of a Triton kernel that applies a scalar function pointwise.

    Three kernel flavours are produced: nd tiles with block pointers, nd tiles
    with explicit offsets and masks, and a 1d tile over the flattened task
    space. Each flavour has a one-tile-per-CTA body and a grid-stride-loop body.
    All methods write into a code buffer exposing ``writeline``,
    ``writemultiline``, ``newline`` and ``indent``.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        rank: int,
        name: str,
        config: CodeGenConfig,
    ):
        """Store the schema, the scalar function, the rank, the kernel name and the config."""
        self.fx = function_schema
        self.fn = scalar_fn
        self.ndim = rank
        self.name = name
        self.config = config

        self.fn_name = scalar_fn.__name__
        self.fn_module = scalar_fn.__module__

    def gen_import_function(self, code: IndentedBuffer):
        """Copy the source of the scalar function into the generated module."""
        code.writeline("@triton.jit")
        code.writemultiline(self.fn.src)
        code.newline()

    def gen_decorators(self, code):
        """Emit the decorators placed in front of the generated kernel."""
        code.writeline("@libentry()")
        num_non_tensor_args = self.fx.num_non_tensor_args()
        if num_non_tensor_args > 0:
            # we do not specialize non tensor args since they are passed into the inlined function
            # which means that their values may not deserve specialization
            non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)]
            code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")
        else:
            code.writeline("@triton.jit")

    def input_name(self, i):
        """Return the generated name of input ``i``: ``in<k>`` for tensors, ``val<k>`` otherwise."""
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        """Return the generated name of output ``i``."""
        return f"out{i}"

    # ---------------------------- signatures ----------------------------

    def _gen_operand_params(self, code):
        """Emit the parameters for input pointers, non-tensor inputs and output pointers."""
        schema = self.fx
        for i in range(schema.num_inputs()):
            name = self.input_name(i)
            if schema.is_tensor(i):
                code.writeline(f"{name}_ptr: tl.tensor, # of tl.pointer_type")
                continue
            arg_type = schema.input_type(i)
            if arg_type is not None:
                code.writeline(f"{name}: {_type_name(arg_type)},")
            else:
                code.writeline(f"{name},")

        for i in range(schema.num_outputs()):
            code.writeline(f"{self.output_name(i)}_ptr: tl.tensor, # of tl.pointer_type")

    def _gen_stride_params(self, code, stride_annotation, with_block_pointer=False):
        """Emit one line of stride parameters per tensor operand, inputs first.

        With ``with_block_pointer`` each operand additionally gets a line of
        ``*_stride_order*`` parameters.
        """
        ndim = self.ndim
        schema = self.fx
        operands = [f"in{i}" for i in range(schema.num_input_tensors())]
        operands += [f"out{i}" for i in range(schema.num_output_tensors())]
        for operand in operands:
            stride_args = _cs(
                f"{operand}_stride{j}: {stride_annotation}" for j in range(ndim)
            )
            code.writeline(f"{stride_args}, # strides for {operand}")
            if with_block_pointer:
                stride_order_args = _cs(
                    f"{operand}_stride_order{j}: tl.constexpr" for j in range(ndim)
                )
                code.writeline(f"{stride_order_args}, # stride order for {operand}")

    def _gen_task_space_params(self, code):
        """Emit the task-space sizes and the total number of tasks."""
        # task space, used to reconstruct multi index
        task_space_args = _cs(f"s{i}" for i in range(self.ndim))
        code.writeline(f"{task_space_args}, # task_space")
        # number of tasks, used to compute mask
        code.writeline("num_tasks,")

    def gen_signature(self, code, with_block_pointer=False):
        """Emit the ``def`` line and parameter list of the nd-tile kernel.

        For rank 0 only the operand parameters are emitted.
        """
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_operand_params(code)

            ndim = self.ndim
            if ndim > 0:
                self._gen_stride_params(code, "tl.constexpr", with_block_pointer)
                self._gen_task_space_params(code)

                # tile size & tiles_per_cta, gsl style
                code.writeline("tiles_per_cta: int,")
                tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(ndim))
                code.writeline(f"{tile_sizes},")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    def gen_signature_1d_tile(self, code):
        """Emit the ``def`` line and parameter list of the 1d-tile kernel.

        Differs from ``gen_signature`` in that strides are annotated ``int``
        and there is a single ``tile_size`` parameter.
        """
        code.writeline(f"def {self.name}(")
        with code.indent():
            self._gen_operand_params(code)

            if self.ndim > 0:
                self._gen_stride_params(code, "int")
                self._gen_task_space_params(code)

                # tile size & tiles_per_cta, gsl style
                code.writeline("tiles_per_cta: int,")
                code.writeline("tile_size: tl.constexpr,")
                code.writeline("one_tile_per_cta: tl.constexpr,")
        code.writeline("):")

    # ------------------------------ bodies ------------------------------

    def gen_num_tiles(self, code):
        """Emit the number of tiles along every axis (the tile-grid size)."""
        for i in range(self.ndim):
            code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})")

    def _gen_compute(self, code):
        """Emit the call of the scalar function on the loaded inputs."""
        schema = self.fx
        inputs_to_scalar_fn = _cs(
            self.input_name(i) for i in range(schema.num_inputs())
        )
        outputs_to_scalar_fn = _cs(
            self.output_name(i) for i in range(schema.num_output_tensors())
        )
        code.writeline("# compute")
        code.writeline(
            f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})"
        )
        code.newline()

    def _gen_tile_id_reconstruction(self, code):
        """Emit the split of the flat ``tile_id`` into one tile index per axis."""
        code.writeline(
            "# pid multi index recontruction: we use c ordering, right axes changes fastest"
        )
        for i in reversed(range(self.ndim)):
            if i > 0:
                code.writeline(f"tile_id{i} = tile_id % num_tiles{i}")
                code.writeline(f"tile_id //= num_tiles{i}")
            else:
                code.writeline(f"tile_id{i} = tile_id")
        code.newline()

    def _gen_masked_load_compute_store(self, code, axis_index):
        """Emit masked loads, the compute call and masked stores using pointer arithmetic.

        ``axis_index(j)`` returns the generated expression that indexes axis ``j``.
        """
        ndim = self.ndim
        schema = self.fx

        def element_offset(operand):
            return " + ".join(
                f"{axis_index(j)} * {operand}_stride{j}" for j in range(ndim)
            )

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            offset_combine = element_offset(f"in{i}")
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)"
            )
        code.newline()

        self._gen_compute(code)

        # stores
        # NOTE(review): the generated statement assigns the result of tl.store
        # to `in{i}`; kept byte-for-byte, confirm whether the assignment is wanted.
        for i in range(schema.num_output_tensors()):
            offset_combine = element_offset(f"out{i}")
            code.writeline(
                f"in{i} = tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)"
            )

    def _gen_gsl_loop(self, code, gen_tile_body):
        """Emit the grid-stride loop and call ``gen_tile_body`` for the loop body."""
        code.writeline("num_ctas = ext.num_programs(0)")
        code.writeline("for j in range(0, tiles_per_cta):")
        with code.indent():
            code.writeline("tile_id = pid + j * num_ctas")
            gen_tile_body(code)

    def gen_body_for_0d(self, code):
        """Emit the body of the rank-0 kernel: unmasked loads, compute and stores."""
        schema = self.fx

        code.writeline("# loads")
        for i in range(schema.num_input_tensors()):
            code.writeline(
                f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the pointer's dtype)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline("# stores")
        for i in range(schema.num_output_tensors()):
            code.writeline(
                f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))"
            )
        code.newline()
        return code

    # nd tile 1d grid kernel with block pointer
    def gen_body_one_tile_per_cta_with_bptr(self, code):
        """Emit the processing of one nd tile using block pointers."""
        ndim = self.ndim
        schema = self.fx

        # block pointer pieces shared by every operand
        shape = _tuple_content(tuple(f"s{i}" for i in range(ndim)))
        offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim)))
        tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim)))

        self._gen_tile_id_reconstruction(code)

        code.writeline("# tile offsets")
        for i in range(ndim):
            # Block pointers only support 32 bit offsets/block_shape (triton
            # asserts otherwise), hence the cast; regular indexing would be
            # needed for 64 bit support.
            code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)")

        code.writeline("# loads")
        num_input_tensors = schema.num_input_tensors()
        fixed_reversed_order = False
        if num_input_tensors > 0:
            # The vendor check does not depend on the operand, so it is done
            # once rather than on every loop iteration.
            import flag_gems

            fixed_reversed_order = flag_gems.vendor_name == "spacemit"
        for i in range(num_input_tensors):
            strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim)))
            if fixed_reversed_order:
                order = _tuple_content(tuple(f"{ndim - j - 1}" for j in range(ndim)))
            else:
                order = _tuple_content(
                    tuple(f"in{i}_stride_order{j}" for j in range(ndim))
                )
            code.writeline(
                f"in{i}_bptr = tl.make_block_ptr("
                f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) "
                "# workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)"
            )
        code.newline()

        self._gen_compute(code)

        code.writeline(
            "# stores, note that store to block pointer does not automatically cast the value to the pointer's dtype"
        )
        for i in range(schema.num_output_tensors()):
            strides = _tuple_content(tuple(f"out{i}_stride{j}" for j in range(ndim)))
            order = _tuple_content(
                tuple(f"out{i}_stride_order{j}" for j in range(ndim))
            )
            code.writeline(
                f"out{i}_bptr = tl.make_block_ptr("
                f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))"
            )
            code.writeline(
                f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))"
            )

    def gen_body_gsl_with_bptr(self, code):
        """Emit the grid-stride loop around the block-pointer tile body."""
        self._gen_gsl_loop(code, self.gen_body_one_tile_per_cta_with_bptr)

    def gen_body_one_tile_per_cta_without_bptr(self, code):
        """Emit the processing of one nd tile using explicit offsets and masks."""
        ndim = self.ndim

        self._gen_tile_id_reconstruction(code)

        # offsets
        for i in range(ndim):
            code.writeline(
                f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})"
            )

        # masks
        for i in range(ndim):
            code.writeline(f"mask{i} = offsets{i} < s{i}")
        masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim))
        mask_combine = " & ".join(masks)
        code.writeline(f"mask = {mask_combine}")

        self._gen_masked_load_compute_store(
            code, lambda j: f"offsets{j}{_broadcast_vec(j, ndim)}"
        )

    def gen_body_gsl_without_bptr(self, code):
        """Emit the grid-stride loop around the offset/mask tile body."""
        self._gen_gsl_loop(code, self.gen_body_one_tile_per_cta_without_bptr)

    def gen_body_one_tile_per_cta_1d_tile(self, code):
        """Emit the processing of one 1d tile of the flattened task space."""
        # tile id
        code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)")
        code.writeline("mask = tid < num_tasks")

        # multi index reconstruction
        for i in reversed(range(self.ndim)):
            if i > 0:
                code.writeline(f"i{i} = tid % s{i}")
                code.writeline(f"tid //= s{i}")
            else:
                code.writeline(f"i{i} = tid")
        code.newline()

        self._gen_masked_load_compute_store(code, lambda j: f"i{j}")

    def gen_body_gsl_1d_tile(self, code):
        """Emit the grid-stride loop around the 1d tile body."""
        self._gen_gsl_loop(code, self.gen_body_one_tile_per_cta_1d_tile)

    # --------------------------- whole kernels ---------------------------

    def _codegen_kernel(
        self, code, gen_signature, gen_tile_body, gen_gsl_body, with_num_tiles
    ):
        """Emit a complete kernel: scalar function, decorators, signature and body.

        Rank 0 gets the dedicated 0d body. Otherwise the body branches on
        ``one_tile_per_cta`` between ``gen_tile_body`` and ``gen_gsl_body``.
        """
        self.gen_import_function(code)
        self.gen_decorators(code)
        gen_signature(code)

        # function body for rank-0
        if self.ndim == 0:
            with code.indent():
                self.gen_body_for_0d(code)
            return code

        with code.indent():
            code.writeline("pid = ext.program_id(0)")
            if with_num_tiles:
                self.gen_num_tiles(code)
            # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute
            code.writeline("if one_tile_per_cta: # monolitic kernel style")
            with code.indent():
                code.writeline("tile_id = pid")
                gen_tile_body(code)
            # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/
            code.writeline("else: # grid-stride-loop style kernel")
            with code.indent():
                gen_gsl_body(code)
        code.newline()
        return code

    def codegen_nd_tile_with_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support with block pointer."""
        return self._codegen_kernel(
            code,
            lambda buffer: self.gen_signature(buffer, with_block_pointer=True),
            self.gen_body_one_tile_per_cta_with_bptr,
            self.gen_body_gsl_with_bptr,
            with_num_tiles=True,
        )

    def codegen_nd_tile_without_bptr(self, code):
        """Generate kernel nd tile & 1d grid with gsl support without block pointer."""
        return self._codegen_kernel(
            code,
            lambda buffer: self.gen_signature(buffer, with_block_pointer=False),
            self.gen_body_one_tile_per_cta_without_bptr,
            self.gen_body_gsl_without_bptr,
            with_num_tiles=True,
        )

    def codegen_nd_tile(self, code):
        """Generate the nd-tile kernel, with or without block pointers as the config prefers."""
        if self.config.prefer_block_pointer:
            self.codegen_nd_tile_with_bptr(code)
        else:
            self.codegen_nd_tile_without_bptr(code)
        return code

    def codegen_1d_tile(self, code):
        """Generate kernel 1d tile & 1d grid with gsl support."""
        return self._codegen_kernel(
            code,
            self.gen_signature_1d_tile,
            self.gen_body_one_tile_per_cta_1d_tile,
            self.gen_body_gsl_1d_tile,
            with_num_tiles=False,
        )

751 

752 

753class WrapperGenerator: 

754 def __init__( 

755 self, 

756 function_schema: FunctionSchema, 

757 jit_fn_name: str, 

758 ndim: int, 

759 name: str, 

760 config: CodeGenConfig, 

761 ): 

762 self.fx = function_schema 

763 self.jit_fn_name = jit_fn_name 

764 self.ndim = ndim 

765 self.name = name 

766 self.config = config 

767 

768 def input_name(self, i): 

769 is_tensor = self.fx.is_tensor(i) 

770 name = "in" if is_tensor else "val" 

771 index = self.fx.input_index(i) 

772 return f"{name}{index}" 

773 

774 def output_name(self, i): 

775 return f"out{i}" 

776 

777 def gen_signature(self, code: IndentedBuffer): 

778 # TODO: check if triton handles constexprs transitively 

779 schema = self.fx 

780 params: List[str] = [] 

781 for i in range(schema.num_inputs()): 

782 if schema.is_tensor(i): 

783 params.append( 

784 f"{self.input_name(i)}: Union[torch.Tensor, StridedBuffer]" 

785 ) 

786 else: 

787 arg_type = schema.input_type(i) 

788 if arg_type is not None: 

789 params.append(f"{self.input_name(i)}: {_type_name(arg_type)}") 

790 else: 

791 params.append(f"{self.input_name(i)}") 

792 # NOTE: [the wrapper's signature and rules for passing parameters ] 

793 # input params: must be passed by position, since the names are renamed to 

794 # in0, in1, val0, val1, ..., So passing these parameters by keyword is wierd 

795 # So we enforce that these parameters must be passed by position. 

796 # maybe we can fix it later 

797 # output parameters: must be passed by keyword, since the scalar function 

798 # do not have output parameters(think of it as some scalar function, output 

799 # parameter does not make sense in this case.) They are added to allow destination 

800 # passing style API. Output parameter is convenient in cases where we want 

801 # to use some pre-defiend outputs(especially when they are some views of other 

802 # tensors). We emphasize that these parameters are added in-addition, we enforce 

803 # that they be passed by keyword. After all, out0, out1, ... does not mismatch 

804 # names form the scalar function, since it does not have output parameters. 

805 params.append("/") 

806 params.append("*") # output params must be passed by keyword 

807 

808 for i in range(schema.num_output_tensors()): 

809 params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]") 

810 code.writeline(f"def {self.name}({_cs(params)}): ") 

811 

812 def gen_docstring(self, code: IndentedBuffer): 

813 schema = self.fx 

814 doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""' 

815 code.writeline(doc) 

816 

817 def gen_same_shape_check(self, code: IndentedBuffer): 

818 schema: FunctionSchema = self.fx 

819 params = [f"in{i}.shape" for i in range(schema.num_input_tensors())] + [ 

820 f"out{i}.shape" for i in range(schema.num_output_tensors()) 

821 ] 

822 check: str = " == ".join(params) 

823 code.writeline(f"assert {check}, 'operand shapes mismatch'") 

824 

825 def gen_task_partition(self, code: IndentedBuffer): 

826 code.writeline("# task partitioning") 

827 ndim = self.ndim 

828 if ndim == 0: 

829 code.writeline("num_warps = 1") 

830 code.writeline("num_ctas = 1") 

831 else: 

832 code.writeline("shape = out0.shape") 

833 code.writeline("num_tasks = out0.numel()") 

834 code.writeline("if num_tasks == 0:") 

835 with code.indent(): 

836 self.gen_return(code) 

837 max_tile_size = self.config.max_tile_size 

838 # Check if all input and output dtypes are complex 

839 all_complex = True 

840 for i in range(self.fx.num_inputs()): 

841 if self.fx.is_tensor(i): 

842 input_dtype = self.fx.input_type(i) 

843 if input_dtype is not None and not ( 

844 input_dtype == torch.complex64 

845 or input_dtype == torch.complex128 

846 ): 

847 all_complex = False 

848 break 

849 if all_complex: 

850 # If all inputs are complex, set max_tile_size to half 

851 max_tile_size = max_tile_size // 2 

852 major, _ = get_device_capability() 

853 if self.name.find("fill_scalar") != -1 and major >= 9: 

854 code.writeline("tile_sizes = tuple([64])") 

855 else: 

856 code.writeline( 

857 f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)" 

858 ) 

859 code.writeline("tile_size = math.prod(tile_sizes)") 

860 code.writeline( 

861 "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))" 

862 ) 

863 

864 if self.name.find("fill_scalar") != -1 and major >= 9: 

865 code.writeline("num_ctas = num_tiles") 

866 else: 

867 max_grid_size0 = self.config.max_grid_size[0] 

868 code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)") 

869 

870 code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)") 

871 code.writeline("num_warps = heuristics_for_num_warps(tile_size)") 

872 code.writeline("one_tile_per_cta = tiles_per_cta==1") 

873 code.writeline("grid = (num_ctas, 1, 1)") 

874 

def gen_task_partition_1d(self, code: IndentedBuffer):
    """Emit the task-partitioning preamble of the wrapper for 1d tiling.

    The statements written into ``code`` bind ``num_warps``, ``num_ctas`` and
    ``grid``.  For ranks above zero they also bind ``shape``, ``num_tasks``,
    ``tile_sizes``, ``tile_size``, ``num_tiles``, ``tiles_per_cta`` and
    ``one_tile_per_cta``, and return early when there are no tasks.
    """
    code.writeline("# task partitioning")

    if self.ndim == 0:
        # Rank-0 task space: one CTA running one warp.
        code.writeline("num_warps = 1")
        code.writeline("num_ctas = 1")
        code.writeline("grid = (num_ctas, 1, 1)")
        return

    code.writeline("shape = out0.shape")
    code.writeline("num_tasks = out0.numel()")
    code.writeline("if num_tasks == 0:")
    with code.indent():
        self.gen_return(code)

    tile_limit = self.config.max_tile_size

    def _rules_out_complex(position):
        # A tensor input whose declared type is known and is not a complex
        # dtype means the operation is not an all-complex one.
        # NOTE(review): tensor inputs whose declared type is None never rule
        # complex out, so they leave the halving below enabled - confirm.
        if not self.fx.is_tensor(position):
            return False
        declared = self.fx.input_type(position)
        return declared is not None and declared not in (
            torch.complex64,
            torch.complex128,
        )

    if not any(_rules_out_complex(pos) for pos in range(self.fx.num_inputs())):
        tile_limit = tile_limit // 2

    major, _ = get_device_capability()
    # fill_scalar kernels on devices with major capability >= 9 use a fixed
    # tile size and one CTA per tile instead of the heuristics.
    fixed_fill_tiling = "fill_scalar" in self.name and major >= 9

    if fixed_fill_tiling:
        code.writeline("tile_sizes = tuple([1024])")
    else:
        code.writeline(
            f"tile_sizes = heuristics_for_tile_size({tile_limit}, num_tasks)"
        )

    code.writeline("tile_size = tile_sizes[0]")
    code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)")

    if fixed_fill_tiling:
        code.writeline("num_ctas = num_tiles")
    else:
        grid_limit = self.config.max_grid_size[0]
        code.writeline(f"num_ctas = min({grid_limit}, num_tiles)")

    for statement in (
        "tiles_per_cta = triton.cdiv(num_tiles, num_ctas)",
        "num_warps = heuristics_for_num_warps(tile_size)",
        "one_tile_per_cta = tiles_per_cta==1",
        "grid = (num_ctas, 1, 1)",
    ):
        code.writeline(statement)

922 

def gen_kernel_launch(
    self,
    code: IndentedBuffer,
):
    """Emit the wrapper statements that launch the nd-tile kernel.

    First binds ``<tensor>_strides`` for every input and output tensor (plus
    ``<tensor>_stride_order`` when the config prefers block pointers), then
    writes the ``<jit_fn_name>[grid](...)`` call inside a
    ``torch_device_fn.device(...)`` context.
    """
    schema = self.fx
    rank = self.ndim
    use_block_ptr = self.config.prefer_block_pointer

    input_tensors = [f"in{k}" for k in range(schema.num_input_tensors())]
    output_tensors = [f"out{k}" for k in range(schema.num_output_tensors())]

    code.writeline("# kernel launch")
    for tensor in input_tensors + output_tensors:
        code.writeline(f"{tensor}_strides = {tensor}.stride()")
        if use_block_ptr:
            if rank >= 2:
                code.writeline(
                    f"{tensor}_stride_order = stride_order({tensor}_strides)"
                )
            else:
                # with at most one dimension there is no order to compute
                code.writeline(f"{tensor}_stride_order = (0,)")

    code.writeline("with torch_device_fn.device(in0.device.index):")
    with code.indent():
        code.writeline(f"{self.jit_fn_name}[grid](")
        with code.indent():
            # positional arguments: every input, then every output
            call_args = [self.input_name(k) for k in range(schema.num_inputs())]
            call_args.extend(
                self.output_name(k) for k in range(schema.num_output_tensors())
            )
            code.writeline(f"{_cs(call_args)},")

            if rank > 0:
                dims = range(rank)
                # The trailing comment text differs between inputs and
                # outputs in the emitted code; both spellings are kept as is.
                labelled = [(name, "stride order for") for name in input_tensors]
                labelled += [(name, "stride orderfor") for name in output_tensors]
                for tensor, order_label in labelled:
                    stride_args = ", ".join(
                        f"{tensor}_strides[{d}]" for d in dims
                    )
                    code.writeline(f"{stride_args}, # stride for {tensor}")
                    if use_block_ptr:
                        order_args = ", ".join(
                            f"{tensor}_stride_order[{d}]" for d in dims
                        )
                        code.writeline(f"{order_args}, # {order_label} {tensor}")

                shape_args = ", ".join(f"shape[{d}]" for d in dims)
                code.writeline(f"{shape_args}, # task indexing space")
                code.writeline("num_tasks, # num tasks")
                code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                for d in dims:
                    code.writeline(f"tile_size{d}=tile_sizes[{d}],")
                code.writeline("one_tile_per_cta=one_tile_per_cta,")
            code.writeline("num_warps=num_warps,")
        code.writeline(")")

996 

def gen_kernel_launch_1d(
    self,
    code: IndentedBuffer,
):
    """Emit the wrapper statements that launch the 1d-tile kernel.

    Binds ``<tensor>_strides`` for every input and output tensor and then
    writes the ``<jit_fn_name>[grid](...)`` call inside a
    ``torch_device_fn.device(...)`` context.  Unlike the nd-tile launch, a
    single ``tile_size`` argument is passed and no stride orders are emitted.
    """
    schema = self.fx
    rank = self.ndim

    input_tensors = [f"in{k}" for k in range(schema.num_input_tensors())]
    output_tensors = [f"out{k}" for k in range(schema.num_output_tensors())]
    all_tensors = input_tensors + output_tensors

    code.writeline("# kernel launch")
    for tensor in all_tensors:
        code.writeline(f"{tensor}_strides = {tensor}.stride()")

    code.writeline("with torch_device_fn.device(in0.device.index):")
    with code.indent():
        code.writeline(f"{self.jit_fn_name}[grid](")
        with code.indent():
            # positional arguments: every input, then every output
            call_args = [self.input_name(k) for k in range(schema.num_inputs())]
            call_args.extend(
                self.output_name(k) for k in range(schema.num_output_tensors())
            )
            code.writeline(f"{_cs(call_args)},")

            if rank > 0:
                dims = range(rank)
                for tensor in all_tensors:
                    stride_args = ", ".join(
                        f"{tensor}_strides[{d}]" for d in dims
                    )
                    code.writeline(f"{stride_args}, # stride for {tensor}")

                shape_args = ", ".join(f"shape[{d}]" for d in dims)
                code.writeline(f"{shape_args}, # task indexing space")
                for statement in (
                    "num_tasks, # num tasks",
                    "tiles_per_cta=tiles_per_cta, # tiles_per_cta",
                    "tile_size=tile_size,",
                    "one_tile_per_cta=one_tile_per_cta,",
                ):
                    code.writeline(statement)
            code.writeline("num_warps=num_warps,")
        code.writeline(")")

1042 

def gen_return(self, code: IndentedBuffer):
    """Emit the wrapper's ``return`` statement listing ``out0 .. outN``."""
    output_names = [f"out{idx}" for idx in range(self.fx.num_output_tensors())]
    code.writeline("return " + _cs(output_names))

1046 

def codegen_nd_tile(self, code):
    """Write the complete nd-tile wrapper function into ``code`` and return it.

    The signature is emitted first; docstring, shape check, task partition,
    kernel launch and return statement follow inside one indentation level.
    """
    self.gen_signature(code)

    body_steps = (
        self.gen_docstring,
        self.gen_same_shape_check,
        self.gen_task_partition,
        self.gen_kernel_launch,
        self.gen_return,
    )
    with code.indent():
        for emit in body_steps:
            emit(code)
    code.newline()
    return code

1058 

def codegen_1d_tile(self, code):
    """Write the complete 1d-tile wrapper function into ``code`` and return it.

    Same layout as the nd-tile wrapper, but the task partition and the kernel
    launch come from their ``_1d`` variants.
    """
    self.gen_signature(code)

    body_steps = (
        self.gen_docstring,
        self.gen_same_shape_check,
        self.gen_task_partition_1d,
        self.gen_kernel_launch_1d,
        self.gen_return,
    )
    with code.indent():
        for emit in body_steps:
            emit(code)
    code.newline()
    return code

1070 

1071 

class ModuleGenerator:
    """Generate the source of one Python module: imports, wrapper and kernel.

    The wrapper is produced by a ``WrapperGenerator`` and the kernel by a
    ``KernelGenerator``; both are built from the same schema, rank and config.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: CodeGenConfig,
    ):
        self.config = config
        self.scalar_fn = scalar_fn
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def _collect_jit_deps(scalar_fn):
        """Collect extra imports and local @triton.jit helper sources.

        Parses, with ``ast``, the source module where ``scalar_fn`` is defined.

        Args:
            scalar_fn: the scalar function, or an object exposing the Python
                function as its ``fn`` attribute.

        Returns:
            A tuple ``(extra_imports, local_sources)``:
              - extra_imports: dict of absolute module path -> set of import
                clauses; a clause is ``"name"`` or ``"name as alias"``.
              - local_sources: list of source strings of the module-level
                functions with a ``jit`` decorator that are NOT decorated
                with ``pointwise_dynamic``.
            ``({}, [])`` is returned whenever the module cannot be located,
            read or parsed.
        """
        import ast
        import importlib.util
        import inspect

        py_fn = getattr(scalar_fn, "fn", scalar_fn)
        module_name = getattr(py_fn, "__module__", None)
        if not module_name:
            return {}, []
        try:
            mod = importlib.import_module(module_name)
            source_file = inspect.getfile(mod)
        except (ImportError, TypeError, OSError):
            return {}, []
        try:
            # Python source defaults to UTF-8; do not depend on the locale.
            with open(source_file, encoding="utf-8") as f:
                module_source = f.read()
            source_lines = module_source.splitlines(keepends=True)
            tree = ast.parse(module_source)
        except (OSError, SyntaxError, ValueError):
            # ValueError covers UnicodeDecodeError (e.g. a non-text module
            # file) and the null-byte error ast.parse may raise.
            return {}, []

        # Collect non-standard import-from lines
        ALREADY_IMPORTED = {
            "math",
            "typing",
            "torch",
            "triton",
            "triton.language",
            "flag_gems.utils.shape_utils",
            "flag_gems.utils.tensor_wrapper",
            "flag_gems.utils.libentry",
            "flag_gems.utils",
            "flag_gems.runtime",
            "flag_gems.utils.pointwise_dynamic",
        }
        package = getattr(mod, "__package__", None)
        extra_imports = {}
        for node in ast.iter_child_nodes(tree):
            if not (isinstance(node, ast.ImportFrom) and node.module):
                continue
            module_path = node.module
            if node.level:
                # The generated module lives outside the source package, so a
                # relative import must be rewritten as an absolute one.
                if not package:
                    continue
                try:
                    module_path = importlib.util.resolve_name(
                        "." * node.level + node.module, package
                    )
                except (ImportError, ValueError):
                    continue
            if module_path in ALREADY_IMPORTED:
                continue
            # Keep "as" aliases: the helper sources refer to the bound name.
            clauses = {
                alias.name
                if alias.asname is None
                else f"{alias.name} as {alias.asname}"
                for alias in node.names
            }
            extra_imports.setdefault(module_path, set()).update(clauses)

        # Collect local @triton.jit functions (without @pointwise_dynamic)
        def _has_decorator(func_node, name):
            # Substring match against the decorator's source lines.
            for dec in func_node.decorator_list:
                src = "".join(source_lines[dec.lineno - 1 : dec.end_lineno])
                if name in src:
                    return True
            return False

        def _extract_source(func_node):
            # Start at the first decorator so the decorators are kept.
            start = func_node.lineno - 1
            if func_node.decorator_list:
                start = func_node.decorator_list[0].lineno - 1
            end = func_node.end_lineno
            return "".join(source_lines[start:end])

        local_sources = []
        for node in ast.iter_child_nodes(tree):
            if not isinstance(node, ast.FunctionDef):
                continue
            # "jit" also matches "triton.jit", so one substring test suffices.
            if not _has_decorator(node, "jit"):
                continue
            if _has_decorator(node, "pointwise_dynamic"):
                continue
            local_sources.append(_extract_source(node))

        return extra_imports, local_sources

    def generate_imports(self, code: IndentedBuffer) -> IndentedBuffer:
        """Write the import header and the local jit helpers into ``code``.

        Emits the fixed imports every generated module needs, then the extra
        ``from ... import ...`` lines and helper sources collected from the
        scalar function's module.  Returns ``code``.
        """
        code.writeline("import math")
        code.writeline("from typing import Union")
        code.writeline("import torch")
        code.writeline("import triton")
        code.writeline("from triton import language as tl")
        code.newline()
        code.writeline("from flag_gems.utils.shape_utils import (")
        code.writeline("    heuristics_for_tile_size,")
        code.writeline("    heuristics_for_num_warps,")
        code.writeline("    stride_order,")
        code.writeline(")")
        code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer")
        code.writeline("from flag_gems.utils.libentry import libentry")
        code.writeline("from flag_gems.utils import triton_lang_extension as ext")
        code.writeline("from flag_gems.runtime import torch_device_fn")

        # Generate extra imports and local JIT deps of the scalar function
        jit_dep_imports, local_jit_sources = self._collect_jit_deps(self.scalar_fn)
        for module_path, names in sorted(jit_dep_imports.items()):
            sorted_names = ", ".join(sorted(names))
            code.writeline(f"from {module_path} import {sorted_names}")

        code.newline()
        code.newline()

        # Emit local @triton.jit helper functions
        for source in local_jit_sources:
            for line in source.splitlines():
                code.writeline(line)
            code.newline()

        return code

    def codegen(self, code: IndentedBuffer):
        """Write the whole module (imports, wrapper, kernel) and return ``code``.

        The 1d-tile or nd-tile generators are chosen by
        ``config.prefer_1d_tile``.
        """
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            code = self.wrapper_gen.codegen_1d_tile(code)
            code = self.kernel_gen.codegen_1d_tile(code)
        else:
            code = self.wrapper_gen.codegen_nd_tile(code)
            code = self.kernel_gen.codegen_nd_tile(code)
        return code

1215 

1216 

@dataclass
class KernelInfo:
    """Information about a generated kernel for C++ integration.

    Records where the generated module was written and the names of the
    kernel and wrapper functions it defines.
    """

    # Path of the generated Python module file.
    file_path: str
    # Name of the generated kernel function inside that module.
    kernel_name: str
    # Name of the generated wrapper function inside that module.
    wrapper_name: str
    # Rank of the task space the kernel was generated for.
    ndim: int

1225 

1226 

class ComplexMode(Enum):
    """How a pointwise function handles complex operands."""

    # No complex support registered.
    NONE = 1
    # add/sub style: view_as_real, run the same kernel, view_as_complex.
    ELEMENTWISE = 2
    # mul/div style: split into ar/ai/br/bi and run a dedicated cross kernel.
    CROSS = 3

1231 

1232 

@dataclass
class ComplexStrategy:
    """Complex-number handling configuration of a PointwiseDynamicFunction.

    The default instance (mode ``NONE``) disables the complex dispatch path.
    """

    # Which complex dispatch scheme to use.
    mode: ComplexMode = ComplexMode.NONE
    # Callable invoked as cross_kernel(ar, ai, br, bi) -> (real, imag) in CROSS mode.
    cross_kernel: object = None
    # When True (and fallback_target is set), scalar operands are turned into
    # tensors and the call is delegated to fallback_target.
    tensorize_scalars: bool = False
    # Callable receiving the tensorized arguments when tensorize_scalars is set.
    fallback_target: object = None

1239 

1240 

# Complex dtype chosen for a real floating-point tensor when it is promoted to
# a complex tensor; dtypes missing from this table fall back to complex64 at
# the lookup site.
# NOTE(review): bfloat16 maps to complex32, whose components are float16 and
# therefore have a narrower exponent range than bfloat16 - confirm intended.
_REAL_TO_COMPLEX = {
    torch.float16: torch.complex32,
    torch.bfloat16: torch.complex32,
    torch.float32: torch.complex64,
    torch.float64: torch.complex128,
}

1247 

1248 

class PointwiseDynamicFunction:
    """Utility to generate functions for a general pointwise operation.

    It generates a wrapper and a JITFunction which are specialized according
    to the rank of the task space (the broadcasted shape of all input
    tensors).  The generated code is written out to the cache directory
    (defaults to ~/.flaggems).
    """

    def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None):
        """Bind the schema and scalar function; no code is generated yet.

        Args:
            op_desc: schema describing inputs, outputs and type promotion.
            scalar_fn: the Triton JIT function applied per element.
            config: code generation config; ``get_codegen_config()`` is used
                when it is falsy.
        """
        self.fx = op_desc

        assert isinstance(scalar_fn, JITFunction)
        self._scalar_fn = scalar_fn
        self._scalar_fn_cache_key = scalar_fn.cache_key
        self.pid = os.getpid()

        self.config: CodeGenConfig = config or get_codegen_config()

        # instantiated & cached overloads
        self.overloads: Mapping[str, Callable] = {}
        # cached kernel info for C++ integration
        self._kernel_info_cache: Mapping[str, KernelInfo] = {}

        # complex dispatch support
        self.complex_strategy = ComplexStrategy()
        self._operand_indices = self._infer_operand_indices()

    # -------------------- operand index inference --------------------

    def _infer_operand_indices(self):
        """Infer operand indices from schema._promotion_methods, done once at init.

        Every entry of a promotion method except the last one is treated as
        an argument index.  Returns a frozenset of those indices.
        """
        indices = set()
        for pm in self.fx._promotion_methods:
            for idx in pm[:-1]:
                indices.add(idx)
        return frozenset(indices)

    # -------------------- register_complex --------------------

    def register_complex(
        self, mode, cross_kernel=None, tensorize_scalars=False, fallback_target=None
    ):
        """Register complex number support for this kernel.

        Args:
            mode: ComplexMode.ELEMENTWISE (add/sub) or ComplexMode.CROSS (mul/div).
            cross_kernel: A PointwiseDynamicFunction for cross-term ops (mul/div).
            tensorize_scalars: If True, scalar operands are converted to tensors
                before delegating to fallback_target.
            fallback_target: A PointwiseDynamicFunction (tensor-tensor version)
                to delegate to after tensorizing scalar operands.

        Returns:
            ``self``, so the call can be chained.
        """
        self.complex_strategy = ComplexStrategy(
            mode=mode,
            cross_kernel=cross_kernel,
            tensorize_scalars=tensorize_scalars,
            fallback_target=fallback_target,
        )
        return self

    # -------------------- call entry --------------------

    def __call__(self, *args, **kwargs):
        """Run the operation, going through the complex path when it applies."""
        if self._should_use_complex_path(args):
            return self._call_complex_dispatch(*args, **kwargs)
        return self._call_real_impl(*args, **kwargs)

    def _call_real_impl(self, *args, **kwargs):
        """Single entry point for real kernel invocation."""
        ndim, args, kwargs = self.prepare_args(*args, **kwargs)
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        return self._unwrap(out)

    # -------------------- complex helpers --------------------

    @staticmethod
    def _is_complex_arg(a):
        """Return True for a complex tensor or a Python ``complex`` scalar."""
        return (isinstance(a, torch.Tensor) and a.is_complex()) or isinstance(
            a, complex
        )

    def _should_use_complex_path(self, args):
        """Return True when complex support is registered and an operand is complex."""
        if self.complex_strategy.mode == ComplexMode.NONE:
            return False
        return any(
            self._is_complex_arg(args[i])
            for i in self._operand_indices
            if i < len(args)
        )

    def _split_args(self, args):
        """Split args into operands and others by original position index."""
        operands = {}
        others = {}
        for i, a in enumerate(args):
            if i in self._operand_indices:
                operands[i] = a
            else:
                others[i] = a
        return operands, others

    def _merge_args(self, operands, others):
        """Rebuild args tuple from operands and others by original position index."""
        total = len(operands) + len(others)
        merged = [None] * total
        for i, v in operands.items():
            merged[i] = v
        for i, v in others.items():
            merged[i] = v
        return tuple(merged)

    def _classify_complex_inputs(self, operands):
        """Classify operands as 'all_complex', 'mixed', or 'real'."""
        complex_count = sum(1 for v in operands.values() if self._is_complex_arg(v))
        if complex_count == len(operands):
            return "all_complex"
        elif complex_count > 0:
            return "mixed"
        return "real"

    def _infer_device(self, operands):
        """Return the device of the first tensor operand, or None if there is none."""
        for v in operands.values():
            if isinstance(v, torch.Tensor):
                return v.device
        return None

    def _infer_complex_dtype(self, operands):
        """Return ``torch.result_type`` of the operands.

        NOTE(review): the operand values are unpacked straight into
        ``torch.result_type`` - presumably exactly two operands are expected.
        """
        return torch.result_type(*operands.values())

    def _tensorize_scalar_operands(self, operands, dtype, device):
        """Convert scalar operands to tensors.

        complex scalars use ``dtype``, floats use float32 and ints/bools use
        int64; tensors and any other value are passed through unchanged.
        """
        result = {}
        for i, v in operands.items():
            if not isinstance(v, torch.Tensor):
                if isinstance(v, complex):
                    result[i] = torch.tensor(v, dtype=dtype, device=device)
                elif isinstance(v, float):
                    result[i] = torch.tensor(v, dtype=torch.float32, device=device)
                elif isinstance(v, (int, bool)):
                    result[i] = torch.tensor(v, dtype=torch.int64, device=device)
                else:
                    result[i] = v
            else:
                result[i] = v
        return result

    def _to_complex_tensor(self, a, target_dtype, device):
        """Convert a scalar or real tensor to a complex tensor.

        Args:
            a: tensor, Python scalar, or any other value.
            target_dtype: complex dtype used for Python scalars.
            device: device used for Python scalars.

        Returns:
            Complex tensors are returned as is.  Real tensors get a zero
            imaginary part.  Scalars become 0-dim tensors of
            ``target_dtype``.  Any other value is returned unchanged.
        """
        if isinstance(a, torch.Tensor):
            if a.is_complex():
                return a
            if a.is_floating_point() and a.dtype != torch.bfloat16:
                cdtype = _REAL_TO_COMPLEX.get(a.dtype, torch.complex64)
            else:
                # torch.complex only accepts half/float/double inputs, so
                # bfloat16 (like the integer and bool dtypes) is widened to
                # float32 first and yields complex64.
                a = a.to(torch.float32)
                cdtype = torch.complex64
            return torch.complex(a, torch.zeros_like(a)).to(cdtype)
        elif isinstance(a, complex):
            return torch.tensor(a, dtype=target_dtype, device=device)
        elif isinstance(a, (int, float)):
            return torch.tensor(complex(a, 0), dtype=target_dtype, device=device)
        return a

    # -------------------- complex dispatch --------------------

    def _call_complex_dispatch(self, *args, **kwargs):
        """Unified complex dispatch entry point."""
        strategy = self.complex_strategy
        operands, others = self._split_args(args)

        device = self._infer_device(operands)
        result_dtype = self._infer_complex_dtype(operands)

        # tensorize scalar operands and delegate to fallback_target
        if strategy.tensorize_scalars and strategy.fallback_target is not None:
            operands = self._tensorize_scalar_operands(operands, result_dtype, device)
            new_args = self._merge_args(operands, others)
            return strategy.fallback_target(*new_args, **kwargs)

        # convert all operands to complex tensors
        for i in list(operands.keys()):
            operands[i] = self._to_complex_tensor(operands[i], result_dtype, device)

        # broadcast complex tensor operands
        complex_tensors = [operands[i] for i in sorted(operands.keys())]
        complex_tensors = torch.broadcast_tensors(*complex_tensors)
        for idx, key in enumerate(sorted(operands.keys())):
            operands[key] = complex_tensors[idx]

        classification = self._classify_complex_inputs(operands)

        if strategy.mode == ComplexMode.CROSS and classification == "all_complex":
            return self._call_complex_cross(operands, result_dtype)
        elif classification in ("all_complex", "mixed"):
            return self._call_complex_elementwise(
                operands, others, result_dtype, kwargs
            )
        else:
            new_args = self._merge_args(operands, others)
            return self._call_real_impl(*new_args, **kwargs)

    def _call_complex_elementwise(self, operands, others, result_dtype, kwargs):
        """Elementwise: view_as_real -> call real kernel -> view_as_complex."""
        real_tensors = {i: torch.view_as_real(t) for i, t in operands.items()}

        # promote to common real dtype
        dtypes = [t.dtype for t in real_tensors.values()]
        common_dtype = dtypes[0]
        for d in dtypes[1:]:
            common_dtype = torch.promote_types(common_dtype, d)
        real_tensors = {i: t.to(common_dtype) for i, t in real_tensors.items()}

        new_args = self._merge_args(real_tensors, others)
        out_real = self._call_real_impl(*new_args, **kwargs)
        return torch.view_as_complex(out_real.contiguous()).to(result_dtype)

    def _call_complex_cross(self, operands, result_dtype):
        """Cross-term: split ar/ai/br/bi -> call cross_kernel -> stack -> view_as_complex."""
        sorted_keys = sorted(operands.keys())
        A, B = operands[sorted_keys[0]], operands[sorted_keys[1]]
        Ar = torch.view_as_real(A)
        Br = torch.view_as_real(B)
        ar, ai = Ar[..., 0], Ar[..., 1]
        br, bi = Br[..., 0], Br[..., 1]

        common_dtype = torch.promote_types(ar.dtype, br.dtype)
        ar, ai = ar.to(common_dtype), ai.to(common_dtype)
        br, bi = br.to(common_dtype), bi.to(common_dtype)

        real, imag = self.complex_strategy.cross_kernel(ar, ai, br, bi)

        out = torch.stack((real, imag), dim=-1)
        return torch.view_as_complex(out.contiguous()).to(result_dtype)

    @staticmethod
    def use_fast_path(tensors):
        """Return True when all tensors can be processed as one flat 1d task.

        That is the case when they share one shape and are either all
        C-contiguous, or share one stride and the first of them is
        non-overlapping and dense.
        """
        return all_the_same_shape(tensors) and (
            all_c_contiguous(tensors)
            or (
                all_the_same_stride(tensors)
                and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
            )
        )

    def prepare_args(self, *args, _skip_tensor_check=False, **kwargs):
        """Allocate missing outputs and wrap tensors for the generated wrapper.

        Performs output allocation (when needed), task-rank inference and
        input/output reinterpretation as ``StridedBuffer`` objects.

        Args:
            *args: the inputs, passed by position.
            _skip_tensor_check: skip the tensor/non-tensor check of ``args``.
            **kwargs: preallocated outputs passed as ``out0``, ``out1``, ...

        Returns:
            ``(ndim, args, kwargs)`` to call the overload of rank ``ndim``.

        Raises:
            ValueError: the positional arguments fail the tensor check.
            RuntimeError: a given output has the wrong shape or overlaps
                internally (non-fast path only).
        """
        schema = self.fx
        outputs_that_need_allocation: List[int] = []
        out_tensors = []
        for i in range(schema.num_output_tensors()):
            k = f"out{i}"
            if k in kwargs:
                out_tensors.append(kwargs[k])
            else:
                outputs_that_need_allocation.append(i)
        # input arguments must be passed by position
        if not _skip_tensor_check and schema._is_tensor is not None:
            if not check_tensor_attributes(args, (schema._is_tensor)):
                raise ValueError(
                    "Input arguments must be passed by position, and the corresponding dtype must be specified."
                )
        in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)]

        # output dtype promotions
        outputs_dtypes_for_allocation = []
        for i in outputs_that_need_allocation:
            *arg_indices, method = schema._promotion_methods[i]
            promote_args = (args[j] for j in arg_indices)
            _, dtype = type_promotion(*promote_args, type_promotion=method)
            outputs_dtypes_for_allocation.append(dtype)

        tensors = out_tensors + in_tensors
        INT32_MAX = torch.iinfo(torch.int32).max
        if tensors[0].numel() > INT32_MAX:
            # NOTE(review): this mutates the config object held by this
            # instance, so the setting persists for later calls - confirm.
            self.config.prefer_block_pointer = False
        if self.use_fast_path(tensors):  # dimension collapse & use physical ordering
            allocated_outputs = [
                torch.empty_like(tensors[0], dtype=dtype)
                for dtype in outputs_dtypes_for_allocation
            ]
            task_shape = (tensors[0].numel(),)
            strides = (1,)
            ndim = 1
            args = tuple(
                (
                    StridedBuffer(item, task_shape, strides)
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(item, task_shape, strides)
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                kwargs[f"out{output_id}"] = StridedBuffer(
                    allocated_outputs[seq_id], task_shape, strides
                )
        else:
            # a simple strategy: all the undefined tensors will follow the first
            # tensor that is not broadcasted, no attempts to simplify task, no
            # reordering, no dimension collapsing
            shapes = tuple(item.shape for item in in_tensors)

            task_shape = broadcast_shapes(shapes)

            if out_tensors:
                for index, item in enumerate(out_tensors):
                    if list(item.shape) != list(task_shape):
                        raise RuntimeError(
                            f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!"
                        )
                    # output arguments must not have internal overlapping for pointwise operation
                    if has_internal_overlapping(item) == MemOverlap.Yes:
                        raise RuntimeError(
                            "Pointwise Input arguments should not have internal overlapping."
                        )

            ndim = len(task_shape)
            for item in tensors:
                if item.shape == task_shape:
                    allocated_outputs = [
                        torch.empty_like(item, dtype=dtype)
                        for dtype in outputs_dtypes_for_allocation
                    ]
                    break
            else:  # nobreak
                device = tensors[0].device
                allocated_outputs = [
                    torch.empty(task_shape, dtype=dtype, device=device)
                    for dtype in outputs_dtypes_for_allocation
                ]
            args = tuple(
                (
                    StridedBuffer(
                        item,
                        task_shape,
                        broadcasted_stride(item.shape, item.stride(), task_shape),
                    )
                    if schema.is_tensor(i)
                    else item
                )
                for i, item in enumerate(args)
            )
            kwargs = {
                k: StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
                for k, item in kwargs.items()
            }
            for seq_id, output_id in enumerate(outputs_that_need_allocation):
                item = allocated_outputs[seq_id]
                kwargs[f"out{output_id}"] = StridedBuffer(
                    item,
                    task_shape,
                    broadcasted_stride(item.shape, item.stride(), task_shape),
                )
        return (ndim, args, kwargs)

    def _unwrap(self, tensors):
        """Unwrap StridedBuffer results: one Tensor, or a tuple of Tensors."""
        if self.fx.num_output_tensors() == 1:
            item = tensors
            return item.unwrap()
        return tuple(item.unwrap() for item in tensors)

    def _compute_kernel_names(self, ndim: int) -> Tuple[str, str, str]:
        """Compute kernel name, wrapper name, and file path for a given ndim.

        This is the single source of truth for naming, used by both instantiate()
        and get_kernel_info() to ensure consistency.

        Returns:
            Tuple of (kernel_name, wrapper_name, file_path)
        """
        scalar_fn_name = self._scalar_fn.__name__
        kernel_name = f"{scalar_fn_name}_kernel_rank_{ndim}"
        wrapper_name = f"{scalar_fn_name}_wrapper_rank_{ndim}"

        file_name = (
            f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_"
            f"{'1d_tile_' if self.config.prefer_1d_tile else ''}"
            f"{'bptr' if (not self.config.prefer_1d_tile and self.config.prefer_block_pointer) else ''}"
            ".py"
        )
        file_path = str(code_cache_dir() / file_name)

        return kernel_name, wrapper_name, file_path

    def instantiate(self, ndim):
        """Generate, load and cache the wrapper for task rank ``ndim``.

        Overloads are cached per ``(ndim, prefer_block_pointer)``; a cached
        overload is returned without regenerating code.
        """
        # NOTE: manually instantiated overload does not have `prepare_args` as
        # preprocessing, so you have to manually allocate output and make sure that
        # the inputs & outputs actually fit the manually instantiated overload
        key = f"{ndim}_{self.config.prefer_block_pointer}"
        if key in self.overloads:
            return self.overloads[key]

        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded, so import it explicitly before use.
        import importlib.util

        code = IndentedBuffer()

        # Use helper to compute names (single source of truth)
        kernel_name, wrapper_name, file_path = self._compute_kernel_names(ndim)

        module_gen = ModuleGenerator(
            self.fx,
            self._scalar_fn,
            ndim,
            kernel_name,
            wrapper_name,
            self.config,
        )
        module_gen.codegen(code)

        # NOTE: [why write the generated code to a file]
        # triton uses inspect to get the source of the jitted function, which requires
        # that the source code can be found by inspect
        # We write it into a file, since inspect cannot find the source of functions dynamically
        # created via exec string. We can help inspect to find the source by hacking linecache
        # library, but we find generating a module simpler, since we can generate 2 functions,
        # the kernel and the wrapper, and the wrapper calls the kernel.
        write_atomic(file_path, code.getvalue())

        # load
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
            file_path,
        )
        m = importlib.util.module_from_spec(spec)
        # do not expose it to sys.modules
        # sys.modules["_add_module"] = m

        # NOTE: [why not import the scalar function]
        # we do not re-import the scalar function, although the generated kernel **calls** it
        # Since a function's __name__ may be changed, importing it by __name__ from the module
        # where it is defined may not work; also the name may be rebound to something else, so
        # importing via name cannot guarantee that the scalar function is imported.
        # So we copy the scalar function and its __globals__ to the generated module to do this
        # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
        spec.loader.exec_module(m)
        m.__dict__.update(self._scalar_fn.__globals__)
        m.__dict__[self._scalar_fn.__name__] = self._scalar_fn

        overload = getattr(m, wrapper_name)
        self.overloads[key] = overload

        # Cache kernel info for C++ integration
        self._kernel_info_cache[key] = KernelInfo(
            file_path=file_path,
            kernel_name=kernel_name,
            wrapper_name=wrapper_name,
            ndim=ndim,
        )

        return overload

    def get_kernel_info(self, ndim: int) -> KernelInfo:
        """Get kernel information for a given ndim.

        This method is useful for C++ integration to get the file path and
        kernel name without duplicating the naming logic.

        If the kernel hasn't been instantiated yet, this will instantiate it first.

        Args:
            ndim: The rank of the task space

        Returns:
            KernelInfo with file_path, kernel_name, wrapper_name, and ndim
        """
        key = f"{ndim}_{self.config.prefer_block_pointer}"

        # Ensure the kernel is instantiated
        if key not in self._kernel_info_cache:
            self.instantiate(ndim)

        return self._kernel_info_cache[key]

1727 

1728 

def pointwise_dynamic(
    f: Optional[JITFunction] = None,
    *,
    num_inputs: Optional[int] = None,
    is_tensor: Optional[List[bool]] = None,
    dtypes: Optional[List[Optional[type]]] = None,
    num_outputs: Optional[int] = None,
    promotion_methods: Optional[Tuple[int, ...]] = None,
    config: Optional[CodeGenConfig] = None,
):
    """Decorator turning a scalar JIT function into a PointwiseDynamicFunction.

    Usable both bare (``@pointwise_dynamic``) and with keyword arguments
    (``@pointwise_dynamic(num_inputs=2, ...)``).

    Args:
        f: the function to wrap when used as a bare decorator.
        num_inputs, is_tensor, dtypes, num_outputs, promotion_methods:
            forwarded to ``FunctionSchema``.  When ``num_inputs``,
            ``is_tensor`` and ``dtypes`` are all None, the number of inputs
            is taken from ``len(fn.arg_names)``.
        config: forwarded to ``PointwiseDynamicFunction``.

    Returns:
        A ``PointwiseDynamicFunction`` when ``f`` is given, otherwise a
        decorator that produces one.
    """

    def decorator(fn):
        # Resolve the arity in a local so that a decorator object applied to
        # several functions infers it again for each of them, instead of
        # keeping the value inferred from the first one.
        resolved_num_inputs = num_inputs
        if (num_inputs is None) and (is_tensor is None) and (dtypes is None):
            resolved_num_inputs = len(fn.arg_names)
        op_desc = FunctionSchema(
            num_inputs=resolved_num_inputs,
            is_tensor=is_tensor,
            dtypes=dtypes,
            num_outputs=num_outputs,
            promotion_methods=promotion_methods,
        )
        return PointwiseDynamicFunction(op_desc, fn, config)

    if f is not None:
        return decorator(f)
    return decorator