精度检查(实验性功能)#
FlagGems 提供了一个实验性的精度检查机制,能够自动将 FlagGems 算子的输出 与原生 PyTorch(CPU)的计算结果进行对比,并将精度不一致的情况记录到日志文件中。 这对于开发过程中验证数值正确性非常有用。
如何启用#
启用精度检查需要两步:
从
flag_gems.logging_utils中调用enable_precision_check()配置精度日志。将
PrecisionCheckRegister作为registrar参数传递给enable()或only_enable(),使算子在注册时被包装上精度检查逻辑。
import flag_gems
from flag_gems.logging_utils import enable_precision_check
from flag_gems.runtime.precision_register import PrecisionCheckRegister
# 第一步:配置精度检查(初始化精度日志)
enable_precision_check()
# 第二步:使用 PrecisionCheckRegister 注册所有算子
flag_gems.enable(registrar=PrecisionCheckRegister)
# 运行你的模型或算子
output = model(input)
也可以配合 only_enable() 仅对特定算子进行精度检查:
from flag_gems.logging_utils import enable_precision_check
from flag_gems.runtime.precision_register import PrecisionCheckRegister
enable_precision_check(rtol=1e-3, atol=1e-4)
flag_gems.only_enable(
include=["mm", "add", "softmax"],
registrar=PrecisionCheckRegister,
)
配置参数#
你可以通过向 enable_precision_check() 传递参数来自定义精度检查的行为。
参数名称 |
数据类型 |
默认值 |
描述 |
|---|---|---|---|
|
|
|
相对误差容忍度 |
|
|
|
绝对误差容忍度 |
|
|
|
每个算子最多检查的调用次数(超过后不再检查以减少开销) |
|
|
|
每个算子仅记录一次失败 |
|
|
|
日志文件路径 |
from flag_gems.logging_utils import enable_precision_check
enable_precision_check(
rtol=1e-3,
atol=1e-5,
max_checks=20,
path="./my_precision.log",
)
关闭精度检查#
如需在运行时关闭精度检查:
from flag_gems.logging_utils import disable_precision_check
disable_precision_check()
日志输出#
精度检查的结果默认写入 ~/.flaggems/precision.log 文件。
只有未通过容忍度检查的算子才会被记录。
日志输出示例:
$ cat ~/.flaggems/precision.log
2025-05-19 10:00:01 [WARNING] Op: add.Tensor | FAIL | in: [(2, 3):torch.float16] | out: (2, 3):torch.float16 | max_abs: 1.200000e-03 | max_rel: 2.500000e-02 | rtol=0.01, atol=0.01
行为细节#
精度检查器内置了多项保护措施以尽量减少对性能的影响:
每个算子仅检查前 N 次调用(由
max_checks控制)超过 100 万元素的张量会被跳过,以避免大张量拷贝的开销
一旦某个算子记录了一次失败,后续不再对其进行检查
纯布局/内存操作(如
clone、view、copy_)会被自动跳过随机采样算子(如
uniform_、normal_)会被自动跳过.out变体的算子也会被跳过对于
float16/bfloat16类型的输入,容忍度会自动放宽到至少1e-2
工作原理#
当使用 PrecisionCheckRegister 作为注册器时,每个算子会被包装上
一个精度检查装饰器。该装饰器的工作流程为:
正常执行 FlagGems(GPU)实现,得到结果。
将输入拷贝到 CPU,调用原生
aten算子计算参考结果。使用配置的容忍度对比两个结果。
如果结果超出容忍度范围,记录一条警告日志。
[!WARNING] 警告
精度检查会将 GPU 张量拷贝到 CPU 并执行原生 PyTorch 计算作为参考, 这会带来显著的性能开销。此功能仅建议在开发调试阶段使用, 不应在生产环境中启用。