Coverage for src/flag_gems/utils/libentry.py: 84%

466 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1from __future__ import annotations 

2 

3import hashlib 

4import inspect 

5import logging 

6import math 

7import multiprocessing 

8import os 

9import time 

10from abc import abstractmethod 

11from collections import OrderedDict 

12from functools import cached_property 

13from itertools import starmap 

14from pathlib import Path 

15from typing import ( 

16 Any, 

17 Callable, 

18 Dict, 

19 Final, 

20 Iterator, 

21 List, 

22 Optional, 

23 Tuple, 

24 Type, 

25 Union, 

26 overload, 

27) 

28 

29import triton 

30 

31from flag_gems import runtime 

32from flag_gems.runtime import torch_device_fn 

33from flag_gems.runtime.backend import _state 

34from flag_gems.utils.code_cache import config_cache_dir 

35from flag_gems.utils.models import PersistantModel, SQLPersistantModel 

36 

logger = logging.getLogger(__name__)

DEVICE_COUNT = runtime.device.device_count

# Split the installed Triton version string on "." and keep the first two
# components; they select the version-specific code paths in this module.
version = triton.__version__.split(".")
# Parse with int() instead of eval(): the components are plain decimal numbers,
# and eval() would execute arbitrary text found in the version string.
major_version, minor_version = int(version[0]), int(version[1])

43 

44 

if major_version == 2:

    def all_kwargs(self):
        """Return the config's `kwargs` merged with the launch options it defines.

        Only the option attributes that actually exist on this config object
        are included; option values take precedence over same-named entries
        in `self.kwargs`.
        """
        optional_attrs = (
            "num_warps",
            "num_ctas",
            "num_stages",
            "num_buffers_warp_spec",
            "num_consumer_groups",
            "reg_dec_producer",
            "reg_inc_consumer",
            "maxnreg",
        )
        merged = {**self.kwargs}
        for attr in optional_attrs:
            if hasattr(self, attr):
                merged[attr] = getattr(self, attr)
        return merged

    # Attach the helper to `triton.Config` when running on Triton 2.
    triton.Config.all_kwargs = all_kwargs

67 

# Optional database URL taken from the environment; None when the variable is unset.
FLAGGEMS_DB_URL = os.environ.get("FLAGGEMS_DB_URL")

69 

70 

class Cache:
    """Base class that pairs a table name with a persistence model.

    Attributes:
        table_name: Name of the table this cache addresses in `model`.
        model: The `PersistantModel` used for storage.
    """

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        """Store the table name and model.

        Args:
            table_name: Name of the table this cache addresses.
            model: Persistence backend used for reads and writes.
            *args: Forwarded to the next `__init__` in the MRO.
            **kwargs: Forwarded to the next `__init__` in the MRO.
        """
        super().__init__(*args, **kwargs)
        self.table_name: Final[str] = table_name
        self.model: Final[PersistantModel] = model

78 

79 

class ConfigCache(Cache):
    """
    `ConfigCache` is used to store the relationship between keys and their known best configurations.

    It offers a mapping-style interface (`in`, `[]`, `[] =`) on top of the
    model's `get_config` / `put_config` calls for a single table.
    """

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        """Create a config cache bound to `table_name` in `model`."""
        super().__init__(table_name, model, *args, **kwargs)

    def __contains__(self, key: Tuple[Union[int, float, str], ...]) -> bool:
        """Return True when the model returns a config for `key`."""
        return self.get(key) is not None

    def __getitem__(self, key: Tuple[Union[int, float, str], ...]) -> triton.Config:
        """Return the config stored for `key`.

        Raises:
            KeyError: If the model returns no config for `key`.
        """
        ret: Optional[triton.Config] = self.get(key)
        if ret is None:
            raise KeyError(f"Key {key} not found in ConfigCache.")
        return ret

    def __setitem__(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        """Store `config` as the config for `key`."""
        self.set(key, config)

    def get(self, key: Tuple[Union[int, float, str], ...]) -> Optional[triton.Config]:
        """Return the model's config for `key`, or None when it has none."""
        return self.model.get_config(self.table_name, key)

    def set(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        """Store `config` for `key` through the model."""
        # The model's return value is passed through unchanged for callers that use it.
        return self.model.put_config(self.table_name, key, config)

111 

112 

class BenchmarkCache(Cache):
    """
    `BenchmarkCache` is used to store the benchmark results for the pair of the specific key and configuration.

    Each instance is bound to one table and one tuning key; configurations
    are used as the lookup keys within it.
    """

    def __init__(
        self,
        table_name: str,
        model: PersistantModel,
        key: Tuple[Union[int, float, str], ...],
        *args,
        **kwargs,
    ) -> None:
        """Create a benchmark cache for `key` inside `table_name`."""
        super().__init__(table_name, model, *args, **kwargs)
        self.key: Final[Tuple[Union[int, float, str], ...]] = key

    def __contains__(self, config: triton.Config) -> bool:
        """Return True when a benchmark result is stored for `config`."""
        # Go through `get` so the lookup passes the table name, exactly as every
        # other accessor does; the previous direct model call omitted it.
        return self.get(config) is not None

    def __getitem__(self, config: triton.Config) -> Tuple[float, float, float]:
        """Return the benchmark result stored for `config`.

        Raises:
            KeyError: If no result is stored for `config` under this key.
        """
        ret: Optional[Tuple[float, float, float]] = self.get(config)
        if ret is None:
            raise KeyError(
                f"Config {config} not found in BenchmarkCache for key {self.key}."
            )
        return ret

    def __setitem__(
        self, config: triton.Config, benchmark: Tuple[float, float, float]
    ) -> None:
        """Store `benchmark` as the result for `config`."""
        return self.set(config, benchmark)

    def get(self, config: triton.Config) -> Optional[Tuple[float, float, float]]:
        """Return the stored result for `config`, or None when there is none."""
        return self.model.get_benchmark(self.table_name, self.key, config)

    def set(self, config: triton.Config, benchmark: Tuple[float, float, float]) -> None:
        """Store `benchmark` for `config` through the model."""
        return self.model.put_benchmark(self.table_name, self.key, config, benchmark)

147 

148 

class LibCache(object):
    """Singleton front-end that hands out `ConfigCache` / `BenchmarkCache` objects.

    Indexing with a string returns the `ConfigCache` for that table; indexing
    with a `(table, key)` tuple returns the `BenchmarkCache` for that pair.
    Cache objects are created on first use and reused afterwards.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Always hand back the same object.
        if cls._instance is None:
            cls._instance = super(LibCache, cls).__new__(cls)
        return cls._instance

    def __init__(self, db_url: Optional[str] = None):
        """Set up the pools and the persistence model.

        Args:
            db_url: Database URL. When None, a SQLite file inside
                `config_cache_dir()` is used, named after the vendor and the
                Triton major/minor version.

        NOTE(review): `__init__` runs on every `LibCache(...)` call, so
        constructing the singleton again resets the pools and the model.
        """
        self.global_cache: Dict = {}
        self.volumn: Dict = {}
        vendor_name = _state.vendor_module.vendor_info.vendor_name
        if db_url is None:
            cache_file_name: str = (
                f"TunedConfig_{vendor_name}_triton_{major_version}_{minor_version}.db"
            )
            cache_path: Path = config_cache_dir() / cache_file_name
            self.db_url: str = f"sqlite:///{cache_path}"
        else:
            self.db_url: str = db_url
        self.config_cache_pool: Dict[str, ConfigCache] = {}
        self.benchmark_cache_pool: Dict[
            Tuple[str, Tuple[Union[int, float, str], ...]], BenchmarkCache
        ] = {}
        self.model: PersistantModel = SQLPersistantModel(self.db_url)

    @overload
    def __getitem__(self, key: str) -> ConfigCache:
        ...

    @overload
    def __getitem__(
        self, key: Tuple[str, Tuple[Union[int, float, str], ...]]
    ) -> BenchmarkCache:
        ...

    def __getitem__(
        self, key: Union[str, Tuple[str, Tuple[Union[int, float, str], ...]]]
    ) -> Union[BenchmarkCache, ConfigCache]:
        """Return the cache selected by `key`.

        Args:
            key: A table name (returns a `ConfigCache`) or a `(table, key)`
                tuple (returns a `BenchmarkCache`).

        Raises:
            AssertionError: If `key` is neither a string nor a tuple.
        """
        if isinstance(key, str):
            return self.get_config(key)
        if isinstance(key, tuple):
            return self.get_benchmark(*key)
        # Raised explicitly (same exception type as before) so the check is
        # not stripped when Python runs with -O.
        raise AssertionError(
            f"the type of key '{key.__class__.__name__}' is unacceptable"
        )

    def get_benchmark(
        self, table: str, key: Tuple[Union[int, float, str], ...]
    ) -> BenchmarkCache:
        """Return the pooled `BenchmarkCache` for `(table, key)`, creating it if needed."""
        ret = self.benchmark_cache_pool.get((table, key))
        if ret is None:
            ret = BenchmarkCache(table, self.model, key)
            self.benchmark_cache_pool[(table, key)] = ret
        return ret

    def get_config(self, table: str) -> ConfigCache:
        """Return the pooled `ConfigCache` for `table`, creating it if needed."""
        ret = self.config_cache_pool.get(table)
        if ret is None:
            ret = ConfigCache(table, self.model)
            self.config_cache_pool[table] = ret
        return ret

208 

209 

# Module-level cache front-end, configured from the FLAGGEMS_DB_URL setting.
libcache = LibCache(db_url=FLAGGEMS_DB_URL)

211 

212 

class LibTuner(triton.runtime.Autotuner):
    """`LibTuner` is the base class for `FlagGems` library autotuner.

    It could be extended in two ways, overriding the `policy` or `run` method in a subclass.
    For `policy` extension, `LibTuner` provides a decorator `register_policy` to register a policy function quickly.
    Please refer to the implementation of `default_policy` for an example.
    """

    # The dispatch table for `LibTuner` subclasses. It's shared across all instances.
    _dispatch_table: Dict[str, Type[LibTuner]] = {}
    # Registered key strategies, looked up by name (`None` may also be registered).
    _strategy_table: Dict[Optional[str], Callable[[Any], Any]] = {}

    def __init__(
        self,
        fn,
        arg_names,
        configs,
        key,
        reset_to_zero,
        restore_value,
        pre_hook=None,
        post_hook=None,
        prune_configs_by: Optional[Dict] = None,
        warmup=None,
        rep=None,
        use_cuda_graph=False,
        do_bench=None,
        strategy=None,
        flagtune_op_name=None,
        flagtune_expand_op_name=None,
        flagtune_yaml_path=None,
        flagtune_pre_hook=None,
    ):
        """Initialise the underlying Triton autotuner and the FlagGems caches.

        The parent constructor is called with the argument list that matches
        the installed Triton version. `strategy` is normalised to one callable
        per entry of `key`. The `flagtune_*` arguments are stored for
        `apply_flagtune`.
        """
        # NOTE(zhengyang): See discussion in https://github.com/triton-lang/triton/pull/4496
        if major_version == 2 or (major_version == 3 and minor_version <= 1):
            if warmup is None:
                warmup = 25
            if rep is None:
                rep = 100
        if major_version == 2:
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                prune_configs_by,
                warmup,
                rep,
            )
            # Unwrap nested wrappers until the plain Python function is reached.
            self.base_fn = fn
            while not inspect.isfunction(self.base_fn):
                self.base_fn = self.base_fn.fn
        elif major_version == 3 and minor_version <= 1:
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                pre_hook,
                post_hook,
                prune_configs_by,
                warmup,
                rep,
                use_cuda_graph,
            )
        else:
            # Triton 3.2+ removed warmup/rep/use_cuda_graph positional arguments.
            # Preserve FlagGems tuning behavior by translating them into do_bench.
            if do_bench is None:
                if use_cuda_graph:
                    from triton.testing import do_bench_cudagraph

                    def do_bench(kernel_call, quantiles):
                        return do_bench_cudagraph(
                            kernel_call,
                            rep=rep if rep is not None else 100,
                            quantiles=quantiles,
                        )

                elif warmup is not None or rep is not None:

                    def do_bench(kernel_call, quantiles):
                        return triton.testing.do_bench(
                            kernel_call,
                            warmup=warmup if warmup is not None else 25,
                            rep=rep if rep is not None else 100,
                            quantiles=quantiles,
                        )

            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                pre_hook=pre_hook,
                post_hook=post_hook,
                prune_configs_by=prune_configs_by,
                do_bench=do_bench,
            )
        self.__name__ = self.base_fn.__name__
        self.keys = key
        self.strategy: List[Callable[[Any], Any]] = self._normalize_strategy(strategy)
        self.config_table_name: str = f"{self.__name__}_{self.kernel_hash}"
        self.benchmark_table_name: str = f"{self.__name__}_{self.cache_key}_benchmark"
        # Indexing `libcache` with a string yields the per-table config cache.
        self.cache: ConfigCache = libcache[self.config_table_name]
        self._flagtune_default_configs = self.configs
        self._flagtune_default_strategy = strategy
        self._flagtune_active = False
        self._flagtune_warned = False
        self._flagtune_op_name = flagtune_op_name
        self._flagtune_expand_op_name = flagtune_expand_op_name or flagtune_op_name
        self._flagtune_yaml_path = flagtune_yaml_path
        self._flagtune_pre_hook = flagtune_pre_hook

    @staticmethod
    def _resolve_strategy(strategy):
        """Turn a strategy name (or `None`) into the registered callable.

        Strings are looked up with `get_strategy`. `None` resolves to whatever
        is registered under the `None` key, and stays `None` if nothing is.
        Any other value is returned unchanged.
        """
        if isinstance(strategy, str):
            return LibTuner.get_strategy(strategy)
        if strategy is None:
            return LibTuner._strategy_table.get(None)
        return strategy

    def _normalize_strategy(self, strategy):
        """Return a list with one strategy callable per entry of `self.keys`.

        A single strategy (name, `None` or callable) is applied to every key;
        a list or tuple must have the same length as `self.keys`.

        Raises:
            AssertionError: If the lengths of strategy and keys differ.
        """
        # Resolve names first: a registered entry may itself be a list.
        if strategy is None or isinstance(strategy, str):
            strategy = self._resolve_strategy(strategy)
        if not isinstance(strategy, (list, tuple)):
            strategy = [strategy] * len(self.keys)
        assert len(strategy) == len(
            self.keys
        ), f"the length of strategy {len(strategy)} must match the length of keys {len(self.keys)}"
        # `None` entries are resolved too; left as-is they would later be
        # called as functions in `get_key`.
        return [self._resolve_strategy(s) for s in strategy]

    def _set_configs_and_strategy(self, configs, strategy):
        """Swap in new configs/strategy and rebuild everything derived from them."""
        self.configs = configs
        self.strategy = self._normalize_strategy(strategy)
        # Drop the cached_property values so the hashes are recomputed.
        self.__dict__.pop("configs_hash", None)
        self.__dict__.pop("kernel_hash", None)
        self.config_table_name = f"{self.__name__}_{self.kernel_hash}"
        self.benchmark_table_name = f"{self.__name__}_{self.cache_key}_benchmark"
        self.cache = libcache[self.config_table_name]

    def apply_flagtune(self):
        """Synchronise this tuner with the runtime FlagTune switch.

        Returns:
            True when configs/strategy were changed, False otherwise
            (no op name, switch state unchanged, or expand config unavailable).
        """
        if self._flagtune_op_name is None:
            return False

        enabled = runtime.flagtune_enabled(self._flagtune_op_name)
        if enabled == self._flagtune_active:
            return False

        if not enabled:
            # Switch turned off: go back to the configs given at construction.
            self._set_configs_and_strategy(
                self._flagtune_default_configs,
                self._flagtune_default_strategy,
            )
            self._flagtune_active = False
            return True

        expand_config = runtime.get_expand_config(
            self._flagtune_expand_op_name,
            yaml_path=self._flagtune_yaml_path,
        )
        configs = runtime.ops_get_configs(
            self._flagtune_expand_op_name,
            yaml_path=self._flagtune_yaml_path,
            pre_hook=self._flagtune_pre_hook,
        )
        if expand_config == -1 or not configs:
            # Warn only once per tuner.
            if not self._flagtune_warned:
                logger.warning(
                    "FlagTune expand config is unavailable for %s; using default configs.",
                    self._flagtune_expand_op_name,
                )
                self._flagtune_warned = True
            return False

        self._set_configs_and_strategy(configs, expand_config["strategy"])
        self._flagtune_active = True
        return True

    @cached_property
    def cache_key(self) -> str:
        """`cache_key` of the innermost `JITFunction` wrapped by this tuner."""
        jit_fn = self.fn
        while not isinstance(jit_fn, triton.runtime.JITFunction):
            jit_fn = jit_fn.fn
        return jit_fn.cache_key

    @cached_property
    def kernel_hash(self) -> str:
        """MD5 hex digest of the kernel cache key combined with `configs_hash`."""
        return hashlib.md5(
            f"{self.cache_key}{self.configs_hash}".encode("utf-8")
        ).hexdigest()[:32]

    @cached_property
    def configs_hash(self) -> str:
        """MD5 hex digest of the comma-joined string forms of `self.configs`."""
        return hashlib.md5(
            ",".join(map(str, self.configs)).encode("utf-8")
        ).hexdigest()[:32]

    def get_key(self, args):
        """Build the tuning key for a call.

        Args:
            args: Mapping from argument name to value.

        Returns:
            A tuple of the (strategy-transformed) key arguments followed by
            `str(dtype)` of every argument that has a `dtype` attribute.
        """
        if self.strategy is None:
            key = tuple(args[k] for k in self.keys if k in args)
        else:
            key = tuple(
                self.strategy[idx](args[name]) for idx, name in enumerate(self.keys)
            )
        key += tuple(str(arg.dtype) for arg in args.values() if hasattr(arg, "dtype"))
        return key

    @abstractmethod
    def policy(
        self,
        fn: Callable[[triton.Config], List[float]],
        configs: Iterator[triton.Config],
        args: Tuple[Any],
        kwargs: Dict[str, Any],
    ) -> Tuple[triton.Config, Dict[str, float]]:
        """Select the best config; must be provided by a subclass.

        `run` calls this as `self.policy(bench, configs, args, kwargs)`, so it
        is an instance method here. Subclasses may still override it with a
        `staticmethod` that takes the four arguments.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"`policy` isn't implemented in {self.__class__.__name__}"
        )

    @classmethod
    def register(cls, name: str):
        """Register a subclass of `LibTuner` with a name.

        Args:
            name: The name of the subclass.
        Returns:
            A decorator that registers the subclass with the name.
        """

        def decorator(subclass):
            cls._dispatch_table[name] = subclass
            return subclass

        return decorator

    @classmethod
    def get(cls, name: str):
        """Return the `LibTuner` subclass registered under `name`."""
        return cls._dispatch_table[name]

    @classmethod
    def get_strategy(cls, name: str):
        """Return the strategy registered under `name`."""
        return cls._strategy_table[name]

    @staticmethod
    def register_policy(
        name: str,
    ) -> Callable[[Callable[..., Any]], Type[LibTuner]]:
        """A decorator to register a policy for `LibTuner`.

        This decorator allows you to create a new `LibTuner` subclass without defining a new class explicitly.
        The new subclass will have the `policy` method set to the provided policy function and will be registered under
        the specified name in the `LibTuner` dispatch table.
        """

        def decorator(
            policy_impl: Callable[
                [
                    Callable[[triton.Config], List[float]],
                    Iterator[triton.Config],
                    Tuple[Any],
                    Dict[str, Any],
                ],
                Tuple[triton.Config, Dict[str, float]],
            ],
        ):
            @LibTuner.register(name)
            class AnonymousLibTunerImpl(LibTuner):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def policy(
                    self,
                    fn: Callable[[triton.Config], List[float]],
                    configs: Iterator[triton.Config],
                    args: Tuple[Any],
                    kwargs: Dict[str, Any],
                ) -> Tuple[triton.Config, Dict[str, float]]:
                    return policy_impl(fn, configs, args, kwargs)

            # The decorated name ends up bound to the generated subclass.
            return AnonymousLibTunerImpl

        return decorator

    @staticmethod
    def register_strategy(name: str):
        """Return a decorator that registers a strategy under `name`."""

        def decorator(
            strategy: Union[Callable[[Any], Any], List[Callable[[Any], Any]]],
        ):
            LibTuner._strategy_table[name] = strategy
            return strategy

        return decorator

    def run(self, *args, **kwargs):
        """Pick a config (tuning on a cache miss) and launch the wrapped kernel.

        With more than one config, the best config is looked up in the config
        cache by the tuning key. On a miss, `policy` is run over the pruned
        configs and the winner is stored. Individual benchmark results are
        read from and written to the benchmark cache.

        Returns:
            Whatever `self.fn.run` returns for the selected config.
        """
        if hasattr(self, "seen_tuned_metas"):
            self.seen_tuned_metas = {}  # flagtree aabs: deduplicate tuned meta
        # `arg_names` corresponds to the arguments of the `JITFunction`'s signature,
        # so please make sure the orders of `arg_names` and `args` match.
        self.nargs = dict(zip(self.arg_names, args))
        used_cached_result = True
        if len(self.configs) > 1:
            all_args = {**self.nargs, **kwargs}
            _args = {k: v for k, v in all_args.items() if k in self.arg_names}
            key = self.get_key(_args)
            if key not in self.cache:
                cache: BenchmarkCache = libcache[self.benchmark_table_name, key]
                # prune configs
                used_cached_result = False
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()

                def bench(config: triton.Config) -> List[float]:
                    # Reuse a stored measurement when there is one.
                    ret = cache.get(config)
                    if ret is None:
                        ret = self._bench(*args, config=config, **kwargs)
                        cache[config] = tuple(ret)
                    return list(ret)

                best_config, timings = self.policy(
                    bench,
                    pruned_configs,
                    args,
                    kwargs,
                )
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = best_config
                full_nargs = {
                    **self.nargs,
                    **kwargs,
                    **self.cache[key].all_kwargs(),
                }
                self.pre_hook(full_nargs, reset_only=True)
                self.configs_timings = timings
            config = self.cache[key]
            if config.pre_hook is None:
                cached_kwargs = config.all_kwargs()
                for original_config in self.configs:
                    if original_config.all_kwargs() == cached_kwargs:
                        # Use the original config which has the pre_hook
                        config = original_config
                        break
        else:
            config = self.configs[0]
        self.best_config = config
        if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
            print(
                f"Triton autotuning for function {self.base_fn.__name__} finished after "
                f"{self.bench_time:.2f}s; key info: {key}, best config selected: {self.best_config};"
            )
        full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
        # A shared pre-hook, when present, takes precedence over the config's own.
        if (
            hasattr(self, "shared_config_pre_hook")
            and self.shared_config_pre_hook is not None
        ):
            self.shared_config_pre_hook(full_nargs)
        elif config.pre_hook is not None:
            config.pre_hook(full_nargs)
        ret = self.fn.run(
            *args,
            **kwargs,
            **config.all_kwargs(),
        )
        self.nargs = None
        return ret

581 

582 

@LibTuner.register_strategy(None)
@LibTuner.register_strategy("default")
def default_strategy(key: Any) -> Any:
    """Return `key` unchanged.

    Registered under both `None` and "default" in the strategy table.
    """
    return key

587 

588 

@LibTuner.register_strategy("log")
def log2_strategy(key: Union[int, float]) -> Union[int, float]:
    """Round `key` up to the nearest power of two.

    Args:
        key: A non-negative number.

    Returns:
        0 when `key` is 0, otherwise `2 ** ceil(log2(key))`.

    Raises:
        ValueError: If `key` is negative (from `math.log2`).
    """
    # log2(0) is undefined; map 0 to 0, as `align32_strategy` does.
    if key == 0:
        return 0
    return 2 ** math.ceil(math.log2(key))

592 

593 

@LibTuner.register_strategy("align32")
def align32_strategy(key: Union[int, float]) -> int:
    """Bucket `key`: powers of two below 32, multiples of 32 from there on.

    Zero maps to zero. Values under 32 are rounded up to a power of two;
    all other values are rounded up to a multiple of 32.
    """
    if key == 0:
        return 0
    below_alignment = key < 32
    if below_alignment:
        exponent = math.ceil(math.log2(key))
        return 2**exponent
    blocks = math.ceil(key / 32)
    return blocks * 32

601 

602 

@LibTuner.register_policy("default")
def default_policy(
    bench_fn: Callable[[triton.Config], List[float]],
    configs: Iterator[triton.Config],
    args: Tuple[Any],
    kwargs: Dict[str, Any],
) -> Tuple[triton.Config, Dict[str, float]]:
    """Default policy for offline autotuning: benchmark everything, keep the minimum.

    Args:
        bench_fn: The function to benchmark a single configuration.
        configs: The collection of the configuration search space.
        args: Kernel launch arguments (unused by this policy).
        kwargs: Kernel launch arguments (unused by this policy).
    Returns:
        A tuple containing the best configuration and a dictionary of timings for each configuration.

    Registering this function through `LibTuner.register_policy` is shorthand
    for writing the subclass by hand:
    ```
    @LibTuner.register("default")
    class DefaultLibTunerImpl(LibTuner):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        @staticmethod
        def policy(bench_fn, configs, args, kwargs):
            timings = {config: bench_fn(config) for config in configs}
            best_config = min(timings, key=timings.get)
            return best_config, timings
    ```
    Policies can therefore be added either by registering a plain function,
    or by subclassing `LibTuner` and overriding `policy` for more control
    over the autotuning process.
    """
    timings: Dict[triton.Config, float] = {}
    for candidate in configs:
        timings[candidate] = bench_fn(candidate)
    # The config whose recorded timing compares smallest wins.
    best_config: triton.Config = min(timings, key=lambda cfg: timings.get(cfg))
    return best_config, timings

656 

657 

def libtuner(
    configs,
    key,
    prune_configs_by=None,
    reset_to_zero=None,
    restore_value=None,
    pre_hook=None,
    post_hook=None,
    warmup=25,
    rep=100,
    use_cuda_graph=False,
    do_bench=None,
    strategy: Union[
        str, Callable[[Any], Any], List[Union[str, Callable[[Any], Any]]]
    ] = "default",
    policy: Union[str, Type[LibTuner]] = "default",
    flagtune_op_name=None,
    flagtune_expand_op_name=None,
    flagtune_yaml_path=None,
    flagtune_pre_hook=None,
):
    """Decorator factory that wraps a kernel in a `LibTuner` subclass.

    `strategy` maps a key value to the value used for cache lookup. It may be
    the name of a registered strategy or a callable, in which case it is
    applied to every entry of `key`. It may also be a list/tuple with one
    name or callable per entry of `key`.
    `policy` is either the name of a registered `LibTuner` subclass or such a
    subclass itself.
    """

    tuner_cls = LibTuner.get(policy) if isinstance(policy, str) else policy
    assert issubclass(
        tuner_cls, LibTuner
    ), f"the class of {tuner_cls.__name__} is {tuner_cls.__class__.__name__}, not a subclass of {LibTuner.__name__}"

    # Keyword options shared by every kernel decorated with this tuner.
    tuner_options = dict(
        pre_hook=pre_hook,
        post_hook=post_hook,
        prune_configs_by=prune_configs_by,
        warmup=warmup,
        rep=rep,
        use_cuda_graph=use_cuda_graph,
        do_bench=do_bench,
        strategy=strategy,
        flagtune_op_name=flagtune_op_name,
        flagtune_expand_op_name=flagtune_expand_op_name,
        flagtune_yaml_path=flagtune_yaml_path,
        flagtune_pre_hook=flagtune_pre_hook,
    )

    def decorator(fn):
        return tuner_cls(
            fn,
            fn.arg_names,
            configs,
            key,
            reset_to_zero,
            restore_value,
            **tuner_options,
        )

    return decorator

718 

719 

class LibEntry(triton.KernelInterface):
    """Launch wrapper that caches the kernel object per argument signature.

    The first call for a given entry key goes through the wrapped function's
    `run` and records the resulting kernel together with the constexpr values
    chosen by any `Autotuner` / `Heuristics` wrappers. Later calls with the
    same entry key launch the cached kernel directly.
    """

    def __init__(
        self,
        fn,
    ):
        """Wrap `fn`, which must wrap (or be) a `triton.runtime.JITFunction`."""
        self.fn = fn
        self.arg_names = fn.arg_names
        self.divisibility = 16
        # One kernel cache per device index, plus a separate one for CPU.
        self.kernel_cache = tuple(dict() for _ in range(DEVICE_COUNT))
        self._has_flagtune_tuner = self._contains_flagtune_tuner(fn)
        self._cpu_cache = dict()

        # Unwrap down to the innermost JITFunction.
        while not isinstance(fn, triton.runtime.JITFunction):
            fn = fn.fn
        self.jit_function: triton.runtime.JITFunction = fn
        self.specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and not p.do_not_specialize
        ]
        self.do_not_specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and p.do_not_specialize
        ]
        self.lock = multiprocessing.Lock()
        self.signature = fn.signature

    @staticmethod
    def _contains_flagtune_tuner(fn):
        """Return True if any wrapper around the JIT function supports FlagTune.

        A wrapper qualifies when it has both an `apply_flagtune` attribute and
        a non-None `_flagtune_op_name`.
        """
        while not isinstance(fn, triton.runtime.JITFunction):
            if (
                getattr(fn, "apply_flagtune", None) is not None
                and getattr(fn, "_flagtune_op_name", None) is not None
            ):
                return True
            fn = getattr(fn, "fn", None)
            if fn is None:
                break
        return False

    def _apply_flagtune(self):
        """Call `apply_flagtune` on every wrapper; drop cached kernels on change."""
        changed = False
        fn = self.fn
        while not isinstance(fn, triton.runtime.JITFunction):
            apply_flagtune = getattr(fn, "apply_flagtune", None)
            if apply_flagtune is not None:
                changed = apply_flagtune() or changed
            fn = getattr(fn, "fn", None)
            if fn is None:
                break
        if changed:
            for cache in self.kernel_cache:
                cache.clear()
            # The CPU path uses its own cache; clear it too so it cannot
            # serve kernels selected under the previous configuration.
            self._cpu_cache.clear()

    def key(self, spec_args, dns_args, const_args):
        """Build the hashable entry key for one call.

        Args:
            spec_args: Values of the specialised parameters.
            dns_args: Values of the do-not-specialise parameters.
            const_args: List of constexpr values.

        Returns:
            A tuple of the per-argument key parts, in that order.
        """

        def spec_arg(arg):
            # Objects with `data_ptr` are keyed by dtype and pointer alignment.
            if hasattr(arg, "data_ptr"):
                return (arg.dtype, arg.data_ptr() % self.divisibility == 0)
            return (type(arg), arg)

        def dns_arg(arg):
            if hasattr(arg, "data_ptr"):
                return arg.dtype
            if not isinstance(arg, int):
                return type(arg)
            # Integers are keyed only by the width bucket they fall in.
            if -(2**31) <= arg and arg <= 2**31 - 1:
                return "i32"
            if 2**63 <= arg and arg <= 2**64 - 1:
                return "u64"
            return "i64"

        spec_key = [spec_arg(arg) for arg in spec_args]
        dns_key = [dns_arg(arg) for arg in dns_args]
        # const args passed by position
        return tuple(spec_key + dns_key + const_args)

    def run(self, *args, **kwargs):
        """Launch the kernel, compiling and caching it on the first call per key.

        `kwargs` must contain `grid` (a tuple or a callable taking the meta
        dict).

        Returns:
            A `(kernel, constexprs)` tuple.
        """
        grid = kwargs["grid"]
        if self._has_flagtune_tuner:
            self._apply_flagtune()

        # On these Triton versions constexpr values are also passed at launch.
        pass_constexprs = major_version == 3 and 3 <= minor_version <= 6

        # collect all the arguments
        spec_args = []  # specialize arguments
        dns_args = []  # do not specialize arguments
        const_args = []  # constexpr arguments
        k_args = OrderedDict()
        param_names = list(self.signature.parameters.keys())
        for i, arg in enumerate(args):
            hashable_arg = arg
            if (
                hasattr(arg, "__class__")
                and arg.__class__.__name__ == "TensorDescriptor"
            ):
                # Create a hashable representation of TensorDescriptor
                hashable_arg = (
                    "TensorDescriptor",
                    tuple(arg.shape) if hasattr(arg, "shape") else None,
                    tuple(arg.strides) if hasattr(arg, "strides") else None,
                    tuple(arg.block_shape) if hasattr(arg, "block_shape") else None,
                    arg.padding if hasattr(arg, "padding") else None,
                    # Add other relevant attributes
                )
            if i in self.specialize_indices:
                k_args[param_names[i]] = arg
                spec_args.append(hashable_arg)
            elif i in self.do_not_specialize_indices:
                k_args[param_names[i]] = arg
                dns_args.append(hashable_arg)
            else:
                if pass_constexprs:
                    k_args[param_names[i]] = arg
                const_args.append(hashable_arg)
        # Parameters not given positionally: take them from kwargs or defaults.
        for p in self.jit_function.params[len(args) :]:
            if p.name in kwargs:
                val = kwargs[p.name]
            elif p.default is inspect._empty:
                continue
            else:
                val = p.default

            if p.is_constexpr:
                const_args.append(val)
                if pass_constexprs:
                    k_args[p.name] = val
            elif p.do_not_specialize:
                dns_args.append(val)
                k_args[p.name] = val
            else:
                spec_args.append(val)
                k_args[p.name] = val

        entry_key = self.key(spec_args, dns_args, const_args)
        device = torch_device_fn.current_device()
        # CPU has one device per process and `current_device()` returns the
        # string "cpu" (can't index into the int-keyed `kernel_cache` tuple).
        # This branch is CPU-generic — any future x86 / RISC-V CPU backend
        # reuses the same path; no ARM-specific assumption here.
        if device == "cpu":
            cache = self._cpu_cache
        else:
            cache = self.kernel_cache[device]
        while entry_key not in cache:
            # NOTE: we serialize the first run of a jit function regardless of which device to run on
            # because Triton runtime is currently not threadsafe.
            with self.lock:
                # Re-check under the lock: another caller may have filled the entry.
                if entry_key in cache:
                    break
                kernel = self.fn.run(*args, **kwargs)
                fn = self.fn
                # collect constexpr arguments for grid computation
                constexprs = {}
                tune_constexprs = {}
                heur_constexprs = {}
                launch_pre_hooks = []
                while not isinstance(fn, triton.runtime.JITFunction):
                    if isinstance(fn, triton.runtime.Autotuner):
                        config = fn.best_config
                        constexprs["num_warps"] = config.num_warps
                        constexprs["num_stages"] = config.num_stages
                        constexprs["num_ctas"] = config.num_ctas
                        constexprs = {**constexprs, **config.kwargs}
                        tune_constexprs = {**tune_constexprs, **config.kwargs}
                        if config.pre_hook is not None:
                            launch_pre_hooks.append(
                                (config.pre_hook, config.all_kwargs())
                            )
                    elif isinstance(fn, triton.runtime.Heuristics):
                        for v, heur in fn.values.items():
                            heur_constexprs[v] = heur(
                                {
                                    **dict(zip(fn.arg_names, args)),
                                    **kwargs,
                                    **constexprs,
                                }
                            )
                            constexprs[v] = heur_constexprs[v]
                    else:
                        raise RuntimeError("Invalid Runtime Function")
                    fn = fn.fn
                # Fill in defaults for constexpr params nothing above provided.
                for p in self.jit_function.params:
                    if (
                        p.is_constexpr
                        and p.name not in constexprs
                        and (p.default is not inspect._empty)
                    ):
                        constexprs[p.name] = p.default
                cache[entry_key] = (
                    kernel,
                    constexprs,
                    tune_constexprs,
                    heur_constexprs,
                    tuple(launch_pre_hooks),
                )
            # `self.fn.run` above already launched the kernel for this call.
            return kernel, constexprs

        (
            kernel,
            constexprs,
            tune_constexprs,
            heur_constexprs,
            launch_pre_hooks,
        ) = cache[entry_key]

        if callable(grid):
            # collect all arguments to the grid fn, i.e.:
            # 1. args,
            # 2. kwargs,
            # 3. all other captured arguments in CompiledKernel from Autotuner & Heuristics
            # when kwargs & captured args conflict, captured args have higher priority
            meta = {**dict(zip(self.arg_names, args)), **kwargs, **constexprs}
            grid = grid(meta)
        # Pad so that slicing the first three entries always works.
        grid = grid + (1, 1)

        if launch_pre_hooks:
            hook_nargs = {**dict(zip(self.arg_names, args)), **kwargs}
            for pre_hook, hook_kwargs in launch_pre_hooks:
                pre_hook({**hook_nargs, **hook_kwargs})

        if pass_constexprs:
            all_args = []
            missing_keys = []
            # Assemble launch arguments in signature order.
            for name in list(self.signature.parameters.keys()):
                if name in k_args:
                    all_args.append(k_args[name])
                elif name in tune_constexprs:
                    all_args.append(tune_constexprs[name])
                elif name in heur_constexprs:
                    all_args.append(heur_constexprs[name])
                elif name in constexprs:
                    all_args.append(constexprs[name])
                else:
                    missing_keys.append(name)
            if len(missing_keys):
                raise RuntimeError(
                    f"[libentry]: probably a bug, the following kernel params where not captured: {missing_keys}"
                )
            kernel[grid[0:3]](*all_args)
        else:
            kernel[grid[0:3]](*k_args.values())
        return kernel, constexprs

961 

962 

def libentry():
    """Return a decorator that wraps a Triton kernel object in a `LibEntry`."""

    def wrap_kernel(fn):
        entry = LibEntry(fn)
        return entry

    return wrap_kernel