TLE 参考#

本页面介绍编写 FlagFFT Triton 内核时使用的 FlagTree/TLE API。

TLE-Lite#

设计理念:一次编写,处处运行。使用高层语义提示(而非硬约束)引导编译器启发式。

内存管理#

tle.load#

tl.load 的扩展,支持异步提示:

x = tle.load(..., is_async=True)

对于稍后在计算密集型区域中重用的全局内存读取,使用 is_async=True。在边界分块上保持 maskother 显式。

示例:尾部块的受保护异步加载

offs = base + tl.arange(0, BLOCK)
mask = offs < n_elements
x = tle.load(x_ptr + offs, mask=mask, other=0.0, is_async=True)

示例:异步加载 + 计算重叠模式

for k in tl.range(0, K, BK, num_stages=2):
    a = tle.load(a_ptr + k * stride_a, is_async=True)
    b = tle.load(b_ptr + k * stride_b, is_async=True)
    acc = tl.dot(a, b, acc)

张量切片#

tle.extract_tile#

将输入张量分割为子块网格,并在指定坐标处提取块。

z = x.extract_tile(index=[0, 0], shape=[2, 2])

tle.insert_tile#

将输入张量分割为子块网格,并在指定坐标处更新块。

z = x.insert_tile(y, index=[0, 0])

示例:寄存器中的逐块后处理

sub = x.extract_tile(index=[1, 0], shape=[2, 2])
sub = tl.maximum(sub, 0.0)  # 对子块执行 ReLU
x = x.insert_tile(sub, index=[1, 0])

流水线#

tle.pipeline_group#

用于显式阶段控制的提示式扩展。

自动阶段划分:

for yoff in tl.range(0, ynumel, YBLOCK, num_stages=2):
    Q = tl.load(...)
    K = tl.load(...)
    KT = tl.trans(K)
    V = tl.dot(Q, KT)

使用 warp 特化的手动阶段划分:

for yoff in tle.range(
    0, ynumel, YBLOCK,
    num_stages=2,
    pipe_stages=[0, 0, 1] if LOAD_TRANS else [0, 1, 1],
    pipe_orders=[0, 1, 2],
    executors=[0, 0, 0] if ONE_CORE else [0, 0, range(1, 31)],
):
    with tle.pipeline_group(0):
        Q = tl.load(...)
        K = tl.load(...)
    with tle.pipeline_group(1):
        KT = tl.trans(K)
    with tle.pipeline_group(2):
        V = tl.dot(Q, KT)

分布式#

tle.device_mesh#

定义物理设备拓扑:

topology = {
    "node": [("node_x", 2), ("node_y", 2)],
    "device": 4,
    "block_cluster": [("cluster_x", 2), ("cluster_y", 2)],
    "block": 4,
}
mesh = tle.device_mesh(topology=topology)
# mesh.shape -> (2, 2, 4, 2, 2, 4),总大小 = 256

tle.sharding#

在设备网格上声明张量分布:

x_shard = tle.sharding(mesh, split=[["cluster_x", "cluster_y"], "device"], partial=["block"])
x = tle.make_sharded_tensor(x_ptr, sharding=x_shard, shape=[4, 4])

符号:tle.S(axis)(分割)、tle.B(广播)、tle.P(axis)(部分)。

tle.reshard#

将张量转换为新的分布状态。编译器自动插入通信原语。

典型转换:Scatter、Gather、Reduce、All-gather、All-reduce。

x_full = tle.reshard(x, spec=tle.sharding(mesh, split=[], partial=[]))

tle.remote#

获取其他设备上张量数据的句柄(点对点或 RDMA/NVLink):

remote_x = tle.remote(x, shard_id=(node_rank, next_device), scope=mesh)

tle.shard_id#

查询当前程序在网格轴上的坐标:

node_rank = tle.shard_id(mesh, "node")
device_rank = tle.shard_id(mesh, "device")

tle.distributed_dot#

线程块簇范围内的跨块矩阵乘法:

def distributed_dot(a, b, c=None):
    """在当前 TBC 范围内执行分布式矩阵乘法。"""

TLE-Struct#

设计理念:架构感知,细粒度调优。按硬件拓扑族分类后端,暴露公共层次结构,让开发者显式定义结构化计算/数据映射。

GPU#

tle.gpu.memory_space#

指定张量内存空间:

x = tle.gpu.memory_space(x, "shared_memory")

tle.gpu.alloc#

分配内存:

a_smem = tle.gpu.alloc([XBLOCK, YBLOCK], dtype=tl.float32, scope=tle.gpu.storage_kind.smem)

tle.gpu.local_ptr#

在共享内存缓冲区上构建指针视图,用于 tl.load/tl.store/tl.atomic*

示例:一维切片

smem = tle.gpu.alloc([BLOCK], dtype=tl.float32, scope=tle.gpu.smem)
idx = offset + tl.arange(0, SLICE)
slice_ptr = tle.gpu.local_ptr(smem, (idx,))
vals = tl.load(slice_ptr)

示例:K 维分块

smem_a = tle.gpu.alloc([BM, BK], dtype=tl.float16, scope=tle.gpu.smem)
rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK))
cols = tl.broadcast_to(tl.arange(0, BK)[None, :] + k_start, (BM, BK))
a_slice = tle.gpu.local_ptr(smem_a, (rows, cols))
a_vals = tl.load(a_slice)

示例:对 local_ptr 执行原子操作

smem_i32 = tle.gpu.alloc([BLOCK], dtype=tl.int32, scope=tle.gpu.smem)
ptr = tle.gpu.local_ptr(smem_i32, (tl.arange(0, BLOCK),))
tl.store(ptr, tl.zeros([BLOCK], dtype=tl.int32))
tl.atomic_add(ptr, 1)
vals = tl.load(ptr)

tle.gpu.local_ptr(远程)#

为远程共享/本地缓冲区具现化指针视图:

smem = tle.gpu.alloc([BM, BK], dtype=tl.float16, scope=tle.gpu.storage_kind.smem)
remote_smem = tle.remote(smem, shard_id=(node_rank, next_device), scope=mesh)
remote_ptr = tle.gpu.local_ptr(remote_smem, (rows, cols))
vals = tl.load(remote_ptr)

tle.gpu.copy#

内存复制:

tle.gpu.copy(a_ptrs + ystride_a * yoffs[None, :], a_smem, [XBLOCK, YBLOCK])

DSA#

tle.dsa.alloc#

分配 DSA 本地缓冲区:

a_ub = tle.dsa.alloc([XBLOCK, YBLOCK], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.UB)
b_l1 = tle.dsa.alloc([XBLOCK, YBLOCK], dtype=tl.float32, mem_addr_space=tle.dsa.ascend.L1)

Ascend 内存空间:UBL1L0AL0BL0C

tle.dsa.copy#

GMEM 指针与 DSA 本地缓冲区之间的显式移动:

tle.dsa.copy(x_ptrs, a_ub, [tail_m, tail_n])       # GMEM -> 本地
tle.dsa.copy(a_ub, out_ptrs, [tail_m, tail_n])      # 本地 -> GMEM

tle.dsa.to_tensor / tle.dsa.to_buffer#

在缓冲区和张量视图之间转换:

c_val = tle.dsa.to_tensor(c_ub, writable=True)
result = c_val * 0.5
d_ub = tle.dsa.to_buffer(result, tle.dsa.ascend.UB)

逐元素计算操作#

tle.dsa.addtle.dsa.subtle.dsa.multle.dsa.divtle.dsa.maxtle.dsa.min —— 对 DSA 本地缓冲区的逐元素二元操作。

示例:算术链

tle.dsa.sub(a_ub, b_ub, tmp_ub)      # tmp = a - b
tle.dsa.mul(tmp_ub, b_ub, tmp_ub)    # tmp = tmp * b
tle.dsa.div(tmp_ub, scale_ub, out_ub)  # out = tmp / scale

示例:通过 max + min 进行钳位

tle.dsa.max(x_ub, floor_ub, tmp_ub)  # tmp = max(x, floor)
tle.dsa.min(tmp_ub, ceil_ub, y_ub)   # y = min(tmp, ceil)

示例集#

共享内存暂存#

a_smem = tle.gpu.alloc([BM, BK], dtype=tl.float16, scope=tle.gpu.storage_kind.smem)
tle.gpu.copy(a_ptrs, a_smem, [BM, BK])
rows = tl.broadcast_to(tl.arange(0, BM)[:, None], (BM, BK))
cols = tl.broadcast_to(tl.arange(0, BK)[None, :], (BM, BK))
a_ptr_local = tle.gpu.local_ptr(a_smem, (rows, cols))
a_tile = tl.load(a_ptr_local)

共享内存原子操作#

bins = 256
counts = tle.gpu.alloc([bins], dtype=tl.int32, scope=tle.gpu.storage_kind.smem)
idx = tl.arange(0, BLOCK) % bins
count_ptr = tle.gpu.local_ptr(counts, (idx,))
tl.atomic_add(count_ptr, 1)

DSA 本地缓冲区流程#

a_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
b_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)
c_ub = tle.dsa.alloc([BM, BK], dtype=tl.float16, mem_addr_space=tle.dsa.ascend.UB)

tle.dsa.copy(a_ptrs, a_ub, [BM, BK])
tle.dsa.copy(b_ptrs, b_ub, [BM, BK])
tle.dsa.add(a_ub, b_ub, c_ub)

c_val = tle.dsa.to_tensor(c_ub, writable=True)
out_ub = tle.dsa.to_buffer(c_val, tle.dsa.ascend.UB)
tle.dsa.copy(out_ub, out_ptrs, [BM, BK])