Coverage for src/flag_gems/runtime/precision_register.py: 0%
94 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1"""Precision-checking register – loaded only when precision checking is enabled.
3This module is NOT imported on the normal execution path. It is lazily
4imported by ``register.py`` only when the user explicitly requests
5``PrecisionCheckRegister``.
6"""
8import functools
10import torch
12from ..config import get_skip_precision_check_ops
13from ..logging_utils import (
14 compare_outputs,
15 get_tensor_info,
16 precision_config,
17 write_precision_result,
18)
19from .op_registrar import GeneralOpRegistrar
21# Maximum tensor element count allowed for precision check
22# (skip if exceeded to avoid large tensor copy overhead)
23_MAX_NUMEL_FOR_CHECK = 1 * 1024 * 1024 # 1M elements
26def _get_dtype_tolerance(args, default_rtol, default_atol):
27 """Automatically adjust tolerance based on the dtype of input tensors."""
28 for a in args:
29 if isinstance(a, torch.Tensor) and a.is_floating_point():
30 if a.dtype in (torch.bfloat16, torch.float16):
31 return (max(default_rtol, 1e-2), max(default_atol, 1e-2))
32 break
33 return (default_rtol, default_atol)
36def _to_cpu(x):
37 """Recursively move tensors to CPU."""
38 if isinstance(x, torch.Tensor):
39 return x.detach().cpu()
40 elif isinstance(x, (list, tuple)):
41 return type(x)(_to_cpu(i) for i in x)
42 elif isinstance(x, dict):
43 return {k: _to_cpu(v) for k, v in x.items()}
44 return x
47def _max_tensor_numel(args):
48 """Return the element count of the largest tensor in the arguments."""
49 max_n = 0
50 for a in args:
51 if isinstance(a, torch.Tensor):
52 max_n = max(max_n, a.numel())
53 return max_n
# Names of operators that are excluded from precision checking altogether.
# The collection comes from the project configuration (conf/operators.yaml)
# and is fetched once, at import time.
_SKIP_OPS = get_skip_precision_check_ops()
60def _parse_op_key(op_key):
61 """Parse op_key once and return (op_name, overload_name, should_skip).
63 This avoids repeated string splitting inside the hot wrapper.
64 """
65 # Strip namespace prefix (e.g. "aten::add.Tensor" -> "add.Tensor")
66 bare_key = op_key.split("::")[-1] if "::" in op_key else op_key
68 # Split into base op name and overload
69 dot_pos = bare_key.find(".")
70 if dot_pos >= 0:
71 op_name = bare_key[:dot_pos]
72 overload_name = bare_key[dot_pos + 1 :]
73 else:
74 op_name = bare_key
75 overload_name = "default"
77 # Determine if this op should be skipped entirely (never checked)
78 should_skip = (
79 overload_name == "out" or op_name.endswith("_out") or op_name in _SKIP_OPS
80 )
82 return op_name, overload_name, should_skip
def _build_failure_record(op_key, args, fg_result, rtol, atol, info):
    """Assemble the FAIL record that is handed to ``write_precision_result``."""
    # Describe each positional argument exactly once; arguments for which
    # get_tensor_info returns a falsy value are left out.
    input_info = [desc for desc in map(get_tensor_info, args) if desc]
    record = {
        "op": op_key,
        "status": "FAIL",
        "inputs": input_info,
        "output": get_tensor_info(fg_result),
        "rtol": rtol,
        "atol": atol,
    }
    if "error" in info:
        record["error"] = info["error"]
        record["fg_value"] = info["fg"]
        record["pt_value"] = info["pt"]
    else:
        record["max_abs_diff"] = info["max_abs"]
        record["max_rel_diff"] = info["max_rel"]
    return record


def _wrap_op_with_precision_check(op_key, fn):
    """Wrap a FlagGems operator so its output is compared with native PyTorch.

    Since FlagGems replaces the CUDA dispatch, the native implementation
    cannot be called on GPU, so inputs are copied to CPU to compute the
    reference result. Performance overhead is controlled by:
      - max_checks: only the first N calls per operator are considered
        (``precision_config["max_checks"]``, default 10)
      - calls involving a tensor with more than ``_MAX_NUMEL_FOR_CHECK``
        elements are not checked (positional and keyword arguments alike)
      - once a failure is logged, that operator is no longer checked

    Args:
        op_key: Operator key, e.g. ``"aten::add.Tensor"`` or ``"add.Tensor"``.
        fn: The FlagGems implementation to wrap.

    Returns:
        ``fn`` itself when the operator is never checked (skipped by
        ``_parse_op_key`` or no matching ``torch.ops.aten`` overload exists);
        otherwise a wrapper with the same call signature. The wrapper always
        returns whatever ``fn`` returns; errors raised by the checking logic
        are logged at DEBUG level and never propagate to the caller.
    """
    import logging  # local import so this block does not depend on file-level edits

    logger = logging.getLogger(__name__)

    # --- Pre-compute everything derivable from op_key at wrap time ---
    op_name, overload_name, should_skip = _parse_op_key(op_key)

    # If this op can never be checked, return the unwrapped function directly
    if should_skip:
        return fn

    # Pre-resolve the aten overload so we don't do getattr on every call
    aten_packet = getattr(torch.ops.aten, op_name, None)
    aten_overload = getattr(aten_packet, overload_name, None) if aten_packet else None

    # If we can't find the native implementation, no point wrapping
    if aten_overload is None:
        return fn

    # In-place operators (trailing underscore) mutate their inputs, so the
    # reference inputs must be captured BEFORE fn runs; copying afterwards
    # would feed already-modified data to the reference implementation.
    is_inplace = op_name.endswith("_")

    # Pre-fetch the set reference for fast membership test
    logged_ops = precision_config["logged_ops"]
    call_count = 0

    def snapshot_inputs(args, kwargs):
        """Copy positional and keyword arguments to the CPU."""
        return (
            [_to_cpu(a) for a in args],
            {k: _to_cpu(v) for k, v in kwargs.items()},
        )

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal call_count

        # Skip operators that have already logged a failure
        if op_key in logged_ops:
            return fn(*args, **kwargs)

        # Sampling: only consider the first N calls per operator
        call_count += 1
        if call_count > precision_config.get("max_checks", 10):
            return fn(*args, **kwargs)

        # Skip large tensors to avoid copy overhead; keyword tensors count too
        kwarg_values = tuple(kwargs.values())
        largest = max(_max_tensor_numel(args), _max_tensor_numel(kwarg_values))
        if largest > _MAX_NUMEL_FOR_CHECK:
            return fn(*args, **kwargs)

        snapshot = None
        if is_inplace:
            try:
                snapshot = snapshot_inputs(args, kwargs)
            except Exception:
                # Without a pre-call snapshot the comparison would be
                # meaningless; run the operator unchecked instead.
                logger.debug(
                    "Precision check skipped for %s: input snapshot failed",
                    op_key,
                    exc_info=True,
                )
                return fn(*args, **kwargs)

        # For out-of-place ops the FlagGems implementation runs first, with
        # no interference from the checker.
        fg_result = fn(*args, **kwargs)

        try:
            # Copy inputs and output to CPU for comparison.
            # The .cpu() call implicitly synchronizes the CUDA stream.
            if snapshot is None:
                snapshot = snapshot_inputs(args, kwargs)
            cpu_args, cpu_kwargs = snapshot
            fg_result_cpu = _to_cpu(fg_result)

            with torch.no_grad():
                pt_result_cpu = aten_overload(*cpu_args, **cpu_kwargs)

            cfg = precision_config
            # Positional arguments come first so they keep priority when the
            # tolerance is chosen from the first floating-point tensor.
            rtol, atol = _get_dtype_tolerance(
                args + kwarg_values, cfg["rtol"], cfg["atol"]
            )
            is_close, info = compare_outputs(fg_result_cpu, pt_result_cpu, rtol, atol)

            if not is_close:
                logged_ops.add(op_key)
                write_precision_result(
                    _build_failure_record(op_key, args, fg_result, rtol, atol, info)
                )
        except Exception:
            # Checking is best-effort: never let it break the real call, but
            # leave a trace so checker bugs are discoverable.
            logger.debug("Precision check failed for %s", op_key, exc_info=True)

        return fg_result

    return wrapper
class PrecisionCheckRegister(GeneralOpRegistrar):
    """Registrar variant that installs precision-checked operator wrappers.

    Instances are created only after the user has explicitly called
    ``enable_precision_check()`` ahead of ``enable()`` / ``only_enable()``;
    the class plays no part in the normal execution path.
    """

    def register_impl(self, key, fn, extra_dispatch_keys=()):
        """Register *fn* for *key*, routed through the precision checker.

        The implementation is registered under ``self.reg_key`` and then
        under every key in *extra_dispatch_keys*.

        Raises:
            ValueError: If no library instance has been provided.
        """
        library = self.lib
        if library is None:
            raise ValueError("Library instance is not provided.")

        checked_fn = _wrap_op_with_precision_check(key, fn)
        primary_key = self.reg_key

        # Bookkeeping records the name of the original function, not the wrapper.
        self.all_ops.append(fn.__name__)
        self.all_keys.append(key)

        self.lib.impl(key, checked_fn, primary_key)
        for extra_key in extra_dispatch_keys:
            self.lib.impl(key, checked_fn, extra_key)