Coverage for src/flag_gems/runtime/backend/_mthreads/fused/sparse_attention.py: 0%
67 statements
coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
# Module logger. NOTE(review): the name says "ops" although this module lives
# under ``_mthreads/fused/``; confirm which name log filters expect.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops.sparse_attention"
)

# Normalized path of "sparse_attention_mthreads_expand.yaml" located in the
# parent of this module's directory. Not referenced by the code in this module
# that is visible here.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(
        os.path.dirname(__file__), "..", "sparse_attention_mthreads_expand.yaml"
    )
)
def sparse_attention_get_configs():
    """Return the autotuning candidates for ``sparse_attn_triton_kernel``.

    The search space currently holds a single configuration:
    ``BLOCK=32`` top-k entries per loop iteration, ``num_stages=6`` and
    ``num_warps=16``.
    """
    only_config = triton.Config({"BLOCK": 32}, num_stages=6, num_warps=16)
    return [only_config]
# Sparse attention with a per-head attention sink. Program axis 0 indexes query
# positions (stride_qm / stride_om / stride_idxm) and axis 1 indexes the batch
# (stride_qb / stride_kvb / stride_ob / stride_idxb).
#
# For every head h the program gathers the KV rows named by topk_idxs, BLOCK
# indices at a time, forms s_j = scale * <q_h, kv_j>, and accumulates with an
# online (running-max) softmax:
#     O_h = sum_j exp(s_j - M) * kv_j / (sum_j exp(s_j - M) + exp(sink_h - M))
# The same gathered rows act as both keys and values.
#
# Index entries < 0 (and lanes past topk, which load as -1) are padding and
# contribute nothing. A query whose indices are all padding yields a zero
# output row (given a finite sink), not NaN.
#
# H is the head tile size used with tl.arange (so it must be a power of two
# and >= H_ACTUAL). Only the first H_ACTUAL rows are read from Q / attn_sink
# and written to O.
@libentry()
@libtuner(
    configs=sparse_attention_get_configs(),
    key=["topk", "H_ACTUAL", "D"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d), typically bf16
    KV,  # (b, n, d), typically bf16; rows serve as keys and values
    O,  # (b, m, h, d); results are cast to O's element type on store
    attn_sink,  # (h,) fp32, read as a flat contiguous vector
    topk_idxs,  # (b, m, topk) int32; negative entries are padding
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,
    topk,
    H_ACTUAL,
    BLOCK: tl.constexpr,
    D: tl.constexpr,
    H: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_b = tl.program_id(1)

    # (H, D) query tile; padded head rows are zero-filled.
    q_base = Q + pid_b * stride_qb + pid_m * stride_qm
    offs_h = tl.arange(0, H)
    offs_d = tl.arange(0, D)
    h_mask = offs_h < H_ACTUAL
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_block = tl.load(q_ptrs, mask=h_mask[:, None], other=0.0)

    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # Online-softmax state per head: unnormalized output, running max, and
    # running sum of exponentials (all relative to the running max).
    acc_o = tl.zeros([H, D], dtype=tl.float32)
    scores_max = tl.full([H], float("-inf"), dtype=tl.float32)
    sum_exp = tl.zeros([H], dtype=tl.float32)

    num_blocks = (topk + BLOCK - 1) // BLOCK
    offs_blk = tl.arange(0, BLOCK)

    for t in range(num_blocks):
        raw_offs = t * BLOCK + offs_blk
        idx_mask = raw_offs < topk
        idxs = tl.load(
            idx_base + raw_offs * stride_idxk,
            mask=idx_mask,
            other=-1,
        )
        # Treat every negative index as padding so it is never dereferenced.
        valid_mask = idxs >= 0

        kv_ptrs = kv_base + idxs[:, None] * stride_kvn + offs_d[None, :] * stride_kvd
        kv_block = tl.load(kv_ptrs, mask=valid_mask[:, None], other=0.0)

        # (H, BLOCK) scaled scores; padding columns are forced to -inf.
        acc_s = tl.dot(q_block, tl.trans(kv_block)) * scale
        acc_s = tl.where(valid_mask[None, :], acc_s, float("-inf"))

        scores_max_prev = scores_max
        scores_max = tl.maximum(scores_max, tl.max(acc_s, axis=1))
        # A head that has seen no valid score yet keeps scores_max == -inf, and
        # exp(-inf - (-inf)) would be NaN. Use 0 as the reference for such
        # heads: all their p values are then exp(-inf) == 0 and the state stays
        # zero instead of becoming NaN.
        ref_max = tl.where(scores_max == float("-inf"), 0.0, scores_max)

        correction = tl.exp(scores_max_prev - ref_max)
        p = tl.exp(acc_s - ref_max[:, None])

        acc_o = acc_o * correction[:, None]
        acc_o += tl.dot(p.to(kv_block.dtype), kv_block)
        sum_exp = sum_exp * correction + tl.sum(p, axis=1)

    # The sink logit only enlarges the denominator. Recompute the safe
    # reference here because the loop may not have run (topk == 0).
    ref_max = tl.where(scores_max == float("-inf"), 0.0, scores_max)
    sink_vals = tl.load(attn_sink + offs_h, mask=h_mask, other=0.0)
    sum_exp = sum_exp + tl.exp(sink_vals - ref_max)

    acc_o = acc_o / sum_exp[:, None]

    o_base = O + pid_b * stride_ob + pid_m * stride_om
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    tl.store(o_ptrs, acc_o.to(O.dtype.element_ty), mask=h_mask[:, None])
def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Run sparse top-k attention with a per-head sink on the current device.

    Each query position attends only to the KV rows listed in ``topk_idxs``.
    Each selected row is used both as key and as value. ``attn_sink`` adds one
    extra logit per head to the softmax denominator.

    Args:
        q: Queries of shape (b, m, h, d).
        kv: Shared key/value rows of shape (b, n, d).
        attn_sink: Per-head sink logits. It must hold at least h elements and
            is read as a flat contiguous vector.
        topk_idxs: Indices into kv's n axis, shape (b, m, topk). Entries equal
            to -1 are padding.
        softmax_scale: Multiplier applied to the q·k scores before softmax.

    Returns:
        A tensor like ``q`` (same shape, dtype and device) holding the output.

    Raises:
        ValueError: If the tensor ranks or shapes are inconsistent with each
            other, or if d is not a power of two (the kernel tiles d with
            ``tl.arange``).
    """
    if q.dim() != 4 or kv.dim() != 3 or topk_idxs.dim() != 3:
        raise ValueError(
            "expected q (b, m, h, d), kv (b, n, d) and topk_idxs (b, m, topk); "
            f"got {tuple(q.shape)}, {tuple(kv.shape)}, {tuple(topk_idxs.shape)}"
        )
    b, m, h, d = q.shape
    kv_b, n, kv_d = kv.shape
    if kv_b != b or kv_d != d:
        raise ValueError(
            f"kv shape {tuple(kv.shape)} is incompatible with q shape {tuple(q.shape)}"
        )
    if tuple(topk_idxs.shape[:2]) != (b, m):
        raise ValueError(
            f"topk_idxs shape {tuple(topk_idxs.shape)} must start with (b, m) = {(b, m)}"
        )
    if attn_sink.numel() < h:
        raise ValueError(
            f"attn_sink has {attn_sink.numel()} elements but q has {h} heads"
        )
    if d <= 0 or d & (d - 1):
        raise ValueError(f"head dim d={d} must be a power of two")

    # The kernel indexes the sink as attn_sink + head without a stride, so a
    # strided or expanded view must be materialized first (no-op if contiguous).
    attn_sink = attn_sink.contiguous()

    topk = topk_idxs.shape[-1]
    o = torch.empty_like(q)
    # tl.arange needs a power-of-two head tile. The floor of 32 presumably
    # keeps tl.dot tiles large enough on this backend; TODO confirm.
    h_padded = max(32, triton.next_power_of_2(h))
    logger.debug(
        "GEMS_MTHREADS SPARSE_ATTENTION, [shape info]: [%s, %s, %s, %s, %s, %s](B, M, KV_LEN, TOPK, H, D)",
        b,
        m,
        n,
        topk,
        h,
        d,
    )
    # One program per (query position, batch) pair.
    grid = (m, b)
    with torch_device_fn.device(q.device):
        sparse_attn_triton_kernel[grid](
            q,
            kv,
            o,
            attn_sink,
            topk_idxs,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            kv.stride(0),
            kv.stride(1),
            kv.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            topk_idxs.stride(0),
            topk_idxs.stride(1),
            topk_idxs.stride(2),
            softmax_scale,
            topk,
            h,
            D=d,
            H=h_padded,
        )
    return o