Coverage for src/flag_gems/runtime/backend/_sunrise/fused/sparse_attention.py: 0%

62 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

# ---------------------------------------------------------------------------
# Triton kernel: sparse attention with attention-sink
# grid = (m, b) — one program per (seq_pos, batch), handles ALL heads
# Aligned with tilelang version: uses tl.dot (GEMM) instead of vector dot
# ---------------------------------------------------------------------------
@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d) bf16
    KV,  # (b, n, d) bf16 — each gathered row is used as both key and value
    O,  # (b, m, h, d) bf16
    attn_sink,  # (h,) fp32 per-head sink logit
    topk_idxs,  # (b, m, topk) int32 row indices into KV; -1 marks an empty slot
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,
    topk,
    BLOCK: tl.constexpr,
    D: tl.constexpr,
    H: tl.constexpr,
):
    # For query position pid_m of batch pid_b and every head h this computes
    #   s_j = scale * <Q[b, m, h, :], KV[b, idx_j, :]>       (idx_j != -1 only)
    #   O[b, m, h, :] = sum_j exp(s_j) * KV[b, idx_j, :]
    #                   / (sum_j exp(s_j) + exp(attn_sink[h]))
    # with a streaming (online) softmax over chunks of BLOCK indices. The sink
    # only enters the denominator, so it absorbs probability mass but adds no
    # value vector.
    # H and D are used directly in tl.arange and Q/O are accessed without
    # masks, so both must be powers of two matching the tensor extents.
    pid_m = tl.program_id(0)
    pid_b = tl.program_id(1)

    offs_h = tl.arange(0, H)
    offs_d = tl.arange(0, D)
    offs_blk = tl.arange(0, BLOCK)

    # ---- load Q matrix: (H, D) — all heads at once ----
    q_base = Q + pid_b * stride_qb + pid_m * stride_qm
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_block = tl.load(q_ptrs)

    # ---- base pointers ----
    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # ---- online softmax state (one entry per head) ----
    acc_o = tl.zeros([H, D], dtype=tl.float32)
    scores_max = tl.full([H], float("-inf"), dtype=tl.float32)
    sum_exp = tl.zeros([H], dtype=tl.float32)

    num_blocks = (topk + BLOCK - 1) // BLOCK

    for t in range(num_blocks):
        # -- gather indices; slots past topk read as -1 (invalid) --
        raw_offs = t * BLOCK + offs_blk
        idxs = tl.load(
            idx_base + raw_offs * stride_idxk, mask=raw_offs < topk, other=-1
        )
        valid_mask = idxs != -1

        # -- gather KV block: (BLOCK, D); invalid rows read as zeros --
        kv_ptrs = kv_base + idxs[:, None] * stride_kvn + offs_d[None, :] * stride_kvd
        kv_block = tl.load(kv_ptrs, mask=valid_mask[:, None], other=0.0)

        # -- scores: (H, D) @ (D, BLOCK) -> (H, BLOCK); invalid columns -> -inf --
        acc_s = tl.dot(q_block, tl.trans(kv_block)) * scale
        acc_s = tl.where(valid_mask[None, :], acc_s, float("-inf"))

        # -- online softmax update --
        new_max = tl.maximum(scores_max, tl.max(acc_s, axis=1))
        # While a head has seen no valid index, new_max is -inf and
        # exp(-inf - (-inf)) would be NaN, corrupting acc_o for good. Shifting
        # by 0 in that case makes every exp() below evaluate to exactly 0;
        # for finite maxima this is the usual max-subtraction.
        safe_max = tl.where(new_max == float("-inf"), 0.0, new_max)
        correction = tl.exp(scores_max - safe_max)  # (H,)
        p = tl.exp(acc_s - safe_max[:, None])  # (H, BLOCK)

        # -- acc_o = acc_o * correction + P @ KV --
        acc_o = acc_o * correction[:, None]
        acc_o += tl.dot(p.to(kv_block.dtype), kv_block)
        sum_exp = sum_exp * correction + tl.sum(p, axis=1)
        scores_max = new_max

    # ---- incorporate attn_sink (denominator only) ----
    final_max = tl.where(scores_max == float("-inf"), 0.0, scores_max)
    sink_vals = tl.load(attn_sink + offs_h).to(tl.float32)
    sum_exp = sum_exp + tl.exp(sink_vals - final_max)

    # ---- normalize ----
    # sum_exp is 0 only if no valid index was seen AND the sink is -inf; acc_o
    # is all zeros then, so dividing by 1 yields zeros instead of 0/0 = NaN.
    denom = tl.where(sum_exp == 0.0, 1.0, sum_exp)
    acc_o = acc_o / denom[:, None]

    # ---- store output: (H, D), converted to O's element type ----
    o_base = O + pid_b * stride_ob + pid_m * stride_om
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    tl.store(o_ptrs, acc_o.to(O.dtype.element_ty))

108 

109 

# ---------------------------------------------------------------------------
# Python wrapper
# ---------------------------------------------------------------------------
def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Run sparse attention with a per-head attention sink.

    Args:
        q: (b, m, h, d) queries.
        kv: (b, n, d) rows used as both keys and values.
        attn_sink: (h,) per-head sink logits (denominator-only softmax term).
        topk_idxs: (b, m, topk) indices into kv's second axis; -1 entries are
            ignored.
        softmax_scale: multiplier applied to the q·kv scores before softmax.

    Returns:
        A (b, m, h, d) tensor with q's dtype.
    """
    b, m, h, d = q.shape
    topk = topk_idxs.shape[-1]
    BLOCK = 64
    grid = (m, b)

    # The kernel builds its head axis with tl.arange(0, H) and feeds it to
    # tl.dot, so H must be a power of two. The original code padded to at
    # least 8 heads (presumably a tl.dot tile minimum on this backend —
    # confirm); non-power-of-two head counts are padded up the same way.
    # Padded heads get zero queries and a zero sink and are sliced off below.
    h_kernel = max(8, triton.next_power_of_2(h))
    padded = h_kernel != h
    if padded:
        q_in = torch.zeros((b, m, h_kernel, d), dtype=q.dtype, device=q.device)
        q_in[:, :, :h] = q
        sink_in = torch.zeros(
            (h_kernel,), dtype=torch.float32, device=attn_sink.device
        )
        sink_in[:h] = attn_sink
        # Every (H, D) output tile is fully written by the kernel, so no
        # zero-fill is needed.
        o_in = torch.empty((b, m, h_kernel, d), dtype=q.dtype, device=q.device)
    else:
        q_in, sink_in = q, attn_sink
        o_in = torch.empty_like(q)

    sparse_attn_triton_kernel[grid](
        q_in,
        kv,
        o_in,
        sink_in,
        topk_idxs,
        q_in.stride(0),
        q_in.stride(1),
        q_in.stride(2),
        q_in.stride(3),
        kv.stride(0),
        kv.stride(1),
        kv.stride(2),
        o_in.stride(0),
        o_in.stride(1),
        o_in.stride(2),
        o_in.stride(3),
        topk_idxs.stride(0),
        topk_idxs.stride(1),
        topk_idxs.stride(2),
        softmax_scale,
        topk,
        BLOCK=BLOCK,
        D=d,
        H=h_kernel,
        num_warps=8,  # 256 threads, matching tilelang
    )

    if padded:
        return o_in[:, :, :h].contiguous()
    return o_in