Coverage for src/flag_gems/utils/libentry.py: 89%

387 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1from __future__ import annotations 

2 

3import hashlib 

4import inspect 

5import logging 

6import math 

7import multiprocessing 

8import os 

9import time 

10from abc import abstractmethod 

11from collections import OrderedDict 

12from functools import cached_property 

13from itertools import starmap 

14from pathlib import Path 

15from typing import ( 

16 Any, 

17 Callable, 

18 Dict, 

19 Final, 

20 Iterator, 

21 List, 

22 Optional, 

23 Tuple, 

24 Type, 

25 Union, 

26 overload, 

27) 

28 

29import triton 

30 

31from flag_gems import runtime 

32from flag_gems.runtime import torch_device_fn 

33from flag_gems.runtime.backend import _state 

34from flag_gems.utils.code_cache import config_cache_dir 

35from flag_gems.utils.models import PersistantModel, SQLPersistantModel 

36 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Stored exactly as exposed by the runtime; `LibEntry` passes it to `range()`.
DEVICE_COUNT = runtime.device.device_count

# Leading two components of the Triton version string, e.g. "3.2.0" -> (3, 2).
# They are parsed with `int` instead of `eval`: the components are plain
# decimal strings, so there is no reason to execute them as Python code.
version = triton.__version__.split(".")
major_version, minor_version = int(version[0]), int(version[1])

43 

44 

if major_version == 2:
    # Attach an `all_kwargs` method to `triton.Config` when running on
    # Triton 2.x (presumably because that release does not provide one).

    def all_kwargs(self):
        """Return the config's `kwargs` merged with the launch options it defines."""
        merged = dict(self.kwargs)
        for option in (
            "num_warps",
            "num_ctas",
            "num_stages",
            "num_buffers_warp_spec",
            "num_consumer_groups",
            "reg_dec_producer",
            "reg_inc_consumer",
            "maxnreg",
        ):
            # Only options that exist on this config object are reported.
            if hasattr(self, option):
                merged[option] = getattr(self, option)
        return merged

    triton.Config.all_kwargs = all_kwargs

67 

# Optional database URL override, read from the environment (None when unset).
FLAGGEMS_DB_URL = os.environ.get("FLAGGEMS_DB_URL")

69 

70 

class Cache(object):
    """Base class for caches that read and write through a persistent model.

    It only records which table the cache addresses and which model object
    performs the storage; subclasses define the actual lookup operations.
    """

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        """Store the table name and the backing model.

        Args:
            table_name: Name of the table this cache addresses.
            model: Object used by subclasses to load and store entries.
            *args, **kwargs: Forwarded unchanged to the next `__init__` in the MRO.
        """
        super().__init__(*args, **kwargs)
        self.table_name: Final[str] = table_name
        self.model: Final[PersistantModel] = model

78 

79 

class ConfigCache(Cache):
    """
    `ConfigCache` is used to store the relationship between keys and their known best configurations.

    All lookups are delegated to `self.model.get_config` / `self.model.put_config`
    using this cache's table name; a stored value of `None` is treated as absent.
    """

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        """Initialise the cache for `table_name`, backed by `model`."""
        super().__init__(table_name, model, *args, **kwargs)

    def __contains__(self, key: Tuple[Union[int, float, str], ...]) -> bool:
        """Return True when the model has a configuration stored for `key`."""
        return self.get(key) is not None

    def __getitem__(self, key: Tuple[Union[int, float, str], ...]) -> triton.Config:
        """Return the configuration stored for `key`.

        Raises:
            KeyError: If the model returns no configuration for `key`.
        """
        ret: Optional[triton.Config] = self.get(key)
        if ret is None:
            raise KeyError(f"Key {key} not found in ConfigCache.")
        return ret

    def __setitem__(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        """Store `config` as the configuration for `key`."""
        self.set(key, config)

    def get(self, key: Tuple[Union[int, float, str], ...]) -> Optional[triton.Config]:
        """Return the configuration the model holds for `key`, or None."""
        return self.model.get_config(self.table_name, key)

    def set(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        """Write `config` for `key` through the model."""
        return self.model.put_config(self.table_name, key, config)

111 

112 

class BenchmarkCache(Cache):
    """
    `BenchmarkCache` is used to store the benchmark results for the pair of the specific key and configuration.

    One instance is bound to a single `(table_name, key)` pair; entries are
    indexed by configuration and delegated to `self.model.get_benchmark` /
    `self.model.put_benchmark`.
    """

    def __init__(
        self,
        table_name: str,
        model: PersistantModel,
        key: Tuple[Union[int, float, str], ...],
        *args,
        **kwargs,
    ) -> None:
        """Bind the cache to `table_name` and the tuning `key`."""
        super().__init__(table_name, model, *args, **kwargs)
        self.key: Final[Tuple[Union[int, float, str], ...]] = key

    def __contains__(self, config: triton.Config) -> bool:
        """Return True when a benchmark is stored for `config` under this key."""
        # Go through `get` so the lookup uses the same
        # (table_name, key, config) arguments as every other accessor; the
        # previous direct model call omitted the table name.
        return self.get(config) is not None

    def __getitem__(self, config: triton.Config) -> Tuple[float, float, float]:
        """Return the benchmark stored for `config`.

        Raises:
            KeyError: If the model returns no benchmark for `config`.
        """
        ret: Optional[Tuple[float, float, float]] = self.get(config)
        if ret is None:
            raise KeyError(
                f"Config {config} not found in BenchmarkCache for key {self.key}."
            )
        return ret

    def __setitem__(
        self, config: triton.Config, benchmark: Tuple[float, float, float]
    ) -> None:
        """Store `benchmark` as the result for `config`."""
        return self.set(config, benchmark)

    def get(self, config: triton.Config) -> Optional[Tuple[float, float, float]]:
        """Return the benchmark the model holds for `config`, or None."""
        return self.model.get_benchmark(self.table_name, self.key, config)

    def set(self, config: triton.Config, benchmark: Tuple[float, float, float]) -> None:
        """Write `benchmark` for `config` through the model."""
        return self.model.put_benchmark(self.table_name, self.key, config, benchmark)

147 

148 

class LibCache(object):
    """Singleton facade handing out `ConfigCache` / `BenchmarkCache` objects.

    Cache objects are created lazily and pooled, so repeated requests for the
    same table (or table/key pair) return the same object. All of them share
    one `SQLPersistantModel` built from `db_url`.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Always hand back the single shared instance.
        # NOTE(review): `__init__` still runs on every `LibCache(...)` call and
        # resets the pools and the model — confirm this is intended.
        if cls._instance is None:
            cls._instance = super(LibCache, cls).__new__(cls)
        return cls._instance

    def __init__(self, db_url: Optional[str] = None):
        """Set up the pools and the backing model.

        Args:
            db_url: Database URL to use. When None, a SQLite file inside
                `config_cache_dir()` is used, named after the device and the
                Triton major/minor version.
        """
        self.global_cache: Dict = {}
        self.volumn: Dict = {}
        vendor_device_name = _state.vendor_module.vendor_info.device_name
        if db_url is None:
            try:
                device_name: str = torch_device_fn.get_device_name().replace(" ", "_")
            except AttributeError:
                # Fall back to the vendor-level device name.
                device_name = vendor_device_name
            # Both branches of the former "nvidia" conditional produced this
            # exact name, so the conditional was dropped.
            cache_file_name: str = (
                f"TunedConfig_{device_name}_triton_{major_version}_{minor_version}.db"
            )
            cache_path: Path = config_cache_dir() / cache_file_name
            self.db_url: str = f"sqlite:///{cache_path}"
        else:
            self.db_url: str = db_url
        self.config_cache_pool: Dict[str, ConfigCache] = {}
        self.benchmark_cache_pool: Dict[
            Tuple[str, Tuple[Union[int, float, str], ...]], BenchmarkCache
        ] = {}
        self.model: PersistantModel = SQLPersistantModel(self.db_url)

    @overload
    def __getitem__(self, key: str) -> ConfigCache:
        ...

    @overload
    def __getitem__(
        self, key: Tuple[str, Tuple[Union[int, float, str], ...]]
    ) -> BenchmarkCache:
        ...

    def __getitem__(
        self, key: Union[str, Tuple[Union[int, float, str], ...]]
    ) -> Union[BenchmarkCache, ConfigCache]:
        """Return a pooled cache object.

        A string selects the `ConfigCache` for that table. A tuple is unpacked
        into the arguments of `get_benchmark`, i.e. `(table, key)`.

        Raises:
            AssertionError: If `key` is neither a string nor a tuple.
        """
        if isinstance(key, str):
            return self.get_config(key)
        if isinstance(key, tuple):
            return self.get_benchmark(*key)
        # Raised explicitly (rather than `assert False`) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        raise AssertionError(
            f"the type of key '{key.__class__.__name__}' is unacceptable"
        )

    def get_benchmark(
        self, table: str, key: Tuple[Union[int, float, str], ...]
    ) -> BenchmarkCache:
        """Return the pooled `BenchmarkCache` for `(table, key)`, creating it on first use."""
        ret = self.benchmark_cache_pool.get((table, key))
        if ret is None:
            ret = BenchmarkCache(table, self.model, key)
            self.benchmark_cache_pool[(table, key)] = ret
        return ret

    def get_config(self, table: str) -> ConfigCache:
        """Return the pooled `ConfigCache` for `table`, creating it on first use."""
        ret = self.config_cache_pool.get(table)
        if ret is None:
            ret = ConfigCache(table, self.model)
            self.config_cache_pool[table] = ret
        return ret

214 

215 

# Shared cache facade created at import time; uses FLAGGEMS_DB_URL when it is
# set, otherwise LibCache derives its own database URL.
libcache = LibCache(FLAGGEMS_DB_URL)

217 

218 

class LibTuner(triton.runtime.Autotuner):
    """`LibTuner` is the base class for `FlagGems` library autotuner.

    It could be extended in two ways, overriding the `policy` or `run` method in a subclass.
    For `policy` extension, `LibTuner` provides a decorator `register_policy` to register a policy function quickly.
    Please refer to the implementation of `default_policy` for an example.

    The best configuration per tuning key and the per-configuration benchmark
    results are read from and written to the module-level `libcache`.
    """

    # The dispatch table for `LibTuner` subclasses. It's shared across all instances.
    _dispatch_table: Dict[str, Type[LibTuner]] = {}
    # Registered key strategies by name. It's shared across all instances.
    _strategy_table: Dict[str, Callable[[Any], Any]] = {}

    def __init__(
        self,
        fn,
        arg_names,
        configs,
        key,
        reset_to_zero,
        restore_value,
        pre_hook=None,
        post_hook=None,
        prune_configs_by: Optional[Dict] = None,
        warmup=None,
        rep=None,
        use_cuda_graph=False,
        do_bench=None,
        strategy=None,
    ):
        """Create the tuner and resolve its key strategies and cache tables.

        Args:
            fn, arg_names, configs, key, reset_to_zero, restore_value,
            pre_hook, post_hook, prune_configs_by, warmup, rep, use_cuda_graph:
                Forwarded to `triton.runtime.Autotuner.__init__`; the subset
                that is forwarded depends on the Triton major version.
            do_bench: Accepted for signature compatibility.
                NOTE(review): it is not forwarded to the parent constructor
                here — confirm whether that is intentional.
            strategy: How each entry of `key` is transformed before it becomes
                part of the cache key. May be a registered strategy name, a
                callable, `None` (the strategy registered under `None`), or a
                list/tuple with one such item per entry of `key`.
        """
        # NOTE(zhengyang): See discussion in https://github.com/triton-lang/triton/pull/4496
        if major_version == 2 or (major_version == 3 and minor_version <= 1):
            if warmup is None:
                warmup = 25
            if rep is None:
                rep = 100
        if major_version == 2:
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                prune_configs_by,
                warmup,
                rep,
            )
            # Unwrap nested wrappers until the plain Python function is reached.
            self.base_fn = fn
            while not inspect.isfunction(self.base_fn):
                self.base_fn = self.base_fn.fn
        else:
            # `self.base_fn` is not assigned in this branch; it is presumably
            # set by the parent constructor on these Triton versions.
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                pre_hook,
                post_hook,
                prune_configs_by,
                warmup,
                rep,
                use_cuda_graph,
            )
        self.__name__ = self.base_fn.__name__
        self.keys = key
        # `None` is a registered strategy name (see `register_strategy(None)`
        # at module level), so it is resolved through the table exactly like a
        # string. Previously `None` was left as-is and `get_key` then tried to
        # call it.
        if strategy is None or isinstance(strategy, str):
            strategy = LibTuner.get_strategy(strategy)
        if not isinstance(strategy, (list, tuple)):
            strategy = [strategy] * len(self.keys)
        assert len(strategy) == len(
            self.keys
        ), f"the length of strategy {len(strategy)} must match the length of keys {len(self.keys)}"
        strategy: List[Callable[[Any], Any]] = [
            LibTuner.get_strategy(s) if s is None or isinstance(s, str) else s
            for s in strategy
        ]
        self.strategy: List[Callable[[Any], Any]] = strategy
        self.config_table_name: str = f"{self.__name__}_{self.kernel_hash}"
        self.benchmark_table_name: str = f"{self.__name__}_{self.cache_key}_benchmark"
        # Indexing `libcache` with a string yields the table's `ConfigCache`.
        self.cache: ConfigCache = libcache[self.config_table_name]

    @cached_property
    def cache_key(self) -> str:
        """Return the `cache_key` of the innermost `JITFunction` wrapped by this tuner."""
        jit_fn = self.fn
        while not isinstance(jit_fn, triton.runtime.JITFunction):
            jit_fn = jit_fn.fn
        return jit_fn.cache_key

    @cached_property
    def kernel_hash(self) -> str:
        """Return an MD5 hex digest over `cache_key` followed by `configs_hash`."""
        return hashlib.md5(
            f"{self.cache_key}{self.configs_hash}".encode("utf-8")
        ).hexdigest()[:32]

    @cached_property
    def configs_hash(self) -> str:
        """Return an MD5 hex digest over the comma-joined `str()` of every config."""
        return hashlib.md5(
            ",".join(map(str, self.configs)).encode("utf-8")
        ).hexdigest()[:32]

    def get_key(self, args):
        """Build the tuning key for one call.

        Args:
            args: Mapping from argument name to value.
        Returns:
            A tuple holding the (strategy-transformed) value of every name in
            `self.keys`, followed by `str(arg.dtype)` for every value in
            `args` that has a `dtype` attribute.
        """
        if self.strategy is None:
            key = tuple(args[k] for k in self.keys if k in args)
        else:
            key = tuple(
                self.strategy[idx](args[name]) for idx, name in enumerate(self.keys)
            )
        key += tuple(str(arg.dtype) for arg in args.values() if hasattr(arg, "dtype"))
        return key

    # This was previously also decorated with `@staticmethod`, which
    # contradicted both the `self` parameter and the `self.policy(...)` call in
    # `run`; it is now an ordinary method.
    @abstractmethod
    def policy(
        self,
        fn: Callable[[triton.Config], List[float]],
        configs: Iterator[triton.Config],
        args: Tuple[Any],
        kwargs: Dict[str, Any],
    ) -> Tuple[triton.Config, Dict[str, float]]:
        """Choose the best configuration; must be provided by a subclass.

        Args:
            fn: Benchmarks a single configuration and returns its timings.
            configs: The (already pruned) configurations to choose from.
            args: Positional kernel launch arguments.
            kwargs: Keyword kernel launch arguments.
        Returns:
            The selected configuration and the collected timings.
        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"`policy` isn't implemented in {self.__class__.__name__}"
        )

    @classmethod
    def register(cls, name: str):
        """Register a subclass of `LibTuner` with a name.

        Args:
            name: The name of the subclass.
        Returns:
            A decorator that registers the subclass with the name.
        """

        def decorator(subclass):
            cls._dispatch_table[name] = subclass
            return subclass

        return decorator

    @classmethod
    def get(cls, name: str):
        """Return the `LibTuner` subclass registered under `name` (KeyError if absent)."""
        return cls._dispatch_table[name]

    @classmethod
    def get_strategy(cls, name: str):
        """Return the strategy registered under `name` (KeyError if absent)."""
        return cls._strategy_table[name]

    @staticmethod
    def register_policy(
        name: str,
    ) -> Callable[[Callable[..., Any]], Type[LibTuner]]:
        """A decorator to register a policy for `LibTuner`.

        This decorator allows you to create a new `LibTuner` subclass without defining a new class explicitly.
        The new subclass will have the `policy` method set to the provided policy function and will be registered under
        the specified name in the `LibTuner` dispatch table.

        Note that the decorated name is bound to the generated subclass, not
        to the policy function itself.
        """

        def decorator(
            policy_impl: Callable[
                [
                    Callable[[triton.Config], List[float]],
                    Iterator[triton.Config],
                    Tuple[Any],
                    Dict[str, Any],
                ],
                Tuple[triton.Config, Dict[str, float]],
            ],
        ):
            @LibTuner.register(name)
            class AnonymousLibTunerImpl(LibTuner):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def policy(
                    self,
                    fn: Callable[[triton.Config], List[float]],
                    configs: Iterator[triton.Config],
                    args: Tuple[Any],
                    kwargs: Dict[str, Any],
                ) -> Tuple[triton.Config, Dict[str, float]]:
                    # The policy function does not receive the tuner instance.
                    return policy_impl(fn, configs, args, kwargs)

            return AnonymousLibTunerImpl

        return decorator

    @staticmethod
    def register_strategy(name: str):
        """Return a decorator that registers a strategy under `name` and returns it unchanged."""

        def decorator(
            strategy: Union[Callable[[Any], Any], List[Callable[[Any], Any]]],
        ):
            LibTuner._strategy_table[name] = strategy
            return strategy

        return decorator

    def run(self, *args, **kwargs):
        """Pick a configuration for this call and launch the wrapped function.

        With more than one candidate configuration, the tuning key is looked
        up in the config cache; on a miss the candidates are pruned,
        benchmarked through `policy` (reusing per-configuration results from
        the benchmark cache) and the winner is stored. With a single
        candidate, that configuration is used directly.

        Returns:
            Whatever `self.fn.run` returns.
        """
        # `arg_names` corresponds to the arguments of the `JITFunction`'s signature,
        # so please make sure the orders of `arg_names` and `args` match.
        self.nargs = dict(zip(self.arg_names, args))
        used_cached_result = True
        if len(self.configs) > 1:
            all_args = {**self.nargs, **kwargs}
            _args = {k: v for k, v in all_args.items() if k in self.arg_names}
            key = self.get_key(_args)
            if key not in self.cache:
                cache: BenchmarkCache = libcache[self.benchmark_table_name, key]
                # prune configs
                used_cached_result = False
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()

                def bench(config: triton.Config) -> List[float]:
                    # Only benchmark configurations without a stored result.
                    ret = cache.get(config)
                    if ret is None:
                        ret = self._bench(*args, config=config, **kwargs)
                        cache[config] = tuple(ret)
                    return list(ret)

                best_config, timings = self.policy(
                    bench,
                    pruned_configs,
                    args,
                    kwargs,
                )
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = best_config
                full_nargs = {
                    **self.nargs,
                    **kwargs,
                    **self.cache[key].all_kwargs(),
                }
                self.pre_hook(full_nargs, reset_only=True)
                self.configs_timings = timings
            config = self.cache[key]
            if config.pre_hook is None:
                cached_kwargs = config.all_kwargs()
                for original_config in self.configs:
                    if original_config.all_kwargs() == cached_kwargs:
                        # Use the original config which has the pre_hook
                        config = original_config
                        break
        else:
            config = self.configs[0]
        self.best_config = config
        if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
            print(
                f"Triton autotuning for function {self.base_fn.__name__} finished after "
                f"{self.bench_time:.2f}s; key info: {key}, best config selected: {self.best_config};"
            )
        if config.pre_hook is not None:
            full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
            config.pre_hook(full_nargs)
        ret = self.fn.run(
            *args,
            **kwargs,
            **config.all_kwargs(),
        )
        self.nargs = None
        return ret

485 

486 

487@LibTuner.register_strategy(None) 

488@LibTuner.register_strategy("default") 

def default_strategy(key: Any) -> Any:
    """Identity strategy: use the key value as-is."""
    return key

491 

492 

493@LibTuner.register_strategy("log") 

def log2_strategy(key: Union[int, float]) -> Union[int, float]:
    """Round a positive key up to the nearest power of two.

    Args:
        key: The raw key value.
    Returns:
        The smallest power of two that is >= `key`. Non-positive keys (for
        which no such bucket exists and `math.log2` would raise) are returned
        unchanged.
    """
    if key <= 0:
        return key
    if isinstance(key, int):
        # Exact integer arithmetic: avoids float rounding for very large keys.
        return 1 << (key - 1).bit_length()
    return 2 ** math.ceil(math.log2(key))

496 

497 

498@LibTuner.register_strategy("align32") 

def align32_strategy(key: Union[int, float]) -> int:
    """Round the key up to the next multiple of 32."""
    blocks = math.ceil(key / 32)
    return blocks * 32

501 

502 

503@LibTuner.register_policy("default") 

def default_policy(
    bench_fn: Callable[[triton.Config], List[float]],
    configs: Iterator[triton.Config],
    args: Tuple[Any],
    kwargs: Dict[str, Any],
) -> Tuple[triton.Config, Dict[triton.Config, List[float]]]:
    """Default policy for offline autotuning.

    Every configuration is benchmarked and the one whose `bench_fn` result
    compares smallest is selected.

    Args:
        bench_fn: The function to benchmark.
        configs: The collection of the configuration search space.
        args: Positional kernel launch arguments (unused by this policy).
        kwargs: Keyword kernel launch arguments (unused by this policy).
    Returns:
        A tuple containing the best configuration and a dictionary mapping each
        configuration to the value `bench_fn` returned for it.
    Raises:
        ValueError: If `configs` is empty (raised by `min`).

    This is one way to implement a default policy for offline autotuning. It's equal to the following
    ```
    @LibTuner.register("default")
    class DefaultLibTunerImpl(LibTuner):
        def __init__(
            self,
            *args,
            **kwargs,
        ):
            super().__init__(
                *args,
                **kwargs,
            )

        @staticmethod
        def policy(
            bench_fn: Callable[[triton.Config], List[float]],
            configs: Iterator[triton.Config],
            args: Tuple[Any],
            kwargs: Dict[str, Any],
        ) -> Tuple[triton.Config, Dict[triton.Config, List[float]]]:
            timings: Dict[triton.Config, List[float]] = {
                config: bench_fn(config) for config in configs
            }
            best_config: triton.Config = min(timings, key=timings.get)
            return best_config, timings
    ```
    In this way policies could be extended by registering a definition function quickly,
    or by creating a new subclass of `LibTuner` and overriding the `policy` method to have
    more control over the autotuning process.
    """
    timings: Dict[triton.Config, List[float]] = {
        config: bench_fn(config) for config in configs
    }
    best_config: triton.Config = min(timings, key=timings.get)
    return best_config, timings

556 

557 

def libtuner(
    configs,
    key,
    prune_configs_by=None,
    reset_to_zero=None,
    restore_value=None,
    pre_hook=None,
    post_hook=None,
    warmup=25,
    rep=100,
    use_cuda_graph=False,
    do_bench=None,
    strategy: Union[
        str, Callable[[Any], Any], List[Union[str, Callable[[Any], Any]]]
    ] = "default",
    policy: Union[str, Type[LibTuner]] = "default",
):
    """Decorator for triton library autotuner.

    `strategy` is a function that takes a key and returns a value.
    It accepts a string, which is the name of a registered strategy, or a callable function.
    In this form it will be applied to each key in the `key` list.
    If it's a tuple or list, it should have the same length as `key`,
    and each element should be a string or a callable function that takes a key and returns a value.
    `policy` accepts a string, which is the name of a registered `LibTuner` subclass, or a `LibTuner` subclass itself.

    The remaining arguments are passed through to the selected `LibTuner`
    subclass when the decorated function is wrapped.
    """

    # A policy given by name is looked up in the dispatch table.
    policy = LibTuner.get(policy) if isinstance(policy, str) else policy
    assert issubclass(
        policy, LibTuner
    ), f"the class of {policy.__name__} is {policy.__class__.__name__}, not a subclass of {LibTuner.__name__}"

    # Keyword options shared by every function this decorator wraps.
    tuner_options = dict(
        pre_hook=pre_hook,
        post_hook=post_hook,
        prune_configs_by=prune_configs_by,
        warmup=warmup,
        rep=rep,
        use_cuda_graph=use_cuda_graph,
        do_bench=do_bench,
        strategy=strategy,
    )

    def decorator(fn):
        return policy(
            fn, fn.arg_names, configs, key, reset_to_zero, restore_value, **tuner_options
        )

    return decorator

610 

611 

class LibEntry(triton.KernelInterface):
    """Caching launcher placed around a (possibly wrapped) Triton `JITFunction`.

    The first call for a given argument signature goes through `self.fn.run`
    and records the resulting kernel together with the constexpr values that
    were chosen by any autotuner / heuristics wrappers. Later calls with the
    same signature launch the recorded kernel directly.
    """

    def __init__(
        self,
        fn,
    ):
        """Wrap `fn` and precompute which parameters are specialized.

        Args:
            fn: A `JITFunction`, or a wrapper chain (reachable via `.fn`)
                that ends in one. It must expose `arg_names`.
        """
        self.fn = fn
        self.arg_names = fn.arg_names
        # Pointer alignment (from `data_ptr() % divisibility`) that becomes
        # part of the cache key for specialized arguments.
        self.divisibility = 16
        # One kernel cache per device index.
        self.kernel_cache = tuple(dict() for _ in range(DEVICE_COUNT))

        # Unwrap until the underlying JITFunction is reached.
        while not isinstance(fn, triton.runtime.JITFunction):
            fn = fn.fn
        self.jit_function: triton.runtime.JITFunction = fn
        # Positions of non-constexpr parameters that are specialized.
        self.specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and not p.do_not_specialize
        ]
        # Positions of non-constexpr parameters marked do-not-specialize.
        self.do_not_specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and p.do_not_specialize
        ]
        # Serializes the first run of the function (see `run`).
        self.lock = multiprocessing.Lock()
        self.signature = fn.signature

    def key(self, spec_args, dns_args, const_args):
        """Build the kernel-cache key for one call.

        Args:
            spec_args: Values of the specialized arguments.
            dns_args: Values of the do-not-specialize arguments.
            const_args: Values of the constexpr arguments (a list).
        Returns:
            A tuple made of the per-argument keys of `spec_args`, then of
            `dns_args`, then the `const_args` values themselves.
        """

        def spec_arg(arg):
            # Objects with `data_ptr`: dtype plus whether the pointer is
            # aligned to `self.divisibility`. Anything else: type and value.
            if hasattr(arg, "data_ptr"):
                return (arg.dtype, arg.data_ptr() % self.divisibility == 0)
            return (type(arg), arg)

        def dns_arg(arg):
            # Only the kind of the argument is keyed, never its value:
            # dtype for objects with `data_ptr`, the Python type for non-ints,
            # and a width label chosen by range for ints.
            if hasattr(arg, "data_ptr"):
                return arg.dtype
            if not isinstance(arg, int):
                return type(arg)
            if -(2**31) <= arg and arg <= 2**31 - 1:
                return "i32"
            if 2**63 <= arg and arg <= 2**64 - 1:
                return "u64"
            return "i64"

        spec_key = [spec_arg(arg) for arg in spec_args]
        dns_key = [dns_arg(arg) for arg in dns_args]
        # const args passed by position
        return tuple(spec_key + dns_key + const_args)

    def run(self, *args, **kwargs):
        """Launch the kernel, compiling/tuning through `self.fn` on a cache miss.

        Args:
            *args: Positional kernel arguments, in parameter order.
            **kwargs: Keyword kernel arguments; must contain `grid`.
        Returns:
            A `(kernel, constexprs)` tuple: the value returned by
            `self.fn.run` on the first call for this key, and the constexpr
            values collected from the wrapper chain and parameter defaults.
        Raises:
            KeyError: If `grid` is not supplied.
            RuntimeError: If the wrapper chain contains an unsupported object,
                or (Triton 3.3-3.6 path) a kernel parameter has no value.
        """
        grid = kwargs["grid"]

        # collect all the arguments
        spec_args = []  # specialize arguments
        dns_args = []  # do not specialize arguments
        const_args = []  # constexpr arguments
        # Arguments that are passed to the cached kernel on a direct launch.
        k_args = OrderedDict()
        param_names = list(self.signature.parameters.keys())
        for i, arg in enumerate(args):
            hashable_arg = arg
            if (
                hasattr(arg, "__class__")
                and arg.__class__.__name__ == "TensorDescriptor"
            ):
                # Create a hashable representation of TensorDescriptor
                hashable_arg = (
                    "TensorDescriptor",
                    tuple(arg.shape) if hasattr(arg, "shape") else None,
                    tuple(arg.strides) if hasattr(arg, "strides") else None,
                    tuple(arg.block_shape) if hasattr(arg, "block_shape") else None,
                    arg.padding if hasattr(arg, "padding") else None,
                    # Add other relevant attributes
                )
            if i in self.specialize_indices:
                k_args[param_names[i]] = arg
                spec_args.append(hashable_arg)
            elif i in self.do_not_specialize_indices:
                k_args[param_names[i]] = arg
                dns_args.append(hashable_arg)
            else:
                # Constexpr arguments are only kept as launch arguments on
                # Triton 3.3 through 3.6.
                if major_version == 3 and 3 <= minor_version <= 6:
                    k_args[param_names[i]] = arg
                const_args.append(hashable_arg)
        # Parameters not covered positionally: take them from kwargs or from
        # their default; parameters with neither are skipped.
        for p in self.jit_function.params[len(args) :]:
            if p.name in kwargs:
                val = kwargs[p.name]
            elif p.default is inspect._empty:
                continue
            else:
                val = p.default

            if p.is_constexpr:
                const_args.append(val)
                if major_version == 3 and 3 <= minor_version <= 6:
                    k_args[p.name] = val
            elif p.do_not_specialize:
                dns_args.append(val)
                k_args[p.name] = val
            else:
                spec_args.append(val)
                k_args[p.name] = val

        entry_key = self.key(spec_args, dns_args, const_args)
        device = torch_device_fn.current_device()
        cache = self.kernel_cache[device]
        # The loop body runs at most once: it either breaks out (another
        # caller filled the cache while we waited for the lock) or returns.
        while entry_key not in cache:
            # NOTE: we serialize the first run of a jit function regardless of which device to run on
            # because Triton runtime is currently not threadsafe.
            with self.lock:
                if entry_key in cache:
                    break
                kernel = self.fn.run(*args, **kwargs)
                fn = self.fn
                # collect constexpr arguments for grid computation
                constexprs = {}
                tune_constexprs = {}
                heur_constexprs = {}
                launch_pre_hooks = []
                # Walk the wrapper chain from the outermost wrapper inwards.
                while not isinstance(fn, triton.runtime.JITFunction):
                    if isinstance(fn, triton.runtime.Autotuner):
                        config = fn.best_config
                        constexprs["num_warps"] = config.num_warps
                        constexprs["num_stages"] = config.num_stages
                        constexprs["num_ctas"] = config.num_ctas
                        constexprs = {**constexprs, **config.kwargs}
                        tune_constexprs = {**tune_constexprs, **config.kwargs}
                        if config.pre_hook is not None:
                            # Remembered so cached launches can replay the hook.
                            launch_pre_hooks.append(
                                (config.pre_hook, config.all_kwargs())
                            )
                    elif isinstance(fn, triton.runtime.Heuristics):
                        for v, heur in fn.values.items():
                            heur_constexprs[v] = heur(
                                {
                                    **dict(zip(fn.arg_names, args)),
                                    **kwargs,
                                    **constexprs,
                                }
                            )
                            constexprs[v] = heur_constexprs[v]
                    else:
                        raise RuntimeError("Invalid Runtime Function")
                    fn = fn.fn
                # Fill in defaults for constexpr parameters nobody set.
                for p in self.jit_function.params:
                    if (
                        p.is_constexpr
                        and p.name not in constexprs
                        and (p.default is not inspect._empty)
                    ):
                        constexprs[p.name] = p.default
                cache[entry_key] = (
                    kernel,
                    constexprs,
                    tune_constexprs,
                    heur_constexprs,
                    tuple(launch_pre_hooks),
                )
            # `self.fn.run` above already performed this call's launch.
            return kernel, constexprs

        (
            kernel,
            constexprs,
            tune_constexprs,
            heur_constexprs,
            launch_pre_hooks,
        ) = cache[entry_key]

        if callable(grid):
            # collect all arguments to the grid fn, i.e.:
            # 1. args,
            # 2. kwargs,
            # 3. all other captured arguments in CompiledKernel from Autotuner & Heuristics
            # when kwargs & captured args conflict, captured args have higher priority
            meta = {**dict(zip(self.arg_names, args)), **kwargs, **constexprs}
            grid = grid(meta)
        # Pad so that `grid[0:3]` below always has three entries.
        grid = grid + (1, 1)

        if launch_pre_hooks:
            hook_nargs = {**dict(zip(self.arg_names, args)), **kwargs}
            for pre_hook, hook_kwargs in launch_pre_hooks:
                pre_hook({**hook_nargs, **hook_kwargs})

        if major_version == 3 and 3 <= minor_version <= 6:
            # On these versions every signature parameter is passed, in
            # signature order, looked up in the sources below in turn.
            all_args = []
            missing_keys = []
            for key in list(self.signature.parameters.keys()):
                if key in k_args:
                    all_args.append(k_args[key])
                elif key in tune_constexprs:
                    all_args.append(tune_constexprs[key])
                elif key in heur_constexprs:
                    all_args.append(heur_constexprs[key])
                elif key in constexprs:
                    all_args.append(constexprs[key])
                else:
                    missing_keys.append(key)
            if len(missing_keys):
                raise RuntimeError(
                    f"[libentry]: probably a bug, the following kernel params where not captured: {missing_keys}"
                )
            kernel[grid[0:3]](*all_args)
        else:
            kernel[grid[0:3]](*k_args.values())
        return kernel, constexprs

815 

816 

def libentry():
    """Decorator for triton library entries."""

    def wrap(kernel_fn):
        # Hand the function to `LibEntry`, which takes over launching it.
        entry = LibEntry(kernel_fn)
        return entry

    return wrap