Coverage for src/flag_gems/runtime/__init__.py: 54%
35 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1from contextlib import contextmanager
3from . import backend, common, error
4from .backend.device import DeviceDetector
5from .configs_loader import TunedConfigLoader
6from .flagtune import flagtune, flagtune_enabled
# Module-level singletons shared by the whole runtime package.
# `config_loader` backs the get_*_config / ops_get_configs wrappers in this module.
config_loader = TunedConfigLoader()
# `device` is read below via `device.vendor_name` and `device.name`.
device = DeviceDetector()

# NOTE: the statements below must keep their current order (see the note string).
"""
The dependency order of the sub-directory is strict, and changing the order arbitrarily may cause errors.
"""

# torch_device_fn is like 'torch.cuda' object
# The backend is configured for the detected vendor first; presumably
# gen_torch_device_object() depends on that setting — TODO confirm in backend.
backend.set_torch_backend_device_fn(device.vendor_name)
torch_device_fn = backend.gen_torch_device_object()
if device.name == "cpu":
    # On CPU the generated device object may lack the `device` context manager
    # and the `_DeviceGuard` class. Install no-op stand-ins only when an
    # attribute is missing, so an existing implementation is never overwritten.
    # Presumably this lets code written against the torch.cuda-style API
    # (e.g. `with torch_device_fn.device(...)`) run unchanged on CPU.
    if not hasattr(torch_device_fn, "device"):

        @contextmanager
        def _noop_device_guard(_device=None):
            """Context manager that accepts and ignores a device argument."""
            yield

        torch_device_fn.device = _noop_device_guard
    if not hasattr(torch_device_fn, "_DeviceGuard"):

        class _NoOpDeviceGuard:
            """Device guard that does nothing on enter or exit."""

            def __init__(self, *args, **kwargs):
                # Accept any constructor arguments so call sites need no changes.
                pass

            def __enter__(self):
                return self

            def __exit__(self, exc_type, exc, tb):
                # Returning False lets any exception raised in the body propagate.
                return False

        torch_device_fn._DeviceGuard = _NoOpDeviceGuard

# torch_backend_device is like 'torch.backend.cuda' object
torch_backend_device = backend.get_torch_backend_device_fn()
def get_tuned_config(op_name):
    """Look up the tuned configuration for an operator.

    Args:
        op_name: Operator name used as the lookup key.

    Returns:
        Whatever ``config_loader.get_tuned_config(op_name)`` returns.
    """
    loader = config_loader
    tuned = loader.get_tuned_config(op_name)
    return tuned
def get_heuristic_config(op_name):
    """Look up the heuristics configuration for an operator.

    The loader-side method is spelled ``get_heuristics_config`` (plural);
    this wrapper keeps the singular public name.

    Args:
        op_name: Operator name used as the lookup key.

    Returns:
        Whatever ``config_loader.get_heuristics_config(op_name)`` returns.
    """
    loader = config_loader
    heuristics = loader.get_heuristics_config(op_name)
    return heuristics
def get_expand_config(op_name, yaml_path=None):
    """Look up the expand configuration for an operator.

    Args:
        op_name: Operator name used as the lookup key.
        yaml_path: Optional value forwarded unchanged to the loader
            (``None`` by default); its interpretation is up to
            ``config_loader.get_expand_config``.

    Returns:
        Whatever ``config_loader.get_expand_config`` returns.
    """
    lookup = {"op_name": op_name, "yaml_path": yaml_path}
    return config_loader.get_expand_config(**lookup)
def ops_get_configs(op_name, pre_hook=None, yaml_path=None):
    """Fetch the configs for an operator from the shared loader.

    All arguments are passed through by keyword, unchanged, to
    ``config_loader.ops_get_configs``.

    Args:
        op_name: Operator name used as the lookup key.
        pre_hook: Optional value forwarded to the loader (``None`` by default).
        yaml_path: Optional value forwarded to the loader (``None`` by default).

    Returns:
        Whatever ``config_loader.ops_get_configs`` returns.
    """
    forwarded = dict(op_name=op_name, pre_hook=pre_hook, yaml_path=yaml_path)
    return config_loader.ops_get_configs(**forwarded)
# Public API of the runtime package. Every name listed here must be bound in
# this module; otherwise ``from flag_gems.runtime import *`` raises
# AttributeError. ("replace_customized_ops" was removed: nothing in this
# module defines it.)
__all__ = [
    "TunedConfigLoader",
    "DeviceDetector",
    "backend",
    "common",
    "config_loader",
    "device",
    "error",
    "flagtune",
    "flagtune_enabled",
    "get_expand_config",
    "get_heuristic_config",
    "get_tuned_config",
    "ops_get_configs",
    "torch_backend_device",
    "torch_device_fn",
]