Coverage for src/flag_gems/runtime/backend/__init__.py: 78%

224 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import ast 

2import functools 

3import importlib 

4import inspect 

5import os 

6import sys 

7from pathlib import Path 

8 

9from ..common import vendors 

10from . import backend_utils 

11 

12 

class BackendState:
    """Process-wide holder for the backend's lazily resolved state.

    Every ``BackendState()`` call yields the same object.  Its attributes are
    set to their defaults only the first time the object is initialised;
    later constructions leave the stored values untouched.
    """

    _instance = None

    def __new__(cls):
        # Build the shared object once and mark it as not yet initialised.
        instance = cls._instance
        if instance is None:
            instance = super().__new__(cls)
            instance._initialized = False
            cls._instance = instance
        return instance

    def __init__(self):
        # __init__ runs on every construction, so only the first run may
        # reset the attributes.
        if not self._initialized:
            self._initialized = True

            # Modules and device handles, filled in by the module helpers.
            self.vendor_module = None
            self.device_name = None
            self.torch_device_object = None
            self.torch_device_fn_device = None
            self.tl_extra_backend_module = None
            self.ops_module = None
            self.fused_module = None
            self.heuristic_config_module = None

            # Bookkeeping flags and caches.
            self.vendor_extra_lib_imported = False
            self.device_fn_cache = {}
            self.customized_ops = None

39 

40 

# Global singleton instance shared by the module-level helpers below.
_state = BackendState()

43 

44 

class BackendArchEvent:
    """Singleton describing the architecture-specific part of the backend.

    The first construction resolves the device architecture through the
    vendor module's ``ARCH_MAP``.  When an architecture is found
    (``has_arch`` is true) it also loads the matching architecture package
    and its autotune / heuristics configuration.  Later constructions return
    the same object without re-running the initialisation.
    """

    has_arch: bool = False
    _instance = None
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, backend=None):
        if BackendArchEvent._initialized:
            return
        BackendArchEvent._initialized = True
        self.backend = backend
        self.error_msgs = []
        self.arch = self.get_arch()
        if self.has_arch:
            self.supported_archs = self._get_supported_archs()
            # current_arch_path is like FlagGems/src/flag_gems/runtime/backend/_nvidia/hopper
            self.current_arch_path = self.supported_archs.get(self.arch)
            self.arch_module = self.get_arch_module()
            self.autotune_configs = self.get_autotune_configs()
            self.heuristics_configs = self.get_heuristics_configs()

    @staticmethod
    def _import_from_dir(directory, module_name):
        """Import ``module_name`` with ``directory`` temporarily first on ``sys.path``.

        The entry is removed again even when the import raises, so a failed
        import does not leave ``sys.path`` modified.
        """
        directory = str(directory)
        sys.path.insert(0, directory)
        try:
            return importlib.import_module(module_name)
        finally:
            # The imported module may itself have edited sys.path.
            if directory in sys.path:
                sys.path.remove(directory)

    def get_functions_from_module(self, module):
        """Return ``(name, function)`` pairs of ``module``; ``[]`` if it is falsy."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_heuristics_configs(self):
        """Return ``HEURISTICS_CONFIGS`` of the architecture module, or ``None``.

        When ``arch_module`` has not been set on the instance, the module
        ``heuristics_config_utils`` is imported from ``current_arch_path``
        instead.
        """
        try:
            heuristic_module = self.arch_module
        except AttributeError:
            heuristic_module = self._import_from_dir(
                self.current_arch_path, "heuristics_config_utils"
            )
        return getattr(heuristic_module, "HEURISTICS_CONFIGS", None)

    def get_autotune_configs(self):
        """Load the tuning configuration stored under ``current_arch_path``."""
        path = self.current_arch_path
        return backend_utils.get_tune_config(file_path=path)

    def get_arch(self, device=0):
        """Resolve the architecture name for ``device`` and update ``has_arch``.

        The lookup key is the first character after the last ``_`` of the
        ``ARCH`` environment variable; when that is empty, the ``major``
        field of the device properties is used instead.

        Returns:
            The ``ARCH_MAP`` entry for the key, ``False`` when the device
            reports itself as unavailable, or ``None`` when the vendor module
            has no ``ARCH_MAP`` or the key is not in it.
        """
        if not hasattr(_state.vendor_module, "ARCH_MAP"):
            return None
        arch_map = _state.vendor_module.ARCH_MAP
        arch_string = os.environ.get("ARCH", "")
        # Slicing (instead of indexing) keeps a value such as "sm_" from
        # raising IndexError; it simply falls back to device detection.
        arch_string_num = arch_string.split("_")[-1][:1]
        if not arch_string_num:
            try:
                if not _state.torch_device_object.is_available():
                    return False
                props = _state.torch_device_object.get_device_properties(device)
                arch_string_num = str(props.major)
            except Exception:
                # Detection is best effort; an empty key is reported as
                # unsupported below.
                self.has_arch = False
        if arch_string_num not in arch_map:
            print(
                f"[INFO] : FlagGems Unsupported GPU arch {arch_string} specialization"
            )
            return None
        self.has_arch = True
        return arch_map[arch_string_num]

    def _get_supported_archs(self, path=None):
        """Map architecture directory names to their paths.

        Scans ``path`` (default: the first entry of the vendor module's
        ``__path__``) and keeps every sub-directory that is neither ``ops``
        nor ``fused`` and does not start with an underscore.
        """
        path = Path(path or _state.vendor_module.__path__[0])
        path = path.parent if path.is_file() else path
        excluded = ("ops", "fused")
        return {
            p.name: str(p)
            for p in path.iterdir()
            if p.is_dir() and p.name not in excluded and not p.name.startswith("_")
        }

    def get_supported_archs(self):
        """Return the names of the supported architectures as a list."""
        return list(self.supported_archs.keys())

    def get_arch_module(self):
        """Load backend.<arch>"""
        path_dir = os.path.dirname(self.current_arch_path)
        return self._import_from_dir(path_dir, self.arch)

    def get_arch_ops(self):
        """Return ``(name, function)`` pairs of the architecture's ``ops`` module.

        The ``ops`` attribute of the architecture module is used when present;
        otherwise ``<arch>.ops`` is imported.  A failed import is retried once
        and, if it fails again, the error is appended to ``error_msgs`` and an
        empty list is returned.
        """
        # Keep the arch directory importable, but add it only once so that
        # repeated calls do not grow sys.path.
        if self.current_arch_path not in sys.path:
            sys.path.append(self.current_arch_path)

        ops_module = getattr(self.arch_module, "ops", None)
        if ops_module is None:
            ops_name = f"{self.arch}.ops"
            try:
                ops_module = importlib.import_module(ops_name)
            except Exception:
                try:
                    ops_module = importlib.import_module(ops_name)
                except Exception as err_msg:
                    self.error_msgs.append(err_msg)

        # Collect the functions exactly once, whichever way the module was found.
        return list(self.get_functions_from_module(ops_module))

148 

149 

150def _import_module_safe(module_name, vendor_name, module_type): 

151 """Helper to import a module with proper error handling.""" 

152 try: 

153 return importlib.import_module(module_name) 

154 except ModuleNotFoundError: 

155 print( 

156 f"[Note] No specialized {module_type} operators were found for " 

157 f"the {vendor_name}, generic {module_type} operators will be used by default." 

158 ) 

159 except Exception as e: 

160 raise RuntimeError(f"Failed to import vendor extra lib: {e}") 

161 

162 

def import_vendor_extra_lib(vendor_name=None):
    """Import the vendor's ``ops`` and ``fused`` modules into the shared state.

    The imports run only until they have completed once; afterwards the call
    is a no-op regardless of ``vendor_name``.
    """
    if not _state.vendor_extra_lib_imported:
        vendor_package = f"_{vendor_name}"
        _state.ops_module = _import_module_safe(
            f"{vendor_package}.ops", vendor_name, "common"
        )
        _state.fused_module = _import_module_safe(
            f"{vendor_package}.fused", vendor_name, "fused"
        )
        # Set last, so a raised import is attempted again on the next call.
        _state.vendor_extra_lib_imported = True

173 

174 

def get_codegen_result(code, result_key):
    """Execute ``code`` in this module's globals and return one of its names.

    Args:
        code: Python source text to parse, compile and execute.
        result_key: Name looked up in the module globals after execution.

    Returns:
        The object bound to ``result_key``.

    Any exception raised by the executed code propagates unchanged.
    """
    # NOTE(review): this runs arbitrary source and leaves every name it binds
    # in the module namespace; callers must only pass trusted text.
    module_namespace = globals()
    tree = ast.parse(code)
    exec(compile(tree, filename="<ast>", mode="exec"), module_namespace)
    return module_namespace[result_key]

183 

184 

@functools.lru_cache(maxsize=32)
def gen_torch_tensor_attr_res(tensor, attr_name):
    """Evaluate ``<tensor>.<attr_name>`` through generated code and return it.

    Both arguments are interpolated textually into the generated source, so
    they are presumably strings naming an expression and one of its
    attributes (TODO confirm against callers).  Results are memoised per
    argument pair by the ``lru_cache`` decorator (up to 32 entries).

    NOTE(review): the generated source is executed by ``get_codegen_result``;
    only pass trusted values.
    """
    # Resolve the device name if it is still unset; the snippet below does
    # not reference it.
    _state.device_name = _state.device_name or get_vendor_info().device_name
    code = f"""
import torch
res = {tensor}.{attr_name}
    """
    return get_codegen_result(code, "res")

193 

194 

def set_tl_extra_backend_module(vendor_name=None):
    """Import the vendor's Triton ``libdevice`` module and store it in the state.

    The imported module is ``triton.language.extra.<name>.libdevice``, where
    ``<name>`` is the vendor's ``triton_extra_name`` when that is set and the
    device name otherwise.
    """
    info = get_vendor_info(vendor_name)
    if not _state.device_name:
        _state.device_name = info.device_name
    extra_name = info.triton_extra_name or _state.device_name
    _state.tl_extra_backend_module = importlib.import_module(
        f"triton.language.extra.{extra_name}.libdevice"
    )

201 

202 

def get_tl_extra_backend_module():
    """Return the module stored by ``set_tl_extra_backend_module`` (``None`` until set)."""
    return _state.tl_extra_backend_module

205 

206 

def set_torch_backend_device_fn(vendor_name=None):
    """Store the ``torch.backends.<device>`` module for the active device.

    For the device names listed below nothing is imported and ``None`` is
    stored instead.
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    device = _state.device_name

    # Devices for which no torch.backends module is imported.
    devices_without_module = ("musa", "aipu", "npu", "txda", "ptpu", "gcu")
    _state.torch_device_fn_device = (
        None
        if device in devices_without_module
        else importlib.import_module(f"torch.backends.{device}")
    )

214 

215 

def get_torch_backend_device_fn():
    """Return the value stored by ``set_torch_backend_device_fn`` (may be ``None``)."""
    return _state.torch_device_fn_device

218 

219 

def gen_torch_device_object(vendor_name=None):
    """Return ``torch.<device_name>``, resolving and caching it on first use.

    The object is obtained by executing a small generated snippet and is then
    kept in the shared state, so later calls return the cached object without
    looking at ``vendor_name``.
    """
    cached = _state.torch_device_object
    if cached is not None:
        return cached

    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    code = f"""
import torch
fn = torch.{_state.device_name}
"""
    device_object = _state.torch_device_object = get_codegen_result(code, "fn")

    # SPACEMIT CPU backend needs special device guard handling
    if vendor_name == "spacemit":
        spacemit = importlib.import_module("flag_gems.runtime.backend._spacemit")
        device_object._DeviceGuard = spacemit._DeviceGuard
        device_object.device = spacemit._DeviceWrapper
        # Override current_device to return integer 0 for kernel cache indexing
        device_object.current_device = lambda: 0

    return device_object

247 

248 

def get_vendor_module(vendor_name, query=False):
    """Import a vendor module located next to this file.

    Args:
        vendor_name: With ``query`` set, the exact module name to import.
            Otherwise the vendor name without its leading underscore.
        query: When true, import and return the requested module without
            touching the cached vendor module.

    Returns:
        The queried module, or the cached vendor module (imported as
        ``"_" + vendor_name`` the first time it is needed).
    """

    def get_module(module_name):
        backend_dir = os.path.dirname(os.path.abspath(__file__))
        # Register the directory once; appending on every call would grow
        # sys.path without bound.
        if backend_dir not in sys.path:
            sys.path.append(backend_dir)
        return importlib.import_module(module_name)

    # The purpose of a query is to provide the user with the instance that he
    # wants to import
    if query:
        return get_module(vendor_name)

    if _state.vendor_module is None:
        _state.vendor_module = get_module("_" + vendor_name)
    return _state.vendor_module

264 

265 

def get_vendor_info(vendor_name=None, query=False):
    """Return the ``vendor_info`` attribute of a vendor module.

    With ``query`` set, the module named ``vendor_name`` is imported for this
    call only; otherwise the cached vendor module is used (and loaded first
    if necessary).
    """
    # get_vendor_module returns the cached module on the non-query path, so
    # one call serves both cases.
    return get_vendor_module(vendor_name, query).vendor_info

271 

272 

def get_vendor_infos():
    """Collect ``vendor_info`` for every vendor whose module can be loaded."""
    collected = []
    for name in vendors.get_all_vendors():
        try:
            info = get_vendor_info(f"_{name}", query=True)
        except Exception:
            # Best effort: vendors that cannot be queried are left out.
            continue
        collected.append(info)
    return collected

282 

283 

def get_customized_ops(vendor_name=None):
    """Return the vendor's customised operators as ``(name, function)`` pairs.

    The list holds the functions of the vendor ``ops`` module followed by
    those of the ``fused`` module.  It is built once and cached in the shared
    state.
    """
    import_vendor_extra_lib(vendor_name)

    cached = _state.customized_ops
    if cached is not None:
        return cached

    collected = _state.customized_ops = []
    for module in (_state.ops_module, _state.fused_module):
        if module is not None:
            collected.extend(inspect.getmembers(module, inspect.isfunction))
    return collected

296 

297 

def get_unused_ops(vendor_name=None):
    """Return a new list built from the vendor module's ``CUSTOMIZED_UNUSED_OPS``."""
    # Make sure the vendor module has been loaded into the shared state.
    get_vendor_module(vendor_name)
    unused = _state.vendor_module.CUSTOMIZED_UNUSED_OPS
    return list(unused)

302 

303 

def get_heuristic_config(vendor_name=None):
    """Return ``HEURISTICS_CONFIGS`` from the vendor's heuristics module.

    When ``_<vendor_name>.heuristics_config_utils`` cannot be imported, the
    ``_nvidia`` module of the same name is used instead.  The module that was
    loaded is stored in the shared state.  ``None`` is returned when it does
    not define ``HEURISTICS_CONFIGS``.
    """
    config_name = "heuristics_config_utils"
    try:
        module = importlib.import_module(f"_{vendor_name}.{config_name}")
    except Exception:
        # Any import failure falls back to the nvidia configuration.
        module = importlib.import_module(f"_nvidia.{config_name}")
    _state.heuristic_config_module = module
    return getattr(module, "HEURISTICS_CONFIGS", None)

313 

314 

def get_tune_config(vendor_name=None):
    """Load the vendor module if needed, then delegate to ``backend_utils.get_tune_config``."""
    get_vendor_module(vendor_name)
    tune_config = backend_utils.get_tune_config(vendor_name)
    return tune_config

319 

320 

def get_expand_config(op_name=None, file_path=None):
    """Forward both arguments unchanged to ``backend_utils.get_expand_config``."""
    return backend_utils.get_expand_config(op_name=op_name, file_path=file_path)

323 

324 

def get_backend_state() -> BackendState:
    """Get the global BackendState singleton instance.

    This is the module-level ``_state`` object itself, not a copy, so changes
    made through the returned object are seen by the helpers in this module.
    """
    return _state

328 

329 

# Public names exported by a star-import of this module.  Every entry must
# name a real module attribute: a star-import raises AttributeError for any
# entry it cannot find, which is why a "*" wildcard entry does not work here.
__all__ = [
    "BackendArchEvent",
    "BackendState",
    "gen_torch_device_object",
    "gen_torch_tensor_attr_res",
    "get_backend_state",
    "get_codegen_result",
    "get_customized_ops",
    "get_expand_config",
    "get_heuristic_config",
    "get_tl_extra_backend_module",
    "get_torch_backend_device_fn",
    "get_tune_config",
    "get_unused_ops",
    "get_vendor_info",
    "get_vendor_infos",
    "get_vendor_module",
    "import_vendor_extra_lib",
    "set_tl_extra_backend_module",
    "set_torch_backend_device_fn",
]