Coverage for src/flag_gems/runtime/backend/device.py: 89%
73 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import os
2import shlex
3import subprocess
4from concurrent.futures import ThreadPoolExecutor, as_completed
6import torch # noqa: F401
8from .. import backend, error
9from ..common import (
10 _VENDOR_TORCH_ATTR,
11 UNSUPPORT_BF16,
12 UNSUPPORT_FP64,
13 UNSUPPORT_INT64,
14 vendors,
15)
# A singleton class to manage device context.
class DeviceDetector:
    """Singleton that detects the accelerator vendor and caches device facts.

    The first successful instantiation performs detection and populates the
    attributes below. Every later ``DeviceDetector(...)`` call returns the same
    object, and because ``__init__`` is guarded by ``initialized``, a
    ``vendor_name`` passed to a later call has no effect.

    Attributes set by ``__init__``:
        vendor_list: keys of ``vendors.get_all_vendors()`` (known vendor names).
        info: vendor info record returned by ``get_vendor``.
        vendor_name: ``info.vendor_name`` (e.g. ``'nvidia'``).
        name: ``info.device_name`` (e.g. ``'cuda'``).
        vendor: the ``vendors.get_all_vendors()`` entry for ``vendor_name``.
        dispatch_key: ``info.dispatch_key``, or ``name.upper()`` when it is None.
        device_count: ``device_count()`` of the object returned by
            ``backend.gen_torch_device_object(vendor_name)``.
        support_fp64 / support_bf16 / support_int64: False when ``vendor`` is in
            the matching ``UNSUPPORT_*`` collection, True otherwise.
    """

    _instance = None

    def __new__(cls, *args, **kargs):
        # Allocate the shared instance exactly once; later calls reuse it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, vendor_name=None):
        """Detect the vendor (first call only) and cache device properties.

        Args:
            vendor_name: Optional vendor name (e.g. ``'nvidia'``) to use in
                preference to auto-detection. Ignored when it is not a known
                vendor, and ignored on any call after the first success.
        """
        if hasattr(self, "initialized"):
            # Singleton was already configured by an earlier call.
            return
        all_vendors = vendors.get_all_vendors()
        # Names of every vendor the package knows about.
        self.vendor_list = all_vendors.keys()
        # Vendor info record for the provided or auto-detected vendor.
        self.info = self.get_vendor(vendor_name)
        # vendor_name is like 'nvidia', device_name is like 'cuda'.
        self.vendor_name = self.info.vendor_name
        self.name = self.info.device_name
        self.vendor = all_vendors[self.vendor_name]
        self.dispatch_key = (
            self.name.upper()
            if self.info.dispatch_key is None
            else self.info.dispatch_key
        )
        self.device_count = backend.gen_torch_device_object(
            self.vendor_name
        ).device_count()
        self.support_fp64 = self.vendor not in UNSUPPORT_FP64
        self.support_bf16 = self.vendor not in UNSUPPORT_BF16
        self.support_int64 = self.vendor not in UNSUPPORT_INT64
        # Mark as initialized only after everything above succeeded, so a failed
        # detection does not leave a half-built singleton behind for later calls.
        self.initialized = True

    def get_vendor(self, vendor_name=None):
        """Resolve the vendor info record for this machine.

        Resolution order:
          1. ``vendor_name``, if it is a known vendor;
          2. the ``GEMS_VENDOR`` environment variable, if it is a known vendor;
          3. a torch / torch_npu attribute probe (e.g. ``torch.mlu``);
          4. a torch/triton library query (not implemented; always falls through);
          5. running each vendor's device query command.

        Requires ``self.vendor_list`` to be set.
        """
        if vendor_name is not None and vendor_name in self.vendor_list:
            return backend.get_vendor_info(vendor_name)

        vendor_from_env = self._get_vendor_from_env()
        if vendor_from_env:
            return backend.get_vendor_info(vendor_from_env)

        # Try to get the vendor name from a quick special attribute like 'torch.mlu'.
        vendor_from_cmd = self._get_vendor_from_quick_cmd()
        if vendor_from_cmd:
            return backend.get_vendor_info(vendor_from_cmd)

        try:
            # Reserved hook for torch/triton-provided vendor info; currently raises.
            return self._get_vendor_from_lib()
        except Exception:
            return self._get_vendor_from_sys()

    def _get_vendor_from_quick_cmd(self):
        """Return the first vendor whose marker attribute exists on torch or torch_npu.

        Checks every attribute in ``_VENDOR_TORCH_ATTR`` on ``torch`` first. If none
        matches, imports ``torch_npu`` (when installed) and checks it the same way.
        Returns None when nothing matches.
        """
        for vendor_name, attr in _VENDOR_TORCH_ATTR.items():
            if hasattr(torch, attr):
                return vendor_name
        try:
            import torch_npu

            for vendor_name, attr in _VENDOR_TORCH_ATTR.items():
                if hasattr(torch_npu, attr):
                    return vendor_name
        except ImportError:
            pass
        return None

    def _get_vendor_from_env(self):
        """Return ``$GEMS_VENDOR`` if it names a known vendor, else None."""
        vendor = os.environ.get("GEMS_VENDOR")
        return vendor if vendor in self.vendor_list else None

    @staticmethod
    def _check_vendor(info):
        """Return ``info`` if its ``device_query_cmd`` runs and exits 0, else None.

        The command's output is discarded, so stdio is sent to DEVNULL rather than
        captured and decoded. Decoding could fail on non-UTF-8 output and wrongly
        mark a present device as absent.
        """
        cmd = info.device_query_cmd
        if not cmd:
            # shlex.split(None) reads stdin on Python < 3.12; never risk that.
            return None
        try:
            result = subprocess.run(
                shlex.split(cmd),
                stdin=subprocess.DEVNULL,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
        except Exception:
            # Best-effort probe: a missing binary, bad quoting, etc. means "not this vendor".
            return None
        return info if result.returncode == 0 else None

    def _get_vendor_from_sys(self):
        """Probe every vendor's device query command in parallel.

        Returns the info record of the first command that finishes successfully.
        Otherwise calls ``error.device_not_found()``. That call is expected to raise
        (TODO confirm); if it returns, this method returns None.
        """
        vendor_infos = backend.get_vendor_infos()
        # NOTE(review): subprocess.run has no timeout here, so a hung query command
        # blocks detection indefinitely. Leaving the `with` block also waits for
        # every probe to finish.
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(self._check_vendor, info) for info in vendor_infos]
            # The first success wins. If several vendors' commands succeed, the
            # chosen one depends on timing.
            for future in as_completed(futures):
                result = future.result()
                if result is not None:
                    return result

        error.device_not_found()

    def get_vendor_name(self):
        """Return the detected vendor name (e.g. ``'nvidia'``)."""
        return self.vendor_name

    def _get_vendor_from_lib(self):
        """Reserved hook for vendor info from triton or torch; always raises RuntimeError.

        Intended future shape:
            try:
                return triton.get_vendor_info()
            except Exception:
                return torch.get_vendor_info()
        """
        raise RuntimeError("The method is not implemented")