Coverage for src/flag_gems/runtime/backend/__init__.py: 78%
224 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import ast
2import functools
3import importlib
4import inspect
5import os
6import sys
7from pathlib import Path
9from ..common import vendors
10from . import backend_utils
class BackendState:
    """Singleton container for the backend's module-level state.

    Every call to ``BackendState()`` yields the same object.  The fields are
    reset to their defaults only the first time the object is constructed;
    later constructions leave the stored values untouched.
    """

    _instance = None

    def __new__(cls):
        instance = cls._instance
        if instance is None:
            # First construction: create the shared object and mark it as
            # not yet initialised so ``__init__`` fills in the defaults.
            instance = super().__new__(cls)
            instance._initialized = False
            cls._instance = instance
        return instance

    def __init__(self):
        if not self._initialized:
            self._initialized = True
            self.vendor_module = None
            self.device_name = None
            self.torch_device_object = None
            self.torch_device_fn_device = None
            self.tl_extra_backend_module = None
            self.ops_module = None
            self.fused_module = None
            self.heuristic_config_module = None
            self.vendor_extra_lib_imported = False
            self.device_fn_cache = {}
            self.customized_ops = None


# Shared state object used by the helpers in this module.
_state = BackendState()
class BackendArchEvent:
    """Singleton that resolves architecture-specific backend resources.

    On first construction the current architecture is looked up in the vendor
    module's ``ARCH_MAP`` (see :meth:`get_arch`).  When a match is found
    (``has_arch`` becomes True) the instance also records the architecture
    directory, imports the architecture module and loads that architecture's
    autotune and heuristics configs.  Later constructions return the same,
    already initialised instance and ignore their arguments.
    """

    has_arch: bool = False
    _instance = None
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, backend=None):
        if BackendArchEvent._initialized:
            return
        BackendArchEvent._initialized = True
        self.backend = backend
        self.error_msgs = []
        self.arch = self.get_arch()
        if self.has_arch:
            self.supported_archs = self._get_supported_archs()
            # current_arch_path is like
            # FlagGems/src/flag_gems/runtime/backend/_nvidia/hopper
            self.current_arch_path = self.supported_archs.get(self.arch)
            self.arch_module = self.get_arch_module()
            self.autotune_configs = self.get_autotune_configs()
            self.heuristics_configs = self.get_heuristics_configs()

    def get_functions_from_module(self, module):
        """Return ``(name, function)`` pairs defined in *module* ([] if falsy)."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_heuristics_configs(self):
        """Return ``HEURISTICS_CONFIGS`` for the current architecture, or None.

        The value is read from ``self.arch_module``.  If that attribute has not
        been set, ``heuristics_config_utils`` is imported from
        ``self.current_arch_path`` instead.
        """
        try:
            heuristic_module = self.arch_module
        except AttributeError:
            arch_path = str(self.current_arch_path)
            sys.path.insert(0, arch_path)
            try:
                heuristic_module = importlib.import_module("heuristics_config_utils")
            finally:
                # Always undo the temporary sys.path entry, even when the
                # import raises.
                sys.path.remove(arch_path)
        return getattr(heuristic_module, "HEURISTICS_CONFIGS", None)

    def get_autotune_configs(self):
        """Load the tune config via ``backend_utils`` from the arch directory."""
        path = self.current_arch_path
        return backend_utils.get_tune_config(file_path=path)

    def get_arch(self, device=0):
        """Resolve the architecture name through the vendor's ``ARCH_MAP``.

        The lookup key is the first character of the last ``_``-separated
        field of the ``ARCH`` environment variable; when that is empty the
        key is ``str(props.major)`` of the properties reported for *device*.

        Returns:
            The mapped architecture name (and sets ``has_arch`` to True) when
            the key is in ``ARCH_MAP``; ``False`` when the device backend
            reports it is not available; ``None`` when the vendor module has
            no ``ARCH_MAP`` or the key is not in it.
        """
        if not hasattr(_state.vendor_module, "ARCH_MAP"):
            return None
        arch_map = _state.vendor_module.ARCH_MAP
        arch_string = os.environ.get("ARCH", "")
        # Slicing (rather than indexing) keeps a value with a trailing "_",
        # such as "sm_", from raising IndexError; it falls back to the device
        # query below exactly like an unset ARCH does.
        arch_string_num = arch_string.split("_")[-1][:1]
        if not arch_string_num:
            try:
                if not _state.torch_device_object.is_available():
                    return False
                props = _state.torch_device_object.get_device_properties(device)
                arch_string_num = str(props.major)
            except Exception:
                # Best effort: a failing device query means "no specialised
                # arch"; the lookup below then reports it as unsupported.
                self.has_arch = False
        if arch_string_num not in arch_map:
            print(
                f"[INFO] : FlagGems Unsupported GPU arch {arch_string} specialization"
            )
            return None
        self.has_arch = True
        return arch_map[arch_string_num]

    def _get_supported_archs(self, path=None):
        """Map sub-directory name to path for each arch directory under *path*.

        *path* defaults to the first entry of the vendor module's
        ``__path__``; if it is a file its parent directory is used.
        Directories named ``ops`` or ``fused`` and names starting with ``_``
        are skipped.
        """
        path = Path(path or _state.vendor_module.__path__[0])
        path = path.parent if path.is_file() else path
        excluded = ("ops", "fused")
        return {
            p.name: str(p)
            for p in path.iterdir()
            if p.is_dir() and p.name not in excluded and not p.name.startswith("_")
        }

    def get_supported_archs(self):
        """Return the names of the supported architectures as a list."""
        return list(self.supported_archs.keys())

    def get_arch_module(self):
        """Import and return the top-level module named ``self.arch``.

        The parent directory of ``self.current_arch_path`` is placed at the
        front of ``sys.path`` for the duration of the import only.
        """
        parent_dir = str(os.path.dirname(self.current_arch_path))
        sys.path.insert(0, parent_dir)
        try:
            return importlib.import_module(self.arch)
        finally:
            # Restore sys.path even if the import fails.
            sys.path.remove(parent_dir)

    def get_arch_ops(self):
        """Return ``(name, function)`` pairs from the architecture's ``ops`` module.

        Uses ``self.arch_module.ops`` when present, otherwise imports
        ``<arch>.ops``.  An import failure is appended to ``self.error_msgs``
        and an empty list is returned.
        """
        arch_path = self.current_arch_path
        # The arch directory stays on sys.path after this call, but it is
        # added only once no matter how often the method runs.
        if arch_path not in sys.path:
            sys.path.append(arch_path)
        ops_module = getattr(self.arch_module, "ops", None)
        if ops_module is None:
            try:
                ops_module = importlib.import_module(f"{self.arch}.ops")
            except Exception as err_msg:
                self.error_msgs.append(err_msg)
                return []
        return list(self.get_functions_from_module(ops_module))
def _import_module_safe(module_name, vendor_name, module_type):
    """Import *module_name*, tolerating its absence.

    Args:
        module_name: Dotted name passed to ``importlib.import_module``.
        vendor_name: Vendor name, used only in the printed note.
        module_type: Operator kind, used only in the printed note.

    Returns:
        The imported module, or ``None`` (after printing a note) when the
        import raises ``ModuleNotFoundError``.

    Raises:
        RuntimeError: For any other exception raised during the import; the
            original exception is attached as ``__cause__``.
    """
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError:
        print(
            f"[Note] No specialized {module_type} operators were found for "
            f"the {vendor_name}, generic {module_type} operators will be used by default."
        )
        return None
    except Exception as e:
        raise RuntimeError(f"Failed to import vendor extra lib: {e}") from e
def import_vendor_extra_lib(vendor_name=None):
    """Import the vendor's ``ops`` and ``fused`` modules once.

    The results (module or ``None``) are stored on the shared state; after the
    first completed call this function does nothing.
    """
    if not _state.vendor_extra_lib_imported:
        package = f"_{vendor_name}"
        _state.ops_module = _import_module_safe(
            f"{package}.ops", vendor_name, "common"
        )
        _state.fused_module = _import_module_safe(
            f"{package}.fused", vendor_name, "fused"
        )
        _state.vendor_extra_lib_imported = True
def get_codegen_result(code, result_key):
    """Execute *code* in this module's globals and return one resulting name.

    Args:
        code: Python source text; it is executed as-is, so callers must only
            pass source they trust.
        result_key: Name to read back from the module globals afterwards.

    Any exception raised while parsing, compiling or running *code*
    propagates to the caller unchanged.
    """
    module_globals = globals()
    code_object = compile(ast.parse(code), filename="<ast>", mode="exec")
    exec(code_object, module_globals)
    return module_globals[result_key]
@functools.lru_cache(maxsize=32)
def gen_torch_tensor_attr_res(tensor, attr_name):
    """Evaluate ``<tensor>.<attr_name>`` as generated code and return it.

    Both arguments are interpolated into source text that is run after an
    ``import torch``.  Results are memoised per argument pair (up to 32
    entries).  The shared device name is filled in first if it is unset.
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info().device_name
    code = f"""
import torch
res = {tensor}.{attr_name}
    """
    return get_codegen_result(code, "res")
def set_tl_extra_backend_module(vendor_name=None):
    """Import the vendor's Triton ``libdevice`` module into the shared state.

    The module imported is ``triton.language.extra.<name>.libdevice`` where
    ``<name>`` is the vendor's ``triton_extra_name`` if truthy, otherwise the
    shared device name (filled from the vendor info when unset).
    """
    info = get_vendor_info(vendor_name)
    if not _state.device_name:
        _state.device_name = info.device_name
    extra_name = info.triton_extra_name or _state.device_name
    _state.tl_extra_backend_module = importlib.import_module(
        f"triton.language.extra.{extra_name}.libdevice"
    )
def get_tl_extra_backend_module():
    """Return the stored Triton extra backend module (``None`` until set)."""
    module = _state.tl_extra_backend_module
    return module
def set_torch_backend_device_fn(vendor_name=None):
    """Store ``torch.backends.<device_name>`` on the shared state.

    For the device names listed below no import is attempted and ``None`` is
    stored instead.  The shared device name is filled from the vendor info
    first if it is unset.
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    device_name = _state.device_name
    without_backend_module = ("musa", "aipu", "npu", "txda", "ptpu", "gcu")
    if device_name in without_backend_module:
        backend_module = None
    else:
        backend_module = importlib.import_module(f"torch.backends.{device_name}")
    _state.torch_device_fn_device = backend_module
def get_torch_backend_device_fn():
    """Return the stored ``torch.backends`` device module (may be ``None``)."""
    backend_module = _state.torch_device_fn_device
    return backend_module
def gen_torch_device_object(vendor_name=None):
    """Return ``torch.<device_name>``, resolving and caching it on first use.

    When *vendor_name* is ``"spacemit"`` the resolved object is additionally
    patched with ``_DeviceGuard``, ``device`` and ``current_device``
    replacements taken from the ``_spacemit`` backend module.
    """
    cached = _state.torch_device_object
    if cached is not None:
        return cached

    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    code = f"""
import torch
fn = torch.{_state.device_name}
"""
    device_object = get_codegen_result(code, "fn")
    _state.torch_device_object = device_object

    # SPACEMIT CPU backend needs special device guard handling
    if vendor_name == "spacemit":
        spacemit = importlib.import_module("flag_gems.runtime.backend._spacemit")
        device_object._DeviceGuard = spacemit._DeviceGuard
        device_object.device = spacemit._DeviceWrapper
        # Override current_device to return integer 0 for kernel cache indexing
        device_object.current_device = lambda: 0

    return device_object
def get_vendor_module(vendor_name, query=False):
    """Import a vendor backend module from this file's directory.

    Args:
        vendor_name: With ``query=True`` the exact module name to import;
            otherwise the vendor name without its leading underscore.
        query: When true, import and return the named module without touching
            the shared state.  This lets a caller inspect any vendor.

    Returns:
        With ``query=True`` the freshly imported module.  Otherwise the
        module stored on the shared state, importing ``_<vendor_name>`` first
        if nothing is stored yet.
    """

    def load(module_name):
        backend_dir = os.path.dirname(os.path.abspath(__file__))
        # Add the directory once; appending on every call would grow sys.path
        # without bound.
        if backend_dir not in sys.path:
            sys.path.append(backend_dir)
        return importlib.import_module(module_name)

    if query:
        return load(vendor_name)

    if _state.vendor_module is None:
        _state.vendor_module = load("_" + vendor_name)
    return _state.vendor_module
def get_vendor_info(vendor_name=None, query=False):
    """Return the ``vendor_info`` attribute of a vendor module.

    With ``query=True`` the module named *vendor_name* is imported and its
    info returned directly; otherwise the shared vendor module is loaded if
    needed and its info returned.
    """
    if not query:
        get_vendor_module(vendor_name)
        return _state.vendor_module.vendor_info
    queried_module = get_vendor_module(vendor_name, query)
    return queried_module.vendor_info
def get_vendor_infos():
    """Collect ``vendor_info`` for every vendor that can be queried.

    Vendors whose lookup raises any ``Exception`` are skipped.
    """
    collected = []
    for name in vendors.get_all_vendors():
        try:
            info = get_vendor_info(f"_{name}", query=True)
        except Exception:
            continue
        collected.append(info)

    return collected
def get_customized_ops(vendor_name=None):
    """Return ``(name, function)`` pairs from the vendor ops and fused modules.

    The vendor extra modules are imported if necessary.  The list is built
    once and stored on the shared state; later calls return the stored list.
    """
    import_vendor_extra_lib(vendor_name)
    if _state.customized_ops is None:
        # Publish the list before filling it so the stored object is the one
        # that gets extended in place.
        collected = _state.customized_ops = []
        for module in (_state.ops_module, _state.fused_module):
            if module is not None:
                collected += inspect.getmembers(module, inspect.isfunction)
    return _state.customized_ops
def get_unused_ops(vendor_name=None):
    """Return a new list copied from the vendor's ``CUSTOMIZED_UNUSED_OPS``.

    The shared vendor module is loaded first if it has not been already.
    """
    get_vendor_module(vendor_name)
    return list(_state.vendor_module.CUSTOMIZED_UNUSED_OPS)
def get_heuristic_config(vendor_name=None):
    """Return ``HEURISTICS_CONFIGS`` from the vendor's heuristics module.

    If importing ``_<vendor_name>.heuristics_config_utils`` raises any
    ``Exception``, ``_nvidia.heuristics_config_utils`` is imported instead.
    The imported module is stored on the shared state; ``None`` is returned
    when it has no ``HEURISTICS_CONFIGS`` attribute.
    """
    config_name = "heuristics_config_utils"
    try:
        module = importlib.import_module(f"_{vendor_name}.{config_name}")
    except Exception:
        module = importlib.import_module(f"_nvidia.{config_name}")
    _state.heuristic_config_module = module
    return getattr(module, "HEURISTICS_CONFIGS", None)
def get_tune_config(vendor_name=None):
    """Return the tune config produced by ``backend_utils.get_tune_config``.

    The shared vendor module is loaded first if it has not been already;
    *vendor_name* is then forwarded as the sole positional argument.
    """
    get_vendor_module(vendor_name)
    return backend_utils.get_tune_config(vendor_name)
def get_expand_config(op_name=None, file_path=None):
    """Forward to ``backend_utils.get_expand_config`` with the same keywords."""
    expand_config = backend_utils.get_expand_config(
        op_name=op_name,
        file_path=file_path,
    )
    return expand_config
def get_backend_state() -> BackendState:
    """Return the module-level ``BackendState`` object shared by this module."""
    shared_state = _state
    return shared_state
# Names exported by a star-import of this module.  Every entry must be an
# attribute of the module: a literal "*" entry is not treated as a wildcard
# and makes the star-import raise AttributeError.
__all__ = [
    "BackendArchEvent",
    "BackendState",
    "gen_torch_device_object",
    "gen_torch_tensor_attr_res",
    "get_backend_state",
    "get_codegen_result",
    "get_customized_ops",
    "get_expand_config",
    "get_heuristic_config",
    "get_tl_extra_backend_module",
    "get_torch_backend_device_fn",
    "get_tune_config",
    "get_unused_ops",
    "get_vendor_info",
    "get_vendor_infos",
    "get_vendor_module",
    "import_vendor_extra_lib",
    "set_tl_extra_backend_module",
    "set_torch_backend_device_fn",
]