TE-FL 自定义后端示例#
本文档介绍添加自定义后端的两种方式。
两种方式#
方式 |
适用场景 |
示例文件 |
|---|---|---|
In-tree |
开源贡献,直接集成 |
|
Out-of-tree |
闭源/第三方插件,独立包 |
|
快速开始#
cd transformer_engine/plugin/examples
# In-tree 方式
python example_intree.py
# Out-of-tree 方式
python example_outtree.py
In-tree 方式(3 步)#
from transformer_engine.plugin.core import (
OpRegistry, OpManager, OpImpl, BackendImplKind
)
# 1. 定义算子实现
def my_rmsnorm(input, weight, eps=1e-5, **kwargs):
variance = input.pow(2).mean(-1, keepdim=True)
return input * torch.rsqrt(variance + eps) * weight, torch.rsqrt(variance + eps)
# 2. 注册到注册表
registry = OpRegistry()
registry.register_impl(OpImpl(
op_name="rmsnorm_fwd",
impl_id="vendor.mybackend",
kind=BackendImplKind.VENDOR,
vendor="mybackend",
fn=my_rmsnorm,
priority=200,
))
# 3. 通过管理器调用
manager = OpManager(registry)
output, rsigma = manager.call("rmsnorm_fwd", input, weight)
Out-of-tree 方式(插件包)#
插件包结构#
my_vendor_plugin/
├── __init__.py # 包含 register(registry) 函数
└── setup.py # 或 pyproject.toml
\\init\\.py#
from transformer_engine.plugin.core import OpImpl, BackendImplKind
def my_rmsnorm(input, weight, eps=1e-5, **kwargs):
# 您的实现
...
def register(registry):
"""由 TE-FL 自动调用"""
registry.register_impl(OpImpl(
op_name="rmsnorm_fwd",
impl_id="vendor.myvendor",
kind=BackendImplKind.VENDOR,
vendor="myvendor",
fn=my_rmsnorm,
priority=200,
))
加载方式#
# 方式 1:环境变量
export TE_FL_PLUGIN_MODULES=my_vendor_plugin
python your_script.py
# 方式 2:pip install(需要配置 entry_points)
pip install my-vendor-plugin
python your_script.py
环境变量#
后端选择#
变量 |
描述 |
值 |
默认值 |
|---|---|---|---|
|
首选后端类型(最高优先级) |
|
|
|
优先使用供应商后端(旧版,优先级低于 |
|
|
|
严格模式 — 如果首选实现失败则报错,不回退 |
|
|
供应商过滤#
变量 |
描述 |
示例 |
|---|---|---|
|
允许的供应商白名单(逗号分隔) |
|
|
拒绝的供应商黑名单(逗号分隔) |
|
逐算子配置#
变量 |
描述 |
示例 |
|---|---|---|
|
逐算子后端排序 |
|
格式:op_name=backend1|backend2;op_name2=backend3|backend4
插件发现#
变量 |
描述 |
示例 |
|---|---|---|
|
要加载的插件模块(逗号分隔) |
|
构建配置#
变量 |
描述 |
值 |
默认值 |
|---|---|---|---|
|
跳过 CUDA 后端(构建时和运行时) |
|
|
|
CUDA 安装路径 |
|
自动检测 |
|
替代 CUDA 路径变量 |
|
自动检测 |
日志#
变量 |
描述 |
值 |
默认值 |
|---|---|---|---|
|
TE-FL 日志级别 |
|
|
示例#
优先使用供应商后端#
export TE_FL_PREFER=vendor
python your_script.py
仅允许特定供应商#
export TE_FL_ALLOW_VENDORS=nvidia,acme
python your_script.py
自定义逐算子排序#
# rmsnorm 使用 acme 供应商,其他使用 flagos
export TE_FL_PER_OP="rmsnorm_fwd=vendor:acme|flagos"
python your_script.py
跳过 CUDA,仅使用 FlagOS#
export TE_FL_SKIP_CUDA=1
export TE_FL_PREFER=flagos
python your_script.py
启用调试日志#
export TEFL_LOG_LEVEL=DEBUG
python your_script.py
预期输出#
运行时,您应该看到类似以下的日志:
[TE-FL manager.py:133 INFO] Registered impl_ids: ['default.flagos', 'reference.torch', 'vendor.mybackend']
[TE-FL manager.py:390 INFO] Op 'rmsnorm_fwd' using 'vendor.mybackend' (kind=vendor, vendor=mybackend)