Coverage for src/flag_gems/utils/libentry.py: 88%

405 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1from __future__ import annotations 

2 

3import hashlib 

4import inspect 

5import logging 

6import math 

7import multiprocessing 

8import os 

9import time 

10from abc import abstractmethod 

11from collections import OrderedDict 

12from functools import cached_property 

13from itertools import starmap 

14from pathlib import Path 

15from typing import ( 

16 Any, 

17 Callable, 

18 Dict, 

19 Final, 

20 Iterator, 

21 List, 

22 Optional, 

23 Tuple, 

24 Type, 

25 Union, 

26 overload, 

27) 

28 

29import triton 

30 

31from flag_gems import runtime 

32from flag_gems.runtime import torch_device_fn 

33from flag_gems.runtime.backend import _state 

34from flag_gems.utils.code_cache import config_cache_dir 

35from flag_gems.utils.models import PersistantModel, SQLPersistantModel 

36 

logger = logging.getLogger(__name__)

# Device count as reported by the active runtime backend.
DEVICE_COUNT = runtime.device.device_count

# Split the Triton version string into its dotted components and keep the
# first two as integers. `int()` is used instead of `eval()` so that a
# malformed or unexpected version string can never be executed as code.
version = triton.__version__.split(".")
major_version, minor_version = int(version[0]), int(version[1])

43 

44 

if major_version == 2:

    def all_kwargs(self):
        """Return the config's meta-parameters merged with its launch options.

        Only the launch options that actually exist as attributes on this
        config object are included; they override same-named entries of
        `self.kwargs`.
        """
        option_names = (
            "num_warps",
            "num_ctas",
            "num_stages",
            "num_buffers_warp_spec",
            "num_consumer_groups",
            "reg_dec_producer",
            "reg_inc_consumer",
            "maxnreg",
        )
        launch_options = {}
        for option in option_names:
            if hasattr(self, option):
                launch_options[option] = getattr(self, option)
        return {**self.kwargs, **launch_options}

    # Attach the helper to `triton.Config` when running on Triton 2.
    triton.Config.all_kwargs = all_kwargs

67 

# Optional database URL override for the tuning cache; `None` when unset.
FLAGGEMS_DB_URL = os.environ.get("FLAGGEMS_DB_URL")

69 

70 

class Cache(object):
    """Base class that binds a table name to the persistence model backing it."""

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        """Remember which table this cache addresses and which model stores it.

        Args:
            table_name: Name of the table passed to every model call.
            model: Persistence model that performs the reads and writes.
            *args: Forwarded to the next `__init__` in the MRO.
            **kwargs: Forwarded to the next `__init__` in the MRO.
        """
        super().__init__(*args, **kwargs)
        self.table_name: Final[str] = table_name
        self.model: Final[PersistantModel] = model

78 

79 

class ConfigCache(Cache):
    """
    `ConfigCache` is used to store the relationship between keys and their known best configurations.

    It offers a mapping-style interface (`in`, `[]`, `[] =`) on top of the
    model's `get_config` / `put_config` calls for a single table.
    """

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        super().__init__(table_name, model, *args, **kwargs)

    def __contains__(self, key: Tuple[Union[int, float, str], ...]) -> bool:
        """Return True when a configuration is stored for `key`."""
        return self.get(key) is not None

    def __getitem__(self, key: Tuple[Union[int, float, str], ...]) -> triton.Config:
        """Return the configuration stored for `key`.

        Raises:
            KeyError: If the model has no configuration for `key`.
        """
        ret: Optional[triton.Config] = self.get(key)
        if ret is None:
            raise KeyError(f"Key {key} not found in ConfigCache.")
        return ret

    def __setitem__(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        """Store `config` as the configuration for `key`."""
        self.set(key, config)

    def get(self, key: Tuple[Union[int, float, str], ...]) -> Optional[triton.Config]:
        """Look `key` up in this cache's table; a `None` result means "absent"."""
        return self.model.get_config(self.table_name, key)

    def set(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        """Write `config` for `key` into this cache's table via the model."""
        return self.model.put_config(self.table_name, key, config)

111 

112 

class BenchmarkCache(Cache):
    """
    `BenchmarkCache` is used to store the benchmark results for the pair of the specific key and configuration.

    One instance is bound to a single (table, key) pair; the mapping-style
    interface is then indexed by configuration.
    """

    def __init__(
        self,
        table_name: str,
        model: PersistantModel,
        key: Tuple[Union[int, float, str], ...],
        *args,
        **kwargs,
    ) -> None:
        """Bind the cache to `table_name` and to the tuning `key`.

        Args:
            table_name: Name of the table passed to every model call.
            model: Persistence model that performs the reads and writes.
            key: Tuning key that all benchmark entries of this cache share.
            *args: Forwarded to `Cache.__init__`.
            **kwargs: Forwarded to `Cache.__init__`.
        """
        super().__init__(table_name, model, *args, **kwargs)
        self.key: Final[Tuple[Union[int, float, str], ...]] = key

    def __contains__(self, config: triton.Config) -> bool:
        """Return True when a benchmark is stored for `config`.

        Delegates to `get` so the lookup passes the table name to the model,
        exactly like every other accessor of this class.
        """
        return self.get(config) is not None

    def __getitem__(self, config: triton.Config) -> Tuple[float, float, float]:
        """Return the benchmark stored for `config`.

        Raises:
            KeyError: If the model has no benchmark for `config` under this key.
        """
        ret: Optional[Tuple[float, float, float]] = self.get(config)
        if ret is None:
            raise KeyError(
                f"Config {config} not found in BenchmarkCache for key {self.key}."
            )
        return ret

    def __setitem__(
        self, config: triton.Config, benchmark: Tuple[float, float, float]
    ) -> None:
        """Store `benchmark` as the result for `config`."""
        return self.set(config, benchmark)

    def get(self, config: triton.Config) -> Optional[Tuple[float, float, float]]:
        """Look up the benchmark for `config`; a `None` result means "absent"."""
        return self.model.get_benchmark(self.table_name, self.key, config)

    def set(self, config: triton.Config, benchmark: Tuple[float, float, float]) -> None:
        """Write `benchmark` for `config` under this cache's table and key."""
        return self.model.put_benchmark(self.table_name, self.key, config, benchmark)

147 

148 

class LibCache(object):
    """Singleton that owns the persistence model and hands out per-table caches.

    `cache[table]` yields a `ConfigCache`; `cache[table, key]` yields a
    `BenchmarkCache`. Both kinds are created on first request and pooled.

    NOTE: `__new__` always returns the same object, so calling
    `LibCache(...)` again re-runs `__init__` on it and resets its attributes.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super(LibCache, cls).__new__(cls)
        return cls._instance

    def __init__(self, db_url: Optional[str] = None):
        """Create the persistence model and the empty cache pools.

        Args:
            db_url: Database URL handed to `SQLPersistantModel`. When `None`,
                a SQLite file inside `config_cache_dir()` is used, named after
                the device and the Triton major/minor version.
        """
        # Not used by the code in this class; the attribute names are kept
        # verbatim because they are reachable by external callers.
        self.global_cache: Dict = {}
        self.volumn: Dict = {}
        # Read up front (even when `db_url` is given), serving as the fallback
        # name if the device runtime cannot report one.
        vendor_device_name = _state.vendor_module.vendor_info.device_name
        if db_url is None:
            try:
                device_name: str = torch_device_fn.get_device_name().replace(" ", "_")
            except AttributeError:
                device_name = vendor_device_name
            cache_file_name: str = (
                f"TunedConfig_{device_name}_triton_{major_version}_{minor_version}.db"
            )
            cache_path: Path = config_cache_dir() / cache_file_name
            self.db_url: str = f"sqlite:///{cache_path}"
        else:
            self.db_url: str = db_url
        self.config_cache_pool: Dict[str, ConfigCache] = {}
        self.benchmark_cache_pool: Dict[
            Tuple[str, Tuple[Union[int, float, str], ...]], BenchmarkCache
        ] = {}
        self.model: PersistantModel = SQLPersistantModel(self.db_url)

    @overload
    def __getitem__(self, key: str) -> ConfigCache:
        ...

    @overload
    def __getitem__(
        self, key: Tuple[str, Tuple[Union[int, float, str], ...]]
    ) -> BenchmarkCache:
        ...

    def __getitem__(
        self, key: Union[str, Tuple[Union[int, float, str], ...]]
    ) -> Union[BenchmarkCache, ConfigCache]:
        """Dispatch on the key type: `str` -> config cache, tuple -> benchmark cache.

        A tuple key is unpacked into the arguments of `get_benchmark`, i.e. it
        is expected to be `(table, key)`.

        Raises:
            AssertionError: If `key` is neither a `str` nor a `tuple`.
        """
        if isinstance(key, str):
            return self.get_config(key)
        if isinstance(key, tuple):
            return self.get_benchmark(*key)
        # Raised explicitly (rather than via `assert False`) so the check is
        # still enforced when Python runs with assertions disabled.
        raise AssertionError(
            f"the type of key '{key.__class__.__name__}' is unacceptable"
        )

    def get_benchmark(
        self, table: str, key: Tuple[Union[int, float, str], ...]
    ) -> BenchmarkCache:
        """Return the pooled `BenchmarkCache` for `(table, key)`, creating it if needed."""
        ret = self.benchmark_cache_pool.get((table, key))
        if ret is None:
            ret = BenchmarkCache(table, self.model, key)
            self.benchmark_cache_pool[(table, key)] = ret
        return ret

    def get_config(self, table: str) -> ConfigCache:
        """Return the pooled `ConfigCache` for `table`, creating it if needed."""
        ret = self.config_cache_pool.get(table)
        if ret is None:
            ret = ConfigCache(table, self.model)
            self.config_cache_pool[table] = ret
        return ret

180 

181 @overload 

182 def __getitem__(self, key: str) -> ConfigCache: 

183 ... 

184 

185 @overload 

186 def __getitem__(self, key: Tuple[Union[int, float, str]]) -> BenchmarkCache: 

187 ... 

188 

189 def __getitem__( 

190 self, key: Union[str, Tuple[Union[int, float, str], ...]] 

191 ) -> Union[BenchmarkCache, ConfigCache]: 

192 if isinstance(key, str): 

193 return self.get_config(key) 

194 elif isinstance(key, tuple): 

195 return self.get_benchmark(*key) 

196 else: 

197 assert False, f"the type of key '{key.__class__.__name__}' is unacceptable" 

198 

199 def get_benchmark( 

200 self, table: str, key: Tuple[Union[int, float, str], ...] 

201 ) -> BenchmarkCache: 

202 ret = self.benchmark_cache_pool.get((table, key)) 

203 if ret is None: 

204 ret = BenchmarkCache(table, self.model, key) 

205 self.benchmark_cache_pool[(table, key)] = ret 

206 return ret 

207 

208 def get_config(self, table: str) -> ConfigCache: 

209 ret = self.config_cache_pool.get(table) 

210 if ret is None: 

211 ret = ConfigCache(table, self.model) 

212 self.config_cache_pool[table] = ret 

213 return ret 

214 

215 

# Module-level cache instance, configured from the optional FLAGGEMS_DB_URL override.
libcache = LibCache(db_url=FLAGGEMS_DB_URL)

217 

218 

class LibTuner(triton.runtime.Autotuner):
    """`LibTuner` is the base class for `FlagGems` library autotuner.

    It could be extended in two ways, overriding the `policy` or `run` method in a subclass.
    For `policy` extension, `LibTuner` provides a decorator `register_policy` to register a policy function quickly.
    Please refer to the implementation of `default_policy` for an example.
    """

    # The dispatch table for `LibTuner` subclasses. It's shared across all instances.
    _dispatch_table: Dict[str, Type[LibTuner]] = {}
    # Registered key strategies, looked up by name. Also shared across all instances.
    _strategy_table: Dict[str, Callable[[Any], Any]] = {}

    def __init__(
        self,
        fn,
        arg_names,
        configs,
        key,
        reset_to_zero,
        restore_value,
        pre_hook=None,
        post_hook=None,
        prune_configs_by: Optional[Dict] = None,
        warmup=None,
        rep=None,
        use_cuda_graph=False,
        do_bench=None,
        strategy=None,
    ):
        """Initialise the underlying Triton autotuner and the persistent caches.

        The base-class constructor is called with the argument list that
        matches the detected Triton version.

        Args:
            fn: The kernel (or wrapper exposing `.fn`) being tuned.
            arg_names: Positional argument names of the kernel.
            configs: Candidate `triton.Config` objects.
            key: Names of the arguments that form the tuning key.
            reset_to_zero: Forwarded to the Triton autotuner.
            restore_value: Forwarded to the Triton autotuner.
            pre_hook: Forwarded to the Triton autotuner (Triton 3 only).
            post_hook: Forwarded to the Triton autotuner (Triton 3 only).
            prune_configs_by: Forwarded to the Triton autotuner.
            warmup: Benchmark warm-up setting; defaults to 25 where it is used.
            rep: Benchmark repetition setting; defaults to 100 where it is used.
            use_cuda_graph: Benchmark with CUDA graphs.
            do_bench: Custom benchmark callable (Triton 3.2+ only).
            strategy: A strategy name, a callable, or a list/tuple of those
                with one entry per `key`. `None` selects the strategy that is
                registered under the name `None`.

        Raises:
            AssertionError: If a strategy list does not have one entry per key.
        """
        # NOTE(zhengyang): See discussion in https://github.com/triton-lang/triton/pull/4496
        if major_version == 2 or (major_version == 3 and minor_version <= 1):
            if warmup is None:
                warmup = 25
            if rep is None:
                rep = 100
        if major_version == 2:
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                prune_configs_by,
                warmup,
                rep,
            )
            # Unwrap nested wrappers until the plain Python function is reached.
            self.base_fn = fn
            while not inspect.isfunction(self.base_fn):
                self.base_fn = self.base_fn.fn
        elif major_version == 3 and minor_version <= 1:
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                pre_hook,
                post_hook,
                prune_configs_by,
                warmup,
                rep,
                use_cuda_graph,
            )
        else:
            # Triton 3.2+ removed warmup/rep/use_cuda_graph positional arguments.
            # Preserve FlagGems tuning behavior by translating them into do_bench.
            if do_bench is None:
                if use_cuda_graph:
                    from triton.testing import do_bench_cudagraph

                    def do_bench(kernel_call, quantiles):
                        return do_bench_cudagraph(
                            kernel_call,
                            rep=rep if rep is not None else 100,
                            quantiles=quantiles,
                        )

                elif warmup is not None or rep is not None:

                    def do_bench(kernel_call, quantiles):
                        return triton.testing.do_bench(
                            kernel_call,
                            warmup=warmup if warmup is not None else 25,
                            rep=rep if rep is not None else 100,
                            quantiles=quantiles,
                        )

            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                pre_hook=pre_hook,
                post_hook=post_hook,
                prune_configs_by=prune_configs_by,
                do_bench=do_bench,
            )
        self.__name__ = self.base_fn.__name__
        self.keys = key
        if isinstance(strategy, str):
            strategy = LibTuner.get_strategy(strategy)
        if not isinstance(strategy, (list, tuple)):
            strategy = [strategy] * len(self.keys)
        if len(strategy) != len(self.keys):
            # Raised explicitly so the check survives `python -O`.
            raise AssertionError(
                f"the length of strategy {len(strategy)} must match the length of keys {len(self.keys)}"
            )
        # Resolve names to callables. `None` is resolved through the table as
        # well; otherwise the default `strategy=None` would leave non-callable
        # entries behind and `get_key` would fail when applying them.
        strategy: List[Callable[[Any], Any]] = [
            LibTuner.get_strategy(s) if s is None or isinstance(s, str) else s
            for s in strategy
        ]
        self.strategy: List[Callable[[Any], Any]] = strategy
        self.config_table_name: str = f"{self.__name__}_{self.kernel_hash}"
        self.benchmark_table_name: str = f"{self.__name__}_{self.cache_key}_benchmark"
        # Indexing `libcache` with a string yields the config cache of that table.
        self.cache: ConfigCache = libcache[self.config_table_name]

    @cached_property
    def cache_key(self) -> str:
        """The `cache_key` of the innermost `JITFunction` wrapped by this tuner."""
        jit_fn = self.fn
        while not isinstance(jit_fn, triton.runtime.JITFunction):
            jit_fn = jit_fn.fn
        return jit_fn.cache_key

    @cached_property
    def kernel_hash(self) -> str:
        """MD5 hex digest over the kernel's cache key and the configs hash."""
        return hashlib.md5(
            f"{self.cache_key}{self.configs_hash}".encode("utf-8")
        ).hexdigest()

    @cached_property
    def configs_hash(self) -> str:
        """MD5 hex digest over the comma-joined string forms of all configs."""
        return hashlib.md5(
            ",".join(map(str, self.configs)).encode("utf-8")
        ).hexdigest()

    def get_key(self, args):
        """Build the tuning key for one call.

        Args:
            args: Mapping from argument name to value.

        Returns:
            A tuple holding the (strategy-transformed) value of every key
            argument, followed by `str(arg.dtype)` for every argument in
            `args` that has a `dtype` attribute.
        """
        if self.strategy is None:
            key = tuple(args[k] for k in self.keys if k in args)
        else:
            key = tuple(
                self.strategy[idx](args[name]) for idx, name in enumerate(self.keys)
            )
        key += tuple(str(arg.dtype) for arg in args.values() if hasattr(arg, "dtype"))
        return key

    @abstractmethod
    def policy(
        self,
        fn: Callable[[triton.Config], List[float]],
        configs: Iterator[triton.Config],
        args: Tuple[Any],
        kwargs: Dict[str, Any],
    ) -> Tuple[triton.Config, Dict[str, float]]:
        """Pick the best configuration; subclasses must override this.

        This is a regular method: `run` invokes it as `self.policy(...)`, so
        the instance is bound to `self` and the remaining four arguments line
        up with the signature. Overrides may still be static methods that
        take only those four arguments.

        Args:
            fn: Benchmarks one configuration and returns its timings.
            configs: The (pruned) configurations to choose from.
            args: Positional kernel launch arguments.
            kwargs: Keyword kernel launch arguments.

        Returns:
            The selected configuration and the collected timings.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"`policy` isn't implemented in {self.__class__.__name__}"
        )

    @classmethod
    def register(cls, name: str):
        """Register a subclass of `LibTuner` with a name.

        Args:
            name: The name of the subclass.
        Returns:
            A decorator that registers the subclass with the name.
        """

        def decorator(subclass):
            cls._dispatch_table[name] = subclass
            return subclass

        return decorator

    @classmethod
    def get(cls, name: str):
        """Return the `LibTuner` subclass registered under `name` (KeyError if none)."""
        return cls._dispatch_table[name]

    @classmethod
    def get_strategy(cls, name: str):
        """Return the strategy registered under `name` (KeyError if none)."""
        return cls._strategy_table[name]

    @staticmethod
    def register_policy(
        name: str,
    ) -> Callable[[Callable[..., Any]], Type[LibTuner]]:
        """A decorator to register a policy for `LibTuner`.

        This decorator allows you to create a new `LibTuner` subclass without defining a new class explicitly.
        The new subclass will have the `policy` method set to the provided policy function and will be registered under
        the specified name in the `LibTuner` dispatch table.

        Note that the decorated name is bound to the generated subclass, not
        to the original policy function.
        """

        def decorator(
            policy_impl: Callable[
                [
                    Callable[[triton.Config], List[float]],
                    Iterator[triton.Config],
                    Tuple[Any],
                    Dict[str, Any],
                ],
                Tuple[triton.Config, Dict[str, float]],
            ],
        ):
            @LibTuner.register(name)
            class AnonymousLibTunerImpl(LibTuner):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def policy(
                    self,
                    fn: Callable[[triton.Config], List[float]],
                    configs: Iterator[triton.Config],
                    args: Tuple[Any],
                    kwargs: Dict[str, Any],
                ) -> Tuple[triton.Config, Dict[str, float]]:
                    return policy_impl(fn, configs, args, kwargs)

            return AnonymousLibTunerImpl

        return decorator

    @staticmethod
    def register_strategy(name: str):
        """Return a decorator that stores a strategy in the table under `name`."""

        def decorator(
            strategy: Union[Callable[[Any], Any], List[Callable[[Any], Any]]],
        ):
            LibTuner._strategy_table[name] = strategy
            return strategy

        return decorator

    def run(self, *args, **kwargs):
        """Choose a configuration for this call and launch the kernel with it.

        With more than one candidate config, the tuning key is looked up in
        the config cache. On a miss the pruned configs are benchmarked through
        `policy` (reusing stored per-config benchmarks) and the winner is
        stored. With a single candidate that config is used directly.

        Returns:
            Whatever the wrapped function's `run` returns.
        """
        if hasattr(self, "seen_tuned_metas"):
            self.seen_tuned_metas = {}  # flagtree aabs: deduplicate tuned meta
        # `arg_names` corresponds to the arguments of the `JITFunction`'s signature,
        # so please make sure the orders of `arg_names` and `args` match.
        self.nargs = dict(zip(self.arg_names, args))
        used_cached_result = True
        if len(self.configs) > 1:
            all_args = {**self.nargs, **kwargs}
            _args = {k: v for k, v in all_args.items() if k in self.arg_names}
            key = self.get_key(_args)
            if key not in self.cache:
                cache: BenchmarkCache = libcache[self.benchmark_table_name, key]
                # prune configs
                used_cached_result = False
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()

                def bench(config: triton.Config) -> List[float]:
                    # Benchmark only configs without a stored result.
                    ret = cache.get(config)
                    if ret is None:
                        ret = self._bench(*args, config=config, **kwargs)
                        cache[config] = tuple(ret)
                    return list(ret)

                best_config, timings = self.policy(
                    bench,
                    pruned_configs,
                    args,
                    kwargs,
                )
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = best_config
                full_nargs = {
                    **self.nargs,
                    **kwargs,
                    **self.cache[key].all_kwargs(),
                }
                self.pre_hook(full_nargs, reset_only=True)
                self.configs_timings = timings
            config = self.cache[key]
            if config.pre_hook is None:
                # The config read back from the cache has no pre_hook; swap in
                # the matching in-memory config, which may carry one.
                cached_kwargs = config.all_kwargs()
                for original_config in self.configs:
                    if original_config.all_kwargs() == cached_kwargs:
                        # Use the original config which has the pre_hook
                        config = original_config
                        break
        else:
            config = self.configs[0]
        self.best_config = config
        if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
            print(
                f"Triton autotuning for function {self.base_fn.__name__} finished after "
                f"{self.bench_time:.2f}s; key info: {key}, best config selected: {self.best_config};"
            )
        full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
        if (
            hasattr(self, "shared_config_pre_hook")
            and self.shared_config_pre_hook is not None
        ):
            self.shared_config_pre_hook(full_nargs)
        elif config.pre_hook is not None:
            config.pre_hook(full_nargs)
        ret = self.fn.run(
            *args,
            **kwargs,
            **config.all_kwargs(),
        )
        self.nargs = None
        return ret

528 

529 

@LibTuner.register_strategy(None)
@LibTuner.register_strategy("default")
def default_strategy(key: Any) -> Any:
    """Identity strategy: the key value is used unchanged.

    Registered under both `"default"` and `None`.
    """
    return key

534 

535 

@LibTuner.register_strategy("log")
def log2_strategy(key: Union[int, float]) -> Union[int, float]:
    """Round `key` up to the nearest power of two.

    Args:
        key: A non-negative key value.

    Returns:
        `0` for a zero key (matching `align32_strategy`), otherwise
        `2 ** ceil(log2(key))`.

    Raises:
        ValueError: If `key` is negative (from `math.log2`).
    """
    # log2(0) is undefined; map an empty dimension to bucket 0 instead of
    # failing with a math domain error.
    if key == 0:
        return 0
    return 2 ** math.ceil(math.log2(key))

539 

540 

@LibTuner.register_strategy("align32")
def align32_strategy(key: Union[int, float]) -> int:
    """Bucket `key`: powers of two below 32, multiples of 32 from there on.

    A zero key maps to 0.
    """
    if key == 0:
        return 0
    if key < 32:
        # Small values: next power of two at or above the key.
        exponent = math.ceil(math.log2(key))
        return 2**exponent
    # Larger values: round up to the next multiple of 32.
    blocks = math.ceil(key / 32)
    return blocks * 32

548 

549 

@LibTuner.register_policy("default")
def default_policy(
    bench_fn: Callable[[triton.Config], List[float]],
    configs: Iterator[triton.Config],
    args: Tuple[Any],
    kwargs: Dict[str, Any],
) -> Tuple[triton.Config, Dict[str, float]]:
    """Default policy for offline autotuning: benchmark everything, keep the minimum.

    Args:
        bench_fn: The function to benchmark.
        configs: The collection of the configuration search space.
        args: Kernel launch arguments.
        kwargs: Kernel launch arguments.
    Returns:
        A tuple containing the best configuration and a dictionary of timings for each configuration.

    Registering this function through `register_policy` is equivalent to
    writing the subclass by hand:
    ```
    @LibTuner.register("default")
    class DefaultLibTunerImpl(LibTuner):
        def __init__(
            self,
            *args,
            **kwargs,
        ):
            super().__init__(
                *args,
                **kwargs,
            )

        @staticmethod
        def policy(
            bench_fn: Callable[[triton.Config], List[float]],
            configs: Iterator[triton.Config],
            args: Tuple[Any],
            kwargs: Dict[str, Any],
        ) -> Tuple[triton.Config, Dict[str, float]]:
            timings: Dict[triton.Config, int] = {
                config: bench_fn(config) for config in configs
            }
            best_config: triton.Config = min(timings, key=timings.get)
            return best_config, timings
    ```
    Policies can therefore be added either by registering a plain function,
    or by subclassing `LibTuner` and overriding `policy` for more control
    over the autotuning process.
    """
    timings: Dict[triton.Config, float] = {}
    for candidate in configs:
        timings[candidate] = bench_fn(candidate)
    # The winner is the config whose recorded timings compare lowest.
    best_config: triton.Config = min(timings, key=timings.get)
    return best_config, timings

603 

604 

def libtuner(
    configs,
    key,
    prune_configs_by=None,
    reset_to_zero=None,
    restore_value=None,
    pre_hook=None,
    post_hook=None,
    warmup=25,
    rep=100,
    use_cuda_graph=False,
    do_bench=None,
    strategy: Union[
        str, Callable[[Any], Any], List[Union[str, Callable[[Any], Any]]]
    ] = "default",
    policy: Union[str, Type[LibTuner]] = "default",
):
    """Decorator for triton library autotuner.

    `strategy` maps a key value to the value stored in the tuning key.
    It may be the name of a registered strategy or a callable, in which case
    it is applied to every entry of `key`. It may also be a tuple or list with
    the same length as `key`, each element being a name or a callable.
    `policy` is either the name of a registered `LibTuner` subclass or a
    `LibTuner` subclass itself.
    """

    tuner_cls = LibTuner.get(policy) if isinstance(policy, str) else policy
    assert issubclass(
        tuner_cls, LibTuner
    ), f"the class of {tuner_cls.__name__} is {tuner_cls.__class__.__name__}, not a subclass of {LibTuner.__name__}"

    def decorator(fn):
        # Everything after the positional core is passed by keyword.
        tuner_options = dict(
            pre_hook=pre_hook,
            post_hook=post_hook,
            prune_configs_by=prune_configs_by,
            warmup=warmup,
            rep=rep,
            use_cuda_graph=use_cuda_graph,
            do_bench=do_bench,
            strategy=strategy,
        )
        return tuner_cls(
            fn,
            fn.arg_names,
            configs,
            key,
            reset_to_zero,
            restore_value,
            **tuner_options,
        )

    return decorator

657 

658 

class LibEntry(triton.KernelInterface):
    """Launch wrapper that memoises compiled kernels per device and argument key.

    The first call for a given key goes through the wrapped function's own
    `run` and records the resulting kernel together with the constexpr values
    chosen by any autotuner / heuristics layers. Later calls with the same
    key launch the recorded kernel directly.
    """

    def __init__(
        self,
        fn,
    ):
        """Wrap `fn` and precompute which parameters are specialised.

        Args:
            fn: A `JITFunction`, or a chain of wrappers exposing `.fn` that
                ends in one.
        """
        self.fn = fn
        self.arg_names = fn.arg_names
        self.divisibility = 16
        # One kernel cache per device index.
        self.kernel_cache = tuple(dict() for _ in range(DEVICE_COUNT))

        while not isinstance(fn, triton.runtime.JITFunction):
            fn = fn.fn
        self.jit_function: triton.runtime.JITFunction = fn
        self.specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and not p.do_not_specialize
        ]
        self.do_not_specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and p.do_not_specialize
        ]
        self.lock = multiprocessing.Lock()
        self.signature = fn.signature

    def key(self, spec_args, dns_args, const_args):
        """Build the kernel-cache key from the three argument groups.

        Args:
            spec_args: Values of specialised parameters.
            dns_args: Values of do-not-specialise parameters.
            const_args: Values of constexpr parameters (a list).

        Returns:
            A tuple of the per-argument key parts, in that group order.
        """

        def spec_arg(arg):
            # Objects exposing `data_ptr`: dtype plus whether the address is
            # a multiple of `self.divisibility`.
            if hasattr(arg, "data_ptr"):
                return (arg.dtype, arg.data_ptr() % self.divisibility == 0)
            return (type(arg), arg)

        def dns_arg(arg):
            if hasattr(arg, "data_ptr"):
                return arg.dtype
            if not isinstance(arg, int):
                return type(arg)
            # Integers are keyed by the width bucket they fall into.
            if -(2**31) <= arg <= 2**31 - 1:
                return "i32"
            if 2**63 <= arg <= 2**64 - 1:
                return "u64"
            return "i64"

        spec_key = [spec_arg(arg) for arg in spec_args]
        dns_key = [dns_arg(arg) for arg in dns_args]
        # const args passed by position
        return tuple(spec_key + dns_key + const_args)

    def run(self, *args, **kwargs):
        """Launch the kernel, compiling and caching it on first use of a key.

        Args:
            *args: Positional kernel arguments, in signature order.
            **kwargs: Keyword kernel arguments; must contain `grid`, either a
                sequence or a callable receiving the launch meta dict.

        Returns:
            A `(kernel, constexprs)` pair.

        Raises:
            KeyError: If `grid` is missing from `kwargs`.
            RuntimeError: If a wrapper layer is neither an autotuner nor a
                heuristics object, or if a kernel parameter was not captured.
        """
        grid = kwargs["grid"]

        # collect all the arguments
        spec_args = []  # specialize arguments
        dns_args = []  # do not specialize arguments
        const_args = []  # constexpr arguments
        k_args = OrderedDict()
        param_names = list(self.signature.parameters.keys())
        # On Triton 3.3-3.6 constexpr values are passed to the kernel as well.
        pass_constexprs = major_version == 3 and 3 <= minor_version <= 6
        for i, arg in enumerate(args):
            hashable_arg = arg
            if arg.__class__.__name__ == "TensorDescriptor":
                # Create a hashable representation of TensorDescriptor
                hashable_arg = (
                    "TensorDescriptor",
                    tuple(arg.shape) if hasattr(arg, "shape") else None,
                    tuple(arg.strides) if hasattr(arg, "strides") else None,
                    tuple(arg.block_shape) if hasattr(arg, "block_shape") else None,
                    getattr(arg, "padding", None),
                    # Add other relevant attributes
                )
            if i in self.specialize_indices:
                k_args[param_names[i]] = arg
                spec_args.append(hashable_arg)
            elif i in self.do_not_specialize_indices:
                k_args[param_names[i]] = arg
                dns_args.append(hashable_arg)
            else:
                if pass_constexprs:
                    k_args[param_names[i]] = arg
                const_args.append(hashable_arg)
        # Parameters not supplied positionally: take them from kwargs or
        # from their defaults; skip those that have neither.
        for p in self.jit_function.params[len(args) :]:
            if p.name in kwargs:
                val = kwargs[p.name]
            elif p.default is inspect._empty:
                continue
            else:
                val = p.default

            if p.is_constexpr:
                const_args.append(val)
                if pass_constexprs:
                    k_args[p.name] = val
            elif p.do_not_specialize:
                dns_args.append(val)
                k_args[p.name] = val
            else:
                spec_args.append(val)
                k_args[p.name] = val

        entry_key = self.key(spec_args, dns_args, const_args)
        device = torch_device_fn.current_device()
        cache = self.kernel_cache[device]
        if entry_key not in cache:
            # NOTE: we serialize the first run of a jit function regardless of which device to run on
            # because Triton runtime is currently not threadsafe.
            with self.lock:
                # Re-check under the lock: another caller may have filled the
                # entry while this one was waiting. If so, fall through to the
                # cached launch below.
                if entry_key not in cache:
                    kernel = self.fn.run(*args, **kwargs)
                    fn = self.fn
                    # collect constexpr arguments for grid computation
                    constexprs = {}
                    tune_constexprs = {}
                    heur_constexprs = {}
                    launch_pre_hooks = []
                    while not isinstance(fn, triton.runtime.JITFunction):
                        if isinstance(fn, triton.runtime.Autotuner):
                            config = fn.best_config
                            constexprs["num_warps"] = config.num_warps
                            constexprs["num_stages"] = config.num_stages
                            constexprs["num_ctas"] = config.num_ctas
                            constexprs = {**constexprs, **config.kwargs}
                            tune_constexprs = {**tune_constexprs, **config.kwargs}
                            if config.pre_hook is not None:
                                launch_pre_hooks.append(
                                    (config.pre_hook, config.all_kwargs())
                                )
                        elif isinstance(fn, triton.runtime.Heuristics):
                            for v, heur in fn.values.items():
                                heur_constexprs[v] = heur(
                                    {
                                        **dict(zip(fn.arg_names, args)),
                                        **kwargs,
                                        **constexprs,
                                    }
                                )
                                constexprs[v] = heur_constexprs[v]
                        else:
                            raise RuntimeError("Invalid Runtime Function")
                        fn = fn.fn
                    for p in self.jit_function.params:
                        if (
                            p.is_constexpr
                            and p.name not in constexprs
                            and (p.default is not inspect._empty)
                        ):
                            constexprs[p.name] = p.default
                    cache[entry_key] = (
                        kernel,
                        constexprs,
                        tune_constexprs,
                        heur_constexprs,
                        tuple(launch_pre_hooks),
                    )
                    # The call to `self.fn.run` above already performed the launch.
                    return kernel, constexprs

        (
            kernel,
            constexprs,
            tune_constexprs,
            heur_constexprs,
            launch_pre_hooks,
        ) = cache[entry_key]

        if callable(grid):
            # collect all arguments to the grid fn,ie:
            # 1. args,
            # 2. kwargs,
            # 3. all all other captured arguments in CompiledKernel from Autotunner & Heuristics
            # when kwargs & captured args conflict, captured args have higher priority
            meta = {**dict(zip(self.arg_names, args)), **kwargs, **constexprs}
            grid = grid(meta)
        # Pad so that three grid entries can always be sliced below; `tuple()`
        # also accepts a grid given (or returned) as a list.
        grid = tuple(grid) + (1, 1)

        if launch_pre_hooks:
            hook_nargs = {**dict(zip(self.arg_names, args)), **kwargs}
            for pre_hook, hook_kwargs in launch_pre_hooks:
                pre_hook({**hook_nargs, **hook_kwargs})

        if pass_constexprs:
            # Assemble the full positional argument list in signature order.
            all_args = []
            missing_keys = []
            for name in param_names:
                if name in k_args:
                    all_args.append(k_args[name])
                elif name in tune_constexprs:
                    all_args.append(tune_constexprs[name])
                elif name in heur_constexprs:
                    all_args.append(heur_constexprs[name])
                elif name in constexprs:
                    all_args.append(constexprs[name])
                else:
                    missing_keys.append(name)
            if missing_keys:
                raise RuntimeError(
                    f"[libentry]: probably a bug, the following kernel params were not captured: {missing_keys}"
                )
            kernel[grid[0:3]](*all_args)
        else:
            kernel[grid[0:3]](*k_args.values())
        return kernel, constexprs

862 

863 

def libentry():
    """Decorator factory for triton library entries.

    The returned decorator wraps the kernel it is applied to in a `LibEntry`.
    """

    def wrap_kernel(fn):
        entry = LibEntry(fn)
        return entry

    return wrap_kernel

870 return decorator