Examples

Examples#

This section includes FlagTree examples.

Sparse MLA Forward#

This module implements a Triton kernel for the forward pass of a sparse MLA (Multi-Headed Attention) mechanism. It demonstrates the use of tle.load for efficient memory access and computation.

import torch
import triton
import triton.language as tl
import triton.experimental.tle.language as tle

spar_mla_fwd_configs = [
    triton.Config({'num_stages': 4, 'num_warps': 8}),
    # triton.Config({'num_stages': 2, 'num_warps': 4}),
]


@triton.autotune(  # Decorate the kernel
    configs=spar_mla_fwd_configs,
    key=['K', 'is_causal'],
)
@triton.jit
def triton_sparse_mla_fwd(q, kv, indices, sm_scale: tl.constexpr, output, lse, stride_qb, stride_qh, stride_qm,
                          stride_qd, stride_kvb, stride_kvg, stride_kvn, stride_kvd, stride_tb, stride_tg, stride_tm,
                          stride_tt,  # topk,for indices
                          stride_ob, stride_oh, stride_om, stride_od, stride_lb, stride_lh, stride_lm, B: tl.constexpr,
                          SQ: tl.constexpr,  # seqlen
                          SKV: tl.constexpr, K: tl.constexpr,  # topk
                          D: tl.constexpr,  # QKV dim
                          TD: tl.constexpr,  # tail dim
                          DP: tl.constexpr, TDP: tl.constexpr, H: tl.constexpr,  # q_head_dim
                          G: tl.constexpr,  # group_size
                          VG: tl.constexpr,  # H/G KV groups
                          BK: tl.constexpr, BH: tl.constexpr, is_causal: tl.constexpr):
    i_b, i_sq, i_gbh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    i_g, i_bh = i_gbh // G, i_gbh % G
    q_base = q + i_b * stride_qb + i_sq * stride_qm + i_gbh * (BH * stride_qh)
    tq_base = q_base + D * stride_qd
    kv_base = kv + i_b * stride_kvb + i_g * stride_kvg
    tkv_base = kv_base + D * stride_kvd
    t_base = indices + i_b * stride_tb + i_sq * stride_tm + i_g * stride_tg
    o_base = output + i_b * stride_ob + i_sq * stride_om + i_gbh * (BH * stride_oh)
    l_base = lse + i_b * stride_lb + i_sq * stride_lm + i_gbh * (BH * stride_lh)

    offs_h = tl.arange(0, BH)
    offs_d = tl.arange(0, DP)
    offs_td = tl.arange(0, TDP)
    offs_od = tl.arange(0, DP)
    offs_t = tl.arange(0, BK)
    mask_h = i_bh * BH + offs_h < G
    mask_d = offs_d < D
    mask_td = offs_td < TD
    mask_od = mask_d

    q_ptr = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_msk = mask_h[:, None] & mask_d[None, :]
    q_blk = tl.load(q_ptr, q_msk, other=0.0)

    tq_ptr = tq_base + offs_h[:, None] * stride_qh + offs_td[None, :] * stride_qd
    tq_msk = mask_h[:, None] & mask_td[None, :]
    tq_blk = tl.load(tq_ptr, tq_msk, other=0.0)

    max_prev = tl.full([BH], float('-inf'), dtype=tl.float32)
    sum_exp = tl.full([BH], 1.0, dtype=tl.float32)
    acc = tl.zeros([BH, DP], dtype=tl.float32)

    log_scale: tl.constexpr = sm_scale * 1.44269504

    max_col = i_sq if is_causal else SQ - 1

    NK = tl.cdiv(K, BK)
    for ck in tl.range(NK, num_stages=0):
        if ck * BK <= max_col:
            t_ptr = (BK * ck + offs_t) * stride_tt
            t_msk = t_ptr < K
            t_ptr += t_base
            kv_ids = tl.load(t_ptr, t_msk, other=-1)
            mask_ids = (kv_ids <= max_col) & (kv_ids >= 0)

            kv_ptr = kv_base + offs_d[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn
            kv_msk = mask_d[:, None] & mask_ids[None, :]
            kv_blk = tle.load(kv_ptr, kv_msk, other=0.0, is_async=True)  # [DP, BK]

            tkv_ptr = tkv_base + offs_td[:, None] * stride_kvd + kv_ids[None, :] * stride_kvn
            tkv_msk = mask_td[:, None] & mask_ids[None, :]
            tkv_blk = tle.load(tkv_ptr, tkv_msk, other=0.0, is_async=False)  # [TDP, BK]

            qk = tl.dot(tq_blk, tkv_blk, out_dtype=tl.float32)
            qk = tl.dot(q_blk, kv_blk, qk, out_dtype=tl.float32)

            qk = tl.where(mask_ids[None, :], qk, float('-inf'))  # [BH, BK]

            new_max = tl.maximum(max_prev, tl.max(qk, axis=1))
            alpha = tl.math.exp2((max_prev - new_max) * log_scale)
            exp_qk = tl.math.exp2(qk * log_scale - new_max[:, None] * log_scale)
            sum_qk = tl.sum(exp_qk, axis=1)
            sum_exp = sum_exp * alpha + sum_qk
            acc = acc * alpha[:, None]
            exp_qk = exp_qk.to(tl.bfloat16)
            acc = tl.dot(exp_qk, tl.trans(kv_blk), acc, out_dtype=tl.float32)  # [BH, BK] @ [BK, DP] = [BH, DP]

            max_prev = new_max

    out_vals = acc / sum_exp[:, None]
    o_ptr = o_base + offs_h[:, None] * stride_oh + offs_od[None, :] * stride_od
    o_msk = mask_h[:, None] & mask_od[None, :]
    tl.store(o_ptr, out_vals.to(q_blk.dtype), o_msk)

    fin_log = max_prev * log_scale + tl.math.log2(sum_exp.to(tl.float32))  # lse / ln2
    l_ptr = l_base + offs_h * stride_lh
    l_msk = mask_h
    tl.store(l_ptr, fin_log.to(q_blk.dtype), l_msk)


def triton_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512):
    is_causal = True
    assert not return_p_sum, "This kernel file is for fwd only"
    assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
    B, SQ, H, DT = q.shape
    _, S, VG, _ = kv.shape

    # assert DT == 576, "you should assign dim otherwise"
    D = d_v

    assert kv.shape[-1] == DT
    TD = DT - D
    DP = triton.next_power_of_2(D)
    TDP = triton.next_power_of_2(TD)
    assert kv.shape[0] == B
    _, _, _, K = indices.shape
    assert indices.shape == (B, SQ, VG, K)
    G = H // VG
    if sm_scale is None:
        sm_scale = DT**-0.5
    BH = 32
    NH = triton.cdiv(G, BH)
    BK = 32
    output = torch.zeros((B, SQ, H, D), device=q.device, dtype=q.dtype)
    lse = torch.full((B, SQ, H), float('-inf'), device=q.device, dtype=q.dtype)
    grid = (B, SQ, VG * NH)  # (SQ//BQ, B*H)
    triton_sparse_mla_fwd[grid](
        q, kv, indices, sm_scale, output, lse, q.stride(0), q.stride(2), q.stride(1), q.stride(3),  # [B, H, SQ, DT]
        kv.stride(0), kv.stride(2), kv.stride(1), kv.stride(3),  # [B, VG, SKV, DT]
        indices.stride(0), indices.stride(2), indices.stride(1), indices.stride(3),  # [B, VG, SQ, K]
        output.stride(0), output.stride(2), output.stride(1), output.stride(3),  # [B, H, SQ, D]
        lse.stride(0), lse.stride(2), lse.stride(1),  # [B, H, SQ]
        B, SQ, S, K, D, TD, DP, TDP, H, G, VG, BK, BH,
        # BD,
        is_causal)
    # sparse_mla_fwd[grid](q, kv, indices, output)
    return output, lse


def ref_sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, is_casual=True, d_v=512):
    q = q.float()
    kv = kv.float()
    indices = indices.transpose(1, 2)
    b, sq, h, dim_q = q.shape
    b, sk, g, _ = kv.shape

    dim = d_v
    # assert kv.shape[-1] == 576, "you should assign dim otherwise"
    # dim = 512
    k = kv
    v = kv[..., :dim]

    b, _, _, dim_v = v.shape
    g_index = g
    h_index = h // g
    compressed_casual_mask = torch.arange(0, sq, dtype=torch.int32,
                                          device="cuda").view(-1,
                                                              1) >= torch.arange(1 - 1, sk * 1, 1, dtype=torch.int32,
                                                                                 device="cuda").view(1, -1)

    mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1)
    mask = mask[..., :-1]
    mask = mask & compressed_casual_mask.view(1, 1, sq, sk)
    mask[:, :, :1 - 1, 0] = True
    mask = mask.view(b, g_index, 1, sq, sk)

    q = q.view(b, sq, g, -1, dim_q)
    score = torch.einsum("bmghd,bngd->bghmn", q, k)
    sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale
    score = score.masked_fill(~mask, float("-inf")).mul(sm_scale)
    p = score.softmax(dim=-1)
    p = p.view(b, g_index, h_index, -1, sq, sk)
    p = p.view(b, g, -1, sq, sk)
    o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v)
    o = o.reshape(b, sq, h, dim_v)
    return o.to(torch.bfloat16)


def test_sparse_mla_fwd(B=1, S=4096, SKV=4096, H=128, HKV=1, DQK=576, DV=512, topk=2048, dtype=torch.bfloat16):
    torch.random.manual_seed(0)
    q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
    kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)

    indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda")
    for b in range(B):
        for t in range(S):
            for h in range(HKV):
                i_i = torch.randperm(max(1, t))[:topk]
                indices[b, t, h, :len(i_i)] = i_i

    ref_bf16_out = ref_sparse_mla_fwd_interface(q, kv, indices, d_v=DV)

    triton_bf16_out, triton_bf16_lse = triton_sparse_mla_fwd_interface(q, kv, indices, d_v=DV)
    print("triton bf16 done \n triton lse tensor: \n", triton_bf16_lse)
    print()

    assert torch.allclose(
        triton_bf16_out.float(),
        ref_bf16_out.float(),
        atol=1e-1,
        rtol=1e-1,
    ), "Triton sparse MLA fwd bf16 does not match reference"
    print("Triton sparse MLA fwd bf16 matches reference!")


if __name__ == "__main__":
    test_sparse_mla_fwd(B=1, S=128, SKV=1024, H=32, HKV=1, DQK=256 + 32, DV=256, topk=64, dtype=torch.bfloat16)