Coverage for src/flag_gems/runtime/register.py: 67%

100 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import warnings 

2 

3from . import backend, common, error 

4from .backend.device import DeviceDetector 

5 

6 

class Register:
    """Register FlagGems operator implementations into a ``torch.library.Library``.

    Construction does all the work: it selects which config entries to register
    (either an explicit include list, or every entry minus exclusions) and then
    registers each selected entry under the detected device's dispatch key.

    A config entry is a sequence
    ``(op_name, func[, condition_func[, extra_dispatch_keys]])``:
    ``condition_func`` (optional, may be ``None``) is called to decide whether
    the entry is enabled, and ``extra_dispatch_keys`` (optional) lists further
    dispatch keys ``func`` is also registered under.
    """

    def __init__(
        self,
        config,
        user_include_ops=None,
        user_exclude_ops=None,
        cpp_patched_ops=None,
        lib=None,
        full_config_by_func=None,
    ):
        """Select config entries and register them immediately.

        Args:
            config: Iterable of config entries (see class docstring).
            user_include_ops: If non-empty, register ONLY these ops. With
                ``full_config_by_func`` they are looked up as function names;
                otherwise an entry matches by function name or op name.
                Exclusions (user and vendor) are not applied in this mode.
            user_exclude_ops: Function names to skip when registering all
                ops; combined with the vendor's unused-op list.
            cpp_patched_ops: Op names provided by C++ patches; these are never
                registered from Python.
            lib: ``torch.library.Library`` instance to register into.
            full_config_by_func: Optional mapping ``func_name -> [config entries]``
                used as a direct lookup for ``user_include_ops``.
        """
        self.device = DeviceDetector()

        # lib is an instance of torch.library.Library.
        # Some inference chips may not support the backward implementation of operators.
        self.lib = lib

        # Dispatch key of the detected device, e.g. 'CUDA'.
        self.reg_key = self.device.dispatch_key
        # Function names / op keys whose device-key registration succeeded.
        self.all_ops = []
        self.all_keys = []
        if self.device.vendor == common.vendors.CAMBRICON:
            # TODO: Cambricon specific, to avoid op deadlock question in libtuner.
            # Should remove this in the future.
            self.torch_ops_map = {}

        # Optional mapping func_name -> list of config entries.
        self.full_config_by_func = full_config_by_func
        self.cpp_patched_ops = set(cpp_patched_ops or [])

        if user_include_ops:
            self.include_ops = list(user_include_ops)
            self.exclude_ops = []
            self.config = config
            self.extract_include_config()
            # Use the filtered include config to avoid registering all ops.
            self.config = self.include_config
        else:
            self.vendor_unused_ops_list = self.get_vendor_unused_op()
            self.exclude_ops = (
                list(user_exclude_ops or []) + self.vendor_unused_ops_list
            )
            self.config = config
            self.config_filter()
        self.for_each()

    @staticmethod
    def _func_name(func):
        """Return ``func.__name__``, or ``str(func)`` for callables without one
        (e.g. ``functools.partial`` objects)."""
        return getattr(func, "__name__", str(func))

    def _is_registrable(self, item):
        """Return True if ``item`` is enabled and its op is not C++-patched.

        The condition function is evaluated first, matching the check order
        used in every selection path.
        """
        return self._config_enabled(item) and item[0] not in self.cpp_patched_ops

    def extract_include_config(self):
        """Build ``self.include_config`` from ``self.include_ops``.

        With ``full_config_by_func`` each requested name is looked up directly.
        Repeated names are collapsed so the same entry is never selected
        (and later registered) twice. Without it, every entry of
        ``self.config`` whose function name or op name was requested is kept.
        Disabled entries and C++-patched ops are dropped; kept entries are
        normalized to ``(op_name, func, extra_dispatch_keys)``.
        Emits a warning when nothing was selected.
        """
        self.include_config = []

        if self.full_config_by_func:
            # dict.fromkeys de-duplicates while preserving the requested order.
            for name in dict.fromkeys(self.include_ops):
                for config_item in self.full_config_by_func.get(name, []):
                    if self._is_registrable(config_item):
                        self.include_config.append(
                            self._normalized_config(config_item)
                        )
        else:
            # Fallback: scan the provided config and match by func name or op name.
            wanted = set(self.include_ops)
            for config_item in self.config:
                op_name, func = config_item[0], config_item[1]
                if self._func_name(func) not in wanted and op_name not in wanted:
                    continue
                if self._is_registrable(config_item):
                    self.include_config.append(self._normalized_config(config_item))

        if not self.include_config:
            warnings.warn(
                "only_enable failed: No op to register. Check if include is correct."
            )

    @staticmethod
    def _config_enabled(item):
        """Return True when the entry has no condition function (absent or
        ``None``) or the condition function returns a truthy value."""
        condition_func = item[2] if len(item) > 2 else None
        return condition_func is None or bool(condition_func())

    @staticmethod
    def _extra_dispatch_keys(item):
        """Return the entry's extra dispatch keys as a tuple (empty if absent)."""
        return tuple(item[3]) if len(item) > 3 else ()

    def _normalized_config(self, item):
        """Return ``(op_name, func, extra_dispatch_keys)`` for a config entry."""
        return item[0], item[1], self._extra_dispatch_keys(item)

    def config_filter(self):
        """Replace ``self.config`` with its normalized, registrable subset.

        Drops entries whose condition is false, whose function name is in
        ``self.exclude_ops``, or whose op is C++-patched.
        """
        excluded = set(self.exclude_ops)
        self.config = [
            self._normalized_config(item)
            for item in self.config
            if self._config_enabled(item)
            and self._func_name(item[1]) not in excluded
            and item[0] not in self.cpp_patched_ops
        ]

    def get_vendor_unused_op(self):
        """Return the backend's unused-op list for the current vendor.

        NVIDIA always yields an empty list; other vendors are queried through
        ``backend.get_unused_ops``.
        """
        if self.device.vendor != common.vendors.NVIDIA:
            return backend.get_unused_ops(self.device.vendor_name)
        return []

    def register_impl(self, key, fn, extra_dispatch_keys=()):
        """Register ``fn`` as the implementation of ``key`` in ``self.lib``.

        ``fn`` is registered under the device dispatch key, then under each of
        ``extra_dispatch_keys``. The op is recorded in ``all_ops``/``all_keys``
        only after the device-key registration succeeded, so a failed
        registration is not reported as registered.

        Raises:
            ValueError: If no library instance was provided.
        """
        if self.lib is None:
            raise ValueError("Library instance is not provided.")
        device_key = self.reg_key
        if self.device.vendor == common.vendors.CAMBRICON:
            import torch

            try:
                # Best-effort lookup of the kernel already registered for this
                # op on the device key; any failure is deliberately ignored.
                self.torch_ops_map["aten::" + key] = torch.library.get_kernel(
                    "aten::" + key, device_key
                )
            except Exception:
                pass
            try:
                self.lib.impl(key, fn, device_key, allow_override=True)
            except TypeError:
                # Older torch versions don't support allow_override
                self.lib.impl(key, fn, device_key)
        else:
            self.lib.impl(key, fn, device_key)

        # Record only once the primary registration has succeeded.
        self.all_ops.append(self._func_name(fn))
        self.all_keys.append(key)

        for dispatch_key in extra_dispatch_keys:
            self.lib.impl(key, fn, dispatch_key)

    def for_each(self):
        """Register every normalized ``(key, func, extra_dispatch_keys)`` entry.

        An exception raised while registering one entry is handed to
        ``error.register_error`` instead of propagating from this loop
        directly.
        """
        for key, func, extra_dispatch_keys in self.config:
            try:
                self.register_impl(key, func, extra_dispatch_keys)
            except Exception as e:
                error.register_error(e)

    def get_all_ops(self):
        """Return the function names recorded as registered."""
        return self.all_ops

    def get_all_keys(self):
        """Return the op keys recorded as registered."""
        return self.all_keys

    def get_unused_ops(self):
        """Return the exclusion list (always empty in include-only mode)."""
        return self.exclude_ops

    def get_vendor_name(self):
        """Return the detected device's vendor name."""
        return self.device.vendor_name

    def get_current_device(self):
        """Return the detected device's name."""
        return self.device.name