Coverage for src/flag_gems/runtime/backend/device.py: 49%
83 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import os
2import shlex
3import subprocess
4from concurrent.futures import ThreadPoolExecutor, as_completed
6import torch # noqa: F401
8from .. import backend, error
9from ..common import (
10 _VENDOR_TORCH_ATTR,
11 UNSUPPORT_BF16,
12 UNSUPPORT_FP64,
13 UNSUPPORT_INT64,
14 vendors,
15)
18# A singleton class to manage device context.
class DeviceDetector:
    """Singleton class to manage device context.

    The first construction detects which vendor backend is in use and caches
    the result on the shared instance; every later ``DeviceDetector()`` call
    returns that same, already-initialized object.

    Attributes set by ``__init__``:
        vendor_list: keys of ``vendors.get_all_vendors()``.
        info: vendor info record returned by :meth:`get_vendor`.
        vendor_name: vendor identifier taken from ``info`` (e.g. 'nvidia').
        name: device name taken from ``info`` (e.g. 'cuda').
        vendor: entry of ``vendors.get_all_vendors()`` for ``vendor_name``.
        dispatch_key: ``info.dispatch_key``, or ``name`` upper-cased when
            the info record does not provide one.
        device_count: result of ``device_count()`` on the torch device
            object produced by ``backend.gen_torch_device_object``.
        support_fp64 / support_bf16 / support_int64: whether ``vendor`` is
            absent from the corresponding ``UNSUPPORT_*`` collection.
    """

    _instance = None

    def __new__(cls, *args, **kargs):
        """Return the shared instance, allocating it on first use only."""
        if cls._instance is None:
            # Allocate exactly once; constructor arguments are consumed by
            # __init__, not by object.__new__.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, vendor_name=None):
        """Detect the vendor and populate the device attributes once.

        Args:
            vendor_name: Optional vendor identifier to use instead of
                auto-detection. Only honoured on the first construction,
                because later constructions skip initialization entirely.
        """
        # __init__ runs on every DeviceDetector() call even though __new__
        # hands back the same object, so guard against re-initialization.
        if hasattr(self, "initialized"):
            return
        self.initialized = True

        all_vendors = vendors.get_all_vendors()
        # All available vendor names.
        self.vendor_list = all_vendors.keys()
        # Vendor information for the provided or detected vendor.
        self.info = self.get_vendor(vendor_name)
        # vendor_name is like 'nvidia', device_name is like 'cuda'.
        self.vendor_name = self.info.vendor_name
        self.name = self.info.device_name
        self.vendor = all_vendors[self.vendor_name]
        self.dispatch_key = (
            self.name.upper()
            if self.info.dispatch_key is None
            else self.info.dispatch_key
        )
        self.device_count = backend.gen_torch_device_object(
            self.vendor_name
        ).device_count()
        self.support_fp64 = self.vendor not in UNSUPPORT_FP64
        self.support_bf16 = self.vendor not in UNSUPPORT_BF16
        self.support_int64 = self.vendor not in UNSUPPORT_INT64

    def get_vendor(self, vendor_name=None) -> tuple:
        """Resolve the vendor info record to use.

        Sources are consulted in this order, stopping at the first hit:

        1. ``vendor_name`` passed by the caller.
        2. Environment variables (:meth:`_get_vendor_from_env`).
        3. Attributes exposed by the torch module
           (:meth:`_get_vendor_from_quick_cmd`).
        4. A torch/triton provided hook (:meth:`_get_vendor_from_lib`),
           which is not implemented yet and always raises.
        5. Running each vendor's device query command
           (:meth:`_get_vendor_from_sys`).

        Args:
            vendor_name: Optional explicit vendor identifier.

        Returns:
            Whatever ``backend.get_vendor_info`` / ``backend.get_vendor_infos``
            yield for the selected vendor.
        """
        # An explicit request wins; previously this argument was overwritten
        # before ever being read.
        if vendor_name:
            return backend.get_vendor_info(vendor_name)

        vendor_from_env = self._get_vendor_from_env()
        if vendor_from_env:
            return backend.get_vendor_info(vendor_from_env)

        # Try a quick probe such as the presence of 'torch.mlu'.
        detected_name = self._get_vendor_from_quick_cmd()
        if detected_name:
            return backend.get_vendor_info(detected_name)

        try:
            # Obtaining vendor info from torch or triton is reserved but not
            # currently implemented, so this always falls through.
            return self._get_vendor_from_lib()
        except Exception:
            return self._get_vendor_from_sys()

    def _get_vendor_from_quick_cmd(self):
        """Guess the vendor from attributes of the torch module.

        Returns:
            The vendor name as ``str``, or ``False`` when nothing matched.
        """
        try:
            import torch_npu

            torch_module = torch_npu
        except ImportError:
            torch_module = torch

        for vendor_name, attr in _VENDOR_TORCH_ATTR.items():
            if hasattr(torch_module, attr):
                return str(vendor_name)

        cuda = getattr(torch_module, "cuda", None)
        if cuda is not None and hasattr(cuda, "get_device_properties"):
            try:
                prop = cuda.get_device_properties(0)
                if "NVIDIA" in prop.name.upper():
                    return "nvidia"
            except Exception:
                # Best-effort probe: any failure querying device 0 simply
                # means this shortcut cannot decide.
                return False

        return False

    def _get_vendor_from_env(self):
        """Read the vendor name from environment variables.

        ``PPU_SDK`` being present (with any value) selects ``"thead"``.
        Otherwise the first non-empty value among ``GEMS_VENDOR``,
        ``FLAGGEMS_VENDOR``, ``GEMS_BACKEND`` and ``FLAGGEMS_BACKEND`` is
        returned lower-cased with surrounding whitespace removed.

        Returns:
            The vendor name as ``str``, or ``False`` when none is configured.
        """
        if "PPU_SDK" in os.environ:
            return "thead"

        env_keys = (
            "GEMS_VENDOR",
            "FLAGGEMS_VENDOR",
            "GEMS_BACKEND",
            "FLAGGEMS_BACKEND",
        )
        for key in env_keys:
            value = os.environ.get(key, "").strip().lower()
            # Skip variables that are set but empty so that a later,
            # properly populated variable is still honoured.
            if value:
                return value

        return False

    def _get_vendor_from_sys(self):
        """Find the vendor by running every vendor's device query command.

        Commands run concurrently in a thread pool; the first info record
        whose command exits with status 0 is returned. Leaving the ``with``
        block still waits for the remaining commands to finish.

        Returns:
            The matching vendor info record. If none matches,
            ``error.device_not_found()`` is called (presumably raising --
            confirm against the ``error`` module).
        """
        vendor_infos = backend.get_vendor_infos()

        def check_vendor(info):
            """Return ``info`` if its query command succeeds, else ``None``."""
            try:
                query_cmd = info.device_query_cmd
                # shlex.split(None) reads from stdin on Python < 3.12, which
                # could block forever; treat a missing command as "no match".
                if not query_cmd:
                    return None
                cmd_args = shlex.split(query_cmd)
                result = subprocess.run(cmd_args, capture_output=True, text=True)
            except Exception:
                # Missing executable, malformed command, etc.: not this vendor.
                return None
            return info if result.returncode == 0 else None

        with ThreadPoolExecutor() as executor:
            futures = {
                executor.submit(check_vendor, info): info for info in vendor_infos
            }
            for future in as_completed(futures):
                result = future.result()
                if result:
                    return result

        error.device_not_found()

    def get_vendor_name(self):
        """Return the detected vendor name (e.g. 'nvidia')."""
        return self.vendor_name

    def _get_vendor_from_lib(self):
        """Placeholder for vendor detection provided by triton or torch.

        Raises:
            RuntimeError: Always, because the interface is not implemented.
        """
        # Reserve the associated interface for triton or torch
        # although they are not implemented yet.
        # try:
        #     return triton.get_vendor_info()
        # except Exception:
        #     return torch.get_vendor_info()
        raise RuntimeError("The method is not implemented")