Coverage for src/flag_gems/runtime/flagtune.py: 60%
91 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import os
2import warnings
3from dataclasses import dataclass
4from types import MappingProxyType
# Environment variable that, when set to exactly "1", turns FlagTune on for
# every registered operator (checked in flagtune_enabled).
USE_FLAGTUNE_ENV = "USE_FLAGTUNE"
# Environment variable carrying the explicit operator selection; flagtune()
# writes it and _include_from_env() reads it back.
FLAGTUNE_INCLUDE_ENV = "FLAGTUNE_INCLUDE"

# Registered operators, keyed by normalized name -> FlagTuneOpSpec.
_flagtune_op_registry = {}
# Selection made through flagtune(); stays None until flagtune() is called,
# in which case the environment variable is consulted instead.
_include_ops = None
@dataclass(frozen=True)
class FlagTuneOpSpec:
    """Immutable registry record describing one FlagTune-capable operator."""

    # Operator name as stored in the registry.
    name: str
    # True when the operator belongs to the default include set.
    default_enabled: bool = False
    # Free-form, human-readable summary of the operator.
    description: str = ""
def _normalize_op_name(op_name):
    """Return ``op_name`` with leading and trailing whitespace removed.

    Raises:
        TypeError: if ``op_name`` is not a ``str``.
        ValueError: if nothing remains after stripping whitespace.
    """
    if isinstance(op_name, str):
        stripped = op_name.strip()
        if stripped:
            return stripped
        raise ValueError("op_name must not be empty")
    raise TypeError("op_name must be a string")
def register_flagtune_op(
    op_name,
    *,
    default=False,
    description="",
    replace=False,
):
    """Register an operator name that can be selected by flag_gems.flagtune.

    Args:
        op_name: Operator name; surrounding whitespace is stripped.
        default: Truthy if the operator belongs to the default include set.
        description: Optional summary; falsy values are stored as "".
        replace: When true, overwrite any spec already stored under the name.

    Returns:
        The FlagTuneOpSpec held by the registry for this name. When an
        identical spec was already registered, that existing object is
        returned.

    Raises:
        TypeError: if ``op_name`` is not a string.
        ValueError: if ``op_name`` is empty, or if a different spec is
            already registered under the name and ``replace`` is false.
    """
    key = _normalize_op_name(op_name)
    candidate = FlagTuneOpSpec(
        name=key,
        default_enabled=bool(default),
        description=str(description or ""),
    )

    if not replace:
        current = _flagtune_op_registry.get(key)
        if current is not None:
            # Re-registering the exact same spec is a harmless no-op.
            if current != candidate:
                raise ValueError(f"FlagTune op {key!r} is already registered")
            return current

    _flagtune_op_registry[key] = candidate
    return candidate
def get_flagtune_registry():
    """Return a read-only mapping of op name -> FlagTuneOpSpec.

    The mapping wraps a copy of the registry taken at call time, so later
    registrations do not show up in a previously returned mapping.
    """
    snapshot = dict(_flagtune_op_registry)
    return MappingProxyType(snapshot)
def get_supported_flagtune_ops():
    """Return the names of all currently registered operators as a frozenset."""
    registered_names = _flagtune_op_registry.keys()
    return frozenset(registered_names)
def get_default_flagtune_include():
    """Return the frozenset of registered names whose spec is default-enabled."""
    defaults = []
    for key, entry in _flagtune_op_registry.items():
        if entry.default_enabled:
            defaults.append(key)
    return frozenset(defaults)
def _split_include(include):
    """Turn an ``include`` argument into a frozenset of stripped op names.

    ``None`` yields the registry's default include set. A string is split on
    both "," and ";". Any other iterable is consumed item by item, each item
    being converted with ``str()``. Entries that are empty after stripping
    are dropped. No check against the registry happens here.

    Raises:
        TypeError: if ``include`` is neither a string nor iterable.
    """
    if include is None:
        return get_default_flagtune_include()

    candidates = include
    if isinstance(candidates, str):
        # Accept ";" as an alternative separator by folding it into ",".
        candidates = candidates.replace(";", ",").split(",")

    try:
        names = frozenset(str(item).strip() for item in candidates)
    except TypeError as err:
        raise TypeError(
            "include must be a comma-separated string or an iterable"
        ) from err

    # Discard the blank entry left behind by empty or whitespace-only items.
    return names - {""}
def _normalize_include(include):
    """Parse ``include`` and verify that every requested op is registered.

    Returns:
        The frozenset of requested operator names.

    Raises:
        TypeError: if ``include`` cannot be parsed (see ``_split_include``).
        ValueError: if any requested name is not in the registry; the message
            lists the offending names and the supported ones.
    """
    requested = _split_include(include)
    known = get_supported_flagtune_ops()
    unknown = sorted(requested - known)
    if not unknown:
        return requested

    # "<none>" keeps the message readable when the registry is empty.
    known_text = ", ".join(sorted(known)) or "<none>"
    raise ValueError(
        f"Unsupported flagtune op(s): {', '.join(unknown)}. "
        f"Supported ops: {known_text}"
    )
def flagtune(include=None):
    """Enable runtime FlagTune for selected operators.

    With include=None the registry's default operators are selected. A string
    (comma- or semicolon-separated) or an iterable names the registered
    operators that should use expanded tuning spaces when their LibTuner
    runs. Only the explicit include list is updated here; setting
    USE_FLAGTUNE=1 remains the legacy opt-in that enables every registered
    FlagTune operator.

    The validated selection is stored in module state and also written to
    the FLAGTUNE_INCLUDE environment variable as a sorted, comma-joined list.

    Raises:
        TypeError: if ``include`` is neither a string nor iterable.
        ValueError: if ``include`` names an operator that is not registered.
    """
    global _include_ops
    # Validate first so a bad argument leaves the previous selection intact.
    selected = _normalize_include(include)
    _include_ops = selected
    os.environ[FLAGTUNE_INCLUDE_ENV] = ",".join(sorted(selected))
def _include_from_env():
    """Read the include set from the FLAGTUNE_INCLUDE environment variable.

    Returns an empty frozenset when the variable is unset. If its value does
    not validate, a warning is emitted and an empty frozenset is returned
    instead of raising.
    """
    raw = os.environ.get(FLAGTUNE_INCLUDE_ENV)
    if raw is not None:
        try:
            return _normalize_include(raw)
        except (TypeError, ValueError) as err:
            # A bad environment value must not break callers; report it and
            # fall through to the empty selection.
            warnings.warn(f"Invalid {FLAGTUNE_INCLUDE_ENV}: {err}")
    return frozenset()
def get_flagtune_include():
    """Return the active include set.

    A selection made through ``flagtune()`` takes precedence; otherwise the
    set is derived from the environment on every call.
    """
    return _include_from_env() if _include_ops is None else _include_ops
def flagtune_enabled(op_name):
    """Return whether FlagTune is active for ``op_name``.

    Invalid names (non-string or blank) and unregistered operators always
    yield False. A registered operator is enabled when USE_FLAGTUNE is set to
    "1" or when it is part of the active include set.
    """
    try:
        name = _normalize_op_name(op_name)
    except (TypeError, ValueError):
        return False

    if name not in get_supported_flagtune_ops():
        return False
    # The global opt-in wins without consulting the include set.
    if os.environ.get(USE_FLAGTUNE_ENV) == "1":
        return True
    return name in get_flagtune_include()
def __getattr__(name):
    """Compute the registry-backed module attributes on access.

    ``SUPPORTED_FLAGTUNE_OPS`` and ``DEFAULT_FLAGTUNE_INCLUDE`` are rebuilt
    from the registry each time they are looked up; any other name raises
    AttributeError.
    """
    if name not in ("SUPPORTED_FLAGTUNE_OPS", "DEFAULT_FLAGTUNE_INCLUDE"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    if name == "SUPPORTED_FLAGTUNE_OPS":
        return get_supported_flagtune_ops()
    return get_default_flagtune_include()
# Built-in operators known to FlagTune. None of them is default-enabled, so
# each must be selected explicitly or through the USE_FLAGTUNE opt-in.
for _builtin_name, _builtin_description in (
    ("mm", "matrix multiplication"),
    ("bmm", "batched matrix multiplication"),
    ("addmm", "matrix multiplication with bias"),
    ("baddbmm", "batched matrix multiplication with bias"),
    ("mv", "matrix-vector multiplication"),
    ("w8a8_block_fp8_matmul", "W8A8 block FP8 matrix multiplication"),
):
    register_flagtune_op(
        _builtin_name,
        default=False,
        description=_builtin_description,
    )
# Keep the loop variables out of the module namespace.
del _builtin_name, _builtin_description
# Public API of this module. DEFAULT_FLAGTUNE_INCLUDE and
# SUPPORTED_FLAGTUNE_OPS have no module-level binding: the module-level
# __getattr__ computes them on access, hence the F822 suppression below.
__all__ = [  # noqa: F822
    "DEFAULT_FLAGTUNE_INCLUDE",
    "FLAGTUNE_INCLUDE_ENV",
    "FlagTuneOpSpec",
    "SUPPORTED_FLAGTUNE_OPS",
    "USE_FLAGTUNE_ENV",
    "flagtune",
    "flagtune_enabled",
    "get_default_flagtune_include",
    "get_flagtune_include",
    "get_flagtune_registry",
    "get_supported_flagtune_ops",
    "register_flagtune_op",
]