Coverage for src/flag_gems/logging_utils.py: 37%
113 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1"""Logging helpers for flag_gems.
3Notes
4-----
51) When you enter through the public APIs `enable`, `only_enable`, or the
6 context manager `use_gems`, the `record` flag controls whether op-level
7 logging is enabled and where it is written.
82) If you import `flag_gems` and call operators directly (e.g., `flag_gems.mm`)
9 without those helpers, call `setup_flaggems_logging()` yourself to initialize
10 the logging mode and file handler.
11"""
13import logging
14import traceback
15from pathlib import Path
17import torch
class LogOncePerLocationFilter(logging.Filter):
    """Logging filter that passes only the first record from each call site.

    A call site is identified by the ``(pathname, lineno)`` pair of the log
    record, so each logging statement is emitted at most once for the lifetime
    of this filter instance, whatever its message text.
    """

    def __init__(self):
        super().__init__()
        # (pathname, lineno) pairs that have already been let through.
        self.logged_locations = set()

    def filter(self, record):
        location = (record.pathname, record.lineno)
        first_time = location not in self.logged_locations
        if first_time:
            self.logged_locations.add(location)
        return first_time
33def _remove_file_handlers(logger: logging.Logger):
34 # Remove and close only the FileHandlers created by setup_flaggems_logging.
35 # This avoids touching unrelated FileHandlers attached by other modules.
36 removed = False
37 for h in list(logger.handlers):
38 if isinstance(h, logging.FileHandler) and getattr(h, "_flaggems_owned", False):
39 h.close()
40 logger.removeHandler(h)
41 removed = True
42 return removed
def setup_flaggems_logging(path=None, record=True, once=False):
    """Attach a file handler to the ``flag_gems`` logger for op-level logging.

    Args:
        path: Log file path. Defaults to ``~/.flaggems/oplist.log``. Missing
            parent directories are created, and the file is truncated
            (``mode="w"``) on every call.
        record: When False this function does nothing; any existing handlers
            are left as they are.
        once: When True, each logging call site is emitted at most once
            (via ``LogOncePerLocationFilter``).
    """
    logger = logging.getLogger("flag_gems")
    if not record:
        return

    # Refresh: drop the file handler from any previous setup so a new path
    # replaces the old one instead of logging to both.
    _remove_file_handlers(logger)

    filename = Path(path or Path.home() / ".flaggems/oplist.log")
    # FileHandler does not create missing directories (a fresh machine has no
    # ~/.flaggems yet); this matches setup_precision_logging.
    filename.parent.mkdir(parents=True, exist_ok=True)
    handler = logging.FileHandler(filename, mode="w")
    # Tag the handler so _remove_file_handlers only removes ones created here.
    handler._flaggems_owned = True

    if once:
        handler.addFilter(LogOncePerLocationFilter())

    formatter = logging.Formatter("[%(levelname)s] %(name)s.%(funcName)s: %(message)s")
    handler.setFormatter(formatter)

    logger.setLevel(logging.DEBUG)
    logger.addHandler(handler)
    # Do not pass records on to ancestor loggers' handlers.
    logger.propagate = False
def teardown_flaggems_logging(logger: logging.Logger | None = None):
    """Detach the file handlers installed by setup_flaggems_logging.

    Intended for use on context exit. When ``logger`` is not given, the
    ``flag_gems`` logger is used.
    """
    _remove_file_handlers(logger or logging.getLogger("flag_gems"))
# Precision check data file writer.
# A plain file is used instead of the logging framework on purpose: precision
# results are structured data (one JSON object per line), not runtime
# diagnostics.

# Open text handle for the precision results file, or None when no file is
# open. Managed by setup_precision_logging / _close_precision_file.
_precision_file = None

# Shared precision-check settings, overwritten by enable_precision_check().
# Defaults mirror that function's defaults so every key it writes (including
# "max_checks") exists even before precision checking is enabled.
precision_config = {
    "enabled": False,
    "rtol": 1e-4,
    "atol": 1e-5,
    "log_once": True,
    "max_checks": 10,
    # Ops already recorded; presumably consulted when log_once is True —
    # TODO confirm against the caller.
    "logged_ops": set(),
}
def setup_precision_logging(path=None):
    """Open the precision results file for writing, replacing any open one.

    Args:
        path: Destination file. Defaults to ``~/.flaggems/precision.log``.
            Missing parent directories are created; an existing file is
            truncated.
    """
    global _precision_file
    # Release the handle from a previous setup before opening a new one.
    _close_precision_file()

    target = Path(path) if path else Path.home() / ".flaggems/precision.log"
    target.parent.mkdir(parents=True, exist_ok=True)
    # The handle must outlive this call, so a `with` block is not usable here.
    _precision_file = open(target, mode="w")  # noqa: SIM115
def _close_precision_file():
    """Close the precision results file if it is open, then forget it.

    A handle that is already closed is left in place untouched.
    """
    global _precision_file
    handle = _precision_file
    if handle is None or handle.closed:
        return
    handle.close()
    _precision_file = None
def write_precision_result(record: dict):
    """Append one precision check result to the data file as a JSON line.

    Each call writes a single JSON object followed by a newline (JSONL
    format), so the output can be post-processed with standard tools such as
    ``jq`` or Python's ``json`` module. A UTC ISO-8601 ``timestamp`` field is
    added to the written object. Does nothing when no precision file is open.

    The caller's ``record`` dict is not modified.
    """
    import json
    from datetime import datetime, timezone

    if _precision_file is None or _precision_file.closed:
        return

    # Build a new dict rather than injecting the timestamp into the caller's.
    entry = {**record, "timestamp": datetime.now(tz=timezone.utc).isoformat()}
    # default=str: values json cannot serialize are written via str().
    _precision_file.write(json.dumps(entry, default=str) + "\n")
    # Flush per record so already-written results reach the file immediately.
    _precision_file.flush()
def get_tensor_info(t):
    """Describe a tensor, or a (nested) list/tuple of tensors, as strings.

    Returns:
        ``"<shape tuple>:<dtype>"`` for a tensor. For a list/tuple, a list of
        its elements' descriptions with empty results (``None`` or empty
        lists) dropped. ``None`` for anything else.
    """
    if isinstance(t, torch.Tensor):
        return f"{tuple(t.shape)}:{t.dtype}"
    if not isinstance(t, (list, tuple)):
        return None
    return list(filter(None, map(get_tensor_info, t)))
def get_call_location():
    """Return ``"file:line"`` of the nearest caller outside flag_gems/torch.

    The stack is scanned from the innermost frame outward. This function's
    own frame is skipped, as is any frame whose filename contains
    ``flag_gems`` or ``torch``. The result therefore points at the code that
    triggered the call, not at the program's entry point.

    Returns ``"unknown"`` when no such frame exists.
    """
    # extract_stack() is ordered outermost-first, and its last entry is this
    # function's own frame: drop that entry and walk the rest in reverse.
    for frame in reversed(traceback.extract_stack()[:-1]):
        if "flag_gems" not in frame.filename and "torch" not in frame.filename:
            return f"{frame.filename}:{frame.lineno}"
    return "unknown"
def _compare_tensors(fg_out, pt_out, rtol, atol):
    """Compare two tensors with NaN/Inf-aware tolerance; see compare_outputs."""
    if fg_out.shape != pt_out.shape:
        return False, {
            "error": "shape_mismatch",
            "fg": tuple(fg_out.shape),
            "pt": tuple(pt_out.shape),
        }
    try:
        fg = fg_out.detach().float()
        pt = pt_out.detach().float()

        # Positions where both sides are NaN, or both are the same signed Inf,
        # are semantically identical — not precision errors.
        both_nan = torch.isnan(fg) & torch.isnan(pt)
        both_same_inf = torch.isinf(fg) & torch.isinf(pt) & (fg == pt)
        ignore_mask = both_nan | both_same_inf

        # Also covers empty tensors, for which .all() is True.
        if ignore_mask.all():
            return True, {"max_abs": 0.0, "max_rel": 0.0}

        # Any NaN/Inf left on either side is a mismatch. One side may be
        # special while the other is finite, or both may be special but
        # different (NaN vs Inf, +Inf vs -Inf).
        special = ~torch.isfinite(fg) | ~torch.isfinite(pt)
        mismatch_special = special & ~ignore_mask
        if mismatch_special.any():
            # Report the first mismatching element.
            idx = tuple(mismatch_special.nonzero(as_tuple=False)[0].tolist())
            return False, {
                "error": "special_value_mismatch",
                "fg": fg[idx].item(),
                "pt": pt[idx].item(),
            }

        # Every non-ignored element is now finite on both sides. The early
        # return above guarantees at least one such element exists.
        valid = ~ignore_mask
        fg_valid = fg[valid]
        pt_valid = pt[valid]
        abs_diff = torch.abs(fg_valid - pt_valid)
        max_abs = abs_diff.max().item()
        # The small epsilon avoids division by zero where the reference is 0.
        max_rel = (abs_diff / (torch.abs(pt_valid) + 1e-12)).max().item()
        is_close = torch.allclose(fg_valid, pt_valid, rtol=rtol, atol=atol)
        return is_close, {"max_abs": max_abs, "max_rel": max_rel}
    except Exception as e:
        # Best effort: a failure inside the checker is reported in the info
        # dict but is not counted as a precision failure.
        return True, {"error": "exception", "message": str(e)}


def compare_outputs(fg_out, pt_out, rtol, atol):
    """Compare a flag_gems output against the PyTorch reference output.

    Tensors are compared elementwise after casting to float32. Matching NaNs
    and same-signed Infs are treated as equal, while any other NaN/Inf
    disagreement is a mismatch. Finite values are checked with
    ``torch.allclose(rtol, atol)``. Lists/tuples are compared pairwise and
    stop at the first failing element.

    Returns:
        ``(ok, info)``. On success ``info`` holds ``max_abs``/``max_rel``.
        On failure it holds an ``error`` key (``shape_mismatch`` or
        ``special_value_mismatch``) with details, or ``max_abs``/``max_rel``
        for a tolerance failure. Container failures add the failing
        ``index``. Inputs that are neither tensor pairs nor sequence pairs
        return ``(True, {})``.
    """
    if isinstance(fg_out, torch.Tensor) and isinstance(pt_out, torch.Tensor):
        return _compare_tensors(fg_out, pt_out, rtol, atol)
    if isinstance(fg_out, (tuple, list)) and isinstance(pt_out, (tuple, list)):
        # NOTE(review): zip stops at the shorter sequence, so a difference in
        # the number of outputs is not reported.
        for i, (fg, pt) in enumerate(zip(fg_out, pt_out)):
            ok, info = compare_outputs(fg, pt, rtol, atol)
            if not ok:
                info["index"] = i
                return False, info
    return True, {}
def enable_precision_check(
    rtol=1e-4, atol=1e-5, log_once=True, max_checks=10, path=None
):
    """Turn on precision checking and open the results file.

    Args:
        rtol: Relative tolerance, stored as ``precision_config["rtol"]``.
        atol: Absolute tolerance, stored as ``precision_config["atol"]``.
        log_once: Stored as ``precision_config["log_once"]``.
        max_checks: Stored as ``precision_config["max_checks"]``.
        path: Results file, forwarded to setup_precision_logging (``None``
            selects its default location).
    """
    setup_precision_logging(path)
    precision_config["enabled"] = True
    precision_config["rtol"] = rtol
    precision_config["atol"] = atol
    precision_config["log_once"] = log_once
    precision_config["max_checks"] = max_checks
    # Reset the record of already-logged ops.
    precision_config["logged_ops"] = set()
def disable_precision_check():
    """Turn precision checking off and release the results file.

    Closes the file opened by setup_precision_logging (if any), then clears
    the ``enabled`` flag in ``precision_config``.
    """
    _close_precision_file()
    precision_config.update(enabled=False)