精度检查(实验性功能)#

FlagGems 提供了一个实验性的精度检查机制,能够自动将 FlagGems 算子的输出 与原生 PyTorch(CPU)的计算结果进行对比,并将精度不一致的情况记录到日志文件中。 这对于开发过程中验证数值正确性非常有用。

如何启用#

启用精度检查需要两步:

  1. flag_gems.logging_utils 中调用 enable_precision_check() 配置精度日志。

  2. PrecisionCheckRegister 作为 registrar 参数传递给 enable()only_enable(),使算子在注册时被包装上精度检查逻辑。

import flag_gems
from flag_gems.logging_utils import enable_precision_check
from flag_gems.runtime.precision_register import PrecisionCheckRegister

# 第一步:配置精度检查(初始化精度日志)
enable_precision_check()

# 第二步:使用 PrecisionCheckRegister 注册所有算子
flag_gems.enable(registrar=PrecisionCheckRegister)

# 运行你的模型或算子
output = model(input)

也可以配合 only_enable() 仅对特定算子进行精度检查:

from flag_gems.logging_utils import enable_precision_check
from flag_gems.runtime.precision_register import PrecisionCheckRegister

enable_precision_check(rtol=1e-3, atol=1e-4)
flag_gems.only_enable(
    include=["mm", "add", "softmax"],
    registrar=PrecisionCheckRegister,
)

配置参数#

你可以通过向 enable_precision_check() 传递参数来自定义精度检查的行为。

参数名称

数据类型

默认值

描述

rtol

float

1e-4

相对误差容忍度

atol

float

1e-5

绝对误差容忍度

max_checks

int

10

每个算子最多检查的调用次数(超过后不再检查以减少开销)

log_once

bool

True

每个算子仅记录一次失败

path

str

~/.flaggems/precision.log

日志文件路径

from flag_gems.logging_utils import enable_precision_check

enable_precision_check(
    rtol=1e-3,
    atol=1e-5,
    max_checks=20,
    path="./my_precision.log",
)

关闭精度检查#

如需在运行时关闭精度检查:

from flag_gems.logging_utils import disable_precision_check

disable_precision_check()

日志输出#

精度检查的结果默认写入 ~/.flaggems/precision.log 文件。 只有未通过容忍度检查的算子才会被记录。

日志输出示例:

$ cat ~/.flaggems/precision.log
2025-05-19 10:00:01 [WARNING] Op: add.Tensor | FAIL | in: [(2, 3):torch.float16] | out: (2, 3):torch.float16 | max_abs: 1.200000e-03 | max_rel: 2.500000e-02 | rtol=0.01, atol=0.01

行为细节#

精度检查器内置了多项保护措施以尽量减少对性能的影响:

  • 每个算子仅检查前 N 次调用(由 max_checks 控制)

  • 超过 100 万元素的张量会被跳过,以避免大张量拷贝的开销

  • 一旦某个算子记录了一次失败,后续不再对其进行检查

  • 纯布局/内存操作(如 cloneviewcopy_)会被自动跳过

  • 随机采样算子(如 uniform_normal_)会被自动跳过

  • .out 变体的算子也会被跳过

  • 对于 float16 / bfloat16 类型的输入,容忍度会自动放宽到至少 1e-2

工作原理#

当使用 PrecisionCheckRegister 作为注册器时,每个算子会被包装上 一个精度检查装饰器。该装饰器的工作流程为:

  1. 正常执行 FlagGems(GPU)实现,得到结果。

  2. 将输入拷贝到 CPU,调用原生 aten 算子计算参考结果。

  3. 使用配置的容忍度对比两个结果。

  4. 如果结果超出容忍度范围,记录一条警告日志。

[!WARNING] 警告

精度检查会将 GPU 张量拷贝到 CPU 并执行原生 PyTorch 计算作为参考, 这会带来显著的性能开销。此功能仅建议在开发调试阶段使用, 不应在生产环境中启用。