Coverage for src/flag_gems/runtime/backend/_sunrise/fused/sparse_attention.py: 0%
62 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import torch
2import triton
3import triton.language as tl
# ---------------------------------------------------------------------------
# Triton kernel: sparse attention with attention-sink
# grid = (m, b) — one program per (seq_pos, batch), handles ALL heads
# Aligned with tilelang version: uses tl.dot (GEMM) instead of vector dot
# ---------------------------------------------------------------------------
@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d) bf16
    KV,  # (b, n, d) bf16
    O,  # (b, m, h, d) bf16
    attn_sink,  # (h,) fp32
    topk_idxs,  # (b, m, topk) int32
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,
    topk,
    BLOCK: tl.constexpr,
    D: tl.constexpr,
    H: tl.constexpr,
):
    """Sparse attention over gathered KV rows with an attention-sink term.

    Launched on a grid of ``(m, b)``: program ``(pid_m, pid_b)`` handles all
    ``H`` heads of one query position. The ``topk`` indices for that position
    are consumed in chunks of ``BLOCK``; each chunk gathers the referenced KV
    rows, computes ``Q @ KV^T * scale`` and folds the result into a running
    (online) softmax. The gathered KV rows serve as both keys and values.

    An index equal to ``-1`` (and every slot past ``topk`` in the last chunk)
    is treated as "no key": its score is forced to ``-inf`` and its KV row is
    loaded as zeros.

    After the loop, ``exp(attn_sink - row_max)`` is added to the softmax
    denominator only (the sink contributes no value), and the normalized
    result is stored to ``O`` as bf16.

    ``H``, ``D`` and ``BLOCK`` are compile-time constants used as
    ``tl.arange`` extents and ``tl.dot`` tile sizes.
    """
    pid_m = tl.program_id(0)
    pid_b = tl.program_id(1)

    # ---- load Q matrix: (H, D) — all heads at once ----
    q_base = Q + pid_b * stride_qb + pid_m * stride_qm
    offs_h = tl.arange(0, H)
    offs_d = tl.arange(0, D)
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_block = tl.load(q_ptrs)  # (H, D)

    # ---- base pointers ----
    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # ---- online softmax state ----
    acc_o = tl.zeros([H, D], dtype=tl.float32)
    scores_max = tl.full([H], float("-inf"), dtype=tl.float32)
    sum_exp = tl.zeros([H], dtype=tl.float32)

    num_blocks = (topk + BLOCK - 1) // BLOCK
    offs_blk = tl.arange(0, BLOCK)

    for t in range(num_blocks):
        # -- gather indices; slots past `topk` read as -1 (invalid) --
        raw_offs = t * BLOCK + offs_blk  # (BLOCK,)
        idx_mask = raw_offs < topk
        idxs = tl.load(
            idx_base + raw_offs * stride_idxk, mask=idx_mask, other=-1
        )  # (BLOCK,)
        valid_mask = idxs != -1  # (BLOCK,)

        # -- gather KV block: (BLOCK, D); invalid rows are not dereferenced --
        kv_ptrs = kv_base + idxs[:, None] * stride_kvn + offs_d[None, :] * stride_kvd
        kv_block = tl.load(kv_ptrs, mask=valid_mask[:, None], other=0.0)

        # -- scores: (H, D) @ (D, BLOCK) = (H, BLOCK) via GEMM --
        acc_s = tl.dot(q_block, tl.trans(kv_block))
        acc_s = acc_s * scale
        # mask invalid positions to -inf
        mask_bias = tl.where(valid_mask, 0.0, float("-inf"))  # (BLOCK,)
        acc_s = acc_s + mask_bias[None, :]  # broadcast: (H, BLOCK)

        # -- online softmax update --
        scores_max_prev = scores_max
        block_max = tl.max(acc_s, axis=1)  # (H,)
        scores_max = tl.maximum(scores_max, block_max)

        # While no valid key has been seen the running max is still -inf and
        # `-inf - (-inf)` is NaN, which would poison acc_o / sum_exp for every
        # later block. Subtract a finite stand-in instead: all exponentials
        # below then evaluate to exactly 0. When the max is finite this is
        # the same value as scores_max, so results are unchanged.
        safe_max = tl.where(scores_max == float("-inf"), 0.0, scores_max)

        correction = tl.exp(scores_max_prev - safe_max)  # (H,)
        p = tl.exp(acc_s - safe_max[:, None])  # (H, BLOCK)

        # -- accumulate output: acc_o = acc_o * correction + P @ KV --
        acc_o = acc_o * correction[:, None]
        acc_o += tl.dot(p.to(tl.bfloat16), kv_block)  # (H, BLOCK) @ (BLOCK, D)

        scores_sum = tl.sum(p, axis=1)  # (H,)
        sum_exp = sum_exp * correction + scores_sum

    # ---- incorporate attn_sink (denominator only) ----
    # If a row never saw a valid key, scores_max is still -inf, the sink term
    # becomes +inf and the zero accumulator normalizes to 0.
    sink_vals = tl.load(attn_sink + offs_h)  # (H,)
    sum_exp = sum_exp + tl.exp(sink_vals - scores_max)

    # ---- normalize ----
    acc_o = acc_o / sum_exp[:, None]

    # ---- store output: (H, D) ----
    o_base = O + pid_b * stride_ob + pid_m * stride_om
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    tl.store(o_ptrs, acc_o.to(tl.bfloat16))
# ---------------------------------------------------------------------------
# Python wrapper
# ---------------------------------------------------------------------------
def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Launch ``sparse_attn_triton_kernel`` and return the attention output.

    Args:
        q: Query tensor of shape ``(b, m, h, d)``.
        kv: Shared key/value tensor, indexed by the kernel as ``(b, n, d)``.
        attn_sink: Per-head sink logits of shape ``(h,)``.
        topk_idxs: Selected KV row indices, indexed by the kernel as
            ``(b, m, topk)``; ``-1`` marks an unused slot.
        softmax_scale: Multiplier applied to the raw ``Q @ KV^T`` scores.

    Returns:
        A tensor with the same shape and dtype as ``q``.

    The kernel processes all heads of one query position in a single program
    and uses the head count as a ``tl.arange`` extent, which must be a power
    of two. When ``h`` is below 8 or not a power of two, the query and sink
    are zero-padded along the head axis to ``max(8, next_power_of_2(h))``,
    and the padded heads are sliced off the result.
    """
    b, m, h, d = q.shape
    topk = topk_idxs.shape[-1]
    BLOCK = 64

    h_pad = max(8, triton.next_power_of_2(h))
    needs_padding = h_pad != h

    if needs_padding:
        q_in = torch.zeros((b, m, h_pad, d), dtype=q.dtype, device=q.device)
        q_in[:, :, :h] = q
        sink_in = torch.zeros((h_pad,), dtype=torch.float32, device=attn_sink.device)
        sink_in[:h] = attn_sink
        out = torch.zeros((b, m, h_pad, d), dtype=q.dtype, device=q.device)
    else:
        q_in = q
        sink_in = attn_sink
        out = torch.empty_like(q)

    grid = (m, b)
    sparse_attn_triton_kernel[grid](
        q_in,
        kv,
        out,
        sink_in,
        topk_idxs,
        q_in.stride(0),
        q_in.stride(1),
        q_in.stride(2),
        q_in.stride(3),
        kv.stride(0),
        kv.stride(1),
        kv.stride(2),
        out.stride(0),
        out.stride(1),
        out.stride(2),
        out.stride(3),
        topk_idxs.stride(0),
        topk_idxs.stride(1),
        topk_idxs.stride(2),
        softmax_scale,
        topk,
        BLOCK=BLOCK,
        D=d,
        H=h_pad,
        num_warps=8,  # 256 threads, matching tilelang
    )

    if needs_padding:
        # Drop the padded heads; .contiguous() gives the caller a compact tensor.
        return out[:, :, :h].contiguous()
    return out