Coverage for src/flag_gems/runtime/__init__.py: 54%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1from contextlib import contextmanager 

2 

3from . import backend, common, error 

4from .backend.device import DeviceDetector 

5from .configs_loader import TunedConfigLoader 

6from .flagtune import flagtune, flagtune_enabled 

7 

# Module-level singletons used by the helpers and setup code in this module:
# ``config_loader`` backs the get_*_config / ops_get_configs wrappers, and
# ``device`` supplies ``device.vendor_name`` / ``device.name`` to the backend
# setup. They are created in this order (loader first); this module warns
# that its initialization order is strict, so keep it.
config_loader = TunedConfigLoader()
device = DeviceDetector()

10 

11""" 

12The dependency order of the sub-directory is strict, and changing the order arbitrarily may cause errors. 

13""" 

14 

# torch_device_fn is like the 'torch.cuda' object for the detected vendor.
# The backend must be selected (set_torch_backend_device_fn) before the
# device object is generated from it.
backend.set_torch_backend_device_fn(device.vendor_name)
torch_device_fn = backend.gen_torch_device_object()
if device.name == "cpu":
    # On CPU, the generated device object may lack the device-context APIs
    # that other code paths expect. Install no-op stand-ins only where the
    # attribute is missing, so an existing implementation is never replaced.
    # NOTE(review): presumably callers use ``torch_device_fn.device(...)`` and
    # ``torch_device_fn._DeviceGuard(...)`` as context managers; confirm at
    # the call sites.
    if not hasattr(torch_device_fn, "device"):

        @contextmanager
        def _noop_device_guard(_device=None):
            """Context manager that ignores ``_device`` and yields ``None``."""
            yield

        torch_device_fn.device = _noop_device_guard
    if not hasattr(torch_device_fn, "_DeviceGuard"):

        class _NoOpDeviceGuard:
            """Guard that accepts any constructor arguments and does nothing."""

            def __init__(self, *args, **kwargs):
                pass

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                # A falsy return lets exceptions raised inside the ``with``
                # body propagate instead of being suppressed.
                return False

        torch_device_fn._DeviceGuard = _NoOpDeviceGuard

# torch_backend_device is like the 'torch.backends.cuda' object
torch_backend_device = backend.get_torch_backend_device_fn()

42 

43 

def get_tuned_config(op_name):
    """Return the tuned configuration for ``op_name``.

    Thin wrapper over the module-level ``config_loader``; ``op_name`` is
    forwarded positionally and the loader's result is returned unchanged.
    """
    lookup = config_loader.get_tuned_config
    return lookup(op_name)

46 

47 

def get_heuristic_config(op_name):
    """Return the heuristics configuration for ``op_name``.

    Delegates to ``config_loader.get_heuristics_config``. Note that the loader
    method name is plural while this wrapper's name is singular.
    """
    lookup = config_loader.get_heuristics_config
    return lookup(op_name)

50 

51 

def get_expand_config(op_name, yaml_path=None):
    """Return the expand configuration for ``op_name``.

    Both arguments are passed by keyword to
    ``config_loader.get_expand_config``; ``yaml_path`` defaults to ``None``
    and is forwarded as-is.
    """
    return config_loader.get_expand_config(
        yaml_path=yaml_path,
        op_name=op_name,
    )

54 

55 

def ops_get_configs(op_name, pre_hook=None, yaml_path=None):
    """Return the configs for ``op_name`` from ``config_loader.ops_get_configs``.

    All three arguments are forwarded by keyword; ``pre_hook`` and
    ``yaml_path`` default to ``None`` and are passed through unchanged.
    """
    request = {"op_name": op_name, "pre_hook": pre_hook, "yaml_path": yaml_path}
    return config_loader.ops_get_configs(**request)

62 

63 

# Public API of this package, used by ``from flag_gems.runtime import *``.
# Every entry must be bound in this module (or be an importable submodule);
# otherwise the star-import raises AttributeError. The stale
# "replace_customized_ops" entry was removed because nothing here defines it.
__all__ = [
    "TunedConfigLoader",
    "DeviceDetector",
    "backend",
    "common",
    "config_loader",
    "device",
    "error",
    "flagtune",
    "flagtune_enabled",
    "get_expand_config",
    "get_heuristic_config",
    "get_tuned_config",
    "ops_get_configs",
    "torch_backend_device",
    "torch_device_fn",
]