Coverage for src/flag_gems/runtime/__init__.py: 65%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1from . import backend, common, error 

2from .backend.device import DeviceDetector 

3from .configloader import ConfigLoader 

4 

# Module-level shared instances: `config_loader` backs the config helper
# functions defined in this module, and `device` supplies the vendor
# information (`vendor_name`, `vendor`) consulted during initialization.
5config_loader = ConfigLoader() 

6device = DeviceDetector() 

7 

8""" 

9The dependency order of the sub-directory is strict, and changing the order arbitrarily may cause errors. 

10""" 

11 

12# The detected vendor name is handed to the backend before the device object is generated; torch_device_fn is like the 'torch.cuda' object. 

13backend.set_torch_backend_device_fn(device.vendor_name) 

14torch_device_fn = backend.gen_torch_device_object() 

15 

16# torch_backend_device is like the 'torch.backend.cuda' object; it is read back from the backend after the setter call above. 

17torch_backend_device = backend.get_torch_backend_device_fn() 

18 

19 

def get_tuned_config(op_name):
    """Look up the tuned configuration registered for ``op_name``.

    The lookup is delegated unchanged to the module-level ``config_loader``;
    whatever it returns (or raises) is passed straight through to the caller.
    """
    tuned = config_loader.get_tuned_config(op_name)
    return tuned

22 

23 

def get_heuristic_config(op_name):
    """Look up the heuristics configuration registered for ``op_name``.

    Note the naming difference: this wrapper is singular, while the loader
    method it forwards to is ``get_heuristics_config``.
    """
    heuristics = config_loader.get_heuristics_config(op_name)
    return heuristics

26 

27 

# Install customized operator implementations into the mapping `_globals`.
#
# Two collections are consulted: the ops returned by
# `backend.get_customized_ops(device.vendor_name)` and, when the arch event
# reports `has_arch`, the ops returned by `event.get_arch_ops()`. Each
# collection is iterated as (fn_name, fn) pairs and every pair is written as
# `_globals[fn_name] = fn`. A RuntimeError raised while doing so is passed to
# `error.customized_op_replace_error`. Nothing is returned explicitly.
#
# NOTE(review): indentation is not visible in this listing, so it cannot be
# told from here whether the `if arch_specific_ops:` branch is nested under
# the non-NVIDIA check or sits at function level -- confirm against the
# source file before relying on either reading.
28def replace_customized_ops(_globals): 

29 event = backend.BackendArchEvent() 

# arch_specific_ops stays None when the event has no arch (has_arch is falsy).
30 arch_specific_ops = event.get_arch_ops() if event.has_arch else None 

# Fetched unconditionally, even though it is only iterated for non-NVIDIA vendors.
31 extended_ops = backend.get_customized_ops(device.vendor_name) 

# Vendor-customized ops are applied only when the vendor is not NVIDIA.
32 if device.vendor != common.vendors.NVIDIA: 

33 try: 

34 for fn_name, fn in extended_ops: 

35 _globals[fn_name] = fn 

36 except RuntimeError as e: 

37 error.customized_op_replace_error(e) 

# Skipped when arch_specific_ops is None or otherwise falsy (e.g. empty).
38 if arch_specific_ops: 

39 try: 

40 for fn_name, fn in arch_specific_ops: 

41 _globals[fn_name] = fn 

42 except RuntimeError as e: 

43 error.customized_op_replace_error(e) 

44 

45 

def get_expand_config(op_name, yaml_path=None):
    """Fetch the expand configuration for ``op_name``.

    Args:
        op_name: Operator name, forwarded by keyword to the loader.
        yaml_path: Optional path forwarded by keyword to the loader;
            defaults to ``None``.

    Returns:
        Whatever ``config_loader.get_expand_config`` returns.
    """
    expanded = config_loader.get_expand_config(
        op_name=op_name,
        yaml_path=yaml_path,
    )
    return expanded

48 

49 

def ops_get_configs(op_name, pre_hook=None, yaml_path=None):
    """Fetch the configs for ``op_name`` from the module-level loader.

    Args:
        op_name: Operator name, forwarded by keyword.
        pre_hook: Optional hook forwarded by keyword; defaults to ``None``.
        yaml_path: Optional path forwarded by keyword; defaults to ``None``.

    Returns:
        Whatever ``config_loader.ops_get_configs`` returns.
    """
    # Gather the keyword arguments first so the delegation is a single call.
    forwarded = {
        "op_name": op_name,
        "pre_hook": pre_hook,
        "yaml_path": yaml_path,
    }
    return config_loader.ops_get_configs(**forwarded)

56 

57 

# Explicit public API of this package.
#
# The previous value, ``["*"]``, named an attribute that does not exist on the
# module, so ``from <this package> import *`` failed with AttributeError.
# ``__all__`` must contain the actual names to export; every public name bound
# in this module is listed so a star-import exposes the full surface.
__all__ = [
    # Sub-modules and classes imported at the top of the file.
    "backend",
    "common",
    "error",
    "ConfigLoader",
    "DeviceDetector",
    # Module-level objects created during initialization.
    "config_loader",
    "device",
    "torch_device_fn",
    "torch_backend_device",
    # Helper functions.
    "get_tuned_config",
    "get_heuristic_config",
    "replace_customized_ops",
    "get_expand_config",
    "ops_get_configs",
]