Coverage for src/flag_gems/runtime/backend/_ascend/fused/sparse_attention.py: 0%
66 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import os
3import torch
4import triton
5import triton.language as tl
# Ascend NPU: let Triton run every program block in parallel so that the large
# 1-D grids launched below do not hit the coreDim > 65535 limit. This is set
# unconditionally at import time and overrides any value already present.
os.environ.update(TRITON_ALL_BLOCKS_PARALLEL="1")
11# ---------------------------------------------------------------------------
12# Triton kernel: sparse attention with attention-sink
13# Adapted for Ascend NPU: 1D grid, tiling for UB overflow
14# ---------------------------------------------------------------------------
# Sparse attention with a per-head attention sink, adapted for Ascend NPU.
#
# Launched on a 1-D grid of size b * m: each program handles one
# (batch, query position) pair and processes all (padded) heads at once.
# The selected KV rows (used as both keys and values) are consumed in tiles of
# BLOCK_SUB rows, grouped into outer blocks of BLOCK rows, to bound on-chip
# buffer usage. An online (streaming) softmax is used. The sink logit is added
# to the softmax denominator only: it absorbs probability mass without
# contributing a value to the output.
@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d) bf16
    KV,  # (b, n, d) bf16
    O,  # (b, m, h, d) bf16
    attn_sink,  # (h,) fp32, read densely as attn_sink + head
    topk_idxs,  # (b, m, topk) int32
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,  # multiplier applied to Q @ KV^T
    topk,  # number of selected KV rows per query position
    kv_len,  # KV rows per batch; indices >= kv_len are masked out
    H_ACTUAL,  # real head count; head rows >= H_ACTUAL are padding
    BLOCK: tl.constexpr,  # outer tile size over the topk axis
    BLOCK_SUB: tl.constexpr,  # inner tile size actually loaded per step
    D: tl.constexpr,  # head dim (tl.arange extent)
    H: tl.constexpr,  # padded head count (>= H_ACTUAL)
    BATCH_STRIDE: tl.constexpr,  # programs per batch entry, i.e. m
):
    # 1D grid: each program handles one (batch, seq_pos)
    pid = tl.program_id(0)
    pid_b = pid // BATCH_STRIDE
    pid_m = pid % BATCH_STRIDE

    # ---- load Q matrix: (H, D) - all heads at once; padded heads read 0 ----
    q_base = Q + pid_b * stride_qb + pid_m * stride_qm
    offs_h = tl.arange(0, H)
    offs_d = tl.arange(0, D)
    h_mask = offs_h < H_ACTUAL
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_block = tl.load(q_ptrs, mask=h_mask[:, None], other=0.0)  # (H, D) bf16

    # ---- base pointers ----
    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # ---- online softmax state (per head) ----
    acc_o = tl.zeros([H, D], dtype=tl.float32)
    scores_max = tl.full([H], float("-inf"), dtype=tl.float32)
    sum_exp = tl.zeros([H], dtype=tl.float32)

    # Two-level tiling: BLOCK (outer) -> BLOCK_SUB (inner). Sub-tiles that run
    # past topk are fully masked and contribute nothing.
    num_block_iter = (topk + BLOCK - 1) // BLOCK
    num_sub_iter = (BLOCK + BLOCK_SUB - 1) // BLOCK_SUB
    offs_blk = tl.arange(0, BLOCK_SUB)

    for t in range(num_block_iter):
        block_start = t * BLOCK
        # Process BLOCK elements in sub-tiles
        for s in range(num_sub_iter):
            sub_start = block_start + s * BLOCK_SUB
            raw_offs = sub_start + offs_blk  # (BLOCK_SUB,)
            idx_mask = raw_offs < topk
            idxs = tl.load(
                idx_base + raw_offs * stride_idxk, mask=idx_mask, other=0
            )  # (BLOCK_SUB,)

            # Clamp negative indices to 0 (matching PyTorch behavior on NPU).
            # NOTE(review): this means a negative (e.g. -1 padding) index
            # attends to KV row 0 instead of being masked out - confirm this
            # matches the reference implementation.
            idxs = tl.where(idxs < 0, 0, idxs)

            # After the clamp every index is >= 0, so only the upper bound
            # needs checking; out-of-range rows are masked, not gathered.
            index_valid = idxs < kv_len
            valid_mask = idx_mask & index_valid  # (BLOCK_SUB,)

            # -- gather KV block: (BLOCK_SUB, D) --
            kv_ptrs = (
                kv_base + idxs[:, None] * stride_kvn + offs_d[None, :] * stride_kvd
            )
            kv_block = tl.load(
                kv_ptrs, mask=valid_mask[:, None], other=0.0
            )  # (BLOCK_SUB, D) bf16

            # -- scores: Q @ KV^T -> (H, BLOCK_SUB) via GEMM --
            acc_s = tl.dot(
                q_block, tl.trans(kv_block)
            )  # (H, D) @ (D, BLOCK_SUB) = (H, BLOCK_SUB)
            acc_s = acc_s * scale
            # mask invalid positions to -inf
            mask_bias = tl.where(valid_mask, 0.0, float("-inf"))  # (BLOCK_SUB,)
            acc_s = acc_s + mask_bias[None, :]  # broadcast: (H, BLOCK_SUB)

            # -- online softmax update --
            scores_max_prev = scores_max
            block_max = tl.max(acc_s, axis=1)  # (H,)
            scores_max = tl.maximum(scores_max, block_max)

            # While every score seen so far for a row is masked, scores_max is
            # still -inf, and (-inf) - (-inf) would produce NaN that then
            # poisons acc_o/sum_exp for good. Shift by 0 in that case so both
            # correction and p evaluate to exactly 0 instead. For finite maxima
            # this is identical to shifting by scores_max.
            max_shift = tl.where(scores_max == float("-inf"), 0.0, scores_max)
            correction = tl.exp(scores_max_prev - max_shift)  # (H,)
            p = tl.exp(acc_s - max_shift[:, None])  # (H, BLOCK_SUB)

            # -- accumulate output: acc_o = acc_o * correction + P @ KV --
            acc_o = acc_o * correction[:, None]
            acc_o += tl.dot(
                p.to(tl.bfloat16), kv_block
            )  # (H, BLOCK_SUB) @ (BLOCK_SUB, D) = (H, D)

            scores_sum = tl.sum(p, axis=1)  # (H,)
            sum_exp = sum_exp * correction + scores_sum

    # ---- incorporate attn_sink: denominator only, padded heads read 0 ----
    sink_vals = tl.load(attn_sink + offs_h, mask=h_mask, other=0.0)  # (H,)
    sum_exp = sum_exp + tl.exp(sink_vals - scores_max)

    # ---- normalize ----
    acc_o = acc_o / sum_exp[:, None]

    # ---- store output: (H, D), padded head rows are not written ----
    o_base = O + pid_b * stride_ob + pid_m * stride_om
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    tl.store(o_ptrs, acc_o.to(tl.bfloat16), mask=h_mask[:, None])
139# ---------------------------------------------------------------------------
140# Python wrapper
141# ---------------------------------------------------------------------------
def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Run sparse attention with a per-head attention sink on Ascend NPU.

    For every (batch, query position), all query heads attend to the KV rows
    selected by ``topk_idxs``; each selected KV row acts as both key and value.
    The softmax denominator additionally includes ``exp(attn_sink[head])``.

    Args:
        q: Query tensor indexed as (b, m, h, d). The kernel comments indicate
            bf16.
        kv: Key/value tensor indexed as (b, n, d).
        attn_sink: Per-head sink logits; the first ``h`` elements are read.
        topk_idxs: Selected KV row indices, indexed as (b, m, topk).
        softmax_scale: Multiplier applied to the raw Q @ KV^T scores.

    Returns:
        A new tensor with the same shape, dtype and device as ``q``.
    """
    b, m, h, d = q.shape
    topk = topk_idxs.shape[-1]
    kv_len = kv.shape[1]
    o = torch.empty_like(q)

    # Nothing to compute: avoid compiling and launching a kernel on an empty
    # grid (b * m == 0) or with a degenerate head/feature dimension.
    if o.numel() == 0:
        return o

    # The kernel indexes attn_sink as `attn_sink + head` without a stride
    # argument, so it must be densely laid out (no-op if already contiguous).
    attn_sink = attn_sink.contiguous()

    # NPU optimization: use tiling to avoid UB overflow
    # BLOCK: number of KV elements per outer loop iteration
    # BLOCK_SUB: tile size for UB management
    # UB (192KB) constraint: need to fit q_block + kv_block + acc_o + intermediate buffers
    # Use fixed BLOCK to avoid edge cases with non-power-of-2 topk
    BLOCK = 64
    BLOCK_SUB = 16  # smaller chunks to fit UB (192KB), with multi-buffer overhead

    # H must be >= 16 for tl.dot; pad to next power of 2
    H_padded = max(16, triton.next_power_of_2(h))

    # NPU: use 1D grid, TRITON_ALL_BLOCKS_PARALLEL handles large grid
    grid = (b * m,)

    sparse_attn_triton_kernel[grid](
        q,
        kv,
        o,
        attn_sink,
        topk_idxs,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        q.stride(3),
        kv.stride(0),
        kv.stride(1),
        kv.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        o.stride(3),
        topk_idxs.stride(0),
        topk_idxs.stride(1),
        topk_idxs.stride(2),
        softmax_scale,
        topk,
        kv_len,
        h,
        BLOCK=BLOCK,
        BLOCK_SUB=BLOCK_SUB,
        D=d,
        H=H_padded,
        BATCH_STRIDE=m,  # for 1D grid: pid = pid_b * m + pid_m
        num_warps=4,  # reduced for NPU
    )
    return o