Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/sparse_attention.py: 0%
42 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import torch
2import triton
3import triton.language as tl
# ---------------------------------------------------------------------------
# Triton kernel: sparse attention with attention-sink
# grid = (m, b, h) — one program per (seq_pos, batch, head)
# Kunlunxin-adapted version
# ---------------------------------------------------------------------------
@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d)
    KV,  # (b, n, d)
    O,  # (b, m, h, d)
    attn_sink,  # (h,)
    topk_idxs,  # (b, m, topk) integer indices into the n axis of KV
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,
    topk,
    D: tl.constexpr,
):
    """Sparse attention with an attention sink for one (seq_pos, batch, head).

    Each program reads one query vector of length ``D``, gathers the ``topk``
    KV rows named by ``topk_idxs[batch, seq_pos, :]``, and combines them with
    a streaming (online) softmax over ``scale * dot(q, kv)``.  The same KV row
    serves as both key and value.  The per-head ``attn_sink`` logit is added
    to the softmax denominator only; it contributes no value vector.

    Negative entries in ``topk_idxs`` are treated as padding: they receive a
    score of ``-inf`` and zero weight, so they influence neither the output
    nor the normaliser.  If every entry is padding (or ``topk`` is 0) the
    stored output is all zeros.

    Args:
        Q, KV, O: pointers to the query, key/value and output tensors.
        attn_sink: pointer to the per-head sink logits.
        topk_idxs: pointer to the selected KV row indices.
        stride_*: element strides of the corresponding tensor dimensions.
        scale: multiplier applied to every raw dot product.
        topk: number of index slots per query position.
        D: head dimension; used as the extent of ``tl.arange``.
    """
    pid_m = tl.program_id(0)
    pid_b = tl.program_id(1)
    pid_h = tl.program_id(2)

    # ---- load Q vector: (D,) for this head ----
    offs_d = tl.arange(0, D)
    q_base = Q + pid_b * stride_qb + pid_m * stride_qm + pid_h * stride_qh
    # Upcast once so the dot product is formed in fp32, not the storage dtype.
    q_vec = tl.load(q_base + offs_d * stride_qd).to(tl.float32)

    # ---- base pointers ----
    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # ---- online softmax state ----
    acc_o = tl.zeros([D], dtype=tl.float32)
    score_max = float("-inf")
    sum_exp = 0.0

    # Process each topk slot one by one.
    for k in range(topk):
        idx = tl.load(idx_base + k * stride_idxk)  # scalar
        # Negative indices mark padding.  Clamp only so that the gather below
        # stays in bounds; the slot itself is masked out of the softmax.
        is_valid = idx >= 0
        safe_idx = tl.where(is_valid, idx, 0)

        # -- gather KV vector: (D,) --
        kv_ptrs = kv_base + safe_idx * stride_kvn + offs_d * stride_kvd
        kv_vec = tl.load(kv_ptrs).to(tl.float32)

        # -- scaled dot product; padding slots get -inf --
        score = tl.sum(q_vec * kv_vec) * scale
        score = tl.where(is_valid, score, float("-inf"))

        # -- online softmax update --
        score_max_prev = score_max
        score_max = tl.maximum(score_max_prev, score)

        # While the running max is still -inf no valid slot has been seen, and
        # exp(-inf - (-inf)) would be NaN; the accumulators are still zero, so
        # a rescale factor of 1 is exact.
        no_valid_yet = score_max == float("-inf")
        correction = tl.where(
            no_valid_yet, 1.0, tl.exp(score_max_prev - score_max)
        )
        p = tl.where(is_valid, tl.exp(score - score_max), 0.0)

        # -- accumulate output and normaliser --
        acc_o = acc_o * correction + p * kv_vec
        sum_exp = sum_exp * correction + p

    # ---- incorporate attn_sink (denominator only) ----
    sink_val = tl.load(attn_sink + pid_h)  # scalar
    sum_exp = sum_exp + tl.exp(sink_val - score_max)

    # ---- normalize ----
    acc_o = acc_o / sum_exp

    # ---- store output: (D,) in the output tensor's element type ----
    o_base = O + pid_b * stride_ob + pid_m * stride_om + pid_h * stride_oh
    o_ptrs = o_base + offs_d * stride_od
    tl.store(o_ptrs, acc_o.to(O.dtype.element_ty))
# ---------------------------------------------------------------------------
# Python wrapper
# ---------------------------------------------------------------------------
def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Launch ``sparse_attn_triton_kernel`` and return the attention output.

    Args:
        q: query tensor; its shape is unpacked as ``(b, m, h, d)``.
        kv: key/value tensor, indexed by the kernel with three strides
            (batch, row, feature).
        attn_sink: per-head sink logits, read by the kernel at offset ``head``.
        topk_idxs: selected KV row indices; the last dimension is ``topk`` and
            three strides (batch, seq_pos, slot) are passed to the kernel.
        softmax_scale: multiplier applied to every q·kv score.

    Returns:
        A new tensor allocated with ``torch.empty_like(q)`` and filled by the
        kernel.  If ``q`` has no elements it is returned without a launch.
    """
    b, m, h, d = q.shape
    topk = topk_idxs.shape[-1]
    o = torch.empty_like(q)

    # Nothing to compute for an empty query; returning here also avoids
    # launching the kernel with a zero-sized grid dimension.
    if q.numel() == 0:
        return o

    grid = (m, b, h)  # each program handles one (seq_pos, batch, head)
    sparse_attn_triton_kernel[grid](
        q,
        kv,
        o,
        attn_sink,
        topk_idxs,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        q.stride(3),
        kv.stride(0),
        kv.stride(1),
        kv.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        o.stride(3),
        topk_idxs.stride(0),
        topk_idxs.stride(1),
        topk_idxs.stride(2),
        softmax_scale,
        topk,
        D=d,
        num_warps=2,
    )
    return o