Coverage for src/flag_gems/runtime/backend/_metax/fused/sparse_attention.py: 0%
55 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import torch
2import triton
3import triton.language as tl
6# ---------------------------------------------------------------------------
7# Triton kernel: sparse attention with attention-sink
8# grid = (m, b) — one program per (seq_pos, batch), handles ALL heads
9# Aligned with tilelang version: uses tl.dot (GEMM) instead of vector dot
10# ---------------------------------------------------------------------------
@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d) query tensor
    KV,  # (b, n, d) shared key/value tensor
    O,  # (b, m, h, d) output tensor
    attn_sink,  # (h,) per-head sink logit, read with unit stride
    topk_idxs,  # (b, m, topk) integer indices into KV's n axis, -1 = no entry
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,
    topk,
    H_ACTUAL,
    BLOCK: tl.constexpr,
    D: tl.constexpr,
    H: tl.constexpr,
):
    """Sparse attention with an attention sink over gathered KV rows.

    Launch grid is (m, b): program_id(0) selects the sequence position and
    program_id(1) the batch entry; one program processes all heads.

    For each query position the kernel walks ``topk_idxs`` in chunks of
    ``BLOCK``, gathers the referenced KV rows, and runs an online softmax in
    fp32. The gathered rows serve as both keys and values. After the loop,
    ``exp(attn_sink - running_max)`` is added to the softmax denominator only.

    ``H`` is the padded (compile-time) head count and ``H_ACTUAL`` the real
    one; rows ``>= H_ACTUAL`` are loaded as zeros and never stored.

    Entries equal to -1, and positions past ``topk``, contribute nothing.
    A query whose entries are all -1 yields an all-zero output row.
    """
    # Widen program ids so batch/sequence offsets cannot wrap in 32-bit math.
    pid_m = tl.program_id(0).to(tl.int64)
    pid_b = tl.program_id(1).to(tl.int64)

    # ---- load Q matrix: (H, D) -- all heads at once ----
    offs_h = tl.arange(0, H)
    offs_d = tl.arange(0, D)
    h_mask = offs_h < H_ACTUAL
    q_base = Q + pid_b * stride_qb + pid_m * stride_qm
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_block = tl.load(q_ptrs, mask=h_mask[:, None], other=0.0)  # (H, D)

    # ---- base pointers ----
    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # ---- online softmax state (fp32) ----
    acc_o = tl.zeros([H, D], dtype=tl.float32)
    scores_max = tl.full([H], float("-inf"), dtype=tl.float32)
    sum_exp = tl.zeros([H], dtype=tl.float32)

    num_blocks = (topk + BLOCK - 1) // BLOCK
    offs_blk = tl.arange(0, BLOCK)

    for t in range(num_blocks):
        # -- gather indices; positions past topk read as -1 (invalid) --
        raw_offs = t * BLOCK + offs_blk  # (BLOCK,)
        idx_mask = raw_offs < topk
        idxs = tl.load(
            idx_base + raw_offs * stride_idxk, mask=idx_mask, other=-1
        )  # (BLOCK,)
        valid_mask = idxs != -1  # (BLOCK,)

        # -- gather KV block: (BLOCK, D); invalid rows are zero-filled --
        # int64 row offsets: idx * stride_kvn may exceed the int32 range.
        kv_rows = idxs.to(tl.int64)
        kv_ptrs = kv_base + kv_rows[:, None] * stride_kvn + offs_d[None, :] * stride_kvd
        kv_block = tl.load(kv_ptrs, mask=valid_mask[:, None], other=0.0)

        # -- scores: Q @ KV^T -> (H, BLOCK) via GEMM --
        acc_s = tl.dot(q_block, tl.trans(kv_block))
        # invalid positions get -inf so they carry zero softmax weight
        acc_s = tl.where(valid_mask[None, :], acc_s * scale, float("-inf"))

        # -- online softmax update --
        block_max = tl.max(acc_s, axis=1)  # (H,)
        new_max = tl.maximum(scores_max, block_max)
        # While no valid key has been seen the running max is still -inf and
        # exp(-inf - (-inf)) would be NaN. Exponentiate against a finite
        # stand-in instead; every term is then exactly 0 in that case.
        safe_max = tl.where(new_max == float("-inf"), 0.0, new_max)
        correction = tl.exp(scores_max - safe_max)  # (H,)
        p = tl.exp(acc_s - safe_max[:, None])  # (H, BLOCK)

        # -- accumulate output: acc_o = acc_o * correction + P @ KV --
        acc_o = acc_o * correction[:, None]
        # cast P to the KV element type so tl.dot sees matching operand dtypes
        acc_o += tl.dot(p.to(kv_block.dtype), kv_block)  # (H, D)

        sum_exp = sum_exp * correction + tl.sum(p, axis=1)
        scores_max = new_max

    # ---- incorporate attn_sink (denominator only) ----
    sink_vals = tl.load(attn_sink + offs_h, mask=h_mask, other=0.0)  # (H,)
    # If no valid key was seen, scores_max is -inf, this term is +inf, and
    # the division below yields 0 (acc_o is still all zeros).
    sum_exp = sum_exp + tl.exp(sink_vals - scores_max)

    # ---- normalize ----
    acc_o = acc_o / sum_exp[:, None]

    # ---- store output: (H, D), real heads only ----
    o_base = O + pid_b * stride_ob + pid_m * stride_om
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    tl.store(o_ptrs, acc_o.to(O.dtype.element_ty), mask=h_mask[:, None])
112# ---------------------------------------------------------------------------
113# Python wrapper
114# ---------------------------------------------------------------------------
def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Launch the sparse-attention Triton kernel and return its output.

    Args:
        q: Query tensor of shape (b, m, h, d).
        kv: Shared key/value tensor; its first three strides are passed to
            the kernel and its last dimension must equal ``d``.
        attn_sink: Per-head sink logits; the kernel reads the first ``h``
            elements with unit stride.
        topk_idxs: Index tensor whose last dimension is ``topk``; its first
            three strides are passed to the kernel. The kernel treats -1 as
            "no entry".
        softmax_scale: Multiplier applied to the raw Q.KV scores.

    Returns:
        A new tensor with the same shape and dtype as ``q``.

    Raises:
        ValueError: If ``kv``'s last dimension differs from ``d`` or ``d`` is
            not a power of two.
    """
    b, m, h, d = q.shape
    topk = topk_idxs.shape[-1]
    o = torch.empty_like(q)

    # Nothing to compute; also avoids launching a zero-sized grid.
    if q.numel() == 0:
        return o

    if kv.shape[-1] != d:
        raise ValueError(
            f"kv last dim ({kv.shape[-1]}) must match q head dim ({d})"
        )
    # The kernel builds its feature offsets with tl.arange(0, D), which
    # requires a power-of-two extent.
    if d & (d - 1):
        raise ValueError(f"head dim must be a power of two, got {d}")

    # The kernel indexes attn_sink with unit stride.
    attn_sink = attn_sink.contiguous()

    BLOCK = 16

    # H must be >= 16 for tl.dot; pad to next power of 2
    H_padded = max(16, triton.next_power_of_2(h))

    grid = (m, b)  # each program handles ALL h heads
    sparse_attn_triton_kernel[grid](
        q,
        kv,
        o,
        attn_sink,
        topk_idxs,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        q.stride(3),
        kv.stride(0),
        kv.stride(1),
        kv.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        o.stride(3),
        topk_idxs.stride(0),
        topk_idxs.stride(1),
        topk_idxs.stride(2),
        softmax_scale,
        topk,
        h,
        BLOCK=BLOCK,
        D=d,
        H=H_padded,
        num_warps=8,  # 256 threads, matching tilelang
    )
    return o