Coverage for src/flag_gems/runtime/backend/__init__.py: 83%
218 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import ast
2import functools
3import importlib
4import inspect
5import os
6import sys
7from pathlib import Path
9from ..common import vendors
10from . import backend_utils
class BackendState:
    """Process-wide container for the backend's mutable state.

    The class is a singleton: every ``BackendState()`` call returns the same
    object, and the fields are reset to their empty defaults only the first
    time that object is initialised.
    """

    _instance = None

    def __new__(cls):
        existing = cls._instance
        if existing is not None:
            return existing
        created = super().__new__(cls)
        cls._instance = created
        # Marks the fresh object so __init__ knows it still has to populate it.
        created._initialized = False
        return created

    def __init__(self):
        # __init__ runs on every BackendState() call; only the first one may
        # assign the defaults, otherwise stored state would be wiped.
        if not self._initialized:
            self._initialized = True
            self.vendor_module = None
            self.device_name = None
            self.torch_device_object = None
            self.torch_device_fn_device = None
            self.tl_extra_backend_module = None
            self.ops_module = None
            self.fused_module = None
            self.heuristic_config_module = None
            self.vendor_extra_lib_imported = False
            self.device_fn_cache = {}
            self.customized_ops = None


# Module-wide shared state object used by the helpers in this file.
_state = BackendState()
class BackendArchEvent:
    """Singleton that resolves the architecture-specific part of a backend.

    On first construction it determines the current architecture via
    ``get_arch`` and, when that architecture is listed in the vendor's
    ``ARCH_MAP``, loads the matching architecture package together with its
    autotune and heuristics configuration. Later constructions return the
    already-initialised instance unchanged.
    """

    has_arch: bool = False
    _instance = None
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, backend=None):
        if BackendArchEvent._initialized:
            return
        BackendArchEvent._initialized = True
        self.backend = backend
        self.error_msgs = []
        self.arch = self.get_arch()
        if self.has_arch:
            self.supported_archs = self._get_supported_archs()
            # current_arch_path is like FlagGems/src/flag_gems/runtime/backend/_nvidia/hopper
            self.current_arch_path = self.supported_archs.get(self.arch)
            self.arch_module = self.get_arch_module()
            self.autotune_configs = self.get_autotune_configs()
            self.heuristics_configs = self.get_heuristics_configs()

    def get_functions_from_module(self, module):
        """Return ``(name, function)`` pairs of ``module``; ``[]`` if it is falsy."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_heuristics_configs(self):
        """Return ``HEURISTICS_CONFIGS`` of the arch module, or ``None``.

        If reading ``self.arch_module`` fails, ``heuristics_config_utils`` is
        imported from ``self.current_arch_path`` instead.
        """
        try:
            heuristic_module = self.arch_module
        except Exception:  # noqa E722
            search_path = str(self.current_arch_path)
            sys.path.insert(0, search_path)
            try:
                heuristic_module = importlib.import_module("heuristics_config_utils")
            finally:
                # Undo the temporary sys.path entry even when the import fails.
                if search_path in sys.path:
                    sys.path.remove(search_path)
        return getattr(heuristic_module, "HEURISTICS_CONFIGS", None)

    def get_autotune_configs(self):
        """Load the tune configuration located at ``self.current_arch_path``."""
        path = self.current_arch_path
        return backend_utils.get_tune_config(file_path=path)

    def get_arch(self, device=0):
        """Resolve the architecture name through the vendor's ``ARCH_MAP``.

        The lookup key is the first character after the last ``_`` of the
        ``ARCH`` environment variable; when that is empty, ``str(props.major)``
        of the device properties is used instead.

        Returns:
            The mapped architecture name (and sets ``self.has_arch = True``).
            ``None`` when the vendor module has no ``ARCH_MAP`` or the key is
            not in it; ``False`` when the torch device reports unavailable.
        """
        if not hasattr(_state.vendor_module, "ARCH_MAP"):
            return
        arch_map = _state.vendor_module.ARCH_MAP
        arch_string = os.environ.get("ARCH", "")
        # Slicing (instead of indexing) keeps a value ending in "_" from
        # raising IndexError; it then falls through to device detection.
        arch_string_num = arch_string.split("_")[-1][:1] if arch_string else arch_string
        if not arch_string_num:
            try:
                if not _state.torch_device_object.is_available():
                    return False
                props = _state.torch_device_object.get_device_properties(device)
                arch_string_num = str(props.major)
            except Exception:
                self.has_arch = False
        if arch_string_num not in arch_map:
            print(
                f"[INFO] : FlagGems Unsupported GPU arch {arch_string} specialization"
            )
        else:
            self.has_arch = True
            return arch_map[arch_string_num]

    def _get_supported_archs(self, path=None):
        """Map sub-directory names under ``path`` to their paths.

        ``path`` defaults to the first entry of the vendor module's
        ``__path__``. Directories named ``ops``/``fused`` or starting with an
        underscore are skipped.
        """
        path = Path(path or _state.vendor_module.__path__[0])
        path = path.parent if path.is_file() else path
        excluded = ("ops", "fused")
        return {
            p.name: str(p)
            for p in path.iterdir()
            if p.is_dir() and p.name not in excluded and not p.name.startswith("_")
        }

    def get_supported_archs(self):
        """Return the names of the discovered architecture directories."""
        return list(self.supported_archs.keys())

    def get_arch_module(self):
        """Load backend.<arch>"""
        path_dir = str(os.path.dirname(self.current_arch_path))
        sys.path.insert(0, path_dir)
        try:
            return importlib.import_module(self.arch)
        finally:
            # Undo the temporary sys.path entry even when the import fails.
            if path_dir in sys.path:
                sys.path.remove(path_dir)

    def get_arch_ops(self):
        """Return ``(name, function)`` pairs from the architecture's ``ops`` module.

        The ``ops`` attribute of the loaded arch module is preferred; if it is
        missing, ``<arch>.ops`` is imported. An import failure is recorded in
        ``self.error_msgs`` and yields an empty list.
        """
        arch_specialized_ops = []
        # The path is still registered as before, but only once per path.
        if self.current_arch_path not in sys.path:
            sys.path.append(self.current_arch_path)
        ops_module = getattr(self.arch_module, "ops", None)
        if ops_module is None:
            try:
                ops_module = importlib.import_module(f"{self.arch}.ops")
            except Exception as err_msg:
                self.error_msgs.append(err_msg)
        # Collect on every path where a module was obtained; previously this
        # only happened after a failed first import, so the result was empty.
        arch_specialized_ops.extend(self.get_functions_from_module(ops_module))
        return arch_specialized_ops
def _import_module_safe(module_name, vendor_name, module_type):
    """Import ``module_name``, treating a missing module as non-fatal.

    Args:
        module_name: Dotted name passed to ``importlib.import_module``.
        vendor_name: Vendor name, used only in the printed note.
        module_type: Operator category label, used only in the printed note.

    Returns:
        The imported module, or ``None`` (after printing a note) when the
        import raises ``ModuleNotFoundError``.

    Raises:
        RuntimeError: For any other exception raised by the import; the
            original exception is attached as ``__cause__``.
    """
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError:
        print(
            f"[Note] No specialized {module_type} operators were found for "
            f"the {vendor_name}, generic {module_type} operators will be used by default."
        )
        return None
    except Exception as e:
        # Chain explicitly so the real import failure stays in the traceback.
        raise RuntimeError(f"Failed to import vendor extra lib: {e}") from e
def import_vendor_extra_lib(vendor_name=None):
    """Import the vendor's ``ops`` and ``fused`` modules once and store them.

    Does nothing when the shared state already marks them as imported.
    """
    if not _state.vendor_extra_lib_imported:
        # (state attribute, sub-module name, label used in the printed note)
        targets = (
            ("ops_module", "ops", "common"),
            ("fused_module", "fused", "fused"),
        )
        for state_attr, submodule, label in targets:
            loaded = _import_module_safe(
                f"_{vendor_name}.{submodule}", vendor_name, label
            )
            setattr(_state, state_attr, loaded)
        _state.vendor_extra_lib_imported = True
def get_codegen_result(code, result_key):
    """Execute ``code`` in this module's globals and return one resulting name.

    NOTE: the source text is run with ``exec``; every name it binds lands in
    the module namespace. Any exception raised while executing propagates.
    """
    module_namespace = globals()
    program = compile(ast.parse(code), filename="<ast>", mode="exec")
    exec(program, module_namespace)
    return module_namespace[result_key]
@functools.lru_cache(maxsize=32)
def gen_torch_tensor_attr_res(tensor, attr_name):
    """Evaluate ``<tensor>.<attr_name>`` through generated code and return it.

    Both arguments are interpolated into the generated source text; results
    are memoised per argument pair (up to 32 entries).
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info().device_name
    snippet = f"""
import torch
res = {tensor}.{attr_name}
    """
    return get_codegen_result(snippet, "res")
def set_tl_extra_backend_module(vendor_name=None):
    """Import the vendor's triton ``libdevice`` module and store it in the state.

    The package segment is the vendor's ``triton_extra_name`` when set,
    otherwise the device name.
    """
    info = get_vendor_info(vendor_name)
    if not _state.device_name:
        _state.device_name = info.device_name
    namespace = info.triton_extra_name
    if not namespace:
        namespace = _state.device_name
    _state.tl_extra_backend_module = importlib.import_module(
        f"triton.language.extra.{namespace}.libdevice"
    )
def get_tl_extra_backend_module():
    """Return the module stored by ``set_tl_extra_backend_module`` (or ``None``)."""
    stored_module = _state.tl_extra_backend_module
    return stored_module
def set_torch_backend_device_fn(vendor_name=None):
    """Store ``torch.backends.<device_name>`` in the state.

    For the device names listed below no module is imported and ``None`` is
    stored instead.
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    current_device = _state.device_name
    module_str = f"torch.backends.{current_device}"
    devices_without_module = ("musa", "aipu", "npu", "txda", "ptpu", "gcu")
    if current_device in devices_without_module:
        _state.torch_device_fn_device = None
        return
    _state.torch_device_fn_device = importlib.import_module(module_str)
def get_torch_backend_device_fn():
    """Return the value stored by ``set_torch_backend_device_fn`` (may be ``None``)."""
    stored_backend = _state.torch_device_fn_device
    return stored_backend
def gen_torch_device_object(vendor_name=None):
    """Return ``torch.<device_name>``, resolving and caching it on first use."""
    cached = _state.torch_device_object
    if cached is not None:
        return cached
    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    snippet = f"""
import torch
fn = torch.{_state.device_name}
"""
    device_object = get_codegen_result(snippet, "fn")
    _state.torch_device_object = device_object
    return device_object
def get_vendor_module(vendor_name, query=False):
    """Import a vendor module from the directory containing this file.

    Args:
        vendor_name: With ``query=True`` the exact module name to import;
            otherwise the vendor name, to which ``"_"`` is prefixed.
        query: When true, import and return the module without touching the
            shared state.

    Returns:
        The imported module. Without ``query`` the module is imported only
        while the shared state holds none, and the stored one is returned.
    """

    def get_module(module_name):
        backend_dir = os.path.dirname(os.path.abspath(__file__))
        # Register the directory once; appending on every call made sys.path
        # grow with duplicate entries.
        if backend_dir not in sys.path:
            sys.path.append(backend_dir)
        return importlib.import_module(module_name)

    # The purpose of a query is to provide the user with the instance that he wants to import
    if query:
        return get_module(vendor_name)

    if _state.vendor_module is None:
        _state.vendor_module = get_module("_" + vendor_name)
    return _state.vendor_module
def get_vendor_info(vendor_name=None, query=False):
    """Return the ``vendor_info`` attribute of a vendor module.

    With ``query`` the module is looked up without changing the shared state;
    otherwise the state's vendor module is loaded if needed and used.
    """
    if not query:
        get_vendor_module(vendor_name)
        return _state.vendor_module.vendor_info
    queried_module = get_vendor_module(vendor_name, query)
    return queried_module.vendor_info
def get_vendor_infos():
    """Collect ``vendor_info`` for every vendor whose module can be queried.

    Vendors whose lookup raises are skipped silently.
    """
    collected = []
    for name in vendors.get_all_vendors():
        try:
            info = get_vendor_info(f"_{name}", query=True)
        except Exception:
            continue
        collected.append(info)
    return collected
def get_customized_ops(vendor_name=None):
    """Return ``(name, function)`` pairs from the vendor's ops and fused modules.

    The list is built once and stored in the shared state; later calls return
    the stored list.
    """
    import_vendor_extra_lib(vendor_name)
    cached = _state.customized_ops
    if cached is not None:
        return cached
    # Store the list before filling it, in the same order as the lookups below.
    collected = _state.customized_ops = []
    for module in (_state.ops_module, _state.fused_module):
        if module is None:
            continue
        collected.extend(inspect.getmembers(module, inspect.isfunction))
    return collected
def get_unused_ops(vendor_name=None):
    """Return a new list built from the vendor module's ``CUSTOMIZED_UNUSED_OPS``.

    The vendor module is loaded into the shared state first if necessary.
    """
    # The former ``global vendor_module`` declaration was dropped: the module
    # is kept on ``_state`` and no global of that name is assigned here.
    get_vendor_module(vendor_name)
    return list(_state.vendor_module.CUSTOMIZED_UNUSED_OPS)
def get_heuristic_config(vendor_name=None):
    """Return ``HEURISTICS_CONFIGS`` for the vendor, falling back to nvidia.

    ``_<vendor>.heuristics_config_utils`` is tried first and
    ``_nvidia.heuristics_config_utils`` second. The first module that imports
    is stored in the shared state and its ``HEURISTICS_CONFIGS`` attribute is
    returned (``None`` if it has none).

    Returns:
        The configuration object, or ``None`` when neither module imports.
    """
    config_name = "heuristics_config_utils"
    default_backend = "nvidia"
    for backend in (vendor_name, default_backend):
        mod_name = f"_{backend}.{config_name}"
        try:
            module = importlib.import_module(mod_name)
        except Exception:
            continue
        # Stop at the first success so the nvidia fallback can never replace
        # a vendor configuration that imported correctly.
        _state.heuristic_config_module = module
        return getattr(module, "HEURISTICS_CONFIGS", None)
    return None
def get_tune_config(vendor_name=None):
    """Return ``backend_utils.get_tune_config(vendor_name)``.

    The vendor module is loaded into the shared state first if necessary.
    """
    # The former ``global vendor_module`` declaration was dropped: the module
    # is kept on ``_state`` and no global of that name is assigned here.
    get_vendor_module(vendor_name)
    return backend_utils.get_tune_config(vendor_name)
def get_expand_config(op_name=None, file_path=None):
    """Forward to ``backend_utils.get_expand_config`` with the same arguments."""
    expand_config = backend_utils.get_expand_config(
        op_name=op_name,
        file_path=file_path,
    )
    return expand_config
def get_backend_state() -> BackendState:
    """Return the module-level ``BackendState`` object shared by this file."""
    shared_state = _state
    return shared_state
# Explicit public API. ``["*"]`` is not a wildcard: a star-import would look
# for an attribute literally named "*" and fail.
__all__ = [
    "BackendArchEvent",
    "BackendState",
    "gen_torch_device_object",
    "gen_torch_tensor_attr_res",
    "get_backend_state",
    "get_codegen_result",
    "get_customized_ops",
    "get_expand_config",
    "get_heuristic_config",
    "get_tl_extra_backend_module",
    "get_torch_backend_device_fn",
    "get_tune_config",
    "get_unused_ops",
    "get_vendor_info",
    "get_vendor_infos",
    "get_vendor_module",
    "import_vendor_extra_lib",
    "set_tl_extra_backend_module",
    "set_torch_backend_device_fn",
]