Coverage for src/flag_gems/runtime/register.py: 61%
95 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import warnings
3from . import backend, common, error
4from .backend.device import DeviceDetector
class Register:
    """Filter operator config entries and register them on a torch library.

    Each config entry is a sequence ``(op_name, func)`` or
    ``(op_name, func, condition_func)``; ``condition_func`` is a zero-argument
    callable and the entry is skipped when it returns a falsy value.

    Construction does all the work: the device is detected, the config is
    narrowed (either to ``user_include_ops`` or by removing excluded ops), and
    every remaining entry is registered through ``lib.impl`` under the
    device's dispatch key.
    """

    def __init__(
        self,
        config,
        user_include_ops=None,
        user_exclude_ops=None,
        cpp_patched_ops=None,
        lib=None,
        full_config_by_func=None,
    ):
        """Build the filtered config and register every entry in it.

        Args:
            config: Iterable of config entries (see the class docstring).
            user_include_ops: If non-empty, only entries matching these names
                are registered and ``user_exclude_ops`` is ignored.
            user_exclude_ops: Function names to skip; combined with the
                vendor's unused-op list. Only used when no include list is
                given.
            cpp_patched_ops: Op names that are never registered here.
            lib: Library object whose ``impl`` method performs registration.
            full_config_by_func: Optional mapping of function name to a list
                of config entries, used as a fast path for include filtering.
        """
        self.device = DeviceDetector()

        # lib is an instance of torch.library.Library
        # Some inference chips may not support the backward implementation of operators
        self.lib = lib

        # reg_key like 'CUDA'
        self.reg_key = self.device.dispatch_key
        self.all_ops = []
        self.all_keys = []
        if self.device.vendor == common.vendors.CAMBRICON:
            # TODO: Cambricon specific, to avoid op deadlock question in libtuner.
            # Should remove this in the future.
            self.torch_ops_map = {}

        # optional mapping func_name -> list of config entries
        self.full_config_by_func = full_config_by_func
        self.cpp_patched_ops = set(cpp_patched_ops or [])

        if user_include_ops:
            self.include_ops = list(user_include_ops)
            self.exclude_ops = []
            self.config = config
            self.extract_include_config()
            # Use the filtered include config to avoid registering all ops.
            self.config = self.include_config
        else:
            self.vendor_unused_ops_list = self.get_vendor_unused_op()
            self.exclude_ops = (
                list(user_exclude_ops or []) + self.vendor_unused_ops_list
            )
            self.config = config
            self.config_filter()
        self.for_each()

    @staticmethod
    def _func_name(func):
        """Return ``func.__name__``, or ``str(func)`` when it has no name."""
        return getattr(func, "__name__", str(func))

    @staticmethod
    def _is_enabled(config_item):
        """Return True unless the entry's optional condition callable is falsy."""
        return len(config_item) < 3 or bool(config_item[2]())

    def extract_include_config(self):
        """Populate ``self.include_config`` with entries named in ``include_ops``.

        When ``full_config_by_func`` is available, entries are looked up
        directly by each requested name. Otherwise ``self.config`` is scanned
        and an entry is selected when either its function name or its op name
        is in ``include_ops``. In both paths, entries whose condition fails or
        whose op name is in ``cpp_patched_ops`` are skipped. A warning is
        emitted when nothing was selected.
        """
        self.include_config = []

        if self.full_config_by_func:
            # Fast path: look up the entries for each requested function name.
            for name in self.include_ops:
                for entry in self.full_config_by_func.get(name, []):
                    op_name, func = entry[0], entry[1]
                    if not self._is_enabled(entry):
                        continue
                    if op_name in self.cpp_patched_ops:
                        continue
                    self.include_config.append((op_name, func))
        else:
            # Fallback: scan the provided config, matching by func or op name.
            wanted = set(self.include_ops)
            for entry in self.config:
                op_name, func = entry[0], entry[1]
                if self._func_name(func) not in wanted and op_name not in wanted:
                    continue
                if not self._is_enabled(entry):
                    continue
                if op_name in self.cpp_patched_ops:
                    continue
                self.include_config.append((op_name, func))

        if not self.include_config:
            warnings.warn(
                "only_enable failed: No op to register. Check if include is correct."
            )

    def config_filter(self):
        """Reduce ``self.config`` to the ``(op_name, func)`` pairs to register.

        An entry is kept when its condition (if any) passes, its function name
        is not in ``exclude_ops``, and its op name is not in
        ``cpp_patched_ops``. Callables without ``__name__`` are matched by
        their ``str()`` form instead of raising ``AttributeError``.
        """
        excluded = set(self.exclude_ops)
        # The condition is evaluated first for every entry, as before, so any
        # condition callable still runs even for entries dropped afterwards.
        self.config = [
            (entry[0], entry[1])
            for entry in self.config
            if self._is_enabled(entry)
            and self._func_name(entry[1]) not in excluded
            and entry[0] not in self.cpp_patched_ops
        ]

    def get_vendor_unused_op(self):
        """Return the backend's unused-op list for the vendor, or ``[]`` for NVIDIA."""
        if self.device.vendor != common.vendors.NVIDIA:
            return backend.get_unused_ops(self.device.vendor_name)
        return []

    def register_impl(self, key, fn):
        """Register ``fn`` as the implementation of op ``key`` for this device.

        The op is recorded in ``all_ops`` / ``all_keys`` only after
        ``lib.impl`` succeeds, so a failed registration is not reported as
        registered.

        Raises:
            ValueError: If no library instance was provided.
        """
        if self.lib is None:
            raise ValueError("Library instance is not provided.")
        device_key = self.reg_key
        if self.device.vendor == common.vendors.CAMBRICON:
            import torch

            try:
                # Remember the kernel currently bound to this op, when one can
                # be looked up.
                self.torch_ops_map["aten::" + key] = torch.library.get_kernel(
                    "aten::" + key, device_key
                )
            except Exception:
                # Best effort: a failed lookup must not block registration.
                pass
            try:
                self.lib.impl(key, fn, device_key, allow_override=True)
            except TypeError:
                # Older torch versions don't support allow_override
                self.lib.impl(key, fn, device_key)
        else:
            self.lib.impl(key, fn, device_key)
        self.all_ops.append(self._func_name(fn))
        self.all_keys.append(key)

    def for_each(self):
        """Register every ``(key, func)`` pair in ``self.config``.

        Any exception raised while registering an entry is handed to
        ``error.register_error``.
        """
        for key, func in self.config:
            try:
                self.register_impl(key, func)
            except Exception as e:
                error.register_error(e)

    def get_all_ops(self):
        """Return the function names recorded by ``register_impl``."""
        return self.all_ops

    def get_all_keys(self):
        """Return the op keys recorded by ``register_impl``."""
        return self.all_keys

    def get_unused_ops(self):
        """Return the list of excluded function names."""
        return self.exclude_ops

    def get_vendor_name(self):
        """Return the detected device's vendor name."""
        return self.device.vendor_name

    def get_current_device(self):
        """Return the detected device's name."""
        return self.device.name