Coverage for src/flag_gems/runtime/precision_register.py: 0%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1"""Precision-checking register – loaded only when precision checking is enabled. 

2 

3This module is NOT imported on the normal execution path. It is lazily 

4imported by ``register.py`` only when the user explicitly requests 

5``PrecisionCheckRegister``. 

6""" 

7 

8import functools 

9 

10import torch 

11 

12from ..config import get_skip_precision_check_ops 

13from ..logging_utils import ( 

14 compare_outputs, 

15 get_tensor_info, 

16 precision_config, 

17 write_precision_result, 

18) 

19from .register import Register 

20 

# Upper bound, in elements, on the largest input tensor for which a precision
# check is still performed.  Calls with a bigger tensor are passed straight
# through so the check does not pay for copying large tensors.
_MAX_NUMEL_FOR_CHECK = 1 << 20  # 1,048,576 elements (1M)

24 

25 

def _get_dtype_tolerance(args, default_rtol, default_atol):
    """Pick ``(rtol, atol)`` according to the first floating-point tensor.

    Only the first element of ``args`` that is a floating-point
    ``torch.Tensor`` is consulted.  If its dtype is ``bfloat16`` or
    ``float16`` both tolerances are raised to at least ``1e-2``; in every
    other case the defaults are returned unchanged.
    """
    half_dtypes = (torch.bfloat16, torch.float16)

    # The first floating-point tensor decides; later ones are not inspected.
    deciding = next(
        (
            item
            for item in args
            if isinstance(item, torch.Tensor) and item.is_floating_point()
        ),
        None,
    )

    if deciding is None or deciding.dtype not in half_dtypes:
        return (default_rtol, default_atol)
    return (max(default_rtol, 1e-2), max(default_atol, 1e-2))

34 

35 

def _to_cpu(x):
    """Recursively copy every tensor contained in ``x`` to CPU.

    Tensors are replaced by ``tensor.detach().cpu()``.  Lists, tuples and
    dicts are rebuilt with the same container type and their elements are
    converted recursively.  Any other value is returned unchanged.

    Named tuples are rebuilt field by field.  Their constructors take one
    positional argument per field, so rebuilding them from a single iterable
    (as is done for plain lists and tuples) would raise ``TypeError``.
    """
    if isinstance(x, torch.Tensor):
        return x.detach().cpu()

    if isinstance(x, (list, tuple)):
        moved = [_to_cpu(item) for item in x]
        if isinstance(x, tuple) and hasattr(x, "_fields"):
            # collections.namedtuple / typing.NamedTuple instance.
            return type(x)(*moved)
        # A concrete list (rather than a generator) is accepted by plain
        # list/tuple as well as by sequence types that require a sequence.
        return type(x)(moved)

    if isinstance(x, dict):
        return {key: _to_cpu(value) for key, value in x.items()}

    return x

45 

46 

def _max_tensor_numel(args):
    """Return the element count of the largest tensor found in ``args``.

    ``args`` is an iterable of call arguments.  Tensors that sit directly in
    it are counted, and so are tensors nested inside list, tuple or dict
    values (for example an operator that receives a list of tensors).
    Returns ``0`` when no tensor is present.
    """
    largest = 0
    # Explicit stack instead of recursion; visiting order does not matter
    # because only the maximum is kept.
    pending = list(args)
    while pending:
        item = pending.pop()
        if isinstance(item, torch.Tensor):
            largest = max(largest, item.numel())
        elif isinstance(item, (list, tuple)):
            pending.extend(item)
        elif isinstance(item, dict):
            pending.extend(item.values())
    return largest

54 

55 

# Names of operators that are excluded from precision checking altogether.
# The collection is provided by the project configuration
# (conf/operators.yaml) and is read once, when this module is imported.
_SKIP_OPS = get_skip_precision_check_ops()

58 

59 

def _parse_op_key(op_key):
    """Split ``op_key`` into ``(op_name, overload_name, should_skip)``.

    The key is parsed once, at wrap time, so that no string handling is
    needed inside the per-call wrapper.

    * Any namespace prefix is dropped (``"aten::add.Tensor"`` ->
      ``"add.Tensor"``).
    * The text before the first ``"."`` is the operator name; the rest is the
      overload name, which is ``"default"`` when the key contains no dot.
    * ``should_skip`` is true for ``out`` overloads, for operator names
      ending in ``"_out"``, and for operator names listed in ``_SKIP_OPS``.
    """
    # Everything after the last "::" (the whole key when there is none).
    unqualified = op_key.split("::")[-1]

    op_name, separator, overload_name = unqualified.partition(".")
    if not separator:
        overload_name = "default"

    # Same evaluation order as a chained ``or``: the skip list is consulted
    # only when neither of the "out" conventions matches.
    if overload_name == "out":
        should_skip = True
    elif op_name.endswith("_out"):
        should_skip = True
    else:
        should_skip = op_name in _SKIP_OPS

    return op_name, overload_name, should_skip

83 

84 

def _wrap_op_with_precision_check(op_key, fn):
    """Wrap a FlagGems operator to compare its output against native PyTorch.

    Since FlagGems replaces the CUDA dispatch, the native implementation
    cannot be called on GPU, so inputs are copied to CPU and the matching
    ``torch.ops.aten`` overload is invoked there to obtain the reference
    result.  Overhead is bounded by:

    - ``max_checks``: only the first N calls per operator are checked
      (read from ``precision_config``, default 10);
    - calls whose largest input tensor exceeds ``_MAX_NUMEL_FOR_CHECK``
      elements are not checked;
    - once a failure has been recorded for an operator it is no longer
      checked.

    Args:
        op_key: Operator key, e.g. ``"aten::add.Tensor"`` or ``"add.Tensor"``.
        fn: The FlagGems implementation to wrap.

    Returns:
        ``fn`` itself when the operator is never checked (skipped operator or
        no matching aten overload); otherwise a wrapper that always returns
        the result of ``fn`` and performs the comparison as a best-effort
        side effect.  Errors raised by the comparison never reach the caller.
    """
    # Function-scope import so the diagnostics below do not depend on the
    # module's import block.
    import logging

    # --- Pre-compute everything derivable from op_key at wrap time ---
    op_name, overload_name, should_skip = _parse_op_key(op_key)

    # Operators that can never be checked are returned unwrapped.
    if should_skip:
        return fn

    # Resolve the aten overload once instead of on every call.
    aten_packet = getattr(torch.ops.aten, op_name, None)
    aten_overload = getattr(aten_packet, overload_name, None) if aten_packet else None

    # Without a native implementation there is nothing to compare against.
    if aten_overload is None:
        return fn

    logger = logging.getLogger(__name__)
    # Shared set of operator keys that have already reported a failure.
    _logged_ops = precision_config["logged_ops"]
    _call_count = 0

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        nonlocal _call_count

        # Operators that already logged a failure are not checked again.
        if op_key in _logged_ops:
            return fn(*args, **kwargs)

        # Sampling: only the first N calls per operator are checked.
        _call_count += 1
        if _call_count > precision_config.get("max_checks", 10):
            return fn(*args, **kwargs)

        # Skip large tensors to avoid copy overhead.  Tensors passed by
        # keyword are copied to CPU as well, so they count towards the limit.
        largest_numel = max(
            _max_tensor_numel(args),
            _max_tensor_numel(tuple(kwargs.values())),
        )
        if largest_numel > _MAX_NUMEL_FOR_CHECK:
            return fn(*args, **kwargs)

        # Execute the FlagGems implementation FIRST with no interference.
        fg_result = fn(*args, **kwargs)

        try:
            # Copy inputs and output to CPU for comparison.
            # The .cpu() call implicitly synchronizes the CUDA stream.
            # NOTE: for in-place ops the inputs may already have been
            # modified by ``fn`` at this point; such ops are expected to be
            # excluded via _SKIP_OPS.
            cpu_args = [_to_cpu(a) for a in args]
            cpu_kwargs = {k: _to_cpu(v) for k, v in kwargs.items()}
            fg_result_cpu = _to_cpu(fg_result)

            with torch.no_grad():
                pt_result_cpu = aten_overload(*cpu_args, **cpu_kwargs)

            rtol, atol = _get_dtype_tolerance(
                args, precision_config["rtol"], precision_config["atol"]
            )
            is_close, info = compare_outputs(fg_result_cpu, pt_result_cpu, rtol, atol)

            if not is_close:
                _logged_ops.add(op_key)

                # Describe each argument once and keep only non-empty results.
                input_info = [
                    described
                    for described in (get_tensor_info(a) for a in args)
                    if described
                ]

                record = {
                    "op": op_key,
                    "status": "FAIL",
                    "inputs": input_info,
                    "output": get_tensor_info(fg_result),
                    "rtol": rtol,
                    "atol": atol,
                }
                if "error" in info:
                    record["error"] = info["error"]
                    record["fg_value"] = info["fg"]
                    record["pt_value"] = info["pt"]
                else:
                    record["max_abs_diff"] = info["max_abs"]
                    record["max_rel_diff"] = info["max_rel"]
                write_precision_result(record)

        except Exception:
            # The check is best-effort and must never break the real call,
            # but the reason it could not run is kept for debugging instead
            # of being discarded silently.
            logger.debug(
                "Precision check for %s could not be completed",
                op_key,
                exc_info=True,
            )

        return fg_result

    return wrapper

178 

179 

class PrecisionCheckRegister(Register):
    """``Register`` variant that installs precision-checked operators.

    Every implementation passed to :meth:`register_impl` is first run through
    ``_wrap_op_with_precision_check`` and the resulting callable is what gets
    registered with the library.

    This class is only instantiated when the user has explicitly called
    ``enable_precision_check()`` before ``enable()`` / ``only_enable()``.
    It is never on the normal execution path.
    """

    def register_impl(self, key, fn, extra_dispatch_keys=()):
        """Register ``fn`` for ``key``, wrapped with the precision check.

        The wrapped callable is registered under ``self.reg_key`` and then
        under each key in ``extra_dispatch_keys``.  The original function
        name and the operator key are recorded in ``self.all_ops`` and
        ``self.all_keys``.

        Raises:
            ValueError: if no library instance is attached (``self.lib`` is
                ``None``).
        """
        if self.lib is None:
            raise ValueError("Library instance is not provided.")

        checked_fn = _wrap_op_with_precision_check(key, fn)
        primary_key = self.reg_key

        # Bookkeeping uses the name of the original, unwrapped function.
        self.all_ops.append(fn.__name__)
        self.all_keys.append(key)

        # Device key first, then any additional dispatch keys, in order.
        for target_key in (primary_key, *extra_dispatch_keys):
            self.lib.impl(key, checked_fn, target_key)

198 for dispatch_key in extra_dispatch_keys: 

199 self.lib.impl(key, wrapped_fn, dispatch_key)