Coverage for src/flag_gems/utils/libentry.py: 89%
387 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1from __future__ import annotations
3import hashlib
4import inspect
5import logging
6import math
7import multiprocessing
8import os
9import time
10from abc import abstractmethod
11from collections import OrderedDict
12from functools import cached_property
13from itertools import starmap
14from pathlib import Path
15from typing import (
16 Any,
17 Callable,
18 Dict,
19 Final,
20 Iterator,
21 List,
22 Optional,
23 Tuple,
24 Type,
25 Union,
26 overload,
27)
29import triton
31from flag_gems import runtime
32from flag_gems.runtime import torch_device_fn
33from flag_gems.runtime.backend import _state
34from flag_gems.utils.code_cache import config_cache_dir
35from flag_gems.utils.models import PersistantModel, SQLPersistantModel
logger = logging.getLogger(__name__)

# Number of devices reported by the runtime; used to size per-device caches.
DEVICE_COUNT = runtime.device.device_count

# Parse the Triton version numerically. `int` is used instead of `eval` so that
# the version string is never executed as code and components such as "08"
# (a syntax error for `eval`) are still accepted.
version = triton.__version__.split(".")
major_version, minor_version = int(version[0]), int(version[1])
if major_version == 2:

    def all_kwargs(self):
        """Return the config's meta-parameters merged with its launch options.

        Starts from `self.kwargs` and overlays every launch option from the
        list below that exists as an attribute on this config object.
        """
        launch_fields = (
            "num_warps",
            "num_ctas",
            "num_stages",
            "num_buffers_warp_spec",
            "num_consumer_groups",
            "reg_dec_producer",
            "reg_inc_consumer",
            "maxnreg",
        )
        merged = {**self.kwargs}
        for field in launch_fields:
            if hasattr(self, field):
                merged[field] = getattr(self, field)
        return merged

    # Attach the helper to `triton.Config` when running against Triton 2.x.
    setattr(triton.Config, "all_kwargs", all_kwargs)
# Optional database URL override for the tuning cache; `None` when unset.
FLAGGEMS_DB_URL = os.environ.get("FLAGGEMS_DB_URL")
class Cache(object):
    """Base class for the persistent caches: pairs a table name with a storage model.

    Args:
        table_name: Name of the table this cache reads from and writes to.
        model: Persistence backend that performs the actual storage operations.
        *args, **kwargs: Forwarded unchanged to the next `__init__` in the MRO.
    """

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        super().__init__(*args, **kwargs)
        self.table_name: Final[str] = table_name
        self.model: Final[PersistantModel] = model
class ConfigCache(Cache):
    """
    `ConfigCache` is used to store the relationship between keys and their known best configurations.

    It exposes a mapping-like interface (`in`, `[]`, `[] =`) on top of the
    model's `get_config` / `put_config` calls for this cache's table.
    """

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        super().__init__(table_name, model, *args, **kwargs)

    def __contains__(self, key: Tuple[Union[int, float, str], ...]) -> bool:
        """Return True when the model has a configuration stored for `key`."""
        return self.get(key) is not None

    def __getitem__(self, key: Tuple[Union[int, float, str], ...]) -> triton.Config:
        """Return the configuration stored for `key`.

        Raises:
            KeyError: If no configuration is stored for `key`.
        """
        ret: Optional[triton.Config] = self.get(key)
        if ret is None:
            raise KeyError(f"Key {key} not found in ConfigCache.")
        return ret

    def __setitem__(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        """Store `config` as the configuration for `key`."""
        self.set(key, config)

    def get(self, key: Tuple[Union[int, float, str], ...]) -> Optional[triton.Config]:
        """Look up `key` in this cache's table; a `None` result means "not found"."""
        return self.model.get_config(self.table_name, key)

    def set(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        """Persist `config` for `key` in this cache's table."""
        return self.model.put_config(self.table_name, key, config)
class BenchmarkCache(Cache):
    """
    `BenchmarkCache` is used to store the benchmark results for the pair of the specific key and configuration.

    Each instance is bound to one table and one tuning key; configurations are
    the mapping keys and benchmark tuples are the values.
    """

    def __init__(
        self,
        table_name: str,
        model: PersistantModel,
        key: Tuple[Union[int, float, str], ...],
        *args,
        **kwargs,
    ) -> None:
        super().__init__(table_name, model, *args, **kwargs)
        self.key: Final[Tuple[Union[int, float, str], ...]] = key

    def __contains__(self, config: triton.Config) -> bool:
        """Return True when a benchmark is stored for `config` under this key."""
        # Go through `get` so the lookup passes the table name exactly like
        # every other accessor; the previous direct model call omitted it.
        return self.get(config) is not None

    def __getitem__(self, config: triton.Config) -> Tuple[float, float, float]:
        """Return the benchmark stored for `config`.

        Raises:
            KeyError: If no benchmark is stored for `config` under this key.
        """
        ret: Optional[Tuple[float, float, float]] = self.get(config)
        if ret is None:
            raise KeyError(
                f"Config {config} not found in BenchmarkCache for key {self.key}."
            )
        return ret

    def __setitem__(
        self, config: triton.Config, benchmark: Tuple[float, float, float]
    ) -> None:
        """Store `benchmark` for `config` under this key."""
        return self.set(config, benchmark)

    def get(self, config: triton.Config) -> Optional[Tuple[float, float, float]]:
        """Look up the benchmark for `config`; a `None` result means "not found"."""
        return self.model.get_benchmark(self.table_name, self.key, config)

    def set(self, config: triton.Config, benchmark: Tuple[float, float, float]) -> None:
        """Persist `benchmark` for `config` under this cache's table and key."""
        return self.model.put_benchmark(self.table_name, self.key, config, benchmark)
class LibCache(object):
    """Singleton registry handing out `ConfigCache` / `BenchmarkCache` objects.

    All caches created through one `LibCache` share a single persistence model
    built from `db_url`. Indexing with a string returns the `ConfigCache` for
    that table; indexing with a `(table, key)` tuple returns the
    `BenchmarkCache` for that table and tuning key.

    NOTE: `__new__` always returns the same object, but `__init__` still runs
    on every construction and re-creates the pools and the model.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super(LibCache, cls).__new__(cls)
        return cls._instance

    def __init__(self, db_url: Optional[str] = None):
        """Set up the cache pools and the persistence model.

        Args:
            db_url: Database URL for the model. When `None`, a SQLite file
                under `config_cache_dir()` is used, named after the device and
                the Triton major/minor version.
        """
        self.global_cache: Dict = {}
        self.volumn: Dict = {}
        # Vendor-provided name, used when the runtime cannot report one.
        device_name = _state.vendor_module.vendor_info.device_name
        if db_url is None:
            try:
                device_name = torch_device_fn.get_device_name().replace(" ", "_")
            except AttributeError:
                # The backend has no `get_device_name`; keep the vendor name.
                pass
            cache_file_name: str = (
                f"TunedConfig_{device_name}_triton_{major_version}_{minor_version}.db"
            )
            cache_path: Path = config_cache_dir() / cache_file_name
            self.db_url: str = f"sqlite:///{cache_path}"
        else:
            self.db_url: str = db_url
        self.config_cache_pool: Dict[str, ConfigCache] = {}
        self.benchmark_cache_pool: Dict[
            Tuple[str, Tuple[Union[int, float, str], ...]], BenchmarkCache
        ] = {}
        self.model: PersistantModel = SQLPersistantModel(self.db_url)

    @overload
    def __getitem__(self, key: str) -> ConfigCache:
        ...

    @overload
    def __getitem__(
        self, key: Tuple[str, Tuple[Union[int, float, str], ...]]
    ) -> BenchmarkCache:
        ...

    def __getitem__(
        self, key: Union[str, Tuple[str, Tuple[Union[int, float, str], ...]]]
    ) -> Union[BenchmarkCache, ConfigCache]:
        """Dispatch on the key type: `str` -> config cache, tuple -> benchmark cache.

        Raises:
            AssertionError: If `key` is neither a string nor a tuple.
        """
        if isinstance(key, str):
            return self.get_config(key)
        if isinstance(key, tuple):
            # The tuple is unpacked as `(table, key)`.
            return self.get_benchmark(*key)
        # Raised explicitly (rather than `assert False`) so the check survives
        # `python -O` instead of silently returning `None`.
        raise AssertionError(
            f"the type of key '{key.__class__.__name__}' is unacceptable"
        )

    def get_benchmark(
        self, table: str, key: Tuple[Union[int, float, str], ...]
    ) -> BenchmarkCache:
        """Return the pooled `BenchmarkCache` for `(table, key)`, creating it on first use."""
        ret = self.benchmark_cache_pool.get((table, key))
        if ret is None:
            ret = BenchmarkCache(table, self.model, key)
            self.benchmark_cache_pool[(table, key)] = ret
        return ret

    def get_config(self, table: str) -> ConfigCache:
        """Return the pooled `ConfigCache` for `table`, creating it on first use."""
        ret = self.config_cache_pool.get(table)
        if ret is None:
            ret = ConfigCache(table, self.model)
            self.config_cache_pool[table] = ret
        return ret
# Module-level cache registry used by the tuners below. When `FLAGGEMS_DB_URL`
# is unset (`None`), `LibCache` falls back to a SQLite file under
# `config_cache_dir()`.
libcache = LibCache(FLAGGEMS_DB_URL)
class LibTuner(triton.runtime.Autotuner):
    """`LibTuner` is the base class for `FlagGems` library autotuner.

    It could be extended in two ways, overriding the `policy` or `run` method in a subclass.
    For `policy` extension, `LibTuner` provides a decorator `register_policy` to register a policy function quickly.
    Please refer to the implementation of `default_policy` for an example.
    """

    # The dispatch table for `LibTuner` subclasses. It's shared across all instances.
    _dispatch_table: Dict[str, Type[LibTuner]] = {}
    # Named key-transformation strategies, shared across all instances.
    # `None` is accepted as a name so that "no strategy given" can be resolved too.
    _strategy_table: Dict[Optional[str], Callable[[Any], Any]] = {}

    def __init__(
        self,
        fn,
        arg_names,
        configs,
        key,
        reset_to_zero,
        restore_value,
        pre_hook=None,
        post_hook=None,
        prune_configs_by: Optional[Dict] = None,
        warmup=None,
        rep=None,
        use_cuda_graph=False,
        do_bench=None,
        strategy=None,
    ):
        """Initialise the underlying autotuner and bind the persistent caches.

        Args:
            fn: The kernel (possibly wrapped) being tuned.
            arg_names: Names of the kernel's arguments.
            configs: Candidate `triton.Config` objects.
            key: Names of the arguments that form the tuning key.
            reset_to_zero, restore_value, pre_hook, post_hook, prune_configs_by,
            warmup, rep, use_cuda_graph: Forwarded to the parent initialiser
                (the Triton 2 branch forwards only the subset it accepts).
            do_bench: Accepted for signature compatibility; not forwarded here.
            strategy: A strategy name, a callable, or a list/tuple of those with
                one entry per element of `key`. Names (including `None`) are
                resolved through the strategy table.
        """
        # NOTE(zhengyang): See discussion in https://github.com/triton-lang/triton/pull/4496
        if major_version == 2 or (major_version == 3 and minor_version <= 1):
            if warmup is None:
                warmup = 25
            if rep is None:
                rep = 100
        if major_version == 2:
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                prune_configs_by,
                warmup,
                rep,
            )
            # Unwrap nested wrappers until the plain Python function is reached.
            self.base_fn = fn
            while not inspect.isfunction(self.base_fn):
                self.base_fn = self.base_fn.fn
        else:
            # NOTE(review): this branch relies on the parent initialiser to set
            # `self.base_fn` - confirm against the targeted Triton versions.
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                pre_hook,
                post_hook,
                prune_configs_by,
                warmup,
                rep,
                use_cuda_graph,
            )
        self.__name__ = self.base_fn.__name__
        self.keys = key
        # Resolve names through the strategy table. `None` (the default) is
        # resolved as well; leaving it in place would put a non-callable into
        # `self.strategy` and make `get_key` fail.
        if strategy is None or isinstance(strategy, str):
            strategy = LibTuner.get_strategy(strategy)
        if not isinstance(strategy, (list, tuple)):
            strategy = [strategy] * len(self.keys)
        assert len(strategy) == len(
            self.keys
        ), f"the length of strategy {len(strategy)} must match the length of keys {len(self.keys)}"
        strategy: List[Callable[[Any], Any]] = [
            LibTuner.get_strategy(s) if (s is None or isinstance(s, str)) else s
            for s in strategy
        ]
        self.strategy: List[Callable[[Any], Any]] = strategy
        self.config_table_name: str = f"{self.__name__}_{self.kernel_hash}"
        self.benchmark_table_name: str = f"{self.__name__}_{self.cache_key}_benchmark"
        # A string index on `libcache` yields the `ConfigCache` for that table.
        self.cache: ConfigCache = libcache[self.config_table_name]

    @cached_property
    def cache_key(self) -> str:
        """Cache key of the innermost `JITFunction` wrapped by this tuner."""
        jit_fn = self.fn
        while not isinstance(jit_fn, triton.runtime.JITFunction):
            jit_fn = jit_fn.fn
        return jit_fn.cache_key

    @cached_property
    def kernel_hash(self) -> str:
        """MD5 hex digest over the kernel's cache key and the config-set hash."""
        return hashlib.md5(
            f"{self.cache_key}{self.configs_hash}".encode("utf-8")
        ).hexdigest()[:32]

    @cached_property
    def configs_hash(self) -> str:
        """MD5 hex digest over the comma-joined string forms of all configs."""
        return hashlib.md5(
            ",".join(map(str, self.configs)).encode("utf-8")
        ).hexdigest()[:32]

    def get_key(self, args):
        """Build the tuning key for one call.

        Args:
            args: Mapping from argument name to value for this call.
        Returns:
            A tuple holding the (strategy-transformed) values of the key
            arguments followed by `str(dtype)` of every argument that has a
            `dtype` attribute.
        """
        if self.strategy is None:
            key = tuple(args[k] for k in self.keys if k in args)
        else:
            key = tuple(
                starmap(
                    lambda position, name: self.strategy[position](args[name]),
                    enumerate(self.keys),
                )
            )
        key += tuple(str(arg.dtype) for arg in args.values() if hasattr(arg, "dtype"))
        return key

    @abstractmethod
    def policy(
        self,
        fn: Callable[[triton.Config], List[float]],
        configs: Iterator[triton.Config],
        args: Tuple[Any],
        kwargs: Dict[str, Any],
    ) -> Tuple[triton.Config, Dict[str, float]]:
        """Select the best configuration; must be overridden by subclasses.

        Declared as a regular method (not a `staticmethod`) so that calling it
        through an instance binds `self` correctly and reaches the
        `NotImplementedError` below instead of failing on argument count.

        Args:
            fn: Benchmarks a single configuration and returns its timings.
            configs: The (pruned) configurations to choose from.
            args: Positional kernel launch arguments.
            kwargs: Keyword kernel launch arguments.
        Returns:
            The chosen configuration and the timings collected for the candidates.
        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"`policy` isn't implemented in {self.__class__.__name__}"
        )

    @classmethod
    def register(cls, name: str):
        """Register a subclass of `LibTuner` with a name.

        Args:
            name: The name of the subclass.
        Returns:
            A decorator that registers the subclass with the name.
        """

        def decorator(subclass):
            cls._dispatch_table[name] = subclass
            return subclass

        return decorator

    @classmethod
    def get(cls, name: str):
        """Return the `LibTuner` subclass registered under `name` (KeyError if absent)."""
        return cls._dispatch_table[name]

    @classmethod
    def get_strategy(cls, name: Optional[str]):
        """Return the strategy registered under `name` (KeyError if absent)."""
        return cls._strategy_table[name]

    @staticmethod
    def register_policy(
        name: str,
    ) -> Callable[[Callable[..., Any]], Type[LibTuner]]:
        """A decorator to register a policy for `LibTuner`.

        This decorator allows you to create a new `LibTuner` subclass without defining a new class explicitly.
        The new subclass will have the `policy` method set to the provided policy function and will be registered under
        the specified name in the `LibTuner` dispatch table.
        """

        def decorator(
            policy_impl: Callable[
                [
                    Callable[[triton.Config], List[float]],
                    Iterator[triton.Config],
                    Tuple[Any],
                    Dict[str, Any],
                ],
                Tuple[triton.Config, Dict[str, float]],
            ],
        ):
            @LibTuner.register(name)
            class AnonymousLibTunerImpl(LibTuner):
                def policy(
                    self,
                    fn: Callable[[triton.Config], List[float]],
                    configs: Iterator[triton.Config],
                    args: Tuple[Any],
                    kwargs: Dict[str, Any],
                ) -> Tuple[triton.Config, Dict[str, float]]:
                    # The policy function itself does not receive the tuner instance.
                    return policy_impl(fn, configs, args, kwargs)

            return AnonymousLibTunerImpl

        return decorator

    @staticmethod
    def register_strategy(name: Optional[str]):
        """Return a decorator that registers a strategy under `name` and returns it unchanged."""

        def decorator(
            strategy: Union[Callable[[Any], Any], List[Callable[[Any], Any]]],
        ):
            LibTuner._strategy_table[name] = strategy
            return strategy

        return decorator

    def run(self, *args, **kwargs):
        """Pick a configuration for this call (tuning if necessary) and run the kernel.

        With more than one candidate configuration, the tuning key is looked up
        in the config cache. On a miss, the pruned configurations are benchmarked
        through `policy` (re-using stored benchmark results where available) and
        the winner is stored. With a single configuration, it is used directly.

        Returns:
            Whatever `self.fn.run` returns for the selected configuration.
        """
        # `arg_names` corresponds to the arguments of the `JITFunction`'s signature,
        # so please make sure the orders of `arg_names` and `args` match.
        self.nargs = dict(zip(self.arg_names, args))
        used_cached_result = True
        if len(self.configs) > 1:
            all_args = {**self.nargs, **kwargs}
            _args = {k: v for k, v in all_args.items() if k in self.arg_names}
            key = self.get_key(_args)
            if key not in self.cache:
                # A `(table, key)` index on `libcache` yields a `BenchmarkCache`.
                cache: BenchmarkCache = libcache[self.benchmark_table_name, key]
                # prune configs
                used_cached_result = False
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()

                def bench(config: triton.Config) -> List[float]:
                    # Only benchmark configurations without a stored result.
                    ret = cache.get(config)
                    if ret is None:
                        ret = self._bench(*args, config=config, **kwargs)
                        cache[config] = tuple(ret)
                    return list(ret)

                best_config, timings = self.policy(
                    bench,
                    pruned_configs,
                    args,
                    kwargs,
                )
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = best_config
                full_nargs = {
                    **self.nargs,
                    **kwargs,
                    **self.cache[key].all_kwargs(),
                }
                self.pre_hook(full_nargs, reset_only=True)
                self.configs_timings = timings
            config = self.cache[key]
            if config.pre_hook is None:
                # The config read back from the cache carries no pre_hook here;
                # swap in the matching original config, which may have one.
                cached_kwargs = config.all_kwargs()
                for original_config in self.configs:
                    if original_config.all_kwargs() == cached_kwargs:
                        # Use the original config which has the pre_hook
                        config = original_config
                        break
        else:
            config = self.configs[0]
        self.best_config = config
        if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
            print(
                f"Triton autotuning for function {self.base_fn.__name__} finished after "
                f"{self.bench_time:.2f}s; key info: {key}, best config selected: {self.best_config};"
            )
        if config.pre_hook is not None:
            full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
            config.pre_hook(full_nargs)
        ret = self.fn.run(
            *args,
            **kwargs,
            **config.all_kwargs(),
        )
        self.nargs = None
        return ret
@LibTuner.register_strategy(None)
@LibTuner.register_strategy("default")
def default_strategy(key: Any) -> Any:
    """Identity strategy: the key value is used as-is.

    Registered under both `"default"` and `None`.
    """
    return key
@LibTuner.register_strategy("log")
def log2_strategy(key: Union[int, float]) -> Union[int, float]:
    """Round `key` up to the next power of two.

    Args:
        key: The value to bucket.
    Returns:
        `2 ** ceil(log2(key))` for non-zero keys, and `0` for a zero key so that
        zero-sized inputs get a bucket of their own.
    Raises:
        ValueError: If `key` is negative (from `math.log2`).
    """
    if key == 0:
        # `math.log2(0)` raises ValueError; give zero its own bucket instead.
        return 0
    return 2 ** math.ceil(math.log2(key))
@LibTuner.register_strategy("align32")
def align32_strategy(key: Union[int, float]) -> int:
    """Round `key` up to the nearest multiple of 32."""
    blocks_of_32 = math.ceil(key / 32)
    return blocks_of_32 * 32
@LibTuner.register_policy("default")
def default_policy(
    bench_fn: Callable[[triton.Config], List[float]],
    configs: Iterator[triton.Config],
    args: Tuple[Any],
    kwargs: Dict[str, Any],
) -> Tuple[triton.Config, Dict[str, float]]:
    """Default policy for offline autotuning.

    Benchmarks every configuration in `configs` and picks the one whose
    benchmark result compares smallest.

    Args:
        bench_fn: The function to benchmark.
        configs: The collection of the configuration search space.
        args: Kernel launch arguments.
        kwargs: Kernel launch arguments.
    Returns:
        A tuple containing the best configuration and a dictionary of timings for each configuration.

    This is one way to implement a default policy for offline autotuning. It's equal to the following
    ```
    @LibTuner.register("default")
    class DefaultLibTunerImpl(LibTuner):
        def __init__(
            self,
            *args,
            **kwargs,
        ):
            super().__init__(
                *args,
                **kwargs,
            )

        @staticmethod
        def policy(
            bench_fn: Callable[[triton.Config], List[float]],
            configs: Iterator[triton.Config],
            args: Tuple[Any],
            kwargs: Dict[str, Any],
        ) -> Tuple[triton.Config, Dict[str, float]]:
            timings: Dict[triton.Config, int] = {
                config: bench_fn(config) for config in configs
            }
            best_config: triton.Config = min(timings, key=timings.get)
            return best_config, timings
    ```
    In this way policies could be extended by registering a definition function quickly,
    or by creating a new subclass of `LibTuner` and overriding the `policy` method to have
    more control over the autotuning process.
    """
    measurements: Dict[triton.Config, List[float]] = {}
    for candidate in configs:
        measurements[candidate] = bench_fn(candidate)
    fastest: triton.Config = min(measurements, key=measurements.get)
    return fastest, measurements
def libtuner(
    configs,
    key,
    prune_configs_by=None,
    reset_to_zero=None,
    restore_value=None,
    pre_hook=None,
    post_hook=None,
    warmup=25,
    rep=100,
    use_cuda_graph=False,
    do_bench=None,
    strategy: Union[
        str, Callable[[Any], Any], List[Union[str, Callable[[Any], Any]]]
    ] = "default",
    policy: Union[str, Type[LibTuner]] = "default",
):
    """Decorator for triton library autotuner.

    `strategy` is a function that takes a key and returns a value.
    It accepts a string, which is the name of a registered strategy, or a callable function.
    In this form it will be applied to each key in the `key` list.
    If it's a tuple or list, it should have the same length as `key`,
    and each element should be a string or a callable function that takes a key and returns a value.
    `policy` accepts a string, which is the name of a registered `LibTuner` subclass, or a `LibTuner` subclass itself.
    """

    # Resolve a policy name to its registered tuner class.
    tuner_cls = LibTuner.get(policy) if isinstance(policy, str) else policy
    assert issubclass(
        tuner_cls, LibTuner
    ), f"the class of {tuner_cls.__name__} is {tuner_cls.__class__.__name__}, not a subclass of {LibTuner.__name__}"

    # Keyword options shared by every kernel this decorator is applied to.
    tuner_options = dict(
        pre_hook=pre_hook,
        post_hook=post_hook,
        prune_configs_by=prune_configs_by,
        warmup=warmup,
        rep=rep,
        use_cuda_graph=use_cuda_graph,
        do_bench=do_bench,
        strategy=strategy,
    )

    def decorator(fn):
        return tuner_cls(
            fn,
            fn.arg_names,
            configs,
            key,
            reset_to_zero,
            restore_value,
            **tuner_options,
        )

    return decorator
class LibEntry(triton.KernelInterface):
    """Launch wrapper that caches compiled kernels per device and argument key.

    The first call for a given key goes through the wrapped function's own `run`
    and records the resulting kernel together with the constexpr values chosen
    by any autotuner / heuristics wrappers. Later calls with the same key launch
    the recorded kernel directly.
    """

    def __init__(
        self,
        fn,
    ):
        self.fn = fn
        self.arg_names = fn.arg_names
        self.divisibility = 16
        # One kernel cache (dict) per device.
        self.kernel_cache = tuple({} for _ in range(DEVICE_COUNT))

        # Peel off wrappers until the underlying JITFunction is reached.
        jit_fn = fn
        while not isinstance(jit_fn, triton.runtime.JITFunction):
            jit_fn = jit_fn.fn
        self.jit_function: triton.runtime.JITFunction = jit_fn

        # Split the non-constexpr parameter positions by specialization mode.
        specialized, unspecialized = [], []
        for param in jit_fn.params:
            if param.is_constexpr:
                continue
            bucket = unspecialized if param.do_not_specialize else specialized
            bucket.append(param.num)
        self.specialize_indices = specialized
        self.do_not_specialize_indices = unspecialized

        self.lock = multiprocessing.Lock()
        self.signature = jit_fn.signature

    @staticmethod
    def _hashable_arg(arg):
        """Return `arg`, or a hashable tuple summary when it is a `TensorDescriptor`."""
        if arg.__class__.__name__ != "TensorDescriptor":
            return arg

        def as_tuple(attr_name):
            return tuple(getattr(arg, attr_name)) if hasattr(arg, attr_name) else None

        # Create a hashable representation of TensorDescriptor
        return (
            "TensorDescriptor",
            as_tuple("shape"),
            as_tuple("strides"),
            as_tuple("block_shape"),
            arg.padding if hasattr(arg, "padding") else None,
            # Add other relevant attributes
        )

    def _collect_launch_state(self, args, kwargs):
        """Gather constexprs and launch pre-hooks from the wrappers around the JIT function.

        Returns:
            `(constexprs, tune_constexprs, heur_constexprs, launch_pre_hooks)` where
            the last element is a tuple of `(pre_hook, config_kwargs)` pairs.
        Raises:
            RuntimeError: If a wrapper is neither an `Autotuner` nor a `Heuristics`.
        """
        constexprs = {}
        tune_constexprs = {}
        heur_constexprs = {}
        launch_pre_hooks = []
        wrapper = self.fn
        while not isinstance(wrapper, triton.runtime.JITFunction):
            if isinstance(wrapper, triton.runtime.Autotuner):
                config = wrapper.best_config
                constexprs.update(
                    num_warps=config.num_warps,
                    num_stages=config.num_stages,
                    num_ctas=config.num_ctas,
                )
                constexprs.update(config.kwargs)
                tune_constexprs.update(config.kwargs)
                if config.pre_hook is not None:
                    launch_pre_hooks.append((config.pre_hook, config.all_kwargs()))
            elif isinstance(wrapper, triton.runtime.Heuristics):
                for name, heur in wrapper.values.items():
                    # Each heuristic sees the values captured so far.
                    value = heur(
                        {
                            **dict(zip(wrapper.arg_names, args)),
                            **kwargs,
                            **constexprs,
                        }
                    )
                    heur_constexprs[name] = value
                    constexprs[name] = value
            else:
                raise RuntimeError("Invalid Runtime Function")
            wrapper = wrapper.fn
        # Fill in defaults for constexpr parameters nothing above provided.
        for param in self.jit_function.params:
            if not param.is_constexpr or param.name in constexprs:
                continue
            if param.default is not inspect._empty:
                constexprs[param.name] = param.default
        return constexprs, tune_constexprs, heur_constexprs, tuple(launch_pre_hooks)

    def key(self, spec_args, dns_args, const_args):
        """Build the kernel-cache key from the three argument groups.

        Args:
            spec_args: Values of specialized arguments.
            dns_args: Values of do-not-specialize arguments.
            const_args: Values of constexpr arguments (a list).
        Returns:
            A tuple combining the per-argument descriptors with the constexpr values.
        """

        def describe_specialized(arg):
            if hasattr(arg, "data_ptr"):
                aligned = arg.data_ptr() % self.divisibility == 0
                return (arg.dtype, aligned)
            return (type(arg), arg)

        def describe_unspecialized(arg):
            if hasattr(arg, "data_ptr"):
                return arg.dtype
            if not isinstance(arg, int):
                return type(arg)
            if -(2**31) <= arg <= 2**31 - 1:
                return "i32"
            if 2**63 <= arg <= 2**64 - 1:
                return "u64"
            return "i64"

        spec_part = list(map(describe_specialized, spec_args))
        dns_part = list(map(describe_unspecialized, dns_args))
        # const args passed by position
        return tuple(spec_part + dns_part + const_args)

    def run(self, *args, **kwargs):
        """Launch the kernel, compiling and caching it on the first call for a key.

        `kwargs` must contain `grid` (a tuple, or a callable taking the meta dict).

        Returns:
            `(kernel, constexprs)` for the kernel that was run.
        """
        grid = kwargs["grid"]
        # For these Triton versions constexpr values are also passed at launch.
        pass_constexprs = major_version == 3 and 3 <= minor_version <= 6

        # collect all the arguments
        spec_args = []  # specialize arguments
        dns_args = []  # do not specialize arguments
        const_args = []  # constexpr arguments
        k_args = OrderedDict()
        param_names = list(self.signature.parameters.keys())
        for position, arg in enumerate(args):
            hashable_arg = self._hashable_arg(arg)
            if position in self.specialize_indices:
                k_args[param_names[position]] = arg
                spec_args.append(hashable_arg)
            elif position in self.do_not_specialize_indices:
                k_args[param_names[position]] = arg
                dns_args.append(hashable_arg)
            else:
                if pass_constexprs:
                    k_args[param_names[position]] = arg
                const_args.append(hashable_arg)

        # Parameters not given positionally: take them from kwargs or defaults.
        for p in self.jit_function.params[len(args) :]:
            if p.name in kwargs:
                val = kwargs[p.name]
            elif p.default is not inspect._empty:
                val = p.default
            else:
                continue

            if p.is_constexpr:
                const_args.append(val)
                if pass_constexprs:
                    k_args[p.name] = val
            elif p.do_not_specialize:
                dns_args.append(val)
                k_args[p.name] = val
            else:
                spec_args.append(val)
                k_args[p.name] = val

        entry_key = self.key(spec_args, dns_args, const_args)
        device = torch_device_fn.current_device()
        cache = self.kernel_cache[device]
        while entry_key not in cache:
            # NOTE: we serialize the first run of a jit function regardless of which device to run on
            # because Triton runtime is currently not threadsafe.
            with self.lock:
                if entry_key in cache:
                    break
                kernel = self.fn.run(*args, **kwargs)
                # collect constexpr arguments for grid computation
                (
                    constexprs,
                    tune_constexprs,
                    heur_constexprs,
                    launch_pre_hooks,
                ) = self._collect_launch_state(args, kwargs)
                cache[entry_key] = (
                    kernel,
                    constexprs,
                    tune_constexprs,
                    heur_constexprs,
                    launch_pre_hooks,
                )
            # The first run already executed the kernel via `self.fn.run`.
            return kernel, constexprs

        (
            kernel,
            constexprs,
            tune_constexprs,
            heur_constexprs,
            launch_pre_hooks,
        ) = cache[entry_key]

        if callable(grid):
            # collect all arguments to the grid fn,ie:
            # 1. args,
            # 2. kwargs,
            # 3. all all other captured arguments in CompiledKernel from Autotunner & Heuristics
            # when kwargs & captured args conflict, captured args have higher priority
            meta = {**dict(zip(self.arg_names, args)), **kwargs, **constexprs}
            grid = grid(meta)
        # Pad so that the first three entries always exist.
        grid = grid + (1, 1)

        if launch_pre_hooks:
            hook_nargs = {**dict(zip(self.arg_names, args)), **kwargs}
            for pre_hook, hook_kwargs in launch_pre_hooks:
                pre_hook({**hook_nargs, **hook_kwargs})

        if pass_constexprs:
            all_args = []
            missing_keys = []
            # Lookup order: call arguments, tuned values, heuristic values, other constexprs.
            sources = (k_args, tune_constexprs, heur_constexprs, constexprs)
            for name in param_names:
                for source in sources:
                    if name in source:
                        all_args.append(source[name])
                        break
                else:
                    missing_keys.append(name)
            if missing_keys:
                raise RuntimeError(
                    f"[libentry]: probably a bug, the following kernel params where not captured: {missing_keys}"
                )
            kernel[grid[0:3]](*all_args)
        else:
            kernel[grid[0:3]](*k_args.values())
        return kernel, constexprs
def libentry():
    """Decorator for triton library entries.

    Returns a decorator that wraps the decorated function in a `LibEntry`.
    """

    def wrap_in_entry(fn):
        entry = LibEntry(fn)
        return entry

    return wrap_in_entry