Coverage for src/flag_gems/fused/sparse_attention.py: 11%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

6# --------------------------------------------------------------------------- 

7# Triton kernel: sparse attention with attention-sink 

8# grid = (m, b) — one program per (seq_pos, batch), handles ALL heads 

9# Aligned with tilelang version: uses tl.dot (GEMM) instead of vector dot 

10# --------------------------------------------------------------------------- 

@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d) bf16
    KV,  # (b, n, d) bf16
    O,  # (b, m, h, d) bf16
    attn_sink,  # (h,) fp32
    topk_idxs,  # (b, m, topk) int32
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,
    topk,
    BLOCK: tl.constexpr,
    D: tl.constexpr,
    H: tl.constexpr,
):
    """Sparse attention over gathered KV rows with a per-head attention sink.

    One program handles a single (sequence position, batch) pair and all H
    heads at once: program_id(0) selects the sequence position and
    program_id(1) the batch entry. For each chunk of BLOCK entries of
    ``topk_idxs`` the kernel gathers the referenced KV rows, computes
    ``Q @ KV^T * scale`` with ``tl.dot``, and folds the result into an online
    softmax. The gathered KV rows serve as both keys and values. After the
    loop, ``exp(attn_sink)`` is added to the softmax denominator only (the
    sink contributes no value), and the normalized result is stored to ``O``.

    An index of -1 marks a padded slot: its KV row is not read and its score
    is forced to -inf. Positions past ``topk`` in the last chunk are treated
    the same way.

    Args:
        Q, KV, O, attn_sink, topk_idxs: tensor base pointers; layouts as
            annotated on the parameters.
        stride_*: element strides of the corresponding tensor dimensions.
        scale: multiplier applied to the raw Q.K scores.
        topk: number of index slots per (batch, position).
        BLOCK: number of index slots processed per loop iteration.
        D: head dimension.
        H: number of heads.

    The Q load and the O store are unmasked, so H and D must match the real
    tensor extents exactly.
    NOTE(review): tl.arange / tl.dot place power-of-two and minimum-size
    requirements on H, D and BLOCK; confirm against the Triton version in use.
    """
    # Promote program ids to int64 so that id * stride cannot wrap around in
    # 32-bit arithmetic for large tensors.
    pid_m = tl.program_id(0).to(tl.int64)
    pid_b = tl.program_id(1).to(tl.int64)

    # ---- load Q matrix: (H, D) - all heads at once ----
    q_base = Q + pid_b * stride_qb + pid_m * stride_qm
    offs_h = tl.arange(0, H)
    offs_d = tl.arange(0, D)
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_block = tl.load(q_ptrs)  # (H, D)

    # ---- base pointers ----
    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # ---- online softmax state ----
    acc_o = tl.zeros([H, D], dtype=tl.float32)
    scores_max = tl.full([H], float("-inf"), dtype=tl.float32)
    sum_exp = tl.zeros([H], dtype=tl.float32)

    num_blocks = (topk + BLOCK - 1) // BLOCK
    offs_blk = tl.arange(0, BLOCK)

    for t in range(num_blocks):
        # -- gather indices --
        raw_offs = t * BLOCK + offs_blk  # (BLOCK,)
        idx_mask = raw_offs < topk
        idxs = tl.load(
            idx_base + raw_offs * stride_idxk, mask=idx_mask, other=-1
        )  # (BLOCK,)
        valid_mask = idxs != -1  # (BLOCK,)

        # -- gather KV block: (BLOCK, D) --
        # Row offsets are computed in int64: idx * stride_kvn can exceed the
        # int32 range for long KV caches.
        kv_rows = idxs.to(tl.int64)
        kv_ptrs = (
            kv_base + kv_rows[:, None] * stride_kvn + offs_d[None, :] * stride_kvd
        )
        kv_block = tl.load(
            kv_ptrs, mask=valid_mask[:, None], other=0.0
        )  # (BLOCK, D); padded rows are zero-filled

        # -- scores: Q @ KV^T -> (H, BLOCK) via GEMM --
        acc_s = tl.dot(q_block, tl.trans(kv_block))  # (H, D) @ (D, BLOCK)
        acc_s = acc_s * scale
        # Select -inf for padded slots. A select (rather than adding a -inf
        # bias) keeps a non-finite Q entry from turning padded scores into NaN.
        acc_s = tl.where(valid_mask[None, :], acc_s, float("-inf"))

        # -- online softmax update --
        scores_max_prev = scores_max
        scores_max = tl.maximum(scores_max, tl.max(acc_s, axis=1))  # (H,)

        # While a head has seen no valid key its running max is still -inf, and
        # exp(-inf - (-inf)) would be NaN and poison every later update.
        # Exponentiate against 0 in that case: all terms then become exp(-inf),
        # i.e. exactly 0, and the state stays clean.
        block_ref = tl.where(scores_max == float("-inf"), 0.0, scores_max)
        correction = tl.exp(scores_max_prev - block_ref)  # (H,)
        p = tl.exp(acc_s - block_ref[:, None])  # (H, BLOCK)

        # -- accumulate output: acc_o = acc_o * correction + P @ KV --
        acc_o = acc_o * correction[:, None]
        # Cast P to the KV element type so both tl.dot operands agree
        # (bf16 for the documented inputs).
        acc_o += tl.dot(p.to(kv_block.dtype), kv_block)  # (H, BLOCK) @ (BLOCK, D)

        sum_exp = sum_exp * correction + tl.sum(p, axis=1)

    # ---- incorporate attn_sink ----
    # The sink only enlarges the denominator. Use the same finite reference as
    # in the loop so a head with no valid keys yields 0 / exp(sink) = 0.
    sink_vals = tl.load(attn_sink + offs_h)  # (H,)
    final_ref = tl.where(scores_max == float("-inf"), 0.0, scores_max)
    sum_exp = sum_exp + tl.exp(sink_vals - final_ref)

    # ---- normalize ----
    acc_o = acc_o / sum_exp[:, None]

    # ---- store output: (H, D) ----
    o_base = O + pid_b * stride_ob + pid_m * stride_om
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    tl.store(o_ptrs, acc_o.to(O.dtype.element_ty))

108 

109 

110# --------------------------------------------------------------------------- 

111# Python wrapper 

112# --------------------------------------------------------------------------- 

def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Run the sparse attention Triton kernel and return its output.

    Args:
        q: 4-D query tensor; its shape is unpacked as (b, m, h, d).
        kv: key/value tensor; its first three strides are forwarded to the
            kernel.
        attn_sink: per-head attention-sink values, passed through unchanged.
        topk_idxs: selected KV indices; the last dimension is the number of
            index slots and the first three strides are forwarded.
        softmax_scale: multiplier the kernel applies to the attention scores.

    Returns:
        A new tensor allocated with ``torch.empty_like(q)`` that the kernel
        fills in.
    """
    batch, seq_len, num_heads, head_dim = q.shape
    num_selected = topk_idxs.shape[-1]
    out = torch.empty_like(q)
    block_size = 64

    # One program per (sequence position, batch); each covers every head.
    launcher = sparse_attn_triton_kernel[(seq_len, batch)]

    # Strides in the positional order of the kernel signature: q, kv, o, idx.
    stride_args = (
        *q.stride(),
        kv.stride(0),
        kv.stride(1),
        kv.stride(2),
        *out.stride(),
        topk_idxs.stride(0),
        topk_idxs.stride(1),
        topk_idxs.stride(2),
    )

    launcher(
        q,
        kv,
        out,
        attn_sink,
        topk_idxs,
        *stride_args,
        softmax_scale,
        num_selected,
        BLOCK=block_size,
        D=head_dim,
        H=num_heads,
        num_warps=8,  # 256 threads, matching tilelang
    )
    return out