Coverage for src/flag_gems/utils/libentry.py: 84%
466 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1from __future__ import annotations
3import hashlib
4import inspect
5import logging
6import math
7import multiprocessing
8import os
9import time
10from abc import abstractmethod
11from collections import OrderedDict
12from functools import cached_property
13from itertools import starmap
14from pathlib import Path
15from typing import (
16 Any,
17 Callable,
18 Dict,
19 Final,
20 Iterator,
21 List,
22 Optional,
23 Tuple,
24 Type,
25 Union,
26 overload,
27)
29import triton
31from flag_gems import runtime
32from flag_gems.runtime import torch_device_fn
33from flag_gems.runtime.backend import _state
34from flag_gems.utils.code_cache import config_cache_dir
35from flag_gems.utils.models import PersistantModel, SQLPersistantModel
logger = logging.getLogger(__name__)

# Device count as exposed by the active runtime device object.
DEVICE_COUNT = runtime.device.device_count

# Parse the Triton version components with `int` instead of `eval`: the
# components are plain decimal strings, and `eval` would execute arbitrary
# expressions (and rejects harmless forms such as a leading zero).
version = triton.__version__.split(".")
major_version, minor_version = int(version[0]), int(version[1])
if major_version == 2:

    def all_kwargs(self):
        """Return the config's meta kwargs merged with its launch options.

        Only the launch options that actually exist as attributes on this
        config are included; a launch option overrides a meta kwarg of the
        same name.
        """
        merged = dict(self.kwargs)
        for option in (
            "num_warps",
            "num_ctas",
            "num_stages",
            "num_buffers_warp_spec",
            "num_consumer_groups",
            "reg_dec_producer",
            "reg_inc_consumer",
            "maxnreg",
        ):
            if hasattr(self, option):
                merged[option] = getattr(self, option)
        return merged

    # NOTE(review): presumably Triton 2.x `Config` has no `all_kwargs` of its
    # own, so it is attached here -- confirm against the Triton 2 sources.
    triton.Config.all_kwargs = all_kwargs
# Database URL taken from the environment once at import time; `None` when the
# variable is not set.
FLAGGEMS_DB_URL = os.environ.get("FLAGGEMS_DB_URL")
class Cache(object):
    """Base class that pairs a table name with the model used to persist it.

    Attributes:
        table_name: Name of the table this cache reads from and writes to.
        model: Persistence backend that performs the actual storage.
    """

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        # Extra arguments are forwarded so the class can take part in
        # cooperative multiple inheritance.
        super().__init__(*args, **kwargs)
        self.table_name: Final[str] = table_name
        self.model: Final[PersistantModel] = model
class ConfigCache(Cache):
    """
    `ConfigCache` is used to store the relationship between keys and their known best configurations.

    It offers a mapping-style interface (`in`, `[]`, `[] =`) on top of the
    model's `get_config` / `put_config` calls for a single table.
    """

    def __init__(
        self, table_name: str, model: PersistantModel, *args, **kwargs
    ) -> None:
        super().__init__(table_name, model, *args, **kwargs)

    def __contains__(self, key: Tuple[Union[int, float, str], ...]) -> bool:
        """Return True when the model holds a configuration for `key`."""
        return self.get(key) is not None

    def __getitem__(self, key: Tuple[Union[int, float, str], ...]) -> triton.Config:
        """Return the configuration stored for `key`.

        Raises:
            KeyError: If no configuration is stored for `key`.
        """
        ret: Optional[triton.Config] = self.get(key)
        if ret is None:
            raise KeyError(f"Key {key} not found in ConfigCache.")
        return ret

    def __setitem__(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        """Store `config` as the configuration for `key`."""
        self.set(key, config)

    def get(self, key: Tuple[Union[int, float, str], ...]) -> Optional[triton.Config]:
        """Look up `key` in this table; the result is compared against None by callers."""
        return self.model.get_config(self.table_name, key)

    def set(
        self, key: Tuple[Union[int, float, str], ...], config: triton.Config
    ) -> None:
        """Persist `config` for `key` in this table."""
        # The model's return value is passed through unchanged for callers
        # that may inspect it.
        return self.model.put_config(self.table_name, key, config)
class BenchmarkCache(Cache):
    """
    `BenchmarkCache` is used to store the benchmark results for the pair of the specific key and configuration.

    One instance is bound to a single (table, key) pair; configurations are
    the mapping keys and benchmark tuples are the values.
    """

    def __init__(
        self,
        table_name: str,
        model: PersistantModel,
        key: Tuple[Union[int, float, str], ...],
        *args,
        **kwargs,
    ) -> None:
        super().__init__(table_name, model, *args, **kwargs)
        self.key: Final[Tuple[Union[int, float, str], ...]] = key

    def __contains__(self, config: triton.Config) -> bool:
        """Return True when a benchmark is stored for `config` under this key."""
        # Go through `get` so the lookup uses the same (table, key, config)
        # arguments as every other access; the previous direct call omitted
        # the table name.
        return self.get(config) is not None

    def __getitem__(self, config: triton.Config) -> Tuple[float, float, float]:
        """Return the benchmark stored for `config`.

        Raises:
            KeyError: If no benchmark is stored for `config` under this key.
        """
        ret: Optional[Tuple[float, float, float]] = self.get(config)
        if ret is None:
            raise KeyError(
                f"Config {config} not found in BenchmarkCache for key {self.key}."
            )
        return ret

    def __setitem__(
        self, config: triton.Config, benchmark: Tuple[float, float, float]
    ) -> None:
        """Store `benchmark` as the result for `config` under this key."""
        self.set(config, benchmark)

    def get(self, config: triton.Config) -> Optional[Tuple[float, float, float]]:
        """Look up the benchmark for `config`; callers treat None as a miss."""
        return self.model.get_benchmark(self.table_name, self.key, config)

    def set(self, config: triton.Config, benchmark: Tuple[float, float, float]) -> None:
        """Persist `benchmark` for `config` under this key."""
        return self.model.put_benchmark(self.table_name, self.key, config, benchmark)
class LibCache(object):
    """Singleton entry point to the persisted tuning caches.

    Indexing with a table name yields that table's `ConfigCache`; indexing
    with a `(table, key)` tuple yields the `BenchmarkCache` for that pair.
    Both kinds of cache objects are pooled so repeated lookups return the
    same object.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Always hand back the one shared instance.
        if cls._instance is None:
            cls._instance = super(LibCache, cls).__new__(cls)
        return cls._instance

    def __init__(self, db_url: Optional[str] = None):
        """Set up the pools and the persistence model.

        Args:
            db_url: Database URL for the model. When None, a SQLite file in
                the config cache directory is used, named after the vendor
                and the Triton major/minor version.
        """
        # NOTE(review): because `__new__` returns the shared instance, this
        # initializer runs again on every `LibCache(...)` call and resets the
        # pools and the model -- confirm that is intended.
        self.global_cache: Dict = {}
        self.volumn: Dict = {}
        vendor_name = _state.vendor_module.vendor_info.vendor_name
        if db_url is None:
            cache_file_name: str = (
                f"TunedConfig_{vendor_name}_triton_{major_version}_{minor_version}.db"
            )
            cache_path: Path = config_cache_dir() / cache_file_name
            self.db_url: str = f"sqlite:///{cache_path}"
        else:
            self.db_url: str = db_url
        self.config_cache_pool: Dict[str, ConfigCache] = {}
        self.benchmark_cache_pool: Dict[
            Tuple[str, Tuple[Union[int, float, str], ...]], BenchmarkCache
        ] = {}
        self.model: PersistantModel = SQLPersistantModel(self.db_url)

    @overload
    def __getitem__(self, key: str) -> ConfigCache:
        ...

    @overload
    def __getitem__(
        self, key: Tuple[str, Tuple[Union[int, float, str], ...]]
    ) -> BenchmarkCache:
        ...

    def __getitem__(
        self, key: Union[str, Tuple[str, Tuple[Union[int, float, str], ...]]]
    ) -> Union[BenchmarkCache, ConfigCache]:
        """Dispatch on the type of `key`.

        Args:
            key: A table name, or a `(table, key)` tuple that is unpacked
                into `get_benchmark`.

        Raises:
            AssertionError: If `key` is neither a string nor a tuple.
        """
        if isinstance(key, str):
            return self.get_config(key)
        if isinstance(key, tuple):
            return self.get_benchmark(*key)
        # Raised explicitly (rather than `assert False`) so the check still
        # fires when Python runs with assertions disabled.
        raise AssertionError(
            f"the type of key '{key.__class__.__name__}' is unacceptable"
        )

    def get_benchmark(
        self, table: str, key: Tuple[Union[int, float, str], ...]
    ) -> BenchmarkCache:
        """Return the pooled `BenchmarkCache` for `(table, key)`, creating it on first use."""
        ret = self.benchmark_cache_pool.get((table, key))
        if ret is None:
            ret = BenchmarkCache(table, self.model, key)
            self.benchmark_cache_pool[(table, key)] = ret
        return ret

    def get_config(self, table: str) -> ConfigCache:
        """Return the pooled `ConfigCache` for `table`, creating it on first use."""
        ret = self.config_cache_pool.get(table)
        if ret is None:
            ret = ConfigCache(table, self.model)
            self.config_cache_pool[table] = ret
        return ret
# Module-level cache instance, created at import time with the database URL
# read from the FLAGGEMS_DB_URL environment variable (None when unset).
libcache = LibCache(FLAGGEMS_DB_URL)
class LibTuner(triton.runtime.Autotuner):
    """`LibTuner` is the base class for `FlagGems` library autotuner.

    It could be extended in two ways, overriding the `policy` or `run` method in a subclass.
    For `policy` extension, `LibTuner` provides a decorator `register_policy` to register a policy function quickly.
    Please refer to the implementation of `default_policy` for an example.
    """

    # The dispatch table for `LibTuner` subclasses. It's shared across all instances.
    _dispatch_table: Dict[str, Type[LibTuner]] = {}
    # Named key-transform strategies registered through `register_strategy`.
    _strategy_table: Dict[str, Callable[[Any], Any]] = {}

    def __init__(
        self,
        fn,
        arg_names,
        configs,
        key,
        reset_to_zero,
        restore_value,
        pre_hook=None,
        post_hook=None,
        prune_configs_by: Optional[Dict] = None,
        warmup=None,
        rep=None,
        use_cuda_graph=False,
        do_bench=None,
        strategy=None,
        flagtune_op_name=None,
        flagtune_expand_op_name=None,
        flagtune_yaml_path=None,
        flagtune_pre_hook=None,
    ):
        """Initialize the underlying Triton autotuner and the FlagGems caches.

        The arguments up to `do_bench` are forwarded to the Triton
        `Autotuner` constructor in the form the installed Triton version
        accepts. `strategy` selects the per-key transforms applied in
        `get_key`. The `flagtune_*` arguments are stored and consumed by
        `apply_flagtune`.
        """
        # NOTE(zhengyang): See discussion in https://github.com/triton-lang/triton/pull/4496
        if major_version == 2 or (major_version == 3 and minor_version <= 1):
            if warmup is None:
                warmup = 25
            if rep is None:
                rep = 100
        if major_version == 2:
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                prune_configs_by,
                warmup,
                rep,
            )
            # Unwrap nested wrappers until the plain Python function is reached.
            self.base_fn = fn
            while not inspect.isfunction(self.base_fn):
                self.base_fn = self.base_fn.fn
        elif major_version == 3 and minor_version <= 1:
            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                pre_hook,
                post_hook,
                prune_configs_by,
                warmup,
                rep,
                use_cuda_graph,
            )
        else:
            # Triton 3.2+ removed warmup/rep/use_cuda_graph positional arguments.
            # Preserve FlagGems tuning behavior by translating them into do_bench.
            if do_bench is None:
                if use_cuda_graph:
                    from triton.testing import do_bench_cudagraph

                    def do_bench(kernel_call, quantiles):
                        return do_bench_cudagraph(
                            kernel_call,
                            rep=rep if rep is not None else 100,
                            quantiles=quantiles,
                        )

                elif warmup is not None or rep is not None:

                    def do_bench(kernel_call, quantiles):
                        return triton.testing.do_bench(
                            kernel_call,
                            warmup=warmup if warmup is not None else 25,
                            rep=rep if rep is not None else 100,
                            quantiles=quantiles,
                        )

            super().__init__(
                fn,
                arg_names,
                configs,
                key,
                reset_to_zero,
                restore_value,
                pre_hook=pre_hook,
                post_hook=post_hook,
                prune_configs_by=prune_configs_by,
                do_bench=do_bench,
            )
        self.__name__ = self.base_fn.__name__
        # `keys` must be set before `_normalize_strategy`, which reads its length.
        self.keys = key
        self.strategy: List[Callable[[Any], Any]] = self._normalize_strategy(strategy)
        self.config_table_name: str = f"{self.__name__}_{self.kernel_hash}"
        self.benchmark_table_name: str = f"{self.__name__}_{self.cache_key}_benchmark"
        # Indexing `libcache` with a string yields the table's config cache.
        self.cache: ConfigCache = libcache[self.config_table_name]
        # Defaults are remembered so `apply_flagtune` can restore them.
        self._flagtune_default_configs = self.configs
        self._flagtune_default_strategy = strategy
        self._flagtune_active = False
        self._flagtune_warned = False
        self._flagtune_op_name = flagtune_op_name
        self._flagtune_expand_op_name = flagtune_expand_op_name or flagtune_op_name
        self._flagtune_yaml_path = flagtune_yaml_path
        self._flagtune_pre_hook = flagtune_pre_hook

    def _normalize_strategy(self, strategy):
        """Expand `strategy` into one callable (or registered entry) per key.

        A string is resolved through the strategy table; a single value is
        repeated for every key; a list/tuple must match `self.keys` in length
        and has its string elements resolved individually.
        """
        if isinstance(strategy, str):
            strategy = LibTuner.get_strategy(strategy)
        if not isinstance(strategy, (list, tuple)):
            strategy = [strategy] * len(self.keys)
        assert len(strategy) == len(
            self.keys
        ), f"the length of strategy {len(strategy)} must match the length of keys {len(self.keys)}"
        return [LibTuner.get_strategy(s) if isinstance(s, str) else s for s in strategy]

    def _set_configs_and_strategy(self, configs, strategy):
        """Swap in new configs/strategy and rebind the derived table names and cache."""
        self.configs = configs
        self.strategy = self._normalize_strategy(strategy)
        # Drop the memoized hashes so they are recomputed from the new configs.
        self.__dict__.pop("configs_hash", None)
        self.__dict__.pop("kernel_hash", None)
        self.config_table_name = f"{self.__name__}_{self.kernel_hash}"
        self.benchmark_table_name = f"{self.__name__}_{self.cache_key}_benchmark"
        self.cache = libcache[self.config_table_name]

    def apply_flagtune(self):
        """Synchronize this tuner with the current FlagTune enablement state.

        Returns:
            True when the configs/strategy were switched (either to the
            expanded FlagTune set or back to the defaults), False otherwise.
        """
        if self._flagtune_op_name is None:
            return False

        enabled = runtime.flagtune_enabled(self._flagtune_op_name)
        if enabled == self._flagtune_active:
            return False

        if not enabled:
            self._set_configs_and_strategy(
                self._flagtune_default_configs,
                self._flagtune_default_strategy,
            )
            self._flagtune_active = False
            return True

        expand_config = runtime.get_expand_config(
            self._flagtune_expand_op_name,
            yaml_path=self._flagtune_yaml_path,
        )
        configs = runtime.ops_get_configs(
            self._flagtune_expand_op_name,
            yaml_path=self._flagtune_yaml_path,
            pre_hook=self._flagtune_pre_hook,
        )
        if expand_config == -1 or not configs:
            # Warn only once per tuner, then keep using the default configs.
            if not self._flagtune_warned:
                logger.warning(
                    "FlagTune expand config is unavailable for %s; using default configs.",
                    self._flagtune_expand_op_name,
                )
                self._flagtune_warned = True
            return False

        self._set_configs_and_strategy(configs, expand_config["strategy"])
        self._flagtune_active = True
        return True

    @cached_property
    def cache_key(self) -> str:
        """Cache key of the innermost `JITFunction` wrapped by this tuner."""
        jit_fn = self.fn
        while not isinstance(jit_fn, triton.runtime.JITFunction):
            jit_fn = jit_fn.fn
        return jit_fn.cache_key

    @cached_property
    def kernel_hash(self) -> str:
        """MD5 hex digest over the kernel cache key combined with the configs hash."""
        return hashlib.md5(
            f"{self.cache_key}{self.configs_hash}".encode("utf-8")
        ).hexdigest()[:32]

    @cached_property
    def configs_hash(self) -> str:
        """MD5 hex digest over the comma-joined string forms of the configs."""
        return hashlib.md5(
            ",".join(map(str, self.configs)).encode("utf-8")
        ).hexdigest()[:32]

    def get_key(self, args):
        """Build the tuning key for a call.

        Args:
            args: Mapping from argument name to value.

        Returns:
            A tuple with one entry per tuning key (passed through its
            strategy), followed by the string dtype of every argument that
            has a `dtype` attribute.
        """
        if self.strategy is None:
            key = tuple(args[k] for k in self.keys if k in args)
        else:
            key = tuple(
                self.strategy[idx](args[name]) for idx, name in enumerate(self.keys)
            )
        key += tuple(str(arg.dtype) for arg in args.values() if hasattr(arg, "dtype"))
        return key

    @abstractmethod
    def policy(
        self,
        fn: Callable[[triton.Config], List[float]],
        configs: Iterator[triton.Config],
        args: Tuple[Any],
        kwargs: Dict[str, Any],
    ) -> Tuple[triton.Config, Dict[str, float]]:
        """Pick the best configuration; subclasses must override this.

        Declared as an instance method so that `self.policy(...)` on a class
        without an override reaches the `NotImplementedError` below instead
        of failing on argument count. Subclasses may still override it with
        a static method taking `(fn, configs, args, kwargs)`.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(
            f"`policy` isn't implemented in {self.__class__.__name__}"
        )

    @classmethod
    def register(cls, name: str):
        """Register a subclass of `LibTuner` with a name.

        Args:
            name: The name of the subclass.
        Returns:
            A decorator that registers the subclass with the name.
        """

        def decorator(subclass):
            cls._dispatch_table[name] = subclass
            return subclass

        return decorator

    @classmethod
    def get(cls, name: str):
        """Return the subclass registered under `name` (KeyError if absent)."""
        return cls._dispatch_table[name]

    @classmethod
    def get_strategy(cls, name: str):
        """Return the strategy registered under `name` (KeyError if absent)."""
        return cls._strategy_table[name]

    @staticmethod
    def register_policy(
        name: str,
    ) -> Callable[[Callable[..., Any]], Type[LibTuner]]:
        """A decorator to register a policy for `LibTuner`.

        This decorator allows you to create a new `LibTuner` subclass without defining a new class explicitly.
        The new subclass will have the `policy` method set to the provided policy function and will be registered under
        the specified name in the `LibTuner` dispatch table.
        """

        def decorator(
            policy_impl: Callable[
                [
                    Callable[[triton.Config], List[float]],
                    Iterator[triton.Config],
                    Tuple[Any],
                    Dict[str, Any],
                ],
                Tuple[triton.Config, Dict[str, float]],
            ],
        ):
            @LibTuner.register(name)
            class AnonymousLibTunerImpl(LibTuner):
                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def policy(
                    self,
                    fn: Callable[[triton.Config], List[float]],
                    configs: Iterator[triton.Config],
                    args: Tuple[Any],
                    kwargs: Dict[str, Any],
                ) -> Tuple[triton.Config, Dict[str, float]]:
                    return policy_impl(fn, configs, args, kwargs)

            # The decorated name is rebound to the generated subclass.
            return AnonymousLibTunerImpl

        return decorator

    @staticmethod
    def register_strategy(name: str):
        """Return a decorator that registers a strategy under `name`."""

        def decorator(
            strategy: Union[Callable[[Any], Any], List[Callable[[Any], Any]]],
        ):
            LibTuner._strategy_table[name] = strategy
            return strategy

        return decorator

    def run(self, *args, **kwargs):
        """Select a configuration for this call and launch the kernel with it.

        With more than one config, the tuning key is looked up in the config
        cache; on a miss the pruned configs are benchmarked through `policy`
        (reusing per-config results from the benchmark cache) and the winner
        is stored. With a single config, that config is used directly.

        Returns:
            Whatever `self.fn.run` returns for the selected configuration.
        """
        if hasattr(self, "seen_tuned_metas"):
            self.seen_tuned_metas = {}  # flagtree aabs: deduplicate tuned meta
        # `arg_names` corresponds to the arguments of the `JITFunction`'s signature,
        # so please make sure the orders of `arg_names` and `args` match.
        self.nargs = dict(zip(self.arg_names, args))
        used_cached_result = True
        if len(self.configs) > 1:
            all_args = {**self.nargs, **kwargs}
            _args = {k: v for k, v in all_args.items() if k in self.arg_names}
            key = self.get_key(_args)
            if key not in self.cache:
                cache: BenchmarkCache = libcache[self.benchmark_table_name, key]
                # prune configs
                used_cached_result = False
                pruned_configs = self.prune_configs(kwargs)
                bench_start = time.time()

                def bench(config: triton.Config) -> List[float]:
                    # Benchmark only configs without a stored result.
                    ret = cache.get(config)
                    if ret is None:
                        ret = self._bench(*args, config=config, **kwargs)
                        cache[config] = tuple(ret)
                    return list(ret)

                best_config, timings = self.policy(
                    bench,
                    pruned_configs,
                    args,
                    kwargs,
                )
                bench_end = time.time()
                self.bench_time = bench_end - bench_start
                self.cache[key] = best_config
                full_nargs = {
                    **self.nargs,
                    **kwargs,
                    **self.cache[key].all_kwargs(),
                }
                self.pre_hook(full_nargs, reset_only=True)
                self.configs_timings = timings
            config = self.cache[key]
            if config.pre_hook is None:
                # A config read back from the cache may lack its pre_hook;
                # recover the matching in-memory config by comparing kwargs.
                cached_kwargs = config.all_kwargs()
                for original_config in self.configs:
                    if original_config.all_kwargs() == cached_kwargs:
                        # Use the original config which has the pre_hook
                        config = original_config
                        break
        else:
            config = self.configs[0]
        self.best_config = config
        if os.getenv("TRITON_PRINT_AUTOTUNING", None) == "1" and not used_cached_result:
            print(
                f"Triton autotuning for function {self.base_fn.__name__} finished after "
                f"{self.bench_time:.2f}s; key info: {key}, best config selected: {self.best_config};"
            )
        full_nargs = {**self.nargs, **kwargs, **config.all_kwargs()}
        # A shared pre-hook, when present, takes precedence over the config's own.
        if (
            hasattr(self, "shared_config_pre_hook")
            and self.shared_config_pre_hook is not None
        ):
            self.shared_config_pre_hook(full_nargs)
        elif config.pre_hook is not None:
            config.pre_hook(full_nargs)
        ret = self.fn.run(
            *args,
            **kwargs,
            **config.all_kwargs(),
        )
        self.nargs = None
        return ret
@LibTuner.register_strategy(None)
@LibTuner.register_strategy("default")
def default_strategy(key: Any) -> Any:
    """Identity strategy: the key value is used unchanged.

    Registered under both `None` and "default".
    """
    return key
@LibTuner.register_strategy("log")
def log2_strategy(key: Union[int, float]) -> float:
    """Round `key` up to the nearest power of two.

    A key of 0 maps to 0, mirroring `align32_strategy`; `math.log2(0)` would
    otherwise raise `ValueError`. Negative keys still raise `ValueError`.
    """
    if key == 0:
        return 0
    return 2 ** math.ceil(math.log2(key))
@LibTuner.register_strategy("align32")
def align32_strategy(key: Union[int, float]) -> int:
    """Bucket `key`: powers of two below 32, multiples of 32 from there on.

    Zero is returned unchanged.
    """
    if key == 0:
        return 0
    if key >= 32:
        # Round up to the next multiple of 32.
        return math.ceil(key / 32) * 32
    # Below 32, round up to the next power of two.
    return 2 ** math.ceil(math.log2(key))
@LibTuner.register_policy("default")
def default_policy(
    bench_fn: Callable[[triton.Config], List[float]],
    configs: Iterator[triton.Config],
    args: Tuple[Any],
    kwargs: Dict[str, Any],
) -> Tuple[triton.Config, Dict[str, float]]:
    """Exhaustive policy: benchmark every candidate and keep the fastest.

    Args:
        bench_fn: Callable that benchmarks one configuration.
        configs: The configurations to search over.
        args: Positional kernel launch arguments (unused here).
        kwargs: Keyword kernel launch arguments (unused here).
    Returns:
        The configuration with the smallest benchmark result, together with
        the mapping from each configuration to its benchmark result.

    Registering this function is equivalent to writing the subclass by hand:
    ```
    @LibTuner.register("default")
    class DefaultLibTunerImpl(LibTuner):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        @staticmethod
        def policy(
            bench_fn: Callable[[triton.Config], List[float]],
            configs: Iterator[triton.Config],
            args: Tuple[Any],
            kwargs: Dict[str, Any],
        ) -> Tuple[triton.Config, Dict[str, float]]:
            timings: Dict[triton.Config, int] = {
                config: bench_fn(config) for config in configs
            }
            best_config: triton.Config = min(timings, key=timings.get)
            return best_config, timings
    ```
    A policy can therefore be added either by registering a plain function,
    or by subclassing `LibTuner` and overriding `policy` for more control
    over the autotuning process.
    """
    timings: Dict[triton.Config, float] = {}
    for candidate in configs:
        timings[candidate] = bench_fn(candidate)
    fastest: triton.Config = min(timings, key=timings.get)
    return fastest, timings
def libtuner(
    configs,
    key,
    prune_configs_by=None,
    reset_to_zero=None,
    restore_value=None,
    pre_hook=None,
    post_hook=None,
    warmup=25,
    rep=100,
    use_cuda_graph=False,
    do_bench=None,
    strategy: Union[
        str, Callable[[Any], Any], List[Union[str, Callable[[Any], Any]]]
    ] = "default",
    policy: Union[str, Type[LibTuner]] = "default",
    flagtune_op_name=None,
    flagtune_expand_op_name=None,
    flagtune_yaml_path=None,
    flagtune_pre_hook=None,
):
    """Decorator factory that wraps a kernel in a `LibTuner` subclass.

    `strategy` maps a key value to the value used for cache lookup. It may be
    the name of a registered strategy or a callable, in which case it is
    applied to every entry of `key`; or a list/tuple with the same length as
    `key` whose elements are names or callables, applied entry by entry.
    `policy` is either the name of a registered `LibTuner` subclass or such a
    subclass itself. All remaining arguments are handed to the tuner's
    constructor unchanged.
    """

    tuner_cls = LibTuner.get(policy) if isinstance(policy, str) else policy
    policy = tuner_cls
    assert issubclass(
        policy, LibTuner
    ), f"the class of {policy.__name__} is {policy.__class__.__name__}, not a subclass of {LibTuner.__name__}"

    # Keyword arguments shared by every kernel decorated with this factory.
    tuner_options = dict(
        pre_hook=pre_hook,
        post_hook=post_hook,
        prune_configs_by=prune_configs_by,
        warmup=warmup,
        rep=rep,
        use_cuda_graph=use_cuda_graph,
        do_bench=do_bench,
        strategy=strategy,
        flagtune_op_name=flagtune_op_name,
        flagtune_expand_op_name=flagtune_expand_op_name,
        flagtune_yaml_path=flagtune_yaml_path,
        flagtune_pre_hook=flagtune_pre_hook,
    )

    def decorator(fn):
        return policy(
            fn,
            fn.arg_names,
            configs,
            key,
            reset_to_zero,
            restore_value,
            **tuner_options,
        )

    return decorator
class LibEntry(triton.KernelInterface):
    """Launch wrapper that caches compiled kernels per device and argument key.

    The first call for a given key runs the wrapped function chain
    (`self.fn.run`) and records the resulting kernel together with the
    constexpr values gathered from the Autotuner/Heuristics wrappers. Later
    calls with the same key launch the recorded kernel directly.
    """

    def __init__(
        self,
        fn,
    ):
        self.fn = fn
        self.arg_names = fn.arg_names
        # Pointer alignment (in bytes) checked when building the cache key.
        self.divisibility = 16
        # One kernel cache per device index.
        self.kernel_cache = tuple(dict() for _ in range(DEVICE_COUNT))
        self._has_flagtune_tuner = self._contains_flagtune_tuner(fn)
        # Separate cache used when the current device is reported as "cpu".
        self._cpu_cache = dict()

        # Unwrap to the innermost JITFunction.
        while not isinstance(fn, triton.runtime.JITFunction):
            fn = fn.fn
        self.jit_function: triton.runtime.JITFunction = fn
        # Positions of non-constexpr parameters, split by specialization.
        self.specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and not p.do_not_specialize
        ]
        self.do_not_specialize_indices = [
            p.num
            for p in self.jit_function.params
            if not p.is_constexpr and p.do_not_specialize
        ]
        self.lock = multiprocessing.Lock()
        self.signature = fn.signature

    @staticmethod
    def _contains_flagtune_tuner(fn):
        """Return True if any wrapper in the `.fn` chain has `apply_flagtune` and a non-None `_flagtune_op_name`."""
        while not isinstance(fn, triton.runtime.JITFunction):
            if (
                getattr(fn, "apply_flagtune", None) is not None
                and getattr(fn, "_flagtune_op_name", None) is not None
            ):
                return True
            fn = getattr(fn, "fn", None)
            if fn is None:
                break
        return False

    def _apply_flagtune(self):
        """Call `apply_flagtune` on every wrapper that has it; clear the per-device caches if any reported a change."""
        changed = False
        fn = self.fn
        while not isinstance(fn, triton.runtime.JITFunction):
            apply_flagtune = getattr(fn, "apply_flagtune", None)
            if apply_flagtune is not None:
                # Call first, then `or`, so every tuner is visited.
                changed = apply_flagtune() or changed
            fn = getattr(fn, "fn", None)
            if fn is None:
                break
        if changed:
            # NOTE(review): only `kernel_cache` is cleared here, not
            # `_cpu_cache` -- confirm whether that is intended.
            for cache in self.kernel_cache:
                cache.clear()

    def key(self, spec_args, dns_args, const_args):
        """Build the hashable cache key from the three argument groups.

        Args:
            spec_args: Values of specialized arguments.
            dns_args: Values of do-not-specialize arguments.
            const_args: List of constexpr argument values.

        Returns:
            A tuple of the per-argument keys followed by the constexpr values.
        """

        def spec_arg(arg):
            # Objects with `data_ptr` are keyed by dtype and pointer
            # alignment; everything else by type and value.
            if hasattr(arg, "data_ptr"):
                return (arg.dtype, arg.data_ptr() % self.divisibility == 0)
            return (type(arg), arg)

        def dns_arg(arg):
            # Non-specialized arguments are keyed by dtype/type only; ints
            # are keyed by the width class their value falls into.
            if hasattr(arg, "data_ptr"):
                return arg.dtype
            if not isinstance(arg, int):
                return type(arg)
            if -(2**31) <= arg and arg <= 2**31 - 1:
                return "i32"
            if 2**63 <= arg and arg <= 2**64 - 1:
                return "u64"
            return "i64"

        spec_key = [spec_arg(arg) for arg in spec_args]
        dns_key = [dns_arg(arg) for arg in dns_args]
        # const args passed by position
        return tuple(spec_key + dns_key + const_args)

    def run(self, *args, **kwargs):
        """Launch the kernel, compiling and caching it on the first call per key.

        `kwargs` must contain `grid` (a tuple or a callable taking the meta
        dict).

        Returns:
            A `(kernel, constexprs)` tuple.
        """
        grid = kwargs["grid"]
        if self._has_flagtune_tuner:
            self._apply_flagtune()

        # collect all the arguments
        spec_args = []  # specialize arguments
        dns_args = []  # do not specialize arguments
        const_args = []  # constexpr arguments
        # Arguments passed to the cached kernel, keyed by parameter name.
        k_args = OrderedDict()
        param_names = list(self.signature.parameters.keys())
        for i, arg in enumerate(args):
            hashable_arg = arg
            if (
                hasattr(arg, "__class__")
                and arg.__class__.__name__ == "TensorDescriptor"
            ):
                # Create a hashable representation of TensorDescriptor
                hashable_arg = (
                    "TensorDescriptor",
                    tuple(arg.shape) if hasattr(arg, "shape") else None,
                    tuple(arg.strides) if hasattr(arg, "strides") else None,
                    tuple(arg.block_shape) if hasattr(arg, "block_shape") else None,
                    arg.padding if hasattr(arg, "padding") else None,
                    # Add other relevant attributes
                )
            if i in self.specialize_indices:
                k_args[param_names[i]] = arg
                spec_args.append(hashable_arg)
            elif i in self.do_not_specialize_indices:
                k_args[param_names[i]] = arg
                dns_args.append(hashable_arg)
            else:
                # Constexpr positional argument: only kept in `k_args` for
                # Triton 3.3 through 3.6.
                if major_version == 3 and 3 <= minor_version <= 6:
                    k_args[param_names[i]] = arg
                const_args.append(hashable_arg)
        # Remaining parameters come from kwargs or their declared defaults;
        # parameters with neither are skipped.
        for p in self.jit_function.params[len(args) :]:
            if p.name in kwargs:
                val = kwargs[p.name]
            elif p.default is inspect._empty:
                continue
            else:
                val = p.default

            if p.is_constexpr:
                const_args.append(val)
                if major_version == 3 and 3 <= minor_version <= 6:
                    k_args[p.name] = val
            elif p.do_not_specialize:
                dns_args.append(val)
                k_args[p.name] = val
            else:
                spec_args.append(val)
                k_args[p.name] = val

        entry_key = self.key(spec_args, dns_args, const_args)
        device = torch_device_fn.current_device()
        # CPU has one device per process and `current_device()` returns the
        # string "cpu" (can't index into the int-keyed `kernel_cache` tuple).
        # This branch is CPU-generic — any future x86 / RISC-V CPU backend
        # reuses the same path; no ARM-specific assumption here.
        if device == "cpu":
            cache = self._cpu_cache
        else:
            cache = self.kernel_cache[device]
        while entry_key not in cache:
            # NOTE: we serialize the first run of a jit function regardless of which device to run on
            # because Triton runtime is currently not threadsafe.
            with self.lock:
                # Re-check under the lock: the entry may have been filled
                # while this caller was waiting.
                if entry_key in cache:
                    break
                kernel = self.fn.run(*args, **kwargs)
                fn = self.fn
                # collect constexpr arguments for grid computation
                constexprs = {}
                tune_constexprs = {}
                heur_constexprs = {}
                launch_pre_hooks = []
                # Walk the wrapper chain from the outermost wrapper down to
                # the JITFunction.
                while not isinstance(fn, triton.runtime.JITFunction):
                    if isinstance(fn, triton.runtime.Autotuner):
                        config = fn.best_config
                        constexprs["num_warps"] = config.num_warps
                        constexprs["num_stages"] = config.num_stages
                        constexprs["num_ctas"] = config.num_ctas
                        constexprs = {**constexprs, **config.kwargs}
                        tune_constexprs = {**tune_constexprs, **config.kwargs}
                        if config.pre_hook is not None:
                            # Remembered so cached launches can re-run it.
                            launch_pre_hooks.append(
                                (config.pre_hook, config.all_kwargs())
                            )
                    elif isinstance(fn, triton.runtime.Heuristics):
                        for v, heur in fn.values.items():
                            heur_constexprs[v] = heur(
                                {
                                    **dict(zip(fn.arg_names, args)),
                                    **kwargs,
                                    **constexprs,
                                }
                            )
                            constexprs[v] = heur_constexprs[v]
                    else:
                        raise RuntimeError("Invalid Runtime Function")
                    fn = fn.fn
                # Fill in constexpr parameters that still only have a default.
                for p in self.jit_function.params:
                    if (
                        p.is_constexpr
                        and p.name not in constexprs
                        and (p.default is not inspect._empty)
                    ):
                        constexprs[p.name] = p.default
                cache[entry_key] = (
                    kernel,
                    constexprs,
                    tune_constexprs,
                    heur_constexprs,
                    tuple(launch_pre_hooks),
                )
            # The call to `self.fn.run` above already launched the kernel.
            return kernel, constexprs

        (
            kernel,
            constexprs,
            tune_constexprs,
            heur_constexprs,
            launch_pre_hooks,
        ) = cache[entry_key]

        if callable(grid):
            # collect all arguments to the grid fn, i.e.:
            # 1. args,
            # 2. kwargs,
            # 3. all other captured arguments in CompiledKernel from Autotuner & Heuristics
            # when kwargs & captured args conflict, captured args have higher priority
            meta = {**dict(zip(self.arg_names, args)), **kwargs, **constexprs}
            grid = grid(meta)
        # Pad so that `grid[0:3]` below always has three entries.
        grid = grid + (1, 1)

        if launch_pre_hooks:
            hook_nargs = {**dict(zip(self.arg_names, args)), **kwargs}
            for pre_hook, hook_kwargs in launch_pre_hooks:
                pre_hook({**hook_nargs, **hook_kwargs})

        if major_version == 3 and 3 <= minor_version <= 6:
            # Pass one value per signature parameter, in signature order,
            # taken from the first source that provides it.
            all_args = []
            missing_keys = []
            for key in list(self.signature.parameters.keys()):
                if key in k_args:
                    all_args.append(k_args[key])
                elif key in tune_constexprs:
                    all_args.append(tune_constexprs[key])
                elif key in heur_constexprs:
                    all_args.append(heur_constexprs[key])
                elif key in constexprs:
                    all_args.append(constexprs[key])
                else:
                    missing_keys.append(key)
            if len(missing_keys):
                raise RuntimeError(
                    f"[libentry]: probably a bug, the following kernel params where not captured: {missing_keys}"
                )
            kernel[grid[0:3]](*all_args)
        else:
            kernel[grid[0:3]](*k_args.values())
        return kernel, constexprs
def libentry():
    """Return a decorator that wraps a kernel function in `LibEntry`."""

    def wrap(kernel_fn):
        return LibEntry(kernel_fn)

    return wrap