Coverage for src/flag_gems/runtime/backend/__init__.py: 72%

285 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import ast 

2import functools 

3import importlib 

4import inspect 

5import os 

6import sys 

7from pathlib import Path 

8 

9from ..common import vendors 

10from . import backend_utils 

11from .backend_utils import BackendEventBase 

12 

13 

class BackendState:
    """Process-wide container for backend handles resolved at runtime.

    Constructing ``BackendState()`` always yields the same object. The
    attribute defaults are applied only on the first construction, so later
    constructions never wipe state that has already been populated.
    """

    _instance = None

    def __new__(cls):
        shared = cls._instance
        if shared is None:
            shared = super().__new__(cls)
            cls._instance = shared
            shared._initialized = False
        return shared

    def __init__(self):
        # Repeated constructions hit this guard and leave existing fields alone.
        if not self._initialized:
            self._initialized = True
            # Vendor package and device naming.
            self.vendor_module = None
            self.device_name = None
            # Torch-side device handles.
            self.torch_device_object = None
            self.torch_device_fn_device = None
            # Triton "extra" libdevice module for the vendor.
            self.tl_extra_backend_module = None
            # Vendor-specialized operator modules, heuristics and caches.
            self.ops_module = None
            self.fused_module = None
            self.heuristic_config_module = None
            self.vendor_extra_lib_imported = False
            self.device_fn_cache = {}
            self.customized_ops = None

40 

41 

# The one shared BackendState; the module-level helpers in this file read and
# write backend handles through this object.
_state: BackendState = BackendState()

44 

45 

class TritonVersionEvent(BackendEventBase):
    """Singleton event exposing ops specialized for the installed Triton version.

    A version-specific package is looked for as a ``triton_<version>``
    sub-directory of the active vendor package (``_state.vendor_module``).
    When it exists it is imported, and its functions are offered via
    :meth:`get_ops` as ``(name, function)`` pairs.
    """

    _instance = None
    has_version_spec = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, version=None):
        """Resolve the Triton version and import its spec package, if any.

        Args:
            version: Triton version string. Defaults to ``triton.__version__``,
                or None when Triton cannot be imported.
        """
        # NOTE: unlike BackendArchEvent there is no init guard, so every
        # construction re-runs this lookup on the shared instance.
        self.has_version_spec = False
        self.module = None
        self.version = version if version is not None else self.get_version()
        self.dir = self.get_version_spec_dir()
        if self.dir and Path(self.dir).exists():
            self.module = self.get_version_spec_module()
            self.has_version_spec = True

    def is_available(self):
        return self.has_version_spec

    def get_version_spec_dir(self, path=None):
        """Return the ``triton_<version>`` directory under ``path`` as a str, or None.

        Args:
            path: Directory (or a file inside it) to search. Defaults to the
                vendor package directory.
        """
        backend_path = Path(path or _state.vendor_module.__path__[0])
        backend_path = backend_path.parent if backend_path.is_file() else backend_path
        # Direct lookup instead of listing the whole directory. A "triton_*"
        # name can never be one of the excluded "ops"/"fused"/"_..." entries.
        candidate = backend_path / f"triton_{self.version}"
        return str(candidate) if candidate.is_dir() else None

    def get_functions_from_module(self, module):
        """Return ``(name, function)`` pairs defined in ``module`` ([] for None)."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_version_spec_module(self):
        """Import the ``triton_<version>`` package located by get_version_spec_dir.

        NOTE(review): import_module treats "." as a package separator, so a
        dotted version such as "3.1.0" will not resolve to a "triton_3.1.0"
        directory. Confirm the expected directory naming.
        """
        module_name = f"triton_{self.version}"
        path_dir = str(os.path.dirname(self.dir))
        sys.path.insert(0, path_dir)
        try:
            return importlib.import_module(module_name)
        finally:
            # Restore sys.path even when the import fails.
            sys.path.remove(path_dir)

    def get_ops(self):
        """Provide a unified interface for the upper layer"""
        return self.get_version_ops()

    def get_version_ops(self):
        """Return ``(name, function)`` pairs from the version spec module.

        Returns an empty list when no spec module was loaded, so callers can
        always iterate the result.
        """
        return self.get_functions_from_module(getattr(self, "module", None))

    def get_version(self):
        """Return ``triton.__version__``, or None when Triton is not installed."""
        try:
            import triton
        except ImportError:
            return None
        return triton.__version__

100 

101 

class BackendArchEvent(BackendEventBase):
    """Singleton event exposing ops and tuning configs for the current GPU arch.

    The arch is read from the ``ARCH`` environment variable or, failing that,
    from the device's ``major`` compute capability. It is then mapped through
    the vendor module's ``ARCH_MAP`` to a sub-directory of the vendor package.
    """

    has_arch: bool = False
    _instance = None
    _initialized: bool = False

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, backend=None):
        # Only the first construction does any work; later calls reuse it.
        if BackendArchEvent._initialized:
            return
        BackendArchEvent._initialized = True
        self.backend = backend
        self.error_msgs = []
        self.arch = self.get_arch()
        if self.has_arch:
            self.supported_archs = self._get_supported_archs()
            # current_arch_path is like FlagGems/src/flag_gems/runtime/backend/_nvidia/hopper
            self.current_arch_path = self.supported_archs.get(self.arch)
            self.arch_module = self.get_arch_module()
            self.autotune_configs = self.get_autotune_configs()
            self.heuristics_configs = self.get_heuristics_configs()

    @staticmethod
    def _import_with_path(path_dir, module_name):
        """Import ``module_name`` with ``path_dir`` temporarily first on sys.path.

        sys.path is restored even if the import raises.
        """
        path_dir = str(path_dir)
        sys.path.insert(0, path_dir)
        try:
            return importlib.import_module(module_name)
        finally:
            sys.path.remove(path_dir)

    def is_available(self):
        return self.has_arch

    def get_functions_from_module(self, module):
        """Return ``(name, function)`` pairs defined in ``module`` ([] for None)."""
        return inspect.getmembers(module, inspect.isfunction) if module else []

    def get_heuristics_configs(self):
        """Return ``HEURISTICS_CONFIGS`` from the arch module, or None if absent."""
        try:
            heuristic_module = self.arch_module
        except AttributeError:
            # Arch package not loaded: import the arch directory's own
            # heuristics_config_utils module instead.
            heuristic_module = self._import_with_path(
                self.current_arch_path, "heuristics_config_utils"
            )
        return getattr(heuristic_module, "HEURISTICS_CONFIGS", None)

    def get_autotune_configs(self):
        """Load the tune config from the current arch directory."""
        path = self.current_arch_path
        return backend_utils.get_tune_config(file_path=path)

    def get_arch(self, device=0):
        """Resolve the current arch and set ``self.has_arch``.

        Args:
            device: Device index queried when ``ARCH`` is not set.

        Returns:
            The ``ARCH_MAP`` entry for the arch. Returns None when the vendor
            has no ``ARCH_MAP`` or the arch is not in it, and False when the
            device reports itself unavailable.
        """
        if not hasattr(_state.vendor_module, "ARCH_MAP"):
            return None
        arch_map = _state.vendor_module.ARCH_MAP
        arch_string = os.environ.get("ARCH", "")
        # e.g. ARCH="sm_90" -> "9": first character after the last "_".
        # Guard the empty tail so a value such as "sm_" cannot raise IndexError.
        tail = arch_string.split("_")[-1]
        arch_string_num = tail[0] if tail else ""
        if not arch_string_num:
            try:
                if not _state.torch_device_object.is_available():
                    return False
                props = _state.torch_device_object.get_device_properties(device)
                arch_string_num = str(props.major)
            except Exception:
                self.has_arch = False
        if arch_string_num not in arch_map:
            # Report the derived number when ARCH was not set, so the message
            # is not left with an empty arch.
            print(
                f"[INFO] : FlagGems Unsupported GPU arch {arch_string or arch_string_num} specialization"
            )
            return None
        self.has_arch = True
        return arch_map[arch_string_num]

    def _get_supported_archs(self, path=None):
        """Map each candidate arch sub-directory name to its full path.

        Skips the "ops"/"fused" directories and names starting with "_".
        """
        path = Path(path or _state.vendor_module.__path__[0])
        path = path.parent if path.is_file() else path
        excluded = ("ops", "fused")
        return {
            p.name: str(p)
            for p in path.iterdir()
            if p.is_dir() and p.name not in excluded and not p.name.startswith("_")
        }

    def get_supported_archs(self):
        return list(self.supported_archs.keys())

    def get_arch_module(self):
        """Load backend.<arch>"""
        return self._import_with_path(
            os.path.dirname(self.current_arch_path), self.arch
        )

    def get_ops(self):
        """Provide a unified interface for the upper layer"""
        return self.get_arch_ops()

    def get_arch_ops(self):
        """Return ``(name, function)`` pairs from the arch's ``ops`` module.

        Import failures are recorded in ``self.error_msgs`` and yield [].
        """
        arch_specialized_ops = []
        # NOTE(review): the arch directory is intentionally left on sys.path,
        # presumably so modules inside it can import each other by bare name.
        # Confirm. It is added only once so repeated calls do not grow sys.path.
        if self.current_arch_path not in sys.path:
            sys.path.append(self.current_arch_path)
        ops_module = getattr(self.arch_module, "ops", None)
        if ops_module is None:
            try:
                ops_module = importlib.import_module(f"{self.arch}.ops")
            except Exception as err_msg:
                self.error_msgs.append(err_msg)
                ops_module = None
        # Extend exactly once, so no op is ever listed twice.
        if ops_module is not None:
            arch_specialized_ops.extend(self.get_functions_from_module(ops_module))
        return arch_specialized_ops

212 

213 

class SpecOpRegistrar:
    """Copy specialized ops from the available spec events into a namespace."""

    def __init__(self, _globals):
        # Target mapping that receives the ops (presumably a module's
        # globals() — confirm at call sites).
        self._globals = _globals

    def apply(self):
        """Bind every ``(name, fn)`` pair from each available event into the target.

        Events are applied in order, so a later event overrides an earlier one
        on a name clash. An event whose ``get_ops()`` returns None contributes
        nothing instead of raising TypeError.
        """
        for event in self._get_specific_events():
            if not event.is_available():
                continue
            for fn_name, fn in event.get_ops() or ():
                self._globals[fn_name] = fn

    def _get_specific_events(self):
        # Arch-specific ops first, then Triton-version-specific ops.
        return (BackendArchEvent(), TritonVersionEvent())

229 

230 

231def _import_module_safe(module_name, vendor_name, module_type): 

232 """Helper to import a module with proper error handling.""" 

233 try: 

234 return importlib.import_module(module_name) 

235 except ModuleNotFoundError: 

236 print( 

237 f"[Note] No specialized {module_type} operators were found for " 

238 f"the {vendor_name}, generic {module_type} operators will be used by default." 

239 ) 

240 except Exception as e: 

241 raise RuntimeError(f"Failed to import vendor extra lib: {e}") 

242 

243 

def import_vendor_extra_lib(vendor_name=None):
    """Import the vendor's optional ``ops`` and ``fused`` packages once.

    The results (module or None) are stored on the backend state. Calls made
    after the first one do nothing.

    Args:
        vendor_name: Vendor name; modules are imported as ``_<vendor_name>.ops``
            and ``_<vendor_name>.fused``.
    """
    if not _state.vendor_extra_lib_imported:
        _state.ops_module = _import_module_safe(
            f"_{vendor_name}.ops", vendor_name, "common"
        )
        _state.fused_module = _import_module_safe(
            f"_{vendor_name}.fused", vendor_name, "fused"
        )
        _state.vendor_extra_lib_imported = True

254 

255 

def get_codegen_result(code, result_key):
    """Execute generated Python source and return one name it binds.

    Args:
        code: Python source to run with ``exec``. Only pass trusted,
            internally generated code.
        result_key: Name bound by ``code`` whose value is returned.

    Returns:
        The object bound to ``result_key`` after execution.

    Raises:
        SyntaxError: If ``code`` does not parse.
        KeyError: If ``code`` does not bind ``result_key``.
        Any exception raised while executing ``code`` propagates unchanged.
    """
    compiled_code = compile(ast.parse(code), filename="<ast>", mode="exec")
    # Run in a per-call copy of the module namespace. Generated code can still
    # see module-level names, but its own bindings (e.g. "torch", "res", "fn")
    # do not leak into this module, and concurrent calls do not overwrite
    # each other's result.
    namespace = dict(globals())
    exec(compiled_code, namespace)
    return namespace[result_key]

264 

265 

@functools.lru_cache(maxsize=32)
def gen_torch_tensor_attr_res(tensor, attr_name):
    """Evaluate ``<tensor>.<attr_name>`` in generated code with torch imported.

    Results are memoized per ``(tensor, attr_name)`` pair. Both arguments are
    spliced verbatim into executed source, so they must come from trusted code.
    The device name is resolved and cached on the backend state as a side
    effect.
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info().device_name
    code = f"""
import torch
res = {tensor}.{attr_name}
    """
    return get_codegen_result(code, "res")

274 

275 

def set_tl_extra_backend_module(vendor_name=None):
    """Import and cache ``triton.language.extra.<name>.libdevice`` for the vendor.

    ``<name>`` is the vendor's ``triton_extra_name`` when set, otherwise the
    device name. The device name is also cached on the backend state if it
    was not already.
    """
    info = get_vendor_info(vendor_name)
    if not _state.device_name:
        _state.device_name = info.device_name
    extra_name = info.triton_extra_name or _state.device_name
    _state.tl_extra_backend_module = importlib.import_module(
        f"triton.language.extra.{extra_name}.libdevice"
    )

282 

283 

def get_tl_extra_backend_module():
    """Return the cached Triton libdevice module (None until it has been set)."""
    state = _state
    return state.tl_extra_backend_module

286 

287 

def set_torch_backend_device_fn(vendor_name=None):
    """Cache ``torch.backends.<device_name>`` on the backend state.

    For the devices listed below, None is cached instead of importing
    (presumably because torch has no such backends module for them; confirm).
    """
    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    device_name = _state.device_name
    skipped = ("musa", "aipu", "npu", "txda", "ptpu", "gcu")
    _state.torch_device_fn_device = (
        None
        if device_name in skipped
        else importlib.import_module(f"torch.backends.{device_name}")
    )

295 

296 

def get_torch_backend_device_fn():
    """Return the cached ``torch.backends.<device>`` module, or None."""
    state = _state
    return state.torch_device_fn_device

299 

300 

def gen_torch_device_object(vendor_name=None):
    """Return ``torch.<device_name>``, creating and caching it on first use.

    Args:
        vendor_name: Vendor used to look up ``device_name`` when none is cached.
            For ``"spacemit"`` the returned object is additionally patched with
            that backend's device guard, device wrapper and a ``current_device``
            that always returns 0.

    Returns:
        The object cached on ``_state.torch_device_object``.
    """
    cached = _state.torch_device_object
    if cached is not None:
        return cached

    if not _state.device_name:
        _state.device_name = get_vendor_info(vendor_name).device_name
    code = f"""
import torch
fn = torch.{_state.device_name}
"""
    device_obj = get_codegen_result(code, "fn")
    _state.torch_device_object = device_obj

    if vendor_name != "spacemit":
        return device_obj

    # SPACEMIT CPU backend needs special device guard handling
    spacemit_backend = importlib.import_module("flag_gems.runtime.backend._spacemit")
    device_obj._DeviceGuard = spacemit_backend._DeviceGuard
    device_obj.device = spacemit_backend._DeviceWrapper
    # Override current_device to return integer 0 for kernel cache indexing
    device_obj.current_device = lambda: 0
    return device_obj

328 

329 

def get_vendor_module(vendor_name, query=False):
    """Import a vendor backend package.

    Args:
        vendor_name: With ``query=True``, the full module name to import
            (e.g. ``"_nvidia"``). Otherwise the bare vendor name, which gets a
            ``"_"`` prefix.
        query: When True, import and return the module without caching it on
            the backend state.

    Returns:
        The imported module. Without ``query``, the first vendor module
        imported is cached in ``_state.vendor_module`` and returned on every
        later call, whatever ``vendor_name`` is passed.
    """

    def get_module(vendor_name):
        # Vendor packages are imported as top-level modules from this file's
        # directory. Add that directory to sys.path only once, so repeated
        # calls do not keep growing sys.path.
        current_dir_path = os.path.dirname(os.path.abspath(__file__))
        if current_dir_path not in sys.path:
            sys.path.append(current_dir_path)
        return importlib.import_module(vendor_name)

    if query:
        # A query returns the requested module without touching the cached vendor.
        return get_module(vendor_name)

    if _state.vendor_module is None:
        _state.vendor_module = get_module("_" + vendor_name)
    return _state.vendor_module

345 

346 

def get_vendor_info(vendor_name=None, query=False):
    """Return the ``vendor_info`` attribute of a vendor module.

    Args:
        vendor_name: Vendor to load (see get_vendor_module for the naming rules).
        query: When True, read it from a freshly imported module instead of
            the cached vendor module.
    """
    if not query:
        get_vendor_module(vendor_name)
        return _state.vendor_module.vendor_info
    return get_vendor_module(vendor_name, query).vendor_info

352 

353 

def get_vendor_infos():
    """Collect ``vendor_info`` from every known vendor package that imports.

    Vendors whose package fails to import (or lacks ``vendor_info``) are
    skipped silently.
    """
    collected = []
    for name in vendors.get_all_vendors():
        try:
            info = get_vendor_info(f"_{name}", query=True)
        except Exception:
            continue
        collected.append(info)
    return collected

363 

364 

def get_customized_ops(vendor_name=None):
    """Return ``(name, function)`` pairs from the vendor ``ops`` and ``fused`` modules.

    The list is built on the first call and cached on the backend state.
    Later calls return that same list object.
    """
    import_vendor_extra_lib(vendor_name)
    if _state.customized_ops is None:
        _state.customized_ops = []
        for module in (_state.ops_module, _state.fused_module):
            if module is not None:
                _state.customized_ops.extend(
                    inspect.getmembers(module, inspect.isfunction)
                )
    return _state.customized_ops

377 

378 

def get_ops(vendor_name=None):
    """Provide a unified interface for the upper layer.

    Returns the vendor's customized ``(name, function)`` pairs.
    """
    customized = get_customized_ops(vendor_name=vendor_name)
    return customized

382 

383 

def get_unused_ops(vendor_name=None):
    """Return a new list of the vendor's ``CUSTOMIZED_UNUSED_OPS``.

    Args:
        vendor_name: Vendor to load if no vendor module is cached yet.
    """
    # The vendor module lives on the BackendState singleton, so no module
    # global needs to be declared here.
    get_vendor_module(vendor_name)
    return list(_state.vendor_module.CUSTOMIZED_UNUSED_OPS)

388 

389 

def get_heuristic_config(vendor_name=None):
    """Return ``HEURISTICS_CONFIGS`` from the vendor's heuristics module.

    Any failure importing ``_<vendor_name>.heuristics_config_utils`` falls
    back to the ``_nvidia`` module. The imported module is cached on the
    backend state. Returns None when the module defines no
    ``HEURISTICS_CONFIGS``.
    """
    config_name = "heuristics_config_utils"
    try:
        module = importlib.import_module(f"_{vendor_name}.{config_name}")
    except Exception:
        module = importlib.import_module(f"_nvidia.{config_name}")
    _state.heuristic_config_module = module
    return getattr(module, "HEURISTICS_CONFIGS", None)

399 

400 

def get_tune_config(vendor_name=None):
    """Ensure the vendor module is loaded, then return backend_utils' tune config.

    Args:
        vendor_name: Vendor to load if none is cached; also forwarded to
            ``backend_utils.get_tune_config``.
    """
    # The vendor module lives on the BackendState singleton, so no module
    # global needs to be declared here.
    get_vendor_module(vendor_name)
    return backend_utils.get_tune_config(vendor_name)

405 

406 

def get_expand_config(op_name=None, file_path=None):
    """Forward to ``backend_utils.get_expand_config`` with the same arguments."""
    options = {"op_name": op_name, "file_path": file_path}
    return backend_utils.get_expand_config(**options)

409 

410 

def get_backend_state() -> BackendState:
    """Return the module's shared BackendState instance."""
    state = _state
    return state

414 

415 

# Explicit public API. "*" is not a valid __all__ entry: `from ... import *`
# looks each listed name up as a module attribute, so ["*"] made star-imports
# raise AttributeError.
__all__ = [
    "BackendState",
    "TritonVersionEvent",
    "BackendArchEvent",
    "SpecOpRegistrar",
    "import_vendor_extra_lib",
    "get_codegen_result",
    "gen_torch_tensor_attr_res",
    "set_tl_extra_backend_module",
    "get_tl_extra_backend_module",
    "set_torch_backend_device_fn",
    "get_torch_backend_device_fn",
    "gen_torch_device_object",
    "get_vendor_module",
    "get_vendor_info",
    "get_vendor_infos",
    "get_customized_ops",
    "get_ops",
    "get_unused_ops",
    "get_heuristic_config",
    "get_tune_config",
    "get_expand_config",
    "get_backend_state",
]