Coverage for src/flag_gems/logging_utils.py: 37%

113 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1"""Logging helpers for flag_gems. 

2 

3Notes 

4----- 

51) When you enter through the public APIs `enable`, `only_enable`, or the 

6 context manager `use_gems`, the `record` flag controls whether op-level 

7 logging is enabled and where it is written. 

82) If you import `flag_gems` and call operators directly (e.g., `flag_gems.mm`) 

9 without those helpers, call `setup_flaggems_logging()` yourself to initialize 

10 the logging mode and file handler. 

11""" 

12 

13import logging 

14import traceback 

15from pathlib import Path 

16 

17import torch 

18 

19 

class LogOncePerLocationFilter(logging.Filter):
    """Logging filter that lets each source location emit at most one record.

    A location is the ``(pathname, lineno)`` pair carried by the log record.
    """

    def __init__(self):
        super().__init__()
        # Locations that have already produced a record.
        self.logged_locations = set()

    def filter(self, record):
        """Return True the first time a location is seen, False afterwards."""
        location = (record.pathname, record.lineno)
        first_time = location not in self.logged_locations
        if first_time:
            self.logged_locations.add(location)
        return first_time

31 

32 

33def _remove_file_handlers(logger: logging.Logger): 

34 # Remove and close only the FileHandlers created by setup_flaggems_logging. 

35 # This avoids touching unrelated FileHandlers attached by other modules. 

36 removed = False 

37 for h in list(logger.handlers): 

38 if isinstance(h, logging.FileHandler) and getattr(h, "_flaggems_owned", False): 

39 h.close() 

40 logger.removeHandler(h) 

41 removed = True 

42 return removed 

43 

44 

def setup_flaggems_logging(path=None, record=True, once=False):
    """Attach a fresh op-log file handler to the ``flag_gems`` logger.

    Any file handler previously installed by this function is closed and
    removed first, so a new ``path`` replaces the old one.

    Args:
        path: Destination log file. Defaults to ``~/.flaggems/oplist.log``.
            The file is opened in ``"w"`` mode, i.e. truncated.
        record: When falsy, the function returns without changing anything.
        once: When truthy, a ``LogOncePerLocationFilter`` is added to the
            handler so each source location is written at most once.
    """
    logger = logging.getLogger("flag_gems")

    if not record:
        return

    # If caller asks for recording, refresh file handler (new path overwrites old).
    _remove_file_handlers(logger)

    filename = Path(path or Path.home() / ".flaggems/oplist.log")
    # FileHandler opens the file immediately and does not create missing
    # directories, so make sure the parent exists (e.g. ~/.flaggems on a
    # fresh machine).
    filename.parent.mkdir(parents=True, exist_ok=True)
    handler = logging.FileHandler(filename, mode="w")
    # Marker read by _remove_file_handlers to recognise handlers we own.
    handler._flaggems_owned = True

    if once:
        handler.addFilter(LogOncePerLocationFilter())

    formatter = logging.Formatter("[%(levelname)s] %(name)s.%(funcName)s: %(message)s")
    handler.setFormatter(formatter)

    logger.setLevel(logging.DEBUG)
    logger.addHandler(handler)
    # Keep op-level records out of the root logger's handlers.
    logger.propagate = False

67 

68 

def teardown_flaggems_logging(logger: logging.Logger | None = None):
    """Detach and close flag_gems-owned file handlers (used on context exit).

    When no logger (or a falsy value) is passed, the ``flag_gems`` logger is
    used.
    """
    target = logger if logger else logging.getLogger("flag_gems")
    _remove_file_handlers(target)

74 

75 

# Precision check data file writer
# We intentionally use a plain file rather than the logging framework here
# because precision results are structured data, not runtime diagnostics.

# Handle of the currently open precision data file, or None when none is open.
_precision_file = None

# Module-wide precision-check settings. enable_precision_check() updates this
# dict in place; disable_precision_check() flips "enabled" back to False.
precision_config = {
    "enabled": False,
    "rtol": 1e-4,
    "atol": 1e-5,
    "log_once": True,
    # Same default as enable_precision_check(max_checks=10), so the key exists
    # even before precision checking has been enabled.
    "max_checks": 10,
    "logged_ops": set(),
}

88 

89 

def setup_precision_logging(path=None):
    """Start a fresh precision results file, closing any file already open.

    ``path`` defaults to ``~/.flaggems/precision.log``. Missing parent
    directories are created and the file is opened in write mode, so existing
    content is truncated.
    """
    global _precision_file
    _close_precision_file()

    target = Path(path or Path.home() / ".flaggems/precision.log")
    target.parent.mkdir(parents=True, exist_ok=True)
    # Kept open across calls on purpose; _close_precision_file releases it.
    _precision_file = open(target, "w")  # noqa: SIM115

98 

99 

def _close_precision_file():
    """Release the precision data file handle and reset it to None."""
    global _precision_file
    handle = _precision_file
    if handle is None or handle.closed:
        # Nothing to close; just make sure the module state is cleared.
        _precision_file = None
        return
    handle.close()
    _precision_file = None

106 

107 

def write_precision_result(record: dict):
    """Write a precision check result as a JSON line to the data file.

    Each call appends one JSON object (JSONL format) so the output can be
    post-processed with standard tools such as ``jq`` or Python's ``json``
    module. A UTC ISO-8601 ``"timestamp"`` field is added to the written
    object; the caller's ``record`` dict is left unmodified. Values that
    ``json`` cannot serialise are converted with ``str``.

    Nothing is written when no precision data file is open.
    """
    import json
    from datetime import datetime, timezone

    if _precision_file is None or _precision_file.closed:
        return

    # Serialise a copy so the timestamp does not leak into the caller's dict.
    # Key order matches in-place assignment: an existing "timestamp" keeps its
    # position, otherwise the key is appended last.
    stamped = {**record, "timestamp": datetime.now(tz=timezone.utc).isoformat()}
    _precision_file.write(json.dumps(stamped, default=str) + "\n")
    # Flush per record so results are on disk even if the process dies.
    _precision_file.flush()

122 

123 

def get_tensor_info(t):
    """Summarise tensors as ``"<shape tuple>:<dtype>"`` strings.

    A tensor yields a single string. A list or tuple yields a list with the
    summaries of its items, recursing into nested sequences and dropping
    items whose summary is empty or None. Anything else yields None.
    """
    if isinstance(t, torch.Tensor):
        return f"{tuple(t.shape)}:{t.dtype}"
    if not isinstance(t, (list, tuple)):
        return None

    summaries = []
    for item in t:
        summary = get_tensor_info(item)
        if summary:
            summaries.append(summary)
    return summaries

131 

132 

def get_call_location():
    """Return ``"<filename>:<lineno>"`` of the nearest caller outside flag_gems/torch.

    The call stack is searched from the most recent call outward, and the
    first frame whose filename contains neither ``"flag_gems"`` nor
    ``"torch"`` is reported. Returns ``"unknown"`` when every frame is
    filtered out.
    """
    # extract_stack() lists frames oldest-first and ends with this function's
    # own frame. Drop that last entry and walk backwards so the result is the
    # closest user call site rather than the program's entry point.
    callers = traceback.extract_stack()[:-1]
    for frame in reversed(callers):
        if "flag_gems" not in frame.filename and "torch" not in frame.filename:
            return f"{frame.filename}:{frame.lineno}"
    return "unknown"

138 

139 

def _compare_tensors(fg_out, pt_out, rtol, atol):
    """Compare two same-shaped tensors element-wise; helper for compare_outputs."""
    fg = fg_out.detach().float()
    pt = pt_out.detach().float()

    # Positions where both are NaN or both are the same Inf are semantically
    # identical, not precision errors.
    both_nan = torch.isnan(fg) & torch.isnan(pt)
    both_same_inf = torch.isinf(fg) & torch.isinf(pt) & (fg == pt)
    ignore_mask = both_nan | both_same_inf

    # If all elements are in the ignore set, they match perfectly.
    if ignore_mask.all():
        return True, {"max_abs": 0.0, "max_rel": 0.0}

    # Any NaN/Inf on either side that is not an identical pair is a mismatch.
    # This covers "one side special, the other finite" as well as +Inf vs -Inf
    # and NaN vs Inf.
    fg_special = torch.isnan(fg) | torch.isinf(fg)
    pt_special = torch.isnan(pt) | torch.isinf(pt)
    either_special = fg_special | pt_special
    mismatch_special = either_special & ~ignore_mask
    if mismatch_special.any():
        # Report the first offending element.
        idx = mismatch_special.nonzero(as_tuple=False)[0]
        return False, {
            "error": "special_value_mismatch",
            "fg": fg[tuple(idx)].item(),
            "pt": pt[tuple(idx)].item(),
        }

    # Only finite elements remain to be compared numerically.
    valid = ~either_special
    if not valid.any():
        return True, {"max_abs": 0.0, "max_rel": 0.0}

    fg_valid = fg[valid]
    pt_valid = pt[valid]
    abs_diff = torch.abs(fg_valid - pt_valid)
    max_abs = abs_diff.max().item()
    # Small epsilon keeps the relative error defined where the reference is 0.
    max_rel = (abs_diff / (torch.abs(pt_valid) + 1e-12)).max().item()
    is_close = torch.allclose(fg_valid, pt_valid, rtol=rtol, atol=atol)
    return is_close, {"max_abs": max_abs, "max_rel": max_rel}


def compare_outputs(fg_out, pt_out, rtol, atol):
    """Compare a flag_gems output against a reference output.

    Args:
        fg_out: Output under test; a tensor or a (possibly nested) list/tuple.
        pt_out: Reference output with the same structure.
        rtol: Relative tolerance passed to ``torch.allclose``.
        atol: Absolute tolerance passed to ``torch.allclose``.

    Returns:
        A ``(ok, info)`` tuple. For tensors, ``info`` holds ``max_abs`` and
        ``max_rel`` over the finite elements, or an ``error`` entry
        (``shape_mismatch``, ``special_value_mismatch``, ``exception``).
        For sequences, the first failing element's ``info`` is returned with
        its position under ``index``; differing lengths give
        ``length_mismatch``. Inputs that are neither a tensor pair nor a
        sequence pair are not compared and yield ``(True, {})``.
    """
    if isinstance(fg_out, torch.Tensor) and isinstance(pt_out, torch.Tensor):
        if fg_out.shape != pt_out.shape:
            return False, {
                "error": "shape_mismatch",
                "fg": tuple(fg_out.shape),
                "pt": tuple(pt_out.shape),
            }
        try:
            return _compare_tensors(fg_out, pt_out, rtol, atol)
        except Exception as e:
            # NOTE(review): a comparison that raises is reported as passing,
            # with the error attached; confirm this best-effort is intended.
            return True, {"error": "exception", "message": str(e)}

    if isinstance(fg_out, (tuple, list)) and isinstance(pt_out, (tuple, list)):
        # zip() would silently drop the surplus elements of the longer side.
        if len(fg_out) != len(pt_out):
            return False, {
                "error": "length_mismatch",
                "fg": len(fg_out),
                "pt": len(pt_out),
            }
        for i, (fg, pt) in enumerate(zip(fg_out, pt_out)):
            ok, info = compare_outputs(fg, pt, rtol, atol)
            if not ok:
                info["index"] = i
                return False, info
        return True, {}

    # Non-tensor leaves (None, scalars, mixed kinds) are not checked. Always
    # returning a tuple keeps the recursive unpacking above from failing.
    return True, {}

201 

202 

def enable_precision_check(
    rtol=1e-4, atol=1e-5, log_once=True, max_checks=10, path=None
):
    """Open the precision data file and switch precision checking on.

    The file is (re)opened through ``setup_precision_logging(path)``, then
    ``precision_config`` is updated in place with the given tolerances and
    limits, and ``logged_ops`` is reset to an empty set.
    """
    setup_precision_logging(path)
    precision_config.update(
        enabled=True,
        rtol=rtol,
        atol=atol,
        log_once=log_once,
        max_checks=max_checks,
        logged_ops=set(),
    )

217 

218 

def disable_precision_check():
    """Switch precision checking off after closing the precision data file."""
    _close_precision_file()
    precision_config.update(enabled=False)