Coverage for src/flag_gems/runtime/backend/_mthreads/fused/sparse_attention.py: 0%
68 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
# NOTE(review): the logger name says "ops" although this module lives under
# "fused" -- left untouched because logging configuration may key on it.
logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.sparse_attention")

# The expand-config YAML is looked up one directory above this module; the
# path is normalised so the ".." hop does not appear in the final string.
_PARENT_DIR = os.path.join(os.path.dirname(__file__), os.pardir)
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(_PARENT_DIR, "sparse_attention_mthreads_expand.yaml")
)
def sparse_attention_get_configs():
    """Return the built-in tuning candidates for the sparse-attention kernel.

    The list holds a single ``triton.Config`` (``BLOCK=32``, 6 stages,
    16 warps). It is the fallback used by the kernel's ``libtuner`` decorator
    when ``USE_FLAGTUNE`` is not set to ``"1"``.
    """
    default_config = triton.Config({"BLOCK": 32}, num_stages=6, num_warps=16)
    return [default_config]
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "sparse_attention", yaml_path=EXPAND_CONFIG_FILENAME
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else sparse_attention_get_configs(),
    key=["topk", "H_ACTUAL", "D"],
    strategy=runtime.get_expand_config(
        "sparse_attention", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d) bf16
    KV,  # (b, n, d) bf16
    O,  # (b, m, h, d) bf16
    attn_sink,  # (h,) fp32
    topk_idxs,  # (b, m, topk) int32
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,
    topk,
    H_ACTUAL,
    BLOCK: tl.constexpr,
    D: tl.constexpr,
    H: tl.constexpr,
):
    """Attend each query position to the KV rows selected by ``topk_idxs``.

    One program handles one ``(pid_m, pid_b)`` pair, i.e. all heads of a single
    query position in a single batch entry. The selected KV rows are gathered
    ``BLOCK`` at a time and folded into a running (online) softmax; the same
    gathered tile is used both as keys and as values. After the loop a per-head
    ``attn_sink`` logit is added to the softmax denominator only.

    Args:
        Q, KV, O: tensor pointers addressed through the ``stride_*`` arguments.
        attn_sink: per-head logits, read with unit stride.
        topk_idxs: row indices into ``KV``; ``-1`` marks an unused slot.
        scale: multiplier applied to the raw ``Q @ KV^T`` scores.
        topk: number of index slots per query position.
        H_ACTUAL: real number of heads; rows ``>= H_ACTUAL`` are padding.
        BLOCK: number of gathered KV rows per loop iteration.
        D: head dimension (extent of ``tl.arange``).
        H: padded head count (extent of ``tl.arange``).
    """
    pid_m = tl.program_id(0)
    pid_b = tl.program_id(1)

    # Load the whole (H, D) query tile once; padded head rows read as zero.
    q_base = Q + pid_b * stride_qb + pid_m * stride_qm
    offs_h = tl.arange(0, H)
    offs_d = tl.arange(0, D)
    h_mask = offs_h < H_ACTUAL
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_block = tl.load(q_ptrs, mask=h_mask[:, None], other=0.0)

    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # Online-softmax state, kept in fp32: weighted value sum, running row max
    # and running denominator.
    acc_o = tl.zeros([H, D], dtype=tl.float32)
    scores_max = tl.full([H], float("-inf"), dtype=tl.float32)
    sum_exp = tl.zeros([H], dtype=tl.float32)

    num_blocks = (topk + BLOCK - 1) // BLOCK
    offs_blk = tl.arange(0, BLOCK)

    for t in range(num_blocks):
        raw_offs = t * BLOCK + offs_blk
        idx_mask = raw_offs < topk
        # Slots past ``topk`` are read as -1 so they share the "unused" path.
        idxs = tl.load(
            idx_base + raw_offs * stride_idxk,
            mask=idx_mask,
            other=-1,
        )
        valid_mask = idxs != -1

        # Gather the selected KV rows; unused slots become zero rows.
        kv_ptrs = kv_base + idxs[:, None] * stride_kvn + offs_d[None, :] * stride_kvd
        kv_block = tl.load(kv_ptrs, mask=valid_mask[:, None], other=0.0)

        acc_s = tl.dot(q_block, tl.trans(kv_block))
        acc_s = acc_s * scale
        # Unused slots get a -inf score so they receive zero probability.
        mask_bias = tl.where(valid_mask, 0.0, float("-inf"))
        acc_s = acc_s + mask_bias[None, :]

        scores_max_prev = scores_max
        block_max = tl.max(acc_s, axis=1)
        scores_max = tl.maximum(scores_max, block_max)

        # While a row has seen no valid slot its running max is still -inf and
        # ``exp(-inf - (-inf))`` would be NaN, poisoning every later block.
        # Subtract 0 instead: both exponentials below then evaluate to 0 for
        # such rows, and nothing changes for rows with a finite max.
        safe_max = tl.where(scores_max == float("-inf"), 0.0, scores_max)

        correction = tl.exp(scores_max_prev - safe_max)
        p = tl.exp(acc_s - safe_max[:, None])

        acc_o = acc_o * correction[:, None]
        # Cast the probabilities to the KV tile's dtype so the dot operands
        # match for any input dtype (bf16 when KV is bf16).
        acc_o += tl.dot(p.to(kv_block.dtype), kv_block)

        scores_sum = tl.sum(p, axis=1)
        sum_exp = sum_exp * correction + scores_sum

    # The sink contributes to the denominator only. For a row with no valid
    # slot, scores_max is -inf, the term is +inf and the output becomes 0.
    sink_vals = tl.load(attn_sink + offs_h, mask=h_mask, other=0.0)
    sum_exp = sum_exp + tl.exp(sink_vals - scores_max)

    acc_o = acc_o / sum_exp[:, None]

    o_base = O + pid_b * stride_ob + pid_m * stride_om
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    # Cast to the output pointer's element type; padded head rows are skipped.
    tl.store(o_ptrs, acc_o.to(O.dtype.element_ty), mask=h_mask[:, None])
def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Launch the sparse-attention Triton kernel and return its output.

    Args:
        q: 4-D query tensor, unpacked as ``(b, m, h, d)``.
        kv: 3-D key/value tensor, unpacked as ``(_, n, _)``.
        attn_sink: per-head sink logits; the kernel reads them with unit
            stride, so a non-contiguous tensor is made contiguous first.
        topk_idxs: index tensor whose last dimension is ``topk``; three
            strides are passed to the kernel, so it must be 3-D.
        softmax_scale: multiplier applied to the attention scores.

    Returns:
        A new tensor allocated with ``torch.empty_like(q)`` and filled by the
        kernel. If ``q`` has no elements it is returned unfilled, without
        logging or launching anything.
    """
    b, m, h, d = q.shape
    _, n, _ = kv.shape
    topk = topk_idxs.shape[-1]
    o = torch.empty_like(q)

    # Nothing to compute: skip tuning and launching on an empty grid.
    if q.numel() == 0:
        return o

    # attn_sink is the only tensor passed without a stride; the kernel
    # addresses it as ``attn_sink + head``. No-op when already contiguous.
    attn_sink = attn_sink.contiguous()

    # The kernel uses the padded head count as a ``tl.arange`` extent.
    # NOTE(review): the floor of 32 is presumably a tile-size requirement of
    # the kernel's ``tl.dot`` calls -- confirm.
    h_padded = max(32, triton.next_power_of_2(h))
    logger.debug(
        "GEMS_MTHREADS SPARSE_ATTENTION, [shape info]: [%s, %s, %s, %s, %s, %s](B, M, KV_LEN, TOPK, H, D)",
        b,
        m,
        n,
        topk,
        h,
        d,
    )
    # One program per (query position, batch entry).
    grid = (m, b)
    with torch_device_fn.device(q.device):
        sparse_attn_triton_kernel[grid](
            q,
            kv,
            o,
            attn_sink,
            topk_idxs,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            kv.stride(0),
            kv.stride(1),
            kv.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            topk_idxs.stride(0),
            topk_idxs.stride(1),
            topk_idxs.stride(2),
            softmax_scale,
            topk,
            h,
            D=d,
            H=h_padded,
        )
    return o