Coverage for src/flag_gems/runtime/backend/device.py: 52%

73 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import os 

2import shlex 

3import subprocess 

4from concurrent.futures import ThreadPoolExecutor, as_completed 

5 

6import torch # noqa: F401 

7 

8from .. import backend, error 

9from ..common import ( 

10 _VENDOR_TORCH_ATTR, 

11 UNSUPPORT_BF16, 

12 UNSUPPORT_FP64, 

13 UNSUPPORT_INT64, 

14 vendors, 

15) 

16 

17 

18# A singleton class to manage device context. 

class DeviceDetector:
    """Process-wide singleton that detects and describes the active device vendor.

    Vendor selection order (first hit wins), implemented by ``get_vendor``:

    1. the explicit ``vendor_name`` constructor argument, if it is registered;
    2. the ``GEMS_VENDOR`` environment variable, if it names a registered vendor;
    3. a vendor marker attribute on ``torch`` (or ``torch_npu``), per
       ``_VENDOR_TORCH_ATTR``;
    4. a library-provided query (not implemented yet; always falls through);
    5. running each vendor's ``device_query_cmd`` and taking the first (in
       registry order) that exits with status 0.

    Because the class is a singleton, only the first construction performs
    detection; later ``DeviceDetector(...)`` calls return the same instance and
    their ``vendor_name`` argument has no effect.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Allocate the shared instance exactly once. Constructor arguments are
        # consumed by __init__, not here.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, vendor_name=None):
        # __init__ runs on every DeviceDetector(...) call because __new__ hands
        # back the shared instance; the flag makes detection happen only once.
        # NOTE(review): the flag is set before detection, so if detection raises,
        # later constructions return a partially initialised instance. The order
        # is kept in case detection indirectly re-enters DeviceDetector().
        if hasattr(self, "initialized"):
            return
        self.initialized = True
        # Names of all registered vendors (a dict keys view).
        self.vendor_list = vendors.get_all_vendors().keys()
        # Vendor-info object for the selected vendor.
        self.info = self.get_vendor(vendor_name)
        # vendor_name is like 'nvidia', device_name is like 'cuda'.
        self.vendor_name = self.info.vendor_name
        self.name = self.info.device_name
        self.vendor = vendors.get_all_vendors()[self.vendor_name]
        # Fall back to the upper-cased device name when the vendor info does
        # not declare a dispatch key.
        self.dispatch_key = (
            self.name.upper()
            if self.info.dispatch_key is None
            else self.info.dispatch_key
        )
        self.device_count = backend.gen_torch_device_object(
            self.vendor_name
        ).device_count()
        self.support_fp64 = self.vendor not in UNSUPPORT_FP64
        self.support_bf16 = self.vendor not in UNSUPPORT_BF16
        self.support_int64 = self.vendor not in UNSUPPORT_INT64

    def get_vendor(self, vendor_name=None):
        """Return the vendor-info object for the vendor to use.

        Args:
            vendor_name: Optional explicit vendor name (e.g. ``'nvidia'``). It is
                used when it is one of ``self.vendor_list``; otherwise
                auto-detection runs as if it had not been given.

        Returns:
            The object returned by ``backend.get_vendor_info`` for the selected
            vendor, or the one found by ``_get_vendor_from_sys``.
        """
        # 1. Explicit argument: checked first so a caller-supplied name takes
        #    precedence over any detection.
        if vendor_name in self.vendor_list:
            return backend.get_vendor_info(vendor_name)

        # 2. GEMS_VENDOR environment variable.
        vendor_from_env = self._get_vendor_from_env()
        if vendor_from_env:
            return backend.get_vendor_info(vendor_from_env)

        # 3. Quick check for a vendor marker attribute such as 'torch.mlu'.
        vendor_from_torch = self._get_vendor_from_quick_cmd()
        if vendor_from_torch:
            return backend.get_vendor_info(vendor_from_torch)

        # 4./5. Library-provided query (currently always raises), then probing
        #       each vendor's device query command.
        try:
            return self._get_vendor_from_lib()
        except Exception:
            return self._get_vendor_from_sys()

    def _get_vendor_from_quick_cmd(self):
        """Return the first vendor whose marker attribute is found, else None.

        ``_VENDOR_TORCH_ATTR`` maps vendor name -> attribute name. ``torch`` is
        checked first; if nothing matches, ``torch_npu`` is imported (when
        installed) and checked the same way.
        """
        for vendor_name, attr in _VENDOR_TORCH_ATTR.items():
            if hasattr(torch, attr):
                return vendor_name
        try:
            import torch_npu

            for vendor_name, attr in _VENDOR_TORCH_ATTR.items():
                if hasattr(torch_npu, attr):
                    return vendor_name
        except ImportError:
            pass
        return None

    def _get_vendor_from_env(self):
        """Return ``$GEMS_VENDOR`` if it names a registered vendor, else None."""
        vendor = os.environ.get("GEMS_VENDOR")
        return vendor if vendor in self.vendor_list else None

    @staticmethod
    def _probe_vendor(info):
        """Return ``info`` if its ``device_query_cmd`` exits with 0, else None."""
        cmd = getattr(info, "device_query_cmd", None)
        # Never probe with a missing/empty command: before Python 3.12,
        # shlex.split(None) reads from stdin instead of failing.
        if not cmd:
            return None
        try:
            result = subprocess.run(shlex.split(cmd), capture_output=True, text=True)
        except Exception:
            # Missing executable, bad quoting, permission errors, ... all mean
            # "this vendor's device is not available".
            return None
        return info if result.returncode == 0 else None

    def _get_vendor_from_sys(self):
        """Probe every registered vendor's query command and return the first hit.

        Probes run concurrently. The winner is chosen in the order returned by
        ``backend.get_vendor_infos()``, so the result does not depend on which
        subprocess happens to finish first. If no probe succeeds,
        ``error.device_not_found()`` is called.
        """
        vendor_infos = backend.get_vendor_infos()
        with ThreadPoolExecutor() as executor:
            # executor.map yields results in submission order.
            for info in executor.map(self._probe_vendor, vendor_infos):
                if info is not None:
                    return info

        error.device_not_found()

    def get_vendor_name(self):
        """Return the detected vendor name (e.g. ``'nvidia'``)."""
        return self.vendor_name

    def _get_vendor_from_lib(self):
        """Reserved hook for a vendor query provided by torch or triton.

        Neither library offers such an API yet, so this always raises;
        ``get_vendor`` catches the error and falls back to probing commands.

        Raises:
            NotImplementedError: always (a ``RuntimeError`` subclass).
        """
        raise NotImplementedError("The method is not implemented")