Coverage for src/flag_gems/runtime/backend/__init__.py: 72%
285 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import ast
2import functools
3import importlib
4import inspect
5import os
6import sys
7from pathlib import Path
9from ..common import vendors
10from . import backend_utils
11from .backend_utils import BackendEventBase
class BackendState:
    """Process-wide holder for the backend's mutable runtime state.

    Calling ``BackendState()`` always yields the same object. The attributes
    are reset to their defaults only on the first construction; later
    constructions leave whatever has been stored untouched.
    """

    _instance = None

    def __new__(cls):
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            # Lets __init__ tell the first construction from later ones.
            instance._initialized = False
            cls._instance = instance
        return instance

    def __init__(self):
        if not self._initialized:
            self._initialized = True
            self.vendor_module = None
            self.device_name = None
            self.torch_device_object = None
            self.torch_device_fn_device = None
            self.tl_extra_backend_module = None
            self.ops_module = None
            self.fused_module = None
            self.heuristic_config_module = None
            self.vendor_extra_lib_imported = False
            self.device_fn_cache = {}
            self.customized_ops = None
# Module-wide shared state object used by the functions in this module.
_state: BackendState = BackendState()
class TritonVersionEvent(BackendEventBase):
    """Singleton event for a Triton-version-specific backend package.

    On construction it looks inside the vendor backend directory for a
    sub-directory named ``triton_<version>``. When one exists it is imported
    as a module and the event reports itself as available.
    """

    _instance = None
    has_version_spec = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, version=None):
        # Unlike __new__, this body runs on every construction, so the shared
        # instance is re-probed each time the class is called.
        self.has_version_spec = False
        self.version = version if version is not None else self.get_version()
        self.dir = self.get_version_spec_dir()
        if self.dir and Path(self.dir).exists():
            self.module = self.get_version_spec_module()
            self.has_version_spec = True

    def is_available(self):
        """Return True when a ``triton_<version>`` package was found and imported."""
        return self.has_version_spec

    def get_version_spec_dir(self, path=None):
        """Return the path of the ``triton_<version>`` directory, or None.

        Args:
            path: Directory (or a file inside it) to search. Defaults to the
                first entry of the vendor module's ``__path__``.
        """
        dir_name = f"triton_{self.version}"
        backend_path = Path(path or _state.vendor_module.__path__[0])
        backend_path = backend_path.parent if backend_path.is_file() else backend_path
        excluded = ("ops", "fused")
        for entry in backend_path.iterdir():
            if (
                entry.name == dir_name
                and entry.is_dir()
                and entry.name not in excluded
                and not entry.name.startswith("_")
            ):
                return str(entry)
        return None

    def get_functions_from_module(self, module):
        """Return ``(name, function)`` pairs found in ``module`` ([] if falsy)."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_version_spec_module(self):
        """Import and return the ``triton_<version>`` package.

        The package's parent directory is put at the front of ``sys.path``
        only for the duration of the import.
        """
        module_name = f"triton_{self.version}"
        search_dir = str(os.path.dirname(self.dir))
        sys.path.insert(0, search_dir)
        try:
            return importlib.import_module(module_name)
        finally:
            # Undo the temporary entry even when the import raises.
            if search_dir in sys.path:
                sys.path.remove(search_dir)

    def get_ops(self):
        """Provide a unified interface for the upper layer."""
        return self.get_version_ops()

    def get_version_ops(self):
        """Return the version-specific operators as ``(name, function)`` pairs.

        No operators are collected here yet. An empty list is returned so
        that callers can iterate over the result.
        """
        return []

    def get_version(self):
        """Return ``triton.__version__``, or None when triton cannot be imported."""
        try:
            import triton
        except ImportError:
            return None
        return triton.__version__
class BackendArchEvent(BackendEventBase):
    """Singleton event for architecture-specific backend code.

    The architecture key comes from the ``ARCH`` environment variable or,
    when that is unset, from the device properties' ``major`` field. It is
    looked up in the vendor module's ``ARCH_MAP``. On a match the arch
    package, its tune config and its ``HEURISTICS_CONFIGS`` are loaded.
    """

    has_arch: bool = False
    _instance = None
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, backend=None):
        # The shared instance is initialised only on the first construction.
        if BackendArchEvent._initialized:
            return
        BackendArchEvent._initialized = True
        self.backend = backend
        self.error_msgs = []
        self.arch = self.get_arch()
        if not self.has_arch:
            return
        self.supported_archs = self._get_supported_archs()
        # current_arch_path is like FlagGems/src/flag_gems/runtime/backend/_nvidia/hopper
        self.current_arch_path = self.supported_archs.get(self.arch)
        if self.current_arch_path is None:
            # ARCH_MAP names an arch that has no directory to load from:
            # record it and report the event as unavailable instead of
            # failing later on a None path.
            self.error_msgs.append(
                FileNotFoundError(f"no backend directory for arch {self.arch!r}")
            )
            self.has_arch = False
            return
        self.arch_module = self.get_arch_module()
        self.autotune_configs = self.get_autotune_configs()
        self.heuristics_configs = self.get_heuristics_configs()

    def is_available(self):
        """Return True when an architecture specialisation was found."""
        return self.has_arch

    def get_functions_from_module(self, module):
        """Return ``(name, function)`` pairs found in ``module`` ([] if falsy)."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_heuristics_configs(self):
        """Return the arch's ``HEURISTICS_CONFIGS``, or None when not defined."""
        try:
            heuristic_module = self.arch_module
        except Exception:  # noqa E722
            # Fallback: import heuristics_config_utils from the arch directory.
            search_dir = str(self.current_arch_path)
            sys.path.insert(0, search_dir)
            try:
                heuristic_module = importlib.import_module("heuristics_config_utils")
            finally:
                if search_dir in sys.path:
                    sys.path.remove(search_dir)
        return getattr(heuristic_module, "HEURISTICS_CONFIGS", None)

    def get_autotune_configs(self):
        """Return the tune config loaded from the current arch directory."""
        path = self.current_arch_path
        return backend_utils.get_tune_config(file_path=path)

    def get_arch(self, device=0):
        """Resolve the architecture name and set ``has_arch`` on success.

        Args:
            device: Device index passed to ``get_device_properties``.

        Returns:
            The ``ARCH_MAP`` value for the detected key; None when the vendor
            has no ``ARCH_MAP`` or the key is not in it; False when the
            device reports itself unavailable (kept for compatibility).
        """
        if not hasattr(_state.vendor_module, "ARCH_MAP"):
            return
        arch_map = _state.vendor_module.ARCH_MAP
        arch_string = os.environ.get("ARCH", "")
        # First character after the last underscore; the [:1] slice yields ""
        # instead of raising when ARCH ends with "_".
        arch_string_num = arch_string.split("_")[-1][:1] if arch_string else arch_string
        if not arch_string_num:
            try:
                if not _state.torch_device_object.is_available():
                    return False
                props = _state.torch_device_object.get_device_properties(device)
                arch_string_num = str(props.major)
            except Exception:
                self.has_arch = False
        if arch_string_num not in arch_map:
            # Show the detected key when ARCH itself was not set.
            print(
                f"[INFO] : FlagGems Unsupported GPU arch "
                f"{arch_string or arch_string_num} specialization"
            )
        else:
            self.has_arch = True
            return arch_map[arch_string_num]

    def _get_supported_archs(self, path=None):
        """Map each candidate arch directory name to its path.

        Directories named ``ops`` or ``fused`` and those starting with ``_``
        are skipped.
        """
        path = Path(path or _state.vendor_module.__path__[0])
        path = path.parent if path.is_file() else path
        excluded = ("ops", "fused")
        return {
            p.name: str(p)
            for p in path.iterdir()
            if p.is_dir() and p.name not in excluded and not p.name.startswith("_")
        }

    def get_supported_archs(self):
        """Return the names of the supported arch directories."""
        return list(self.supported_archs.keys())

    def get_arch_module(self):
        """Load backend.<arch>"""
        search_dir = str(os.path.dirname(self.current_arch_path))
        sys.path.insert(0, search_dir)
        try:
            return importlib.import_module(self.arch)
        finally:
            # Undo the temporary entry even when the import raises.
            if search_dir in sys.path:
                sys.path.remove(search_dir)

    def get_ops(self):
        """Provide a unified interface for the upper layer"""
        return self.get_arch_ops()

    def get_arch_ops(self):
        """Return ``(name, function)`` pairs from the arch's ``ops`` module.

        An import failure is recorded in ``error_msgs`` and yields an empty
        list.
        """
        # Keep the arch directory importable, without stacking duplicate
        # entries on repeated calls.
        if self.current_arch_path not in sys.path:
            sys.path.append(self.current_arch_path)
        ops_module = getattr(self.arch_module, "ops", None)
        if ops_module is None:
            try:
                ops_module = importlib.import_module(f"{self.arch}.ops")
            except Exception as err_msg:
                self.error_msgs.append(err_msg)
        return list(self.get_functions_from_module(ops_module))
class SpecOpRegistrar:
    """Install specialised operators into a caller-supplied namespace."""

    def __init__(self, _globals):
        # Mapping (typically a module's globals()) that receives the operators.
        self._globals = _globals

    def apply(self):
        """Store every operator of each available event under its own name.

        Events are processed in order, so a later event overrides an earlier
        one that provides the same name.
        """
        for event in self._get_specific_events():
            if not event.is_available():
                continue
            # An available event may still have nothing to offer and return
            # None from get_ops(); treat that as "no operators".
            for fn_name, fn in event.get_ops() or ():
                self._globals[fn_name] = fn

    def _get_specific_events(self):
        """Return the events to consult, in application order."""
        return (BackendArchEvent(), TritonVersionEvent())
231def _import_module_safe(module_name, vendor_name, module_type):
232 """Helper to import a module with proper error handling."""
233 try:
234 return importlib.import_module(module_name)
235 except ModuleNotFoundError:
236 print(
237 f"[Note] No specialized {module_type} operators were found for "
238 f"the {vendor_name}, generic {module_type} operators will be used by default."
239 )
240 except Exception as e:
241 raise RuntimeError(f"Failed to import vendor extra lib: {e}")
def import_vendor_extra_lib(vendor_name=None):
    """Import the vendor's ``ops`` and ``fused`` modules at most once.

    The results (a module, or None when absent) are stored on the shared
    state; later calls return immediately.
    """
    if _state.vendor_extra_lib_imported:
        return
    targets = (
        ("ops_module", "ops", "common"),
        ("fused_module", "fused", "fused"),
    )
    for state_attr, submodule, module_type in targets:
        imported = _import_module_safe(
            f"_{vendor_name}.{submodule}", vendor_name, module_type
        )
        setattr(_state, state_attr, imported)
    _state.vendor_extra_lib_imported = True
def get_codegen_result(code, result_key):
    """Run ``code`` in this module's global namespace and return one name.

    Args:
        code: Python source text to execute.
        result_key: Global name to read back after execution.

    Any exception raised while parsing or running ``code`` propagates to the
    caller, as does KeyError when ``result_key`` was not defined.
    """
    module_namespace = globals()
    program = compile(ast.parse(code), filename="<ast>", mode="exec")
    exec(program, module_namespace)
    return module_namespace[result_key]
@functools.lru_cache(maxsize=32)
def gen_torch_tensor_attr_res(tensor, attr_name):
    """Evaluate ``<tensor>.<attr_name>`` through generated source.

    Both arguments are interpolated into source text, so they are presumably
    strings holding an expression and an attribute name (confirm against
    callers). Results are memoised for up to 32 distinct argument pairs.
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info().device_name
    snippet = f"""
import torch
res = {tensor}.{attr_name}
    """
    return get_codegen_result(snippet, "res")
def set_tl_extra_backend_module(vendor_name=None):
    """Import the vendor's Triton ``libdevice`` module and store it on the state.

    The package segment is the vendor info's ``triton_extra_name`` when set,
    otherwise the device name.
    """
    info = get_vendor_info(vendor_name)
    if not _state.device_name:
        _state.device_name = info.device_name
    libdevice_owner = info.triton_extra_name or _state.device_name
    _state.tl_extra_backend_module = importlib.import_module(
        f"triton.language.extra.{libdevice_owner}.libdevice"
    )
def get_tl_extra_backend_module():
    """Return the Triton extra backend module currently held by the shared state."""
    tl_extra_module = _state.tl_extra_backend_module
    return tl_extra_module
def set_torch_backend_device_fn(vendor_name=None):
    """Store ``torch.backends.<device>`` on the shared state.

    For the device names listed below no module is imported and None is
    stored instead.
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    device = _state.device_name
    if device in ("musa", "aipu", "npu", "txda", "ptpu", "gcu"):
        _state.torch_device_fn_device = None
        return
    _state.torch_device_fn_device = importlib.import_module(f"torch.backends.{device}")
def get_torch_backend_device_fn():
    """Return the torch backends module (or None) held by the shared state."""
    device_fn = _state.torch_device_fn_device
    return device_fn
def gen_torch_device_object(vendor_name=None):
    """Return ``torch.<device_name>``, resolving and caching it on first use.

    For the "spacemit" vendor the resolved object is additionally patched
    with that backend's device guard, device wrapper and a constant
    ``current_device``.
    """
    cached = _state.torch_device_object
    if cached is not None:
        return cached
    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    code = f"""
import torch
fn = torch.{_state.device_name}
"""
    device_obj = get_codegen_result(code, "fn")
    _state.torch_device_object = device_obj

    # SPACEMIT CPU backend needs special device guard handling
    if vendor_name == "spacemit":
        spacemit = importlib.import_module("flag_gems.runtime.backend._spacemit")
        device_obj._DeviceGuard = spacemit._DeviceGuard
        device_obj.device = spacemit._DeviceWrapper
        # Override current_device to return integer 0 for kernel cache indexing
        device_obj.current_device = lambda: 0

    return device_obj
def get_vendor_module(vendor_name, query=False):
    """Import a vendor backend module from this file's directory.

    Args:
        vendor_name: With ``query=True``, the exact module name to import.
            Otherwise the vendor name without its leading underscore.
        query: When True, import and return the requested module without
            touching the shared state.

    Returns:
        The requested module (query mode), or the module stored on the
        shared state, which is imported on first use.
    """

    def get_module(vendor_name):
        current_dir_path = os.path.dirname(os.path.abspath(__file__))
        # Make sibling vendor packages importable, without adding a new
        # sys.path entry on every lookup.
        if current_dir_path not in sys.path:
            sys.path.append(current_dir_path)
        return importlib.import_module(vendor_name)

    if query:
        # The purpose of a query is to provide the user with the instance
        # that he wants to import
        return get_module(vendor_name)

    if _state.vendor_module is None:
        _state.vendor_module = get_module("_" + vendor_name)
    return _state.vendor_module
def get_vendor_info(vendor_name=None, query=False):
    """Return the ``vendor_info`` attribute of a vendor module.

    In query mode the module named ``vendor_name`` is imported and inspected
    directly; otherwise the module stored on the shared state is used.
    """
    if not query:
        get_vendor_module(vendor_name)
        return _state.vendor_module.vendor_info
    return get_vendor_module(vendor_name, query).vendor_info
def get_vendor_infos():
    """Return the ``vendor_info`` of every vendor whose module can be loaded.

    Vendors whose lookup raises are skipped.
    """
    collected = []
    for name in vendors.get_all_vendors():
        try:
            info = get_vendor_info(f"_{name}", query=True)
        except Exception:
            continue
        collected.append(info)
    return collected
def get_customized_ops(vendor_name=None):
    """Return the vendor's ``(name, function)`` operator pairs.

    The list is built once from the vendor ``ops`` module followed by the
    ``fused`` module, then cached on the shared state.
    """
    import_vendor_extra_lib(vendor_name)
    cached = _state.customized_ops
    if cached is not None:
        return cached
    collected = _state.customized_ops = []
    for module in (_state.ops_module, _state.fused_module):
        if module is not None:
            collected += inspect.getmembers(module, inspect.isfunction)
    return collected
def get_ops(vendor_name=None):
    """Provide a unified interface for the upper layer"""
    customized = get_customized_ops(vendor_name)
    return customized
def get_unused_ops(vendor_name=None):
    """Return a fresh list copy of the vendor module's ``CUSTOMIZED_UNUSED_OPS``."""
    get_vendor_module(vendor_name)
    unused = _state.vendor_module.CUSTOMIZED_UNUSED_OPS
    return list(unused)
def get_heuristic_config(vendor_name=None):
    """Return ``HEURISTICS_CONFIGS`` from the vendor's heuristics module.

    If the vendor's module cannot be imported, the ``_nvidia`` one is used
    instead. The imported module is stored on the shared state. None is
    returned when the module does not define ``HEURISTICS_CONFIGS``.
    """
    config_name = "heuristics_config_utils"
    try:
        module = importlib.import_module(f"_{vendor_name}.{config_name}")
    except Exception:
        module = importlib.import_module(f"_nvidia.{config_name}")
    _state.heuristic_config_module = module
    return getattr(module, "HEURISTICS_CONFIGS", None)
def get_tune_config(vendor_name=None):
    """Return ``backend_utils.get_tune_config(vendor_name)``.

    The vendor module is loaded onto the shared state first.
    """
    get_vendor_module(vendor_name)
    tune_config = backend_utils.get_tune_config(vendor_name)
    return tune_config
def get_expand_config(op_name=None, file_path=None):
    """Forward both arguments, as keywords, to ``backend_utils.get_expand_config``."""
    options = {"op_name": op_name, "file_path": file_path}
    return backend_utils.get_expand_config(**options)
def get_backend_state() -> BackendState:
    """Return the module-level ``BackendState`` object shared by this module."""
    shared_state = _state
    return shared_state
# Names exported by a star-import of this module. Every entry must be a real
# attribute name: a literal "*" entry makes ``from ... import *`` raise
# AttributeError, because the module has no attribute called "*".
__all__ = [
    "BackendArchEvent",
    "BackendState",
    "SpecOpRegistrar",
    "TritonVersionEvent",
    "gen_torch_device_object",
    "gen_torch_tensor_attr_res",
    "get_backend_state",
    "get_codegen_result",
    "get_customized_ops",
    "get_expand_config",
    "get_heuristic_config",
    "get_ops",
    "get_tl_extra_backend_module",
    "get_torch_backend_device_fn",
    "get_tune_config",
    "get_unused_ops",
    "get_vendor_info",
    "get_vendor_infos",
    "get_vendor_module",
    "import_vendor_extra_lib",
    "set_tl_extra_backend_module",
    "set_torch_backend_device_fn",
]