Coverage for src/flag_gems/runtime/backend/_cambricon/utils/pointwise_dynamic.py: 0%

1111 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import importlib 

2import os 

3from dataclasses import dataclass 

4from typing import Callable, Iterable, List, Mapping, Optional, Sequence, Tuple 

5 

6import torch 

7import triton 

8from triton.runtime.jit import JITFunction 

9 

10from flag_gems.utils.code_cache import code_cache_dir 

11from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

12from flag_gems.utils.codegen_config_utils import CodeGenConfig, get_codegen_config 

13from flag_gems.utils.shape_utils import ( 

14 MemOverlap, 

15 all_c_contiguous, 

16 all_the_same_shape, 

17 all_the_same_stride, 

18 broadcast_shapes, 

19 broadcasted_stride, 

20 check_tensor_attributes, 

21 has_internal_overlapping, 

22) 

23from flag_gems.utils.tensor_wrapper import StridedBuffer 

24from flag_gems.utils.type_utils import ELEMENTWISE_TYPE_PROMOTION_KIND, type_promotion 

25 

26 

27# ------------------ Operation Description --------------------------- 

def _type_name(type) -> str:
    """Render a type as source text.

    Builtin scalar types (bool, int, float, str) render as their bare name
    (``"int"``); anything else, e.g. a ``torch.dtype``, renders via ``str()``.
    """
    builtin_scalars = (bool, int, float, str)
    if type not in builtin_scalars:
        return str(type)
    return type.__name__

35 

36 

def _check_typed_list(container, type):
    """Assert that every item of ``container`` is an instance of ``type``.

    ``type`` may be a single class or a tuple of classes, as for ``isinstance``.
    Raises AssertionError at the first mismatching item.
    """
    assert all(isinstance(element, type) for element in container)

40 

41 

def _check_sized_list(container, size):
    """Assert that ``container`` holds exactly ``size`` items."""
    actual_length = len(container)
    assert actual_length == size

44 

45 

def _tuple_content(strings: Sequence[str]) -> str:
    """Render ``strings`` as the inside of a Python tuple literal.

    A single element gets a trailing comma (``"a,"``) so that wrapping the
    result in parentheses yields a 1-tuple; otherwise the items are joined
    with ``", "`` (an empty sequence gives ``""``).
    """
    if len(strings) != 1:
        return ", ".join(strings)
    return f"{strings[0]},"

54 

55 

def _cs(strings: Iterable[str]) -> str:
    """Join ``strings`` into one comma-separated string (``"a, b"``)."""
    separator = ", "
    return separator.join(strings)

58 

59 

def _broadcast_vec(i, ndim):
    """Return an indexing suffix that keeps axis ``i`` and adds ``None`` axes elsewhere.

    For example ``_broadcast_vec(1, 3) == "[None, :, None]"``. Appended to a
    1-D per-axis vector in generated code, it yields an ``ndim``-rank view
    that broadcasts along every other axis.
    """
    parts = []
    for axis in range(ndim):
        parts.append(":" if axis == i else "None")
    return "[" + ", ".join(parts) + "]"

63 

64 

class FunctionSchema:
    """Describe the inputs and outputs of a pointwise operation.

    Each input is either a tensor or a non-tensor value; non-tensor inputs may
    carry a declared Python type. Every output has a type-promotion rule of
    the form ``(*input_indices, promotion_kind)``.
    """

    # number of inputs (outputs not included)
    _num_inputs: int
    # per input: True when the input is a tensor
    _is_tensor: List[bool]
    # per input: declared type, or None when unspecified
    _dtypes: List[Optional[type]]

    _num_input_tensors: int
    _num_non_tensor_inputs: int

    _num_outputs: int
    # per output: (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[<name>])
    _promotion_methods: List[Tuple[int, ...]]

    def __init__(
        self,
        *,
        num_inputs: Optional[int] = None,
        is_tensor: Optional[List[bool]] = None,
        dtypes: Optional[List[Optional[type]]] = None,
        num_outputs: Optional[int] = None,
        promotion_methods=None,
    ):
        """Build a schema from any of ``num_inputs``, ``is_tensor`` or ``dtypes``.

        The input count comes from the first of ``num_inputs``, ``is_tensor``,
        ``dtypes`` that is given. Any of the others that is also given must
        have that length.

        Defaults for ``is_tensor`` when it is not given:
          * if ``num_inputs`` is given, every input is a tensor;
          * if only ``dtypes`` is given, inputs whose dtype is None are tensors.

        ``promotion_methods`` is required. Each entry is
        ``(*arg_indices, method_name)``. Its length sets the number of outputs;
        if ``num_outputs`` is also given, the two must match.

        Raises:
            ValueError: if ``promotion_methods`` is missing, or if none of
                ``num_inputs``, ``is_tensor``, ``dtypes`` is given.
            AssertionError: on wrongly typed or sized lists, or fewer than one
                input or output.
        """
        if is_tensor is not None:
            _check_typed_list(is_tensor, bool)
        if dtypes is not None:
            _check_typed_list(dtypes, (type, type(None)))

        if promotion_methods is None:
            raise ValueError(
                "No type promotion method provided! You must provide type promotion method for each output!"
            )
        self._promotion_methods = self.canonicalize_promotion_methods(
            promotion_methods
        )

        # Copies are stored so later mutation of the caller's lists cannot
        # silently change the schema.
        if num_inputs is not None:
            self._num_inputs = num_inputs
            if is_tensor is not None:
                _check_sized_list(is_tensor, num_inputs)
                self._is_tensor = list(is_tensor)
            else:
                self._is_tensor = [True] * num_inputs

            if dtypes is not None:
                _check_sized_list(dtypes, num_inputs)
                self._dtypes = list(dtypes)
            else:
                self._dtypes = [None] * num_inputs
        elif is_tensor is not None:
            self._num_inputs = len(is_tensor)
            self._is_tensor = list(is_tensor)
            if dtypes is not None:
                _check_sized_list(dtypes, self._num_inputs)
                self._dtypes = list(dtypes)
            else:
                self._dtypes = [None] * self._num_inputs
        elif dtypes is not None:
            # is_tensor is necessarily None here (handled by the branch above):
            # inputs without a declared type are treated as tensors.
            self._num_inputs = len(dtypes)
            self._dtypes = list(dtypes)
            self._is_tensor = [item is None for item in dtypes]
        else:
            raise ValueError(
                "Cannot create FunctionSchema when none of (num_inputs, is_tensor, dtypes) is specified."
            )

        if num_outputs is not None:
            self._num_outputs = num_outputs
            _check_sized_list(promotion_methods, num_outputs)
        else:
            self._num_outputs = len(promotion_methods)

        assert self._num_inputs >= 1
        assert self._num_outputs >= 1

        self._num_input_tensors = sum(self._is_tensor)
        self._num_non_tensor_inputs = self._num_inputs - self._num_input_tensors
        self._input_id = self._compute_input_id()

    @staticmethod
    def canonicalize_promotion_methods(promotion_methods):
        """Map each ``(*arg_indices, method_name)`` to ``(*arg_indices, kind)``.

        ``kind`` is ``ELEMENTWISE_TYPE_PROMOTION_KIND[method_name]``.
        """
        canonicalized = []
        for item in promotion_methods:
            *arg_indices, method = item
            canonicalized.append(
                (*arg_indices, ELEMENTWISE_TYPE_PROMOTION_KIND[method])
            )
        return canonicalized

    def num_inputs(self):
        """Number of input arguments (outputs not included)."""
        return self._num_inputs

    def num_outputs(self):
        """Number of outputs."""
        return self._num_outputs

    def is_tensor(self, arg_id: int) -> bool:
        """Whether input ``arg_id`` is a tensor."""
        return self._is_tensor[arg_id]

    def input_type(self, arg_id) -> Optional[type]:
        """Declared type of input ``arg_id``, or None if unspecified."""
        return self._dtypes[arg_id]

    def output_type(self, i):
        """Canonicalized promotion rule ``(*arg_indices, kind)`` of output ``i``."""
        return self._promotion_methods[i]

    def num_input_tensors(self) -> int:
        """Number of tensor inputs."""
        return self._num_input_tensors

    def num_output_tensors(self) -> int:
        """Number of output tensors (same as ``num_outputs``)."""
        return self._num_outputs

    def num_non_tensor_args(self) -> int:
        """Number of non-tensor inputs."""
        return self._num_non_tensor_inputs

    def signature(self, outputs_in_arg: bool = False) -> str:
        """Render a human-readable signature.

        Tensor inputs render as ``StridedBuffer``. Non-tensor inputs render as
        their declared type name, or ``scalar`` if none is declared.

        With ``outputs_in_arg`` the outputs are also appended to the argument
        list. Each output gets its own alias label ``StridedBuffer(a<i>!)``,
        where ``i`` is the output index.
        """
        input_types = []
        for is_tensor, dtype in zip(self._is_tensor, self._dtypes):
            if is_tensor:
                input_types.append("StridedBuffer")
            elif dtype is None:
                input_types.append("scalar")
            else:
                input_types.append(_type_name(dtype))

        if outputs_in_arg:
            # One distinct label per output, indexed by the output position.
            output_types = [
                f"StridedBuffer(a{i}!)" for i in range(self.num_outputs())
            ]
            input_types.extend(output_types)
        else:
            output_types = ["StridedBuffer"] * self.num_outputs()
        sig = f"Pointwise: {', '.join(input_types)} -> {', '.join(output_types)}"
        return sig

    def _compute_input_id(self):
        """Map each input to its position among inputs of the same kind.

        Tensors and non-tensors are numbered separately, e.g. for inputs
        ``(tensor, scalar, tensor)`` the result is ``[0, 0, 1]``.
        """
        input_tensor_index = 0
        non_tensor_index = 0
        mapping: List[int] = []
        for i in range(self.num_inputs()):
            if self.is_tensor(i):
                mapping.append(input_tensor_index)
                input_tensor_index += 1
            else:
                mapping.append(non_tensor_index)
                non_tensor_index += 1
        return mapping

    def input_index(self, idx):
        """Position of input ``idx`` among inputs of the same kind."""
        return self._input_id[idx]

    def __str__(self) -> str:
        return self.signature(outputs_in_arg=False)

221 

222 

223class KernelGenerator: 

224 def __init__( 

225 self, 

226 function_schema: FunctionSchema, 

227 scalar_fn: triton.JITFunction, 

228 rank: int, 

229 name: str, 

230 config: CodeGenConfig, 

231 ): 

232 self.fx = function_schema 

233 self.fn = scalar_fn 

234 self.ndim = rank 

235 self.name = name 

236 self.config = config 

237 

238 self.fn_name = scalar_fn.__name__ 

239 self.fn_module = scalar_fn.__module__ 

240 

241 def gen_import_function(self, code: IndentedBuffer): 

242 code.writeline("@triton.jit") 

243 code.writemultiline(self.fn.src) 

244 

245 def gen_config_prune(self, code): 

246 code.newline() 

247 code.newline() 

248 code.writeline("def config_prune(configs, named_args, **kwargs):") 

249 with code.indent(): 

250 code.writeline("new_configs = []") 

251 code.writeline("elem_sizes = []") 

252 for i in range(self.fx.num_input_tensors()): 

253 code.writeline( 

254 f"elem_sizes.append(named_args['in{i}_ptr'].dtype.itemsize)" 

255 ) 

256 for i in range(self.fx.num_output_tensors()): 

257 code.writeline( 

258 f"elem_sizes.append(named_args['out{i}_ptr'].dtype.itemsize)" 

259 ) 

260 

261 code.writeline("max_elem_size = max(elem_sizes)") 

262 shape = ", ".join(f"s{i}" for i in range(self.ndim)) 

263 named_shape = ", ".join(f"named_args['s{i}']" for i in range(self.ndim)) 

264 code.writeline(f"{shape} = {named_shape}") 

265 tile_sizes = ", ".join(f"tile_size{i}" for i in range(self.ndim)) 

266 tile_size_dict = ", ".join( 

267 f"'tile_size{i}': tile_size{i}" for i in range(self.ndim) 

268 ) 

269 

270 code.writeline("if max_elem_size < 8:") 

271 with code.indent(): 

272 code.writeline("max_tile_sizes = [1024, 2048, 4096, 8192, 16000]") 

273 code.writeline("for max_tile_size in max_tile_sizes:") 

274 with code.indent(): 

275 code.writeline( 

276 f"({tile_sizes}, ) = heuristics_for_tile_size(max_tile_size, {shape})" 

277 ) 

278 code.writeline( 

279 f"new_configs.append(triton.Config({{{tile_size_dict}}}, num_stages=3, num_warps=1))" 

280 ) 

281 code.writeline("else:") 

282 with code.indent(): 

283 code.writeline("max_tile_sizes = [1024, 2048, 4096, 8000]") 

284 code.writeline("for max_tile_size in max_tile_sizes:") 

285 with code.indent(): 

286 code.writeline( 

287 f"({tile_sizes}, ) = heuristics_for_tile_size(max_tile_size, {shape})" 

288 ) 

289 code.writeline( 

290 f"new_configs.append(triton.Config({{{tile_size_dict}}}, num_stages=3, num_warps=1))" 

291 ) 

292 

293 code.writeline("return new_configs") 

294 

295 def gen_hooks(self, code): 

296 code.newline() 

297 code.newline() 

298 code.writeline("restore_copies = {}") 

299 code.writeline( 

300 "KEYSET = torch._C.DispatchKeySet(torch._C.DispatchKey.PrivateUse1)" 

301 ) 

302 code.writeline("def pre_hook(kwargs, reset_only=False):") 

303 with code.indent(): 

304 code.writeline("if not reset_only:") 

305 with code.indent(): 

306 code.writeline( 

307 "torch_copy_ = flag_gems.current_work_registrar.torch_ops_map['aten::copy_']" 

308 ) 

309 code.writeline(f"for name in {self.name}.fn.restore_value:") 

310 with code.indent(): 

311 code.writeline("restore_copy = torch.empty_like(kwargs[name])") 

312 code.writeline( 

313 "restore_copies[name] = torch_copy_.call_boxed(KEYSET, restore_copy, kwargs[name])" 

314 ) 

315 

316 code.writeline("def post_hook(kwargs, exception):") 

317 with code.indent(): 

318 code.writeline(f"for name in {self.name}.fn.restore_value:") 

319 with code.indent(): 

320 code.writeline( 

321 "torch_copy_ = flag_gems.current_work_registrar.torch_ops_map['aten::copy_']" 

322 ) 

323 code.writeline( 

324 "kwargs[name] = torch_copy_.call_boxed(KEYSET, kwargs[name], restore_copies[name])" 

325 ) 

326 

327 def gen_decorators(self, code): 

328 if self.ndim in [1, 2, 3, 4] and (not self.config.prefer_1d_tile): 

329 self.gen_config_prune(code) 

330 

331 if self.fn_name == "_copy_kernel": 

332 self.gen_hooks(code) 

333 

334 num_non_tensor_args = self.fx.num_non_tensor_args() 

335 if num_non_tensor_args > 0: 

336 non_tensor_arg_names = ", ".join( 

337 f"'val{i}'" for i in range(num_non_tensor_args) 

338 ) 

339 

340 shapes = ", ".join(f"'s{i}'" for i in range(self.ndim)) 

341 stride_args = [] 

342 for i in range(self.fx.num_input_tensors()): 

343 stride_args.append(_cs(f"'in{i}_stride{j}'" for j in range(self.ndim))) 

344 for i in range(self.fx.num_output_tensors()): 

345 stride_args.append(_cs(f"'out{i}_stride{j}'" for j in range(self.ndim))) 

346 

347 code.writeline("@libentry()") 

348 if self.ndim == 1 and (not self.config.prefer_1d_tile): 

349 code.writeline("@libtuner(") 

350 with code.indent(): 

351 code.writeline("configs=[") 

352 with code.indent(): 

353 code.writeline( 

354 "triton.Config({'tile_size0': 1024}, num_stages=3, num_warps=1)," 

355 ) 

356 code.writeline( 

357 "triton.Config({'tile_size0': 2048}, num_stages=3, num_warps=1)," 

358 ) 

359 code.writeline("],") 

360 if num_non_tensor_args > 0: 

361 code.writeline( 

362 f"key=['num_tasks', {_cs(stride_args)}, {non_tensor_arg_names}]," 

363 ) 

364 else: 

365 code.writeline(f"key=['num_tasks', {_cs(stride_args)}],") 

366 code.writeline("prune_configs_by={'early_config_prune': config_prune},") 

367 output_params = [ 

368 f"out{i}_ptr" for i in range(self.fx.num_output_tensors()) 

369 ] 

370 output_elements = ", ".join(f"'{name}'" for name in output_params) 

371 code.writeline(f"restore_value=[{output_elements}],") 

372 if self.fn_name == "_copy_kernel": 

373 code.writeline("pre_hook=pre_hook,") 

374 code.writeline("post_hook=post_hook,") 

375 code.writeline(")") 

376 

377 if self.ndim == 2 and (not self.config.prefer_1d_tile): 

378 code.writeline("@libtuner(") 

379 with code.indent(): 

380 code.writeline("configs=[") 

381 with code.indent(): 

382 code.writeline( 

383 "triton.Config({'tile_size0': 1, 'tile_size1': 1024}, num_stages=3, num_warps=1)," 

384 ) 

385 code.writeline( 

386 "triton.Config({'tile_size0': 1, 'tile_size1': 2048}, num_stages=3, num_warps=1)," 

387 ) 

388 code.writeline("],") 

389 if num_non_tensor_args > 0: 

390 code.writeline( 

391 f"key=[{shapes}, {_cs(stride_args)}, {non_tensor_arg_names}]," 

392 ) 

393 else: 

394 code.writeline(f"key=[{shapes}, {_cs(stride_args)}],") 

395 code.writeline("prune_configs_by={'early_config_prune': config_prune},") 

396 output_params = [ 

397 f"out{i}_ptr" for i in range(self.fx.num_output_tensors()) 

398 ] 

399 output_elements = ", ".join(f"'{name}'" for name in output_params) 

400 code.writeline(f"restore_value=[{output_elements}],") 

401 if self.fn_name == "_copy_kernel": 

402 code.writeline("pre_hook=pre_hook,") 

403 code.writeline("post_hook=post_hook,") 

404 code.writeline(")") 

405 

406 if self.ndim == 3 and (not self.config.prefer_1d_tile): 

407 code.writeline("@libtuner(") 

408 with code.indent(): 

409 code.writeline("configs=[") 

410 with code.indent(): 

411 code.writeline( 

412 """ 

413 triton.Config({'tile_size0': 1, 'tile_size1': 1, 'tile_size2': 1024}, num_stages=3, num_warps=1), 

414 """ 

415 ) 

416 code.writeline( 

417 """ 

418 triton.Config({'tile_size0': 1, 'tile_size1': 1, 'tile_size2': 2048}, num_stages=3, num_warps=1), 

419 """ 

420 ) 

421 code.writeline("],") 

422 if num_non_tensor_args > 0: 

423 code.writeline( 

424 f"key=[{shapes}, {_cs(stride_args)}, {non_tensor_arg_names}]," 

425 ) 

426 else: 

427 code.writeline(f"key=[{shapes}, {_cs(stride_args)}],") 

428 code.writeline("prune_configs_by={'early_config_prune': config_prune},") 

429 output_params = [ 

430 f"out{i}_ptr" for i in range(self.fx.num_output_tensors()) 

431 ] 

432 output_elements = ", ".join(f"'{name}'" for name in output_params) 

433 code.writeline(f"restore_value=[{output_elements}],") 

434 if self.fn_name == "_copy_kernel": 

435 code.writeline("pre_hook=pre_hook,") 

436 code.writeline("post_hook=post_hook,") 

437 code.writeline(")") 

438 

439 if self.ndim == 4 and (not self.config.prefer_1d_tile): 

440 code.writeline("@libtuner(") 

441 with code.indent(): 

442 code.writeline("configs=[") 

443 with code.indent(): 

444 code.writeline( 

445 """ 

446 triton.Config({'tile_size0': 1,'tile_size1': 1,'tile_size2': 1,'tile_size3': 1024},num_stages=3,num_warps=1), 

447 """ 

448 ) 

449 code.writeline( 

450 """ 

451 triton.Config({'tile_size0': 1,'tile_size1': 1,'tile_size2': 1,'tile_size3': 2048},num_stages=3,num_warps=1), 

452 """ 

453 ) 

454 code.writeline("],") 

455 if num_non_tensor_args > 0: 

456 code.writeline( 

457 f"key=[{shapes}, {_cs(stride_args)}, {non_tensor_arg_names}]," 

458 ) 

459 else: 

460 code.writeline(f"key=[{shapes}, {_cs(stride_args)}],") 

461 code.writeline("prune_configs_by={'early_config_prune': config_prune},") 

462 output_params = [ 

463 f"out{i}_ptr" for i in range(self.fx.num_output_tensors()) 

464 ] 

465 output_elements = ", ".join(f"'{name}'" for name in output_params) 

466 code.writeline(f"restore_value=[{output_elements}],") 

467 if self.fn_name == "_copy_kernel": 

468 code.writeline("pre_hook=pre_hook,") 

469 code.writeline("post_hook=post_hook,") 

470 code.writeline(")") 

471 

472 if num_non_tensor_args > 0: 

473 # we do not specialize non tensor args since they are passed into the inlined function 

474 # which means that their values may not deserve specialization 

475 non_specialize_arg_names = [f"val{i}" for i in range(num_non_tensor_args)] 

476 code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})") 

477 else: 

478 code.writeline("@triton.jit") 

479 

480 def input_name(self, i): 

481 is_tensor = self.fx.is_tensor(i) 

482 name = "in" if is_tensor else "val" 

483 index = self.fx.input_index(i) 

484 return f"{name}{index}" 

485 

486 def output_name(self, i): 

487 return f"out{i}" 

488 

489 def gen_signature(self, code, with_block_pointer=False): 

490 code.writeline(f"def {self.name}(") 

491 with code.indent(): 

492 input_tensor_index = 0 

493 non_tensor_index = 0 

494 output_tensor_index = 0 

495 

496 schema = self.fx 

497 # signature: inputs ptrs & non tensor inputs 

498 for i in range(schema.num_inputs()): 

499 if schema.is_tensor(i): 

500 code.writeline( 

501 f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type" 

502 ) 

503 input_tensor_index += 1 

504 else: 

505 if schema.input_type(i) is not None: 

506 code.writeline( 

507 f"val{non_tensor_index}: {_type_name(schema.input_type(i))}," 

508 ) 

509 else: 

510 code.writeline(f"val{non_tensor_index},") 

511 non_tensor_index += 1 

512 

513 # signature: output ptrs 

514 for i in range(schema.num_outputs()): 

515 code.writeline( 

516 f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type" 

517 ) 

518 output_tensor_index += 1 

519 

520 # signature: strides, for each tensor arguments 

521 ndim = self.ndim 

522 if ndim > 0: 

523 # strides for inputs 

524 for i in range(schema.num_input_tensors()): 

525 stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim)) 

526 code.writeline(f"{stride_args}, # strides for in{i}") 

527 if with_block_pointer: 

528 stride_order_args = _cs( 

529 f"in{i}_stride_order{j}: tl.constexpr" for j in range(ndim) 

530 ) 

531 code.writeline(f"{stride_order_args}, # stride order for in{i}") 

532 zero_stride_args = _cs( 

533 f"in{i}_zero_stride{j}: tl.constexpr" for j in range(ndim) 

534 ) 

535 code.writeline( 

536 f"{zero_stride_args}, # zero stride flag for in{i}" 

537 ) 

538 

539 # strides for outputs 

540 for i in range(schema.num_output_tensors()): 

541 stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim)) 

542 code.writeline(f"{stride_args}, # strides for out{i}") 

543 if with_block_pointer: 

544 stride_order_args = _cs( 

545 f"out{i}_stride_order{j}: tl.constexpr" for j in range(ndim) 

546 ) 

547 code.writeline( 

548 f"{stride_order_args}, # stride order for out{i}" 

549 ) 

550 zero_stride_args = _cs( 

551 f"out{i}_zero_stride{j}: tl.constexpr" for j in range(ndim) 

552 ) 

553 code.writeline( 

554 f"{zero_stride_args}, # zero stride flag for out{i}" 

555 ) 

556 

557 # task space, used to reconstruct multi index 

558 task_space_args = _cs(f"s{i}" for i in range(ndim)) 

559 code.writeline(f"{task_space_args}, # task_space") 

560 

561 # number of tasks, used to compute mask 

562 code.writeline("num_tasks,") 

563 if self.config.prefer_block_pointer: 

564 code.writeline("FALLBACK_BPTR: tl.constexpr,") 

565 

566 # tile size & tiles_per_cta, gsl style 

567 if ndim > 0: 

568 tile_sizes = _cs(f"tile_size{i}: tl.constexpr" for i in range(ndim)) 

569 code.writeline(f"{tile_sizes},") 

570 if ndim > 4: 

571 code.writeline("tiles_per_cta: int,") 

572 code.writeline("one_tile_per_cta: tl.constexpr,") 

573 code.writeline("):") 

574 

575 def gen_signature_1d_tile(self, code): 

576 code.writeline(f"def {self.name}(") 

577 with code.indent(): 

578 input_tensor_index = 0 

579 non_tensor_index = 0 

580 output_tensor_index = 0 

581 

582 schema = self.fx 

583 # signature: inputs ptrs & non tensor inputs 

584 for i in range(schema.num_inputs()): 

585 if schema.is_tensor(i): 

586 code.writeline( 

587 f"in{input_tensor_index}_ptr: tl.tensor, # of tl.pointer_type" 

588 ) 

589 input_tensor_index += 1 

590 else: 

591 if schema.input_type(i) is not None: 

592 code.writeline( 

593 f"val{non_tensor_index}: {_type_name(schema.input_type(i))}," 

594 ) 

595 else: 

596 code.writeline(f"val{non_tensor_index},") 

597 non_tensor_index += 1 

598 

599 # signature: output ptrs 

600 for i in range(schema.num_outputs()): 

601 code.writeline( 

602 f"out{output_tensor_index}_ptr: tl.tensor, # of tl.pointer_type" 

603 ) 

604 output_tensor_index += 1 

605 

606 # signature: strides, for each tensor arguments 

607 ndim = self.ndim 

608 if ndim > 0: 

609 # strides for inputs 

610 for i in range(schema.num_input_tensors()): 

611 stride_args = _cs(f"in{i}_stride{j}: int" for j in range(ndim)) 

612 code.writeline(f"{stride_args}, # strides for in{i}") 

613 

614 # strides for outputs 

615 for i in range(schema.num_output_tensors()): 

616 stride_args = _cs(f"out{i}_stride{j}: int" for j in range(ndim)) 

617 code.writeline(f"{stride_args}, # strides for out{i}") 

618 

619 # task space, used to reconstruct multi index 

620 task_space_args = _cs(f"s{i}" for i in range(ndim)) 

621 code.writeline(f"{task_space_args}, # task_space") 

622 

623 # number of tasks, used to compute mask 

624 code.writeline("num_tasks,") 

625 

626 if self.config.prefer_block_pointer: 

627 code.writeline("FALLBACK_BPTR: tl.constexpr,") 

628 

629 # tile size & tiles_per_cta, gsl style 

630 if ndim > 0: 

631 code.writeline("tiles_per_cta: int,") 

632 code.writeline("tile_size: tl.constexpr,") 

633 code.writeline("one_tile_per_cta: tl.constexpr,") 

634 code.writeline("):") 

635 

636 def gen_num_tiles(self, code): 

637 # tile-grid size 

638 ndim = self.ndim 

639 for i in range(ndim): 

640 if i < ndim: 

641 code.writeline(f"num_tiles{i} = tl.cdiv(s{i}, tile_size{i})") 

642 

643 def gen_body_for_0d(self, code): 

644 schema = self.fx 

645 inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())] 

646 outputs_to_scalar_fn = [ 

647 self.output_name(i) for i in range(schema.num_output_tensors()) 

648 ] 

649 inputs_to_scalar_fn = _cs(inputs_to_scalar_fn) 

650 outputs_to_scalar_fn = _cs(outputs_to_scalar_fn) 

651 

652 code.writeline("# loads") 

653 for i in range(schema.num_input_tensors()): 

654 code.writeline( 

655 f"in{i} = tl.load(in{i}_ptr).to(in{i}_ptr.type.element_ty) " 

656 "# workaround the bug on bool, we should use the pointer's dtype)" 

657 ) 

658 code.newline() 

659 

660 code.writeline("# compute") 

661 code.writeline( 

662 f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})" 

663 ) 

664 code.newline() 

665 

666 code.writeline("# stores") 

667 for i in range(schema.num_output_tensors()): 

668 code.writeline( 

669 f"tl.store(out{i}_ptr, out{i}.to(out{i}_ptr.type.element_ty))" 

670 ) 

671 code.newline() 

672 return code 

673 

674 # nd tile 1d grid kernel with block pointer 

675 def gen_body_one_tile_per_cta_with_bptr(self, code): 

676 ndim = self.ndim 

677 schema = self.fx 

678 

679 # block pointer for each operand 

680 shape = _tuple_content(tuple(f"s{i}" for i in range(ndim))) 

681 offsets = _tuple_content(tuple(f"offset{i}" for i in range(ndim))) 

682 tile_sizes = _tuple_content(tuple(f"tile_size{i}" for i in range(ndim))) 

683 

684 # reconstruct pid multi index 

685 code.writeline( 

686 "# pid multi index recontruction: we use c ordering, right axes changes fastest" 

687 ) 

688 for i in reversed(range(ndim)): 

689 if i > 0: 

690 code.writeline(f"tile_id{i} = tile_id % num_tiles{i}") 

691 code.writeline(f"tile_id //= num_tiles{i}") 

692 else: 

693 code.writeline(f"tile_id{i} = tile_id") 

694 code.newline() 

695 

696 # cta_offsets 

697 code.writeline("# tile offsets") 

698 

699 # Because block pointer only support `tl.int32` indexing, when max offsets 

700 # of ptrs exceeding 2^31, we should fallback it to noraml indexing method. 

701 code.writeline("if not FALLBACK_BPTR:") 

702 with code.indent(): 

703 for i in range(ndim): 

704 # Or else: AssertionError: Block pointers only support 32 bit 

705 # `offsets/block_shape`, add a `.to(tl.int32)` or use regular indexing 

706 # for 64 bit support 

707 code.writeline(f"offset{i} = (tile_id{i} * tile_size{i}).to(tl.int32)") 

708 

709 # loads 

710 code.writeline("# loads") 

711 for i in range(schema.num_input_tensors()): 

712 strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim))) 

713 order = _tuple_content( 

714 tuple(f"in{i}_stride_order{j}" for j in range(ndim)) 

715 ) 

716 

717 for j in range(ndim): 

718 code.writeline(f"if in{i}_zero_stride{j}:") 

719 with code.indent(): 

720 code.writeline(f"in{i}_stride{j} = 0") 

721 

722 code.writeline( 

723 f"in{i}_bptr = tl.make_block_ptr(" 

724 f"in{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))" 

725 ) 

726 

727 code.writeline( 

728 f"in{i} = tl.load(in{i}_bptr, boundary_check=({order})).to(in{i}_ptr.type.element_ty) " 

729 ) 

730 code.newline() 

731 

732 # compute 

733 # TODO: sepearate this part 

734 inputs_to_scalar_fn = [ 

735 self.input_name(i) for i in range(schema.num_inputs()) 

736 ] 

737 outputs_to_scalar_fn = [ 

738 self.output_name(i) for i in range(schema.num_output_tensors()) 

739 ] 

740 inputs_to_scalar_fn = _cs(inputs_to_scalar_fn) 

741 outputs_to_scalar_fn = _cs(outputs_to_scalar_fn) 

742 

743 code.writeline("# compute") 

744 code.writeline( 

745 f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})" 

746 ) 

747 code.newline() 

748 

749 # stores 

750 for i in range(schema.num_output_tensors()): 

751 strides = _tuple_content( 

752 tuple(f"out{i}_stride{j}" for j in range(ndim)) 

753 ) 

754 order = _tuple_content( 

755 tuple(f"out{i}_stride_order{j}" for j in range(ndim)) 

756 ) 

757 

758 for j in range(ndim): 

759 code.writeline(f"if out{i}_zero_stride{j}:") 

760 with code.indent(): 

761 code.writeline(f"out{i}_stride{j} = 0") 

762 

763 code.writeline( 

764 f"out{i}_bptr = tl.make_block_ptr(" 

765 f"out{i}_ptr, ({shape}), ({strides}), ({offsets}), ({tile_sizes}), order=({order}))" 

766 ) 

767 

768 code.writeline( 

769 f"tl.store(out{i}_bptr, out{i}.to(out{i}_bptr.type.element_ty), boundary_check=({order}))" 

770 ) 

771 code.writeline("else:") 

772 with code.indent(): 

773 # offsets 

774 for i in range(ndim): 

775 code.writeline( 

776 f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})" 

777 ) 

778 

779 # masks 

780 for i in range(ndim): 

781 code.writeline(f"mask{i} = offsets{i} < s{i}") 

782 masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim)) 

783 mask_combine = " & ".join(masks) 

784 code.writeline(f"mask = {mask_combine}") 

785 

786 # loads 

787 code.writeline("# loads") 

788 for i in range(schema.num_input_tensors()): 

789 strides = _tuple_content(tuple(f"in{i}_stride{j}" for j in range(ndim))) 

790 order = _tuple_content( 

791 tuple(f"in{i}_stride_order{j}" for j in range(ndim)) 

792 ) 

793 

794 for j in range(ndim): 

795 code.writeline(f"if in{i}_zero_stride{j}:") 

796 with code.indent(): 

797 code.writeline(f"in{i}_stride{j} = 0") 

798 offsets = tuple( 

799 f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}" 

800 for j in range(ndim) 

801 ) 

802 offset_combine = " + ".join(offsets) 

803 code.writeline( 

804 f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)" 

805 ) 

806 

807 code.newline() 

808 

809 # compute 

810 inputs_to_scalar_fn = [ 

811 self.input_name(i) for i in range(schema.num_inputs()) 

812 ] 

813 outputs_to_scalar_fn = [ 

814 self.output_name(i) for i in range(schema.num_output_tensors()) 

815 ] 

816 inputs_to_scalar_fn = _cs(inputs_to_scalar_fn) 

817 outputs_to_scalar_fn = _cs(outputs_to_scalar_fn) 

818 

819 code.writeline("# compute") 

820 code.writeline( 

821 f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})" 

822 ) 

823 code.newline() 

824 

825 # store 

826 for i in range(schema.num_output_tensors()): 

827 strides = _tuple_content( 

828 tuple(f"out{i}_stride{j}" for j in range(ndim)) 

829 ) 

830 order = _tuple_content( 

831 tuple(f"out{i}_stride_order{j}" for j in range(ndim)) 

832 ) 

833 

834 for j in range(ndim): 

835 code.writeline(f"if out{i}_zero_stride{j}:") 

836 with code.indent(): 

837 code.writeline(f"out{i}_stride{j} = 0") 

838 

839 offsets = tuple( 

840 f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}" 

841 for j in range(ndim) 

842 ) 

843 offset_combine = " + ".join(offsets) 

844 code.writeline( 

845 f"in{i} = tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)" 

846 ) 

847 

848 def gen_body_gsl_with_bptr(self, code): 

849 code.writeline("num_ctas = tle.num_programs(0)") 

850 if self.ndim <= 4: 

851 num_tiles = " * ".join([f"num_tiles{i}" for i in range(self.ndim)]) 

852 code.writeline( 

853 f"tiles_per_cta = tl.cdiv({num_tiles}, num_ctas).to(tl.int32)" 

854 ) 

855 code.writeline("for j in range(0, tiles_per_cta):") 

856 with code.indent(): 

857 code.writeline("tile_id = pid + j * num_ctas") 

858 self.gen_body_one_tile_per_cta_with_bptr(code) 

859 

860 def gen_body_one_tile_per_cta_without_bptr(self, code): 

861 ndim = self.ndim 

862 schema = self.fx 

863 

864 # reconstruct pid multi index 

865 code.writeline( 

866 "# pid multi index recontruction: we use c ordering, right axes changes fastest" 

867 ) 

868 for i in reversed(range(ndim)): 

869 if i > 0: 

870 code.writeline(f"tile_id{i} = tile_id % num_tiles{i}") 

871 code.writeline(f"tile_id //= num_tiles{i}") 

872 else: 

873 code.writeline(f"tile_id{i} = tile_id") 

874 code.newline() 

875 

876 # offsets 

877 for i in range(ndim): 

878 code.writeline( 

879 f"offsets{i} = tile_id{i} * tile_size{i} + tl.arange(0, tile_size{i})" 

880 ) 

881 

882 # masks 

883 for i in range(ndim): 

884 code.writeline(f"mask{i} = offsets{i} < s{i}") 

885 masks = tuple(f"mask{i}{_broadcast_vec(i, ndim)}" for i in range(ndim)) 

886 mask_combine = " & ".join(masks) 

887 code.writeline(f"mask = {mask_combine}") 

888 

889 # loads 

890 code.writeline("# loads") 

891 for i in range(schema.num_input_tensors()): 

892 offsets = tuple( 

893 f"offsets{j}{_broadcast_vec(j, ndim)} * in{i}_stride{j}" 

894 for j in range(ndim) 

895 ) 

896 offset_combine = " + ".join(offsets) 

897 code.writeline( 

898 f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)" 

899 ) 

900 

901 code.newline() 

902 

903 # compute 

904 # TODO: sepearate this part 

905 inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())] 

906 outputs_to_scalar_fn = [ 

907 self.output_name(i) for i in range(schema.num_output_tensors()) 

908 ] 

909 inputs_to_scalar_fn = _cs(inputs_to_scalar_fn) 

910 outputs_to_scalar_fn = _cs(outputs_to_scalar_fn) 

911 

912 code.writeline("# compute") 

913 code.writeline( 

914 f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})" 

915 ) 

916 code.newline() 

917 

918 # stores 

919 for i in range(schema.num_output_tensors()): 

920 offsets = tuple( 

921 f"offsets{j}{_broadcast_vec(j, ndim)} * out{i}_stride{j}" 

922 for j in range(ndim) 

923 ) 

924 offset_combine = " + ".join(offsets) 

925 code.writeline( 

926 f"in{i} = tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)" 

927 ) 

928 

929 def gen_body_gsl_without_bptr(self, code): 

930 code.writeline("num_ctas = tle.num_programs(0)") 

931 if self.ndim <= 4: 

932 num_tiles = " * ".join([f"num_tiles{i}" for i in range(self.ndim)]) 

933 code.writeline(f"tiles_per_cta = tl.cdiv({num_tiles}, num_ctas)") 

934 code.writeline("for j in range(0, tiles_per_cta):") 

935 with code.indent(): 

936 code.writeline("tile_id = pid + j * num_ctas") 

937 self.gen_body_one_tile_per_cta_without_bptr(code) 

938 

939 def codegen_nd_tile_with_bptr(self, code): 

940 """Generate kernel nd tile & 1d grid with gsl support with block pointer.""" 

941 self.gen_import_function(code) 

942 self.gen_decorators(code) 

943 self.gen_signature(code, with_block_pointer=True) 

944 

945 # function body for rank-0 

946 if self.ndim == 0: 

947 with code.indent(): 

948 self.gen_body_for_0d(code) 

949 return code 

950 

951 with code.indent(): 

952 code.writeline("pid = tle.program_id(0)") 

953 self.gen_num_tiles(code) 

954 # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute 

955 if self.ndim > 4: 

956 code.writeline("if one_tile_per_cta: # monolitic kernel style") 

957 with code.indent(): 

958 code.writeline("tile_id = pid") 

959 self.gen_body_one_tile_per_cta_with_bptr(code) 

960 # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/ 

961 code.writeline("else: # grid-stride-loop style kernel") 

962 with code.indent(): 

963 self.gen_body_gsl_with_bptr(code) 

964 else: 

965 self.gen_body_gsl_with_bptr(code) 

966 code.newline() 

967 return code 

968 

969 def codegen_nd_tile_without_bptr(self, code): 

970 self.gen_import_function(code) 

971 self.gen_decorators(code) 

972 self.gen_signature(code, with_block_pointer=False) 

973 

974 # function body for rank-0 

975 if self.ndim == 0: 

976 with code.indent(): 

977 self.gen_body_for_0d(code) 

978 return code 

979 

980 with code.indent(): 

981 code.writeline("pid = tle.program_id(0)") 

982 self.gen_num_tiles(code) 

983 # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute 

984 if self.ndim > 4: 

985 code.writeline("if one_tile_per_cta: # monolitic kernel style") 

986 with code.indent(): 

987 code.writeline("tile_id = pid") 

988 self.gen_body_one_tile_per_cta_without_bptr(code) 

989 # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/ 

990 code.writeline("else: # grid-stride-loop style kernel") 

991 with code.indent(): 

992 self.gen_body_gsl_without_bptr(code) 

993 else: 

994 self.gen_body_gsl_without_bptr(code) 

995 code.newline() 

996 return code 

997 

998 def codegen_nd_tile(self, code): 

999 use_block_pointer = self.config.prefer_block_pointer 

1000 if use_block_pointer: 

1001 self.codegen_nd_tile_with_bptr(code) 

1002 else: 

1003 self.codegen_nd_tile_without_bptr(code) 

1004 return code 

1005 

1006 def gen_body_one_tile_per_cta_1d_tile(self, code): 

1007 ndim = self.ndim 

1008 schema = self.fx 

1009 

1010 # tile id 

1011 code.writeline("tid = tile_id * tile_size + tl.arange(0, tile_size)") 

1012 code.writeline("mask = tid < num_tasks") 

1013 

1014 # multi index reconstruction 

1015 for i in reversed(range(ndim)): 

1016 if i > 0: 

1017 code.writeline(f"i{i} = tid % s{i}") 

1018 code.writeline(f"tid //= s{i}") 

1019 else: 

1020 code.writeline(f"i{i} = tid") 

1021 code.newline() 

1022 

1023 # loads 

1024 code.writeline("# loads") 

1025 for i in range(schema.num_input_tensors()): 

1026 offsets = tuple(f"i{j} * in{i}_stride{j}" for j in range(ndim)) 

1027 offset_combine = " + ".join(offsets) 

1028 code.writeline( 

1029 f"in{i} = tl.load(in{i}_ptr + {offset_combine}, mask=mask).to(in{i}_ptr.type.element_ty)" 

1030 ) 

1031 

1032 code.newline() 

1033 

1034 # compute 

1035 # TODO: sepearate this part 

1036 inputs_to_scalar_fn = [self.input_name(i) for i in range(schema.num_inputs())] 

1037 outputs_to_scalar_fn = [ 

1038 self.output_name(i) for i in range(schema.num_output_tensors()) 

1039 ] 

1040 inputs_to_scalar_fn = _cs(inputs_to_scalar_fn) 

1041 outputs_to_scalar_fn = _cs(outputs_to_scalar_fn) 

1042 

1043 code.writeline("# compute") 

1044 code.writeline( 

1045 f"{outputs_to_scalar_fn} = {self.fn_name}({inputs_to_scalar_fn})" 

1046 ) 

1047 code.newline() 

1048 

1049 # stores 

1050 for i in range(schema.num_output_tensors()): 

1051 offsets = tuple(f"i{j} * out{i}_stride{j}" for j in range(ndim)) 

1052 offset_combine = " + ".join(offsets) 

1053 code.writeline( 

1054 f"in{i} = tl.store(out{i}_ptr + {offset_combine}, out{i}, mask=mask)" 

1055 ) 

1056 

1057 def gen_body_gsl_1d_tile(self, code): 

1058 code.writeline("num_ctas = tle.num_programs(0)") 

1059 code.writeline("for j in range(0, tiles_per_cta):") 

1060 with code.indent(): 

1061 code.writeline("tile_id = pid + j * num_ctas") 

1062 self.gen_body_one_tile_per_cta_1d_tile(code) 

1063 

1064 def codegen_1d_tile(self, code): 

1065 """Generate kernel 1d tile & 1d grid with gsl support.""" 

1066 self.gen_import_function(code) 

1067 self.gen_decorators(code) 

1068 self.gen_signature_1d_tile(code) 

1069 

1070 # function body for rank-0 

1071 if self.ndim == 0: 

1072 with code.indent(): 

1073 self.gen_body_for_0d(code) 

1074 return code 

1075 

1076 with code.indent(): 

1077 code.writeline("pid = tle.program_id(0)") 

1078 # code.writeline("num_ctas = te.num_programs(0)") 

1079 # monolitic kernel: one_tile_per_cta, it may requires a very large grid to compute 

1080 code.writeline("if one_tile_per_cta: # monolitic kernel style") 

1081 with code.indent(): 

1082 code.writeline("tile_id = pid") 

1083 self.gen_body_one_tile_per_cta_1d_tile(code) 

1084 # https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/ 

1085 code.writeline("else: # grid-stride-loop style kernel") 

1086 with code.indent(): 

1087 self.gen_body_gsl_1d_tile(code) 

1088 code.newline() 

1089 return code 

1090 

1091 

class WrapperGenerator:
    """Generate the Python wrapper function of a pointwise-dynamic module.

    The generated wrapper checks operand shapes, partitions the task,
    launches the triton kernel named ``jit_fn_name`` and returns the outputs.
    Every ``gen_*`` method appends lines of generated source to an
    ``IndentedBuffer``. The order of the calls is the order of the emitted code.
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        jit_fn_name: str,
        ndim: int,
        name: str,
        config: CodeGenConfig,
    ):
        self.fx = function_schema  # schema of the pointwise op (inputs / outputs)
        self.jit_fn_name = jit_fn_name  # name of the generated kernel to launch
        self.ndim = ndim  # rank of the task space this wrapper is specialized for
        self.name = name  # name of the generated wrapper function
        self.config = config  # tile / grid / block-pointer codegen options

    def input_name(self, i):
        """Name of the i-th input in generated code.

        Tensor inputs are named ``in<k>`` and non-tensor inputs ``val<k>``.
        ``k`` comes from ``self.fx.input_index(i)`` (presumably the position
        among inputs of the same kind -- confirm in FunctionSchema).
        """
        is_tensor = self.fx.is_tensor(i)
        name = "in" if is_tensor else "val"
        index = self.fx.input_index(i)
        return f"{name}{index}"

    def output_name(self, i):
        """Name of the i-th output tensor in generated code (``out<i>``)."""
        return f"out{i}"

    def gen_signature(self, code: IndentedBuffer):
        """Emit the wrapper header ``def <name>(inputs..., /, *, outputs...):``.

        Tensor inputs are annotated ``Union[torch.Tensor, StridedBuffer]``.
        Non-tensor inputs are annotated with the schema's type when it is known.
        """
        # TODO: check if triton handles constexprs transitively
        schema = self.fx
        params: List[str] = []
        for i in range(schema.num_inputs()):
            if schema.is_tensor(i):
                params.append(
                    f"{self.input_name(i)}: Union[torch.Tensor, StridedBuffer]"
                )
            else:
                arg_type = schema.input_type(i)
                if arg_type is not None:
                    params.append(f"{self.input_name(i)}: {_type_name(arg_type)}")
                else:
                    params.append(f"{self.input_name(i)}")
        # NOTE: [the wrapper's signature and rules for passing parameters]
        # Input params must be passed by position. Their names are rewritten
        # to in0, in1, val0, val1, ..., so passing them by keyword would be
        # weird. We therefore enforce positional passing (maybe we can fix it later).
        # Output params must be passed by keyword. The scalar function has no
        # output parameters (for a scalar function an output parameter does not
        # make sense). They are added to allow a destination-passing style API.
        # Output parameters are convenient when we want to use pre-defined
        # outputs, especially views of other tensors. Because they are extra
        # parameters, we enforce that they are passed by keyword. After all,
        # out0, out1, ... cannot clash with names from the scalar function,
        # which has no output parameters.
        params.append("/")
        params.append("*")  # output params must be passed by keyword

        for i in range(schema.num_output_tensors()):
            params.append(f"{self.output_name(i)}: Union[torch.Tensor, StridedBuffer]")
        code.writeline(f"def {self.name}({_cs(params)}): ")

    def gen_docstring(self, code: IndentedBuffer):
        """Emit the generated wrapper's docstring (the schema signature)."""
        schema = self.fx
        doc = f'"""Generated wrapper function with {schema.signature(outputs_in_arg=True)}"""'
        code.writeline(doc)

    def gen_same_shape_check(self, code: IndentedBuffer):
        """Emit an assert that every input and output has the same shape."""
        schema: FunctionSchema = self.fx
        params = [f"in{i}.shape" for i in range(schema.num_input_tensors())] + [
            f"out{i}.shape" for i in range(schema.num_output_tensors())
        ]
        # chained comparison: a == b == c ...
        check: str = " == ".join(params)
        code.writeline(f"assert {check}, 'operand shapes mismatch'")

    def gen_fallback_bptr(self, code: IndentedBuffer):
        """Emit a module-level helper ``fallback_bptr(t)``.

        The helper returns True when any dimension of ``t`` has size >= 2**31
        and False for empty tensors. The wrapper uses it to decide whether to
        avoid the block-pointer path (presumably because block-pointer offsets
        are 32-bit -- confirm).
        """
        code.writeline("def fallback_bptr(t):")
        with code.indent():
            code.writeline("ndim = t.dim()")
            code.writeline("sizes = t.size()")
            code.writeline("if t.numel() == 0:")
            with code.indent():
                code.writeline("return False")
            code.writeline("for i in range(ndim):")
            with code.indent():
                code.writeline("if sizes[i] >= 2147483648:")
                with code.indent():
                    code.writeline("return True")
            code.writeline("return False")
        code.newline()
        code.newline()

    def gen_task_partition(self, code: IndentedBuffer):
        """Emit n-d task partitioning: tile sizes, warps, CTAs, flags and the grid.

        Rank 0 uses one warp and one CTA. Otherwise:

        * tile sizes come from ``heuristics_for_tile_size`` over the output shape;
        * an empty task returns the outputs immediately;
        * for rank 1..4 the grid is a callable over the launch meta-parameters
          (``s{i}`` / ``tile_size{i}``), because the launch does not pass
          explicit tile sizes for those ranks.
        """
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            max_tile_size = self.config.max_tile_size
            code.writeline(
                f"tile_sizes = heuristics_for_tile_size({max_tile_size}, *shape)"
            )
            code.writeline("tile_size = math.prod(tile_sizes)")
            code.writeline(
                "num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))"
            )
            code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
            max_grid_size0 = self.config.max_grid_size[0]
            code.writeline(f"num_ctas = min({max_grid_size0} // num_warps, num_tiles)")

            code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
            code.writeline("one_tile_per_cta = tiles_per_cta==1")
        if self.config.prefer_block_pointer:
            # fall back from block pointers if any operand has a huge dimension
            code.writeline("FALLBACK_BPTR = False")
            inputs = ",".join(
                [f"in{i}" for i in range(self.fx.num_input_tensors())]
            )
            outputs = ",".join(
                [f"out{i}" for i in range(self.fx.num_output_tensors())]
            )
            code.writeline(f"all_tensors = [{inputs}, {outputs}]")
            code.writeline("for t in all_tensors:")
            with code.indent():
                code.writeline("if fallback_bptr(t):")
                with code.indent():
                    code.writeline("FALLBACK_BPTR = True")
                    code.writeline("break")
        if ndim > 0 and ndim <= 4:
            max_grid_size0 = self.config.max_grid_size[0]
            dynamic_num_tiles = " * ".join(
                f"triton.cdiv(meta['s{i}'], meta['tile_size{i}'])" for i in range(ndim)
            )
            code.writeline(
                f"grid = lambda meta: (min({max_grid_size0} // num_warps, {dynamic_num_tiles}), )"
            )
        else:
            code.writeline("grid = (num_ctas, 1, 1)")

    def gen_task_partition_1d(self, code: IndentedBuffer):
        """Emit 1-D task partitioning over the flattened ``num_tasks`` elements.

        Rank 0 uses one warp and one CTA. Otherwise a single tile size is
        chosen for ``num_tasks``, and the CTA count is capped at
        ``max_grid_size[0]``. The grid is always ``(num_ctas, 1, 1)``.
        """
        code.writeline("# task partitioning")
        ndim = self.ndim
        if ndim == 0:
            code.writeline("num_warps = 1")
            code.writeline("num_ctas = 1")
        else:
            code.writeline("shape = out0.shape")
            code.writeline("num_tasks = out0.numel()")
            code.writeline("if num_tasks == 0:")
            with code.indent():
                self.gen_return(code)
            max_tile_size = self.config.max_tile_size
            code.writeline(
                f"tile_sizes = heuristics_for_tile_size({max_tile_size}, num_tasks)"
            )
            code.writeline("tile_size = tile_sizes[0]")
            code.writeline("num_tiles = triton.cdiv(num_tasks, tile_size)")
            max_grid_size0 = self.config.max_grid_size[0]
            code.writeline(f"num_ctas = min({max_grid_size0}, num_tiles)")

            code.writeline("tiles_per_cta = triton.cdiv(num_tiles, num_ctas)")
            code.writeline("num_warps = heuristics_for_num_warps(tile_size)")
            code.writeline("one_tile_per_cta = tiles_per_cta==1")
        if self.config.prefer_block_pointer:
            # fall back from block pointers if any operand has a huge dimension
            code.writeline("FALLBACK_BPTR = False")
            inputs = ",".join(
                [f"in{i}" for i in range(self.fx.num_input_tensors())]
            )
            outputs = ",".join(
                [f"out{i}" for i in range(self.fx.num_output_tensors())]
            )
            code.writeline(f"all_tensors = [{inputs}, {outputs}]")
            code.writeline("for t in all_tensors:")
            with code.indent():
                code.writeline("if fallback_bptr(t):")
                with code.indent():
                    code.writeline("FALLBACK_BPTR = True")
                    code.writeline("break")
        code.writeline("grid = (num_ctas, 1, 1)")

    def gen_kernel_launch(
        self,
        code: IndentedBuffer,
    ):
        """Emit the n-d kernel launch ``<jit_fn_name>[grid](...)``.

        Arguments are passed in this order:

        * all inputs, then outputs;
        * per-tensor strides (plus stride order and zero-stride flags when
          block pointers are preferred);
        * the task shape and ``num_tasks``;
        * keyword meta-parameters.

        Rank 0 passes no strides, shape or ``num_tasks``.
        """
        schema = self.fx
        ndim = self.ndim

        with_block_pointer = self.config.prefer_block_pointer

        code.writeline("# kernel launch")
        for i in range(schema.num_input_tensors()):
            code.writeline(f"in{i}_strides = in{i}.stride()")
            if not with_block_pointer:
                continue
            if ndim >= 2:  # where ndim is 1, we don't need to compute stride order
                code.writeline(f"in{i}_stride_order = stride_order(in{i}_strides)")
            else:
                code.writeline(f"in{i}_stride_order = (0,)")
            code.writeline(
                f"in{i}_zero_strides = [True if s == 0 else False for s in in{i}_strides]"
            )
        for i in range(schema.num_output_tensors()):
            code.writeline(f"out{i}_strides = out{i}.stride()")
            if not with_block_pointer:
                continue
            if ndim >= 2:
                code.writeline(f"out{i}_stride_order = stride_order(out{i}_strides)")
            else:
                code.writeline(f"out{i}_stride_order = (0,)")
            code.writeline(
                f"out{i}_zero_strides = [True if s == 0 else False for s in out{i}_strides]"
            )

        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            code.writeline(f"{self.jit_fn_name}[grid](")
            with code.indent():
                params = []
                # NOTE: WRAP
                for i in range(schema.num_inputs()):
                    if schema.is_tensor(i):
                        params.append(f"{self.input_name(i)}")
                    else:
                        params.append(self.input_name(i))
                for i in range(schema.num_output_tensors()):
                    params.append(f"{self.output_name(i)}")

                code.writeline(f"{_cs(params)},")

                if ndim > 0:
                    for i in range(schema.num_input_tensors()):
                        s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for in{i}")
                        if with_block_pointer:
                            order = ", ".join(
                                f"in{i}_stride_order[{j}]" for j in range(ndim)
                            )
                            code.writeline(f"{order}, # stride order for in{i}")
                            zero_strides = ", ".join(
                                f"in{i}_zero_strides[{j}]" for j in range(ndim)
                            )
                            code.writeline(
                                f"{zero_strides}, # zero stride flag for in{i}"
                            )

                    for i in range(schema.num_output_tensors()):
                        s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for out{i}")
                        if with_block_pointer:
                            order = ", ".join(
                                f"out{i}_stride_order[{j}]" for j in range(ndim)
                            )
                            code.writeline(f"{order}, # stride orderfor out{i}")
                            zero_strides = ", ".join(
                                f"out{i}_zero_strides[{j}]" for j in range(ndim)
                            )
                            code.writeline(
                                f"{zero_strides}, # zero stride flag for out{i}"
                            )

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
                    code.writeline(f"{shape_args}, # task indexing space")
                    code.writeline("num_tasks, # num tasks")
                if self.config.prefer_block_pointer:
                    code.writeline("FALLBACK_BPTR=FALLBACK_BPTR,")
                # ranks 1..4 let the kernel derive tiling itself; only rank > 4
                # (and rank 0, with no tile sizes) gets explicit tiling arguments
                if ndim > 4:
                    code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                if ndim == 0 or ndim > 4:
                    for i in range(ndim):
                        code.writeline(f"tile_size{i}=tile_sizes[{i}],")
                if ndim > 4:
                    code.writeline("one_tile_per_cta=one_tile_per_cta,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

    def gen_kernel_launch_1d(
        self,
        code: IndentedBuffer,
    ):
        """Emit the 1-D-tile kernel launch ``<jit_fn_name>[grid](...)``.

        Arguments are passed in this order:

        * inputs, then outputs;
        * per-tensor strides;
        * the task shape and ``num_tasks``;
        * keyword tiling meta-parameters.

        Rank 0 passes no strides, shape or tiling arguments.
        """
        schema = self.fx
        ndim = self.ndim

        code.writeline("# kernel launch")
        for i in range(schema.num_input_tensors()):
            code.writeline(f"in{i}_strides = in{i}.stride()")
        for i in range(schema.num_output_tensors()):
            code.writeline(f"out{i}_strides = out{i}.stride()")

        code.writeline("with torch_device_fn.device(in0.device.index):")
        with code.indent():
            code.writeline(f"{self.jit_fn_name}[grid](")
            with code.indent():
                params = []
                # NOTE: WRAP
                for i in range(schema.num_inputs()):
                    if schema.is_tensor(i):
                        params.append(f"{self.input_name(i)}")
                    else:
                        params.append(self.input_name(i))
                for i in range(schema.num_output_tensors()):
                    params.append(f"{self.output_name(i)}")

                code.writeline(f"{_cs(params)},")

                if ndim > 0:
                    for i in range(schema.num_input_tensors()):
                        s = ", ".join(f"in{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for in{i}")
                    for i in range(schema.num_output_tensors()):
                        s = ", ".join(f"out{i}_strides[{j}]" for j in range(ndim))
                        code.writeline(f"{s}, # stride for out{i}")

                    shape_args: str = ", ".join(f"shape[{i}]" for i in range(ndim))
                    code.writeline(f"{shape_args}, # task indexing space")
                    code.writeline("num_tasks, # num tasks")
                    if self.config.prefer_block_pointer:
                        code.writeline("FALLBACK_BPTR=FALLBACK_BPTR,")
                    code.writeline("tiles_per_cta=tiles_per_cta, # tiles_per_cta")
                    code.writeline("tile_size=tile_size,")
                    code.writeline("one_tile_per_cta=one_tile_per_cta,")
                code.writeline("num_warps=num_warps,")
            code.writeline(")")

    def gen_return(self, code: IndentedBuffer):
        """Emit ``return out0, out1, ...`` (a bare value for a single output)."""
        return_exprs = _cs(f"out{i}" for i in range(self.fx.num_output_tensors()))
        code.writeline(f"return {return_exprs}")

    def codegen_nd_tile(self, code):
        """Emit the whole n-d-tile wrapper (plus the fallback helper when using block pointers).

        Returns:
            The same ``code`` buffer.
        """
        if self.config.prefer_block_pointer:
            self.gen_fallback_bptr(code)
        self.gen_signature(code)

        with code.indent():
            self.gen_docstring(code)
            self.gen_same_shape_check(code)
            self.gen_task_partition(code)
            self.gen_kernel_launch(code)
            self.gen_return(code)
        code.newline()
        return code

    def codegen_1d_tile(self, code):
        """Emit the whole 1-D-tile wrapper (plus the fallback helper when using block pointers).

        Returns:
            The same ``code`` buffer.
        """
        if self.config.prefer_block_pointer:
            self.gen_fallback_bptr(code)
        self.gen_signature(code)

        with code.indent():
            self.gen_docstring(code)
            self.gen_same_shape_check(code)
            self.gen_task_partition_1d(code)
            self.gen_kernel_launch_1d(code)
            self.gen_return(code)
        code.newline()
        return code

1448 

1449 

class ModuleGenerator:
    """Assemble the full source of a generated pointwise-dynamic module.

    The module contains the imports, copies of the scalar function's local
    ``@triton.jit`` helpers, the wrapper (``WrapperGenerator``) and the kernel
    (``KernelGenerator``).
    """

    def __init__(
        self,
        function_schema: FunctionSchema,
        scalar_fn: triton.JITFunction,
        ndim: int,
        jit_fn_name: str,
        wrapper_name: str,
        config: CodeGenConfig,
    ):
        self.config = config
        self.scalar_fn = scalar_fn
        self.wrapper_gen = WrapperGenerator(
            function_schema, jit_fn_name, ndim, wrapper_name, config
        )
        self.kernel_gen = KernelGenerator(
            function_schema, scalar_fn, ndim, jit_fn_name, config
        )

    @staticmethod
    def _collect_jit_deps(scalar_fn):
        """Collect extra imports and local @triton.jit helper sources.

        Parses the source module where ``scalar_fn`` is defined using AST.

        Returns:
            A tuple ``(extra_imports, local_sources)``:

            * ``extra_imports``: dict of absolute module path -> set of imported
              names. A name is ``"name"`` or ``"name as alias"``. Relative
              imports are resolved against the defining module's package, so
              they still work from the generated module in the cache directory.
            * ``local_sources``: list of source strings for top-level
              ``@triton.jit`` functions that are not decorated with
              ``@pointwise_dynamic``.

            ``({}, [])`` when the module or its source cannot be loaded.
        """
        import ast
        import importlib.util
        import inspect
        import tokenize

        py_fn = getattr(scalar_fn, "fn", scalar_fn)
        module_name = getattr(py_fn, "__module__", None)
        if not module_name:
            return {}, []
        try:
            mod = importlib.import_module(module_name)
            source_file = inspect.getfile(mod)
        except (ImportError, TypeError, OSError):
            return {}, []
        try:
            # tokenize.open honours PEP 263 coding cookies / BOM and defaults to
            # UTF-8; plain open() would use the locale encoding and could raise
            # UnicodeDecodeError on non-ASCII sources.
            with tokenize.open(source_file) as f:
                module_source = f.read()
            source_lines = module_source.splitlines(keepends=True)
            tree = ast.parse(module_source)
        except (OSError, SyntaxError, UnicodeDecodeError):
            return {}, []

        # Collect non-standard import-from lines
        ALREADY_IMPORTED = {
            "math",
            "typing",
            "torch",
            "triton",
            "triton.language",
            "flag_gems.utils.shape_utils",
            "flag_gems.utils.tensor_wrapper",
            "flag_gems.utils.libentry",
            "flag_gems.utils",
            "flag_gems.runtime",
            "flag_gems.utils.pointwise_dynamic",
            "utils.pointwise_dynamic",
            "randn",
            "utils",
            "all",
        }
        package = getattr(mod, "__package__", None)
        extra_imports = {}
        for node in ast.iter_child_nodes(tree):
            if not (isinstance(node, ast.ImportFrom) and node.module):
                continue
            if node.module in ALREADY_IMPORTED:
                continue
            module_path = node.module
            if node.level:
                # relative import: the generated module lives in the code cache
                # directory, so it must be rewritten to an absolute path
                try:
                    module_path = importlib.util.resolve_name(
                        "." * node.level + node.module, package
                    )
                except (ImportError, ValueError):
                    continue
                if module_path in ALREADY_IMPORTED:
                    continue
            names = {
                f"{alias.name} as {alias.asname}" if alias.asname else alias.name
                for alias in node.names
            }
            extra_imports.setdefault(module_path, set()).update(names)

        # Collect local @triton.jit functions (without @pointwise_dynamic)
        def _has_decorator(func_node, name):
            # substring match on the decorator's source text
            for dec in func_node.decorator_list:
                src = "".join(source_lines[dec.lineno - 1 : dec.end_lineno])
                if name in src:
                    return True
            return False

        def _extract_source(func_node):
            # include decorators so the copied helper is still jitted
            start = func_node.lineno - 1
            if func_node.decorator_list:
                start = func_node.decorator_list[0].lineno - 1
            end = func_node.end_lineno
            return "".join(source_lines[start:end])

        local_sources = []
        for node in ast.iter_child_nodes(tree):
            if not isinstance(node, ast.FunctionDef):
                continue
            # "jit" also matches "triton.jit"
            if not _has_decorator(node, "jit"):
                continue
            if _has_decorator(node, "pointwise_dynamic"):
                continue
            local_sources.append(_extract_source(node))

        return extra_imports, local_sources

    def generate_imports(self, code: IndentedBuffer) -> IndentedBuffer:
        """Emit the module's imports and the scalar function's local jit helpers.

        Returns:
            The same ``code`` buffer.
        """
        code.writeline("import math")
        code.writeline("from typing import Union")
        code.writeline("import torch")
        code.writeline("import triton")
        code.writeline("from triton import language as tl")
        code.newline()
        code.writeline("import flag_gems")
        code.writeline("from flag_gems.utils.shape_utils import (")
        code.writeline("    heuristics_for_tile_size,")
        code.writeline("    heuristics_for_num_warps,")
        code.writeline("    stride_order,")
        code.writeline(")")
        code.writeline("from flag_gems.utils.tensor_wrapper import StridedBuffer")
        code.writeline("from flag_gems.utils.libentry import libentry, libtuner")
        code.writeline("from flag_gems.utils import triton_lang_extension as tle")
        code.writeline("from flag_gems.runtime import torch_device_fn")

        # Generate extra imports and local JIT deps of the scalar function
        jit_dep_imports, local_jit_sources = self._collect_jit_deps(self.scalar_fn)
        for module_path, names in sorted(jit_dep_imports.items()):
            sorted_names = ", ".join(sorted(names))
            code.writeline(f"from {module_path} import {sorted_names}")

        code.newline()
        code.newline()

        # Emit local @triton.jit helper functions
        for source in local_jit_sources:
            for line in source.splitlines():
                code.writeline(line)
            code.newline()

        return code

    def codegen(self, code: IndentedBuffer):
        """Emit the full module: imports, then wrapper, then kernel (1-D or n-d tiling).

        Returns:
            The same ``code`` buffer.
        """
        code = self.generate_imports(code)
        if self.config.prefer_1d_tile:
            code = self.wrapper_gen.codegen_1d_tile(code)
            code = self.kernel_gen.codegen_1d_tile(code)
        else:
            code = self.wrapper_gen.codegen_nd_tile(code)
            code = self.kernel_gen.codegen_nd_tile(code)
        return code

1597 return code 

1598 

1599 

@dataclass()
class KernelInfo:
    """Location and names of one generated kernel, for C++ integration.

    Fields:
        file_path: path of the generated module file.
        kernel_name: name of the triton kernel inside that module.
        wrapper_name: name of the Python wrapper inside that module.
        ndim: rank of the task space the kernel is specialized for.
    """

    file_path: str
    kernel_name: str
    wrapper_name: str
    ndim: int

1608 

1609 

1610class PointwiseDynamicFunction: 

1611 """Utility to generate function for general pointwise operation. It generate wrapper & JITFunction 

1612 which are specialized according to the rank of the task space(the broadcasted shape of all input tensors). 

1613 The generated code are written out to the cache directory (defaults to ~/.flaggems). 

1614 """ 

1615 

1616 def __init__(self, op_desc: FunctionSchema, scalar_fn: JITFunction, config=None): 

1617 self.fx = op_desc 

1618 

1619 assert isinstance(scalar_fn, JITFunction) 

1620 self._scalar_fn = scalar_fn 

1621 self._scalar_fn_cache_key = scalar_fn.cache_key 

1622 self.pid = os.getpid() 

1623 

1624 self.config: CodeGenConfig = config or get_codegen_config() 

1625 

1626 # instantiated & cached overloads 

1627 self.overloads: Mapping[str, Callable] = {} 

1628 # cached kernel info for C++ integration 

1629 self._kernel_info_cache: Mapping[str, KernelInfo] = {} 

1630 

1631 def __call__(self, *args, **kwargs): 

1632 # inputs must be passed by position, outputs must be passed by keyword 

1633 ndim, args, kwargs = self.prepare_args(*args, **kwargs) 

1634 overload = self.instantiate(ndim) 

1635 out = overload(*args, **kwargs) 

1636 # NOTE: overload keeps the type of outputs: 

1637 # if a pre-defiend output is a Tensor or StridedBuffer, the corresponding 

1638 # output is also a Tensor StridedBuffer, respectively 

1639 # since prepare_args Wraps all the arguments, the outputs are all StridedBuffer 

1640 # but if manually instantiated overload is directly called, take care of 

1641 # that manually 

1642 return self._unwrap(out) 

1643 

1644 @staticmethod 

1645 def use_fast_path(tensors): 

1646 return all_the_same_shape(tensors) and ( 

1647 all_c_contiguous(tensors) 

1648 or ( 

1649 all_the_same_stride(tensors) 

1650 and torch.ops.aten.is_non_overlapping_and_dense(tensors[0]) 

1651 ) 

1652 ) 

1653 

1654 def prepare_args(self, *args, **kwargs): 

1655 # output allocation(when needed) 

1656 # task simplification & task-rank infernece & input-output reinterpretation 

1657 schema = self.fx 

1658 outputs_that_need_allocation: List[int] = [] 

1659 out_tensors = [] 

1660 for i in range(schema.num_output_tensors()): 

1661 k = f"out{i}" 

1662 if k in kwargs: 

1663 out_tensors.append(kwargs[k]) 

1664 else: 

1665 outputs_that_need_allocation.append(i) 

1666 # input arguments must be passed by position 

1667 if schema._is_tensor is not None: 

1668 if not check_tensor_attributes(args, (schema._is_tensor)): 

1669 raise ValueError( 

1670 "Input arguments must be passed by position, and the corresponding dtype must be specified." 

1671 ) 

1672 in_tensors = [item for i, item in enumerate(args) if schema.is_tensor(i)] 

1673 

1674 # output dtype promotions 

1675 outputs_dtypes_for_allocation = [] 

1676 for i in outputs_that_need_allocation: 

1677 *arg_indices, method = schema._promotion_methods[i] 

1678 promote_args = (args[j] for j in arg_indices) 

1679 _, dtype = type_promotion(*promote_args, type_promotion=method) 

1680 outputs_dtypes_for_allocation.append(dtype) 

1681 

1682 tensors = out_tensors + in_tensors 

1683 if self.use_fast_path(tensors): # dimension collapse & use physical ordering 

1684 allocated_outputs = [ 

1685 torch.empty_like(tensors[0], dtype=dtype) 

1686 for dtype in outputs_dtypes_for_allocation 

1687 ] 

1688 task_shape = (tensors[0].numel(),) 

1689 strides = (1,) 

1690 ndim = 1 

1691 args = tuple( 

1692 ( 

1693 StridedBuffer(item, task_shape, strides) 

1694 if schema.is_tensor(i) 

1695 else item 

1696 ) 

1697 for i, item in enumerate(args) 

1698 ) 

1699 kwargs = { 

1700 k: StridedBuffer(item, task_shape, strides) 

1701 for k, item in kwargs.items() 

1702 } 

1703 for seq_id, output_id in enumerate(outputs_that_need_allocation): 

1704 kwargs[f"out{output_id}"] = StridedBuffer( 

1705 allocated_outputs[seq_id], task_shape, strides 

1706 ) 

1707 else: 

1708 # a simple strategy: all the undefined tensors will follow the first 

1709 # tensor that is not broadcated, no attempts to simplify task, no reordering, 

1710 # no dimenion collapsing 

1711 shapes = tuple(item.shape for item in in_tensors) 

1712 

1713 task_shape = broadcast_shapes(shapes) 

1714 

1715 if out_tensors: 

1716 for index, item in enumerate(out_tensors): 

1717 if list(item.shape) != list(task_shape): 

1718 raise RuntimeError( 

1719 f"out tensor at index {index} shape is invalid, should be {task_shape} but is {item.shape}!" 

1720 ) 

1721 # output arguments must not have internal overlapping for pointwise operation 

1722 if has_internal_overlapping(item) == MemOverlap.Yes: 

1723 raise RuntimeError( 

1724 "Pointwise Input arguments should not have internal overlapping." 

1725 ) 

1726 

1727 ndim = len(task_shape) 

1728 for item in tensors: 

1729 if item.shape == task_shape: 

1730 allocated_outputs = [ 

1731 torch.empty_like(item, dtype=dtype) 

1732 for dtype in outputs_dtypes_for_allocation 

1733 ] 

1734 break 

1735 else: # nobreak 

1736 device = tensors[0].device 

1737 allocated_outputs = [ 

1738 torch.empty(task_shape, dtype=dtype, device=device) 

1739 for dtype in outputs_dtypes_for_allocation 

1740 ] 

1741 args = tuple( 

1742 ( 

1743 StridedBuffer( 

1744 item, 

1745 task_shape, 

1746 broadcasted_stride(item.shape, item.stride(), task_shape), 

1747 ) 

1748 if schema.is_tensor(i) 

1749 else item 

1750 ) 

1751 for i, item in enumerate(args) 

1752 ) 

1753 kwargs = { 

1754 k: StridedBuffer( 

1755 item, 

1756 task_shape, 

1757 broadcasted_stride(item.shape, item.stride(), task_shape), 

1758 ) 

1759 for k, item in kwargs.items() 

1760 } 

1761 for seq_id, output_id in enumerate(outputs_that_need_allocation): 

1762 item = allocated_outputs[seq_id] 

1763 kwargs[f"out{output_id}"] = StridedBuffer( 

1764 item, 

1765 task_shape, 

1766 broadcasted_stride(item.shape, item.stride(), task_shape), 

1767 ) 

1768 return (ndim, args, kwargs) 

1769 

1770 def _unwrap(self, tensors): 

1771 # unwrap StridedBuffer to get Tensor 

1772 if self.fx.num_output_tensors() == 1: 

1773 item = tensors 

1774 return item.unwrap() 

1775 return tuple(item.unwrap() for item in tensors) 

1776 

def _compute_kernel_names(self, ndim: int) -> Tuple[str, str, str]:
    """Derive the kernel name, wrapper name and cache file path for ``ndim``.

    This is the shared naming helper. ``instantiate()`` (and therefore
    ``get_kernel_info()``) takes its names from here, which keeps them in sync.

    The generated file is named
    ``pointwise_dynamic_<cache_key>_<kernel_name>_<variant>.py`` inside
    ``code_cache_dir()``. ``<variant>`` is:

    - ``1d_tile_`` when the config prefers 1d tiles;
    - ``bptr`` when it prefers block pointers but not 1d tiles;
    - empty otherwise.

    Returns:
        Tuple of (kernel_name, wrapper_name, file_path)
    """
    fn_name = self._scalar_fn.__name__
    kernel_name, wrapper_name = (
        f"{fn_name}_kernel_rank_{ndim}",
        f"{fn_name}_wrapper_rank_{ndim}",
    )

    cfg = self.config
    if cfg.prefer_1d_tile:
        variant = "1d_tile_"
    elif cfg.prefer_block_pointer:
        variant = "bptr"
    else:
        variant = ""

    stem = f"pointwise_dynamic_{self._scalar_fn_cache_key}_{kernel_name}_{variant}"
    file_path = str(code_cache_dir() / f"{stem}.py")
    return kernel_name, wrapper_name, file_path

1799 

def instantiate(self, ndim):
    """Generate, write, load and cache the rank-``ndim`` overload.

    NOTE: a manually instantiated overload does not get ``prepare_args`` as
    preprocessing. The caller must allocate the outputs and make sure the
    inputs & outputs actually fit the manually instantiated overload.

    Args:
        ndim: rank of the task space the overload is generated for.

    Returns:
        The generated wrapper function. Overloads are cached per
        ``(ndim, config.prefer_block_pointer)``; a cache hit returns the
        previously loaded wrapper without regenerating any code.
    """
    key = f"{ndim}_{self.config.prefer_block_pointer}"
    if key in self.overloads:
        return self.overloads[key]

    code = IndentedBuffer()

    # Use helper to compute names (single source of truth)
    kernel_name, wrapper_name, file_path = self._compute_kernel_names(ndim)

    module_gen = ModuleGenerator(
        self.fx,
        self._scalar_fn,
        ndim,
        kernel_name,
        wrapper_name,
        self.config,
    )
    module_gen.codegen(code)

    # NOTE: [why write the generated code to a file]
    # triton uses inspect to get the source of the jitted function, so that
    # source must be findable by inspect. inspect cannot find the source of
    # functions created dynamically via exec of a string. We could help it by
    # hacking the linecache library, but generating a module is simpler. It
    # also lets us generate 2 functions: the kernel and the wrapper, where the
    # wrapper calls the kernel.
    write_atomic(file_path, code.getvalue())

    # Load the generated module. It is deliberately NOT registered in
    # sys.modules.
    spec = importlib.util.spec_from_file_location(
        f"_gen_module_{self._scalar_fn_cache_key}_rank_{ndim}",
        file_path,
    )
    m = importlib.util.module_from_spec(spec)

    # NOTE: [why not import the scalar function]
    # We do not re-import the scalar function, although the generated kernel
    # **calls** it. A function's __name__ may have been changed, so importing
    # that __name__ from its defining module may yield something else. The
    # name may also have been rebound to another object. Either way, importing
    # by name cannot guarantee we get the scalar function.
    # So we copy the scalar function and its __globals__ into the generated
    # module instead.
    # https://stackoverflow.com/questions/11170949/how-to-make-a-copy-of-a-python-module-at-runtime
    spec.loader.exec_module(m)
    # Only fill in names the generated module does not already define. A
    # plain dict.update would overwrite the module's own imports/definitions
    # and its dunders (__name__, __file__, __spec__, ...) with those of the
    # scalar function's module.
    namespace = m.__dict__
    for name, value in self._scalar_fn.__globals__.items():
        namespace.setdefault(name, value)
    # The scalar function itself must always resolve to the exact object we
    # hold, even if the generated module bound that name to something else.
    namespace[self._scalar_fn.__name__] = self._scalar_fn

    overload = getattr(m, wrapper_name)
    self.overloads[key] = overload

    # Cache kernel info for C++ integration
    self._kernel_info_cache[key] = KernelInfo(
        file_path=file_path,
        kernel_name=kernel_name,
        wrapper_name=wrapper_name,
        ndim=ndim,
    )

    return overload

1864 

def get_kernel_info(self, ndim: int) -> KernelInfo:
    """Get kernel information for a given ndim.

    This method is useful for C++ integration: it gives the file path and
    kernel name without duplicating the naming logic.

    If the kernel hasn't been instantiated yet, this instantiates it first.
    An overload for ``ndim`` may already be cached without any kernel info
    recorded for it; in that case ``instantiate`` returns early and records
    nothing. The info is then rebuilt from ``_compute_kernel_names`` instead
    of raising ``KeyError``.

    Args:
        ndim: The rank of the task space

    Returns:
        KernelInfo with file_path, kernel_name, wrapper_name, and ndim
    """
    key = f"{ndim}_{self.config.prefer_block_pointer}"

    # Ensure the kernel is instantiated (this normally records its info).
    if key not in self._kernel_info_cache:
        self.instantiate(ndim)

    # Fallback: the overload was cached without kernel info, so derive it from
    # the shared naming helper rather than failing with KeyError.
    if key not in self._kernel_info_cache:
        kernel_name, wrapper_name, file_path = self._compute_kernel_names(ndim)
        self._kernel_info_cache[key] = KernelInfo(
            file_path=file_path,
            kernel_name=kernel_name,
            wrapper_name=wrapper_name,
            ndim=ndim,
        )

    return self._kernel_info_cache[key]

1886 

1887 

def pointwise_dynamic(
    f: Optional[JITFunction] = None,
    *,
    num_inputs: Optional[int] = None,
    is_tensor: Optional[List[bool]] = None,
    dtypes: Optional[List[Optional[type]]] = None,
    num_outputs: Optional[int] = None,
    promotion_methods: Optional[Tuple[int, ...]] = None,
    config: Optional[CodeGenConfig] = None,
):
    """Wrap a scalar function as a ``PointwiseDynamicFunction``.

    Usable both bare (``@pointwise_dynamic``) and with keyword arguments
    (``@pointwise_dynamic(num_inputs=2, ...)``).

    Args:
        f: the scalar function to wrap. When given, it is wrapped immediately;
            otherwise a reusable decorator is returned.
        num_inputs, is_tensor, dtypes, num_outputs, promotion_methods:
            forwarded to ``FunctionSchema``. If none of ``num_inputs``,
            ``is_tensor`` and ``dtypes`` is given, ``num_inputs`` defaults to
            ``len(fn.arg_names)`` of each decorated function.
        config: codegen configuration passed to ``PointwiseDynamicFunction``.

    Returns:
        A ``PointwiseDynamicFunction`` when ``f`` is given, else a decorator
        producing one.
    """

    def decorator(fn):
        # Resolve the input count per decorated function. Rebinding the
        # enclosing ``num_inputs`` (via ``nonlocal``) would leak the first
        # function's arity into every later use of the same decorator.
        resolved_num_inputs = num_inputs
        if (num_inputs is None) and (is_tensor is None) and (dtypes is None):
            resolved_num_inputs = len(fn.arg_names)
        op_desc = FunctionSchema(
            num_inputs=resolved_num_inputs,
            is_tensor=is_tensor,
            dtypes=dtypes,
            num_outputs=num_outputs,
            promotion_methods=promotion_methods,
        )
        return PointwiseDynamicFunction(op_desc, fn, config)

    if f is not None:
        return decorator(f)
    return decorator