Coverage for src/flag_gems/runtime/backend/device.py: 49%

83 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import os 

2import shlex 

3import subprocess 

4from concurrent.futures import ThreadPoolExecutor, as_completed 

5 

6import torch # noqa: F401 

7 

8from .. import backend, error 

9from ..common import ( 

10 _VENDOR_TORCH_ATTR, 

11 UNSUPPORT_BF16, 

12 UNSUPPORT_FP64, 

13 UNSUPPORT_INT64, 

14 vendors, 

15) 

16 

17 

# A singleton class to manage device context.
class DeviceDetector:
    """Singleton that detects the accelerator vendor and exposes its properties.

    The first successful construction resolves the vendor (see ``get_vendor``)
    and stores the attributes below. Every later ``DeviceDetector(...)`` call
    returns the same instance without re-running detection. A ``vendor_name``
    passed after the first successful construction therefore has no effect.

    Attributes set on first construction:
        vendor_list: keys of ``vendors.get_all_vendors()``.
        info: vendor information object resolved by ``get_vendor``.
        vendor_name: vendor identifier such as ``"nvidia"``.
        name: device name such as ``"cuda"``.
        vendor: entry of ``vendors.get_all_vendors()`` for ``vendor_name``.
        dispatch_key: ``info.dispatch_key``, or ``name.upper()`` when that is None.
        device_count: ``device_count()`` of the vendor's torch device object.
        support_fp64 / support_bf16 / support_int64: True when ``vendor`` is
            not listed in the matching ``UNSUPPORT_*`` collection.
    """

    _instance = None

    def __new__(cls, *args, **kwargs):
        # Allocate the shared instance exactly once; constructor arguments are
        # consumed by __init__, not here.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, vendor_name=None):
        """Run vendor detection on the first successful construction only.

        Args:
            vendor_name: optional explicit vendor name (e.g. ``"nvidia"``).
                It takes precedence over every automatic detection method.
        """
        # Python calls __init__ on every DeviceDetector() call, even though
        # __new__ hands back the cached instance; skip if already set up.
        if hasattr(self, "initialized"):
            return

        # A list of all available vendor names.
        self.vendor_list = vendors.get_all_vendors().keys()
        # Vendor information for the explicit or auto-detected vendor.
        self.info = self.get_vendor(vendor_name)
        # vendor_name is like 'nvidia', device_name is like 'cuda'.
        self.vendor_name = self.info.vendor_name
        self.name = self.info.device_name
        self.vendor = vendors.get_all_vendors()[self.vendor_name]
        self.dispatch_key = (
            self.name.upper()
            if self.info.dispatch_key is None
            else self.info.dispatch_key
        )
        self.device_count = backend.gen_torch_device_object(
            self.vendor_name
        ).device_count()
        self.support_fp64 = self.vendor not in UNSUPPORT_FP64
        self.support_bf16 = self.vendor not in UNSUPPORT_BF16
        self.support_int64 = self.vendor not in UNSUPPORT_INT64

        # Mark as initialized only after detection succeeded, so a failed
        # attempt does not leave a half-populated singleton behind.
        self.initialized = True

    def get_vendor(self, vendor_name=None):
        """Resolve vendor information, trying sources in priority order.

        1. The explicit ``vendor_name`` argument, if non-empty.
        2. Environment overrides (``_get_vendor_from_env``).
        3. Cheap torch-based probing such as ``torch.mlu``
           (``_get_vendor_from_quick_cmd``).
        4. torch/triton library hooks (``_get_vendor_from_lib``; currently
           always raises, so this falls through to step 5).
        5. Running each vendor's device query command (``_get_vendor_from_sys``).

        Returns:
            The vendor info object from ``backend.get_vendor_info`` or from
            ``backend.get_vendor_infos()``.
        """
        if vendor_name:
            # Lower-case for consistency with the environment-variable path.
            return backend.get_vendor_info(str(vendor_name).lower())

        vendor_from_env = self._get_vendor_from_env()
        if vendor_from_env:
            return backend.get_vendor_info(vendor_from_env)

        vendor_from_torch = self._get_vendor_from_quick_cmd()
        if vendor_from_torch:
            return backend.get_vendor_info(vendor_from_torch)

        try:
            # Obtaining a vendor_info from the methods provided by torch or
            # triton, which is not currently implemented.
            return self._get_vendor_from_lib()
        except Exception:
            return self._get_vendor_from_sys()

    def _get_vendor_from_quick_cmd(self):
        """Infer the vendor name from the torch module, or return False.

        Uses ``torch_npu`` when it can be imported, otherwise ``torch``. The
        first vendor whose attribute from ``_VENDOR_TORCH_ATTR`` exists on
        that module is returned. Failing that, ``"nvidia"`` is returned when
        CUDA device 0 reports a name containing "NVIDIA".
        """
        try:
            import torch_npu

            torch_module = torch_npu
        except ImportError:
            torch_module = torch

        for vendor_name, attr in _VENDOR_TORCH_ATTR.items():
            if hasattr(torch_module, attr):
                return str(vendor_name)

        if hasattr(torch_module, "cuda") and hasattr(
            torch_module.cuda, "get_device_properties"
        ):
            try:
                prop = torch_module.cuda.get_device_properties(0)
                if "NVIDIA" in prop.name.upper():
                    return "nvidia"
            except Exception:
                # Querying device 0 failed; let the slower detectors decide.
                return False

        return False

    def _get_vendor_from_env(self):
        """Return a vendor name forced via environment variables, or False.

        If ``PPU_SDK`` is set, ``"thead"`` is returned. Otherwise the first of
        GEMS_VENDOR, FLAGGEMS_VENDOR, GEMS_BACKEND and FLAGGEMS_BACKEND that
        holds a non-blank value is returned, stripped and lower-cased. Blank
        values are skipped, so they no longer mask the keys that follow them.
        """
        if "PPU_SDK" in os.environ:
            return "thead"

        env_keys = (
            "GEMS_VENDOR",
            "FLAGGEMS_VENDOR",
            "GEMS_BACKEND",
            "FLAGGEMS_BACKEND",
        )
        for key in env_keys:
            value = os.environ.get(key, "").strip()
            if value:
                return value.lower()

        return False

    def _get_vendor_from_sys(self):
        """Detect the vendor by running each vendor's device query command.

        All commands run concurrently. Among those that exit with status 0,
        the vendor declared first in ``backend.get_vendor_infos()`` wins. This
        keeps the choice deterministic when several toolkits are installed.
        If none succeed, ``error.device_not_found()`` is called (presumably
        it raises; otherwise this method returns None).
        """
        vendor_infos = backend.get_vendor_infos()

        def check_vendor(info):
            # NOTE(review): subprocess.run has no timeout here, so a hung
            # query command blocks detection; consider adding one.
            try:
                cmd_args = shlex.split(info.device_query_cmd)
                result = subprocess.run(cmd_args, capture_output=True, text=True)
                return info if result.returncode == 0 else None
            except Exception:
                return None

        matches = {}
        # Leaving the `with` block waits for every task anyway, so collecting
        # all results costs no extra latency compared to returning early.
        with ThreadPoolExecutor() as executor:
            futures = {
                executor.submit(check_vendor, info): order
                for order, info in enumerate(vendor_infos)
            }
            for future in as_completed(futures):
                info = future.result()
                if info is not None:
                    matches[futures[future]] = info

        if matches:
            return matches[min(matches)]

        error.device_not_found()

    def get_vendor_name(self):
        """Return the detected vendor name (e.g. ``"nvidia"``)."""
        return self.vendor_name

    def _get_vendor_from_lib(self):
        """Placeholder for vendor detection through torch/triton APIs.

        Raises:
            RuntimeError: always; neither library provides such an API yet.
        """
        # Reserved for a future interface, e.g.:
        #   try:
        #       return triton.get_vendor_info()
        #   except Exception:
        #       return torch.get_vendor_info()
        raise RuntimeError("The method is not implemented")