Coverage for src/flag_gems/runtime/op_registrar.py: 67%
100 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import warnings
3from . import backend, common, error
4from .backend.device import DeviceDetector
class GeneralOpRegistrar:
    """Filter a list of op config entries and register them on a library.

    Each config entry is a sequence shaped like
    ``(op_name, func[, condition_func[, extra_dispatch_keys]])``:

    * ``op_name`` - key passed to ``lib.impl``.
    * ``func`` - the implementation to register.
    * ``condition_func`` - optional zero-argument callable; the entry is
      skipped when it returns a falsy value.
    * ``extra_dispatch_keys`` - optional additional dispatch keys the same
      implementation is registered under.

    Two mutually exclusive modes are supported:

    * include mode (``user_include_ops`` is truthy): only the requested ops
      are registered.
    * exclude mode (otherwise): every enabled entry is registered except the
      ones listed in ``user_exclude_ops`` or reported as unused for the
      current vendor.

    In both modes, entries whose ``op_name`` is in ``cpp_patched_ops`` are
    skipped. Registration happens eagerly, at construction time.
    """

    def __init__(
        self,
        config,
        user_include_ops=None,
        user_exclude_ops=None,
        cpp_patched_ops=None,
        lib=None,
        full_config_by_func=None,
    ):
        """Build the filtered config and register every surviving entry.

        Args:
            config: iterable of config entries (see the class docstring).
            user_include_ops: names to register exclusively; a truthy value
                switches the registrar to include mode.
            user_exclude_ops: function names to skip in exclude mode.
            cpp_patched_ops: op names that must never be registered here.
            lib: object providing ``impl(key, fn, dispatch_key)``.
            full_config_by_func: optional mapping from a function name to the
                list of config entries for it; enables the include-mode fast
                path.
        """
        self.device = DeviceDetector()

        # lib is an instance of torch.library.Library
        # Some inference chips may not support the backward implementation of operators
        self.lib = lib

        # reg_key like 'CUDA'
        self.reg_key = self.device.dispatch_key
        self.all_ops = []
        self.all_keys = []
        if self.device.vendor == common.vendors.CAMBRICON:
            # TODO: Cambricon specific, to avoid op deadlock question in libtuner.
            # Should remove this in the future.
            self.torch_ops_map = {}

        # optional mapping func_name -> list of config entries
        self.full_config_by_func = full_config_by_func
        self.cpp_patched_ops = set(cpp_patched_ops or [])

        # Defined in both modes so the attributes exist regardless of which
        # branch below runs.
        self.include_ops = []
        self.include_config = []
        self.config = config

        if user_include_ops:
            self.include_ops = list(user_include_ops)
            self.exclude_ops = []
            self.extract_include_config()
            # Use the filtered include config to avoid registering all ops.
            self.config = self.include_config
        else:
            self.vendor_unused_ops_list = self.get_vendor_unused_op()
            self.exclude_ops = (
                list(user_exclude_ops or []) + self.vendor_unused_ops_list
            )
            self.config_filter()
        self.for_each()

    def extract_include_config(self):
        """Populate ``self.include_config`` with the entries named in ``include_ops``.

        When ``full_config_by_func`` is available, each requested name is
        looked up directly in that mapping. Otherwise ``self.config`` is
        scanned and an entry matches when either its function name or its op
        name was requested. Disabled entries and cpp-patched ops are dropped
        in both paths. A warning is emitted when nothing matched.
        """
        self.include_config = []

        if self.full_config_by_func:
            # Fast path: direct lookup per requested name.
            for name in self.include_ops:
                for config_item in self.full_config_by_func.get(name, []):
                    # respect optional condition functions
                    if not self._config_enabled(config_item):
                        continue
                    if config_item[0] in self.cpp_patched_ops:
                        continue
                    self.include_config.append(
                        self._normalized_config(config_item)
                    )
        else:
            # fallback: scan provided config and match by func name or op name
            requested = set(self.include_ops)
            for config_item in self.config:
                op_name, func = config_item[0], config_item[1]
                if (
                    self._func_name(func) not in requested
                    and op_name not in requested
                ):
                    continue
                if not self._config_enabled(config_item):
                    continue
                if op_name in self.cpp_patched_ops:
                    continue
                self.include_config.append(self._normalized_config(config_item))

        if not self.include_config:
            warnings.warn(
                "only_enable failed: No op to register. Check if include is correct."
            )

    @staticmethod
    def _func_name(func):
        """Return ``func.__name__``, or ``str(func)`` when it has no ``__name__``.

        Callables such as ``functools.partial`` objects carry no ``__name__``;
        using one lookup everywhere keeps filtering and bookkeeping consistent.
        """
        return func.__name__ if hasattr(func, "__name__") else str(func)

    @staticmethod
    def _config_enabled(item):
        """Return True when the entry has no condition or its condition passes."""
        condition_func = item[2] if len(item) > 2 else None
        return condition_func is None or bool(condition_func())

    @staticmethod
    def _extra_dispatch_keys(item):
        """Return the entry's extra dispatch keys as a tuple (possibly empty).

        A missing or ``None`` fourth slot yields ``()``. A bare string is
        treated as a single key rather than being split into characters.
        """
        keys = item[3] if len(item) > 3 else None
        if keys is None:
            return ()
        if isinstance(keys, str):
            return (keys,)
        return tuple(keys)

    def _normalized_config(self, item):
        """Reduce an entry to ``(op_name, func, extra_dispatch_keys)``."""
        return item[0], item[1], self._extra_dispatch_keys(item)

    def config_filter(self):
        """Replace ``self.config`` with the normalized entries to register.

        An entry survives when it is enabled, its function name is not in
        ``exclude_ops`` and its op name is not in ``cpp_patched_ops``.
        """
        # Built once so each membership test does not rescan the list.
        excluded = set(self.exclude_ops)
        self.config = [
            self._normalized_config(item)
            for item in self.config
            if self._config_enabled(item)
            and self._func_name(item[1]) not in excluded
            and item[0] not in self.cpp_patched_ops
        ]

    def get_vendor_unused_op(self):
        """Return the backend's unused-op list for non-NVIDIA vendors, else ``[]``."""
        if self.device.vendor != common.vendors.NVIDIA:
            return backend.get_unused_ops(self.device.vendor_name)
        return []

    def register_impl(self, key, fn, extra_dispatch_keys=()):
        """Register ``fn`` for ``key`` under the device key and any extra keys.

        The function name and key are recorded in ``all_ops`` / ``all_keys``
        before the registration call is attempted.

        Raises:
            ValueError: if no library instance was provided.
        """
        if self.lib is None:
            raise ValueError("Library instance is not provided.")
        device_key = self.reg_key
        self.all_ops.append(self._func_name(fn))
        self.all_keys.append(key)
        if self.device.vendor == common.vendors.CAMBRICON:
            import torch

            # Best effort: remember the kernel currently registered for this
            # op; a failed lookup is deliberately ignored.
            try:
                self.torch_ops_map["aten::" + key] = torch.library.get_kernel(
                    "aten::" + key, device_key
                )
            except Exception:
                pass
            try:
                self.lib.impl(key, fn, device_key, allow_override=True)
            except TypeError:
                # Older torch versions don't support allow_override
                self.lib.impl(key, fn, device_key)
        else:
            self.lib.impl(key, fn, device_key)

        for dispatch_key in extra_dispatch_keys:
            self.lib.impl(key, fn, dispatch_key)

    def for_each(self):
        """Register every entry of ``self.config``, reporting failures per entry.

        An exception raised while registering one entry is handed to
        ``error.register_error`` so the remaining entries are still attempted.
        """
        for key, func, extra_dispatch_keys in self.config:
            try:
                self.register_impl(key, func, extra_dispatch_keys)
            except Exception as e:
                error.register_error(e)

    def get_all_ops(self):
        """Return the function names passed to ``register_impl`` so far."""
        return self.all_ops

    def get_all_keys(self):
        """Return the op keys passed to ``register_impl`` so far."""
        return self.all_keys

    def get_unused_ops(self):
        """Return the list of excluded function names (empty in include mode)."""
        return self.exclude_ops

    def get_vendor_name(self):
        """Return the detected device's vendor name."""
        return self.device.vendor_name

    def get_current_device(self):
        """Return the detected device's name."""
        return self.device.name