Coverage for src/flag_gems/runtime/register.py: 61%

95 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import warnings 

2 

3from . import backend, common, error 

4from .backend.device import DeviceDetector 

5 

6 

class Register:
    """Select operator implementations from a config and register them on a library.

    Each config entry is a sequence ``(op_name, func)`` or
    ``(op_name, func, condition_func)``.  When a third element is present it is
    called with no arguments and the entry is kept only if the result is truthy.
    Entries whose ``op_name`` is listed in ``cpp_patched_ops`` are always skipped.

    Construction performs the whole workflow: the config is filtered (either by
    an include list or by an exclude list) and every surviving entry is passed
    to :meth:`register_impl`.

    Args:
        config: Iterable of config entries as described above.
        user_include_ops: Optional names to enable exclusively.  When truthy,
            only matching entries are registered and ``user_exclude_ops`` is
            not consulted.
        user_exclude_ops: Optional function names to skip.  Only used when
            ``user_include_ops`` is empty or ``None``.
        cpp_patched_ops: Optional op names that must never be registered here.
        lib: Library object whose ``impl(key, fn, dispatch_key)`` method performs
            the registration (an instance of ``torch.library.Library``).
        full_config_by_func: Optional mapping ``name -> list of config entries``
            used as a fast lookup path for ``user_include_ops``.
    """

    def __init__(
        self,
        config,
        user_include_ops=None,
        user_exclude_ops=None,
        cpp_patched_ops=None,
        lib=None,
        full_config_by_func=None,
    ):
        self.device = DeviceDetector()

        # lib is an instance of torch.library.Library
        # Some inference chips may not support the backward implementation of operators
        self.lib = lib

        # reg_key like 'CUDA'
        self.reg_key = self.device.dispatch_key
        self.all_ops = []
        self.all_keys = []
        if self.device.vendor == common.vendors.CAMBRICON:
            # TODO: Cambricon specific, to avoid op deadlock question in libtuner.
            # Should remove this in the future.
            self.torch_ops_map = {}

        # optional mapping func_name -> list of config entries
        self.full_config_by_func = full_config_by_func
        self.cpp_patched_ops = set(cpp_patched_ops or [])
        self.config = config

        if user_include_ops:
            self.include_ops = list(user_include_ops)
            self.exclude_ops = []
            self.extract_include_config()
            # Use the filtered include config to avoid registering all ops.
            self.config = self.include_config
        else:
            self.vendor_unused_ops_list = self.get_vendor_unused_op()
            self.exclude_ops = (
                list(user_exclude_ops or []) + self.vendor_unused_ops_list
            )
            self.config_filter()
        self.for_each()

    @staticmethod
    def _func_name(func):
        """Return ``func.__name__``, or ``str(func)`` when it has no ``__name__``."""
        return func.__name__ if hasattr(func, "__name__") else str(func)

    def _is_selectable(self, config_item):
        """Return True if the entry's condition passes and its op is not C++-patched."""
        # The optional third element is a zero-argument predicate.
        if len(config_item) > 2 and not config_item[2]():
            return False
        return config_item[0] not in self.cpp_patched_ops

    def extract_include_config(self):
        """Build ``self.include_config`` from the entries named in ``self.include_ops``.

        If ``self.full_config_by_func`` is truthy, each requested name is looked
        up in that mapping.  Otherwise ``self.config`` is scanned and an entry
        matches when either its function name or its op name was requested.
        Matching entries are then filtered by their optional condition and by
        ``self.cpp_patched_ops`` and stored as ``(op_name, func)`` tuples.

        Emits a warning when nothing is left to register.
        """
        if self.full_config_by_func:
            # Fast path: direct lookup per requested name, in request order.
            candidates = (
                item
                for name in self.include_ops
                for item in self.full_config_by_func.get(name, [])
            )
        else:
            # Fallback: scan provided config and match by func name or op name.
            wanted = set(self.include_ops)
            candidates = (
                item
                for item in self.config
                if self._func_name(item[1]) in wanted or item[0] in wanted
            )

        self.include_config = [
            (item[0], item[1]) for item in candidates if self._is_selectable(item)
        ]

        if not self.include_config:
            warnings.warn(
                "only_enable failed: No op to register. Check if include is correct."
            )

    def config_filter(self):
        """Reduce ``self.config`` to the ``(op_name, func)`` entries that may be registered.

        An entry is dropped when its optional condition is falsy, when its op
        name is in ``self.cpp_patched_ops``, or when its function name is in
        ``self.exclude_ops``.
        """
        excluded = set(self.exclude_ops)
        self.config = [
            (item[0], item[1])
            for item in self.config
            if self._is_selectable(item)
            # _func_name tolerates callables without __name__ (e.g. partials).
            and self._func_name(item[1]) not in excluded
        ]

    def get_vendor_unused_op(self):
        """Return the backend's unused-op list for the vendor; empty for NVIDIA."""
        if self.device.vendor != common.vendors.NVIDIA:
            return backend.get_unused_ops(self.device.vendor_name)
        return []

    def register_impl(self, key, fn):
        """Register ``fn`` as the implementation of ``key`` under ``self.reg_key``.

        The function name and the key are appended to ``self.all_ops`` and
        ``self.all_keys`` before the registration call is made.

        Raises:
            ValueError: If no library instance was supplied.
        """
        if self.lib is None:
            raise ValueError("Library instance is not provided.")
        device_key = self.reg_key
        self.all_ops.append(self._func_name(fn))
        self.all_keys.append(key)
        if self.device.vendor == common.vendors.CAMBRICON:
            import torch

            # Best effort: remember the kernel currently bound to this op.
            # A failed lookup simply leaves the map without an entry.
            try:
                self.torch_ops_map["aten::" + key] = torch.library.get_kernel(
                    "aten::" + key, device_key
                )
            except Exception:
                pass
            try:
                self.lib.impl(key, fn, device_key, allow_override=True)
            except TypeError:
                # Older torch versions don't support allow_override
                self.lib.impl(key, fn, device_key)
        else:
            self.lib.impl(key, fn, device_key)

    def for_each(self):
        """Register every ``(key, func)`` pair in ``self.config``.

        Any exception from a single registration is handed to
        ``error.register_error`` rather than propagated from here.
        """
        for key, func in self.config:
            try:
                self.register_impl(key, func)
            except Exception as e:
                error.register_error(e)

    def get_all_ops(self):
        """Return the function names recorded by :meth:`register_impl`."""
        return self.all_ops

    def get_all_keys(self):
        """Return the op keys recorded by :meth:`register_impl`."""
        return self.all_keys

    def get_unused_ops(self):
        """Return the list of excluded function names."""
        return self.exclude_ops

    def get_vendor_name(self):
        """Return the detected device's vendor name."""
        return self.device.vendor_name

    def get_current_device(self):
        """Return the detected device's name."""
        return self.device.name

152 

153 def get_current_device(self): 

154 return self.device.name