Coverage for src/flag_gems/runtime/backend/_iluvatar/fused/sparse_attention.py: 0%
66 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import torch
2import triton
3import triton.language as tl
6# ---------------------------------------------------------------------------
7# Triton kernel: sparse attention with attention-sink
8# grid = (m, b) — one program per (seq_pos, batch), handles ALL heads
9# Aligned with tilelang version: uses tl.dot (GEMM) instead of vector dot
10#
11# Iluvatar-compatible: no tl.load mask/other, no tl.where
12# ---------------------------------------------------------------------------
@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d) bf16
    KV,  # (b, n, d) bf16
    O,  # (b, m, h, d) bf16
    attn_sink,  # (h,) fp32, padded by the caller to at least H elements
    topk_idxs,  # (b, m, topk) int32; negative entries are treated as padding
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,
    topk,
    H_ACTUAL,
    BLOCK: tl.constexpr,
    D: tl.constexpr,
    H: tl.constexpr,
):
    # Sparse attention with an attention sink.
    #
    # One program handles one (program_id(0) = seq position, program_id(1) =
    # batch) pair and all heads at once. For each block of BLOCK selected KV
    # rows it computes scores with tl.dot, applies an online softmax, and
    # accumulates the weighted KV rows. After the loop the per-head sink logit
    # is added to the softmax denominator only (it contributes no value row),
    # and the normalised (H, D) tile is stored for the first H_ACTUAL heads.
    #
    # scale     multiplier applied to the raw Q @ KV^T scores
    # topk      number of index slots per (batch, seq position)
    # H_ACTUAL  number of real heads; rows H_ACTUAL..H-1 are padding
    # BLOCK     KV rows gathered per loop iteration
    # D, H      tile sizes used with tl.arange
    #
    # Masking is done with arithmetic and address clamping only: there is no
    # masked tl.load and no tl.where anywhere in the kernel.
    pid_m = tl.program_id(0)
    pid_b = tl.program_id(1)

    # ---- load Q matrix: (H, D) — all heads at once ----
    q_base = Q + pid_b * stride_qb + pid_m * stride_qm
    offs_h = tl.arange(0, H)
    offs_d = tl.arange(0, D)
    h_mask = offs_h < H_ACTUAL
    h_mask_f = h_mask.to(tl.float32)
    # Padded head rows re-read the last real head instead of addressing memory
    # past the tensor; the values are zeroed right below, so real heads are
    # unaffected and the unmasked load stays in bounds.
    safe_offs_h = tl.minimum(offs_h, H_ACTUAL - 1)
    q_ptrs = (
        q_base + safe_offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    )
    q_block = tl.load(q_ptrs)  # (H, D) bf16
    # zero padded heads via arithmetic (avoid tl.where)
    q_block = (q_block.to(tl.float32) * h_mask_f[:, None]).to(tl.bfloat16)

    # ---- base pointers ----
    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # ---- online softmax state ----
    acc_o = tl.zeros([H, D], dtype=tl.float32)
    scores_max = tl.full([H], float("-inf"), dtype=tl.float32)
    sum_exp = tl.zeros([H], dtype=tl.float32)

    num_blocks = (topk + BLOCK - 1) // BLOCK
    offs_blk = tl.arange(0, BLOCK)

    for t in range(num_blocks):
        # -- gather indices (clamp to avoid OOB, mask via score bias) --
        raw_offs = t * BLOCK + offs_blk  # (BLOCK,)
        in_range_f = (raw_offs < topk).to(tl.float32)
        safe_raw_offs = tl.minimum(raw_offs, topk - 1)
        idxs = tl.load(idx_base + safe_raw_offs * stride_idxk)  # (BLOCK,)

        # A slot is valid only if it lies inside topk AND holds a non-negative
        # index. Negative indices are clamped to row 0 purely so the gather
        # address is legal; they must not take part in the softmax.
        valid_f = in_range_f * (idxs >= 0).to(tl.float32)  # 1.0 valid, 0.0 not

        # -- gather KV block: (BLOCK, D) --
        safe_idxs = tl.maximum(idxs, 0)
        kv_ptrs = (
            kv_base + safe_idxs[:, None] * stride_kvn + offs_d[None, :] * stride_kvd
        )
        kv_block = tl.load(kv_ptrs)  # (BLOCK, D) bf16

        # -- scores: Q @ KV^T -> (H, BLOCK) via GEMM --
        acc_s = tl.dot(q_block, tl.trans(kv_block))  # (H, D) @ (D, BLOCK)
        acc_s = acc_s * scale
        # mask invalid positions to -large via arithmetic (avoid tl.where)
        mask_bias = (valid_f - 1.0) * 1e30  # 0 for valid, -1e30 for invalid
        acc_s = acc_s + mask_bias[None, :]  # broadcast: (H, BLOCK)

        # -- online softmax update --
        scores_max_prev = scores_max
        block_max = tl.max(acc_s, axis=1)  # (H,)
        scores_max = tl.maximum(scores_max, block_max)

        correction = tl.exp(scores_max_prev - scores_max)  # (H,)
        p = tl.exp(acc_s - scores_max[:, None])  # (H, BLOCK)
        # While no valid slot has been seen yet the running max is itself the
        # -1e30 bias, which would give invalid slots a weight of exp(0) = 1.
        # Zero them explicitly so they never reach the accumulators.
        p = p * valid_f[None, :]

        # -- accumulate output: acc_o = acc_o * correction + P @ KV --
        acc_o = acc_o * correction[:, None]
        acc_o += tl.dot(p.to(tl.bfloat16), kv_block)  # (H, BLOCK) @ (BLOCK, D)

        scores_sum = tl.sum(p, axis=1)  # (H,)
        sum_exp = sum_exp * correction + scores_sum

    # ---- incorporate attn_sink ----
    # attn_sink is padded to H elements by the caller, safe to load with offs_h
    sink_vals = tl.load(attn_sink + offs_h)  # (H,)
    # zero padded heads' sink via arithmetic
    sink_vals = sink_vals * h_mask_f
    sum_exp = sum_exp + tl.exp(sink_vals - scores_max)

    # ---- normalize ----
    acc_o = acc_o / sum_exp[:, None]

    # ---- store output: (H, D) ----
    o_base = O + pid_b * stride_ob + pid_m * stride_om
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    # Only real heads are written back; padded rows are dropped by the mask.
    tl.store(o_ptrs, acc_o.to(tl.bfloat16), mask=h_mask[:, None])
122# ---------------------------------------------------------------------------
123# Python wrapper
124# ---------------------------------------------------------------------------
def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Launch ``sparse_attn_triton_kernel`` and return the attention output.

    Args:
        q: 4-D query tensor unpacked as ``(b, m, h, d)``.
        kv: Key/value tensor; its first three strides are forwarded to the
            kernel (the kernel's parameter comments describe it as
            ``(b, n, d)``).
        attn_sink: 1-D per-head sink values. Zero-padded here when it has
            fewer entries than the padded head count.
        topk_idxs: Index tensor whose last dimension is ``topk``; its first
            three strides are forwarded to the kernel.
        softmax_scale: Multiplier the kernel applies to the raw scores.

    Returns:
        A new tensor allocated with ``torch.empty_like(q)`` and filled by the
        kernel.
    """
    batch, seq_len, num_heads, head_dim = q.shape
    num_selected = topk_idxs.shape[-1]
    out = torch.empty_like(q)

    # tl.dot needs a head tile of at least 16 rows; round up to a power of 2.
    padded_heads = max(16, triton.next_power_of_2(num_heads))

    # The kernel reads `padded_heads` sink entries unconditionally, so make
    # sure that many exist.
    sink = attn_sink
    sink_len = attn_sink.shape[0]
    if sink_len < padded_heads:
        sink = torch.zeros(
            padded_heads, dtype=attn_sink.dtype, device=attn_sink.device
        )
        sink[:sink_len] = attn_sink

    # Large head dims get a smaller KV block and fewer warps to stay within
    # resource limits and lower register pressure.
    block_size, warp_count = (16, 2) if head_dim >= 256 else (64, 8)

    # One program per (seq position, batch); each program covers all heads.
    launch_grid = (seq_len, batch)
    sparse_attn_triton_kernel[launch_grid](
        q,
        kv,
        out,
        sink,
        topk_idxs,
        *(q.stride(axis) for axis in range(4)),
        *(kv.stride(axis) for axis in range(3)),
        *(out.stride(axis) for axis in range(4)),
        *(topk_idxs.stride(axis) for axis in range(3)),
        softmax_scale,
        num_selected,
        num_heads,
        BLOCK=block_size,
        D=head_dim,
        H=padded_heads,
        num_warps=warp_count,
        num_stages=1,
    )
    return out