Coverage for src/flag_gems/runtime/backend/_iluvatar/fused/sparse_attention.py: 0%

66 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

6# --------------------------------------------------------------------------- 

7# Triton kernel: sparse attention with attention-sink 

8# grid = (m, b) — one program per (seq_pos, batch), handles ALL heads 

9# Aligned with tilelang version: uses tl.dot (GEMM) instead of vector dot 

10# 

11# Iluvatar-compatible: no tl.load mask/other, no tl.where 

12# --------------------------------------------------------------------------- 

@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d) bf16
    KV,  # (b, n, d) bf16
    O,  # (b, m, h, d) bf16
    attn_sink,  # (>= h,) fp32, unit stride
    topk_idxs,  # (b, m, topk) integer row indices into KV
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,
    topk,
    H_ACTUAL,
    BLOCK: tl.constexpr,
    D: tl.constexpr,
    H: tl.constexpr,
):
    # Sparse attention with an attention sink, computed with an online softmax.
    #
    # One program handles every head of one (seq_pos, batch) pair:
    #   program_id(0) -> seq_pos, program_id(1) -> batch.
    # H is the padded head count used for the tile shape; H_ACTUAL is the
    # number of real heads. Rows h >= H_ACTUAL are computed but never stored.
    #
    # Masking is done with arithmetic only (no tl.load mask/other, no
    # tl.where). A top-k slot is ignored when it lies past `topk` or when its
    # index is negative; such slots get exactly zero attention weight.
    pid_m = tl.program_id(0)
    pid_b = tl.program_id(1)

    # ---- load Q matrix: (H, D) — all heads at once ----
    q_base = Q + pid_b * stride_qb + pid_m * stride_qm
    offs_h = tl.arange(0, H)
    offs_d = tl.arange(0, D)
    h_mask = offs_h < H_ACTUAL
    h_mask_f = h_mask.to(tl.float32)
    # Clamp padded head rows onto the last real head so the unmasked load never
    # reads past the tensor; the duplicated rows are zeroed right below.
    safe_offs_h = tl.minimum(offs_h, H_ACTUAL - 1)
    q_ptrs = (
        q_base + safe_offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    )
    q_block = tl.load(q_ptrs)  # (H, D) bf16
    # zero padded heads via arithmetic (avoid tl.where)
    q_block = (q_block.to(tl.float32) * h_mask_f[:, None]).to(tl.bfloat16)

    # ---- base pointers ----
    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # ---- online softmax state ----
    acc_o = tl.zeros([H, D], dtype=tl.float32)
    scores_max = tl.full([H], float("-inf"), dtype=tl.float32)
    sum_exp = tl.zeros([H], dtype=tl.float32)

    num_blocks = (topk + BLOCK - 1) // BLOCK
    offs_blk = tl.arange(0, BLOCK)

    for t in range(num_blocks):
        # -- gather indices (clamp the offset so the load stays in bounds) --
        raw_offs = t * BLOCK + offs_blk  # (BLOCK,)
        safe_raw_offs = tl.minimum(raw_offs, topk - 1)
        idxs = tl.load(idx_base + safe_raw_offs * stride_idxk)  # (BLOCK,)

        # A slot is valid when it is inside the top-k list AND holds a
        # non-negative index. 1.0 for valid, 0.0 for invalid.
        valid_f = (raw_offs < topk).to(tl.float32) * (idxs >= 0).to(tl.float32)

        # -- gather KV block: (BLOCK, D) --
        # Negative indices are redirected to row 0 only to keep the load in
        # bounds; their contribution is removed through valid_f.
        safe_idxs = tl.maximum(idxs, 0)
        kv_ptrs = (
            kv_base + safe_idxs[:, None] * stride_kvn + offs_d[None, :] * stride_kvd
        )
        kv_block = tl.load(kv_ptrs)  # (BLOCK, D) bf16

        # -- scores: Q @ KV^T -> (H, BLOCK) via GEMM --
        acc_s = tl.dot(q_block, tl.trans(kv_block))
        acc_s = acc_s * scale
        # push invalid slots to -large via arithmetic (avoid tl.where):
        # 0 for valid, -1e30 for invalid
        mask_bias = (valid_f - 1.0) * 1e30
        acc_s = acc_s + mask_bias[None, :]  # broadcast: (H, BLOCK)

        # -- online softmax update --
        scores_max_prev = scores_max
        block_max = tl.max(acc_s, axis=1)  # (H,)
        scores_max = tl.maximum(scores_max, block_max)

        correction = tl.exp(scores_max_prev - scores_max)  # (H,)
        p = tl.exp(acc_s - scores_max[:, None])  # (H, BLOCK)
        # When every slot seen so far is invalid the running max is the bias
        # itself and exp() yields 1 for masked slots; force them to exactly 0.
        p = p * valid_f[None, :]

        # -- accumulate output: acc_o = acc_o * correction + P @ KV --
        acc_o = acc_o * correction[:, None]
        acc_o += tl.dot(p.to(tl.bfloat16), kv_block)  # (H, BLOCK) @ (BLOCK, D)

        scores_sum = tl.sum(p, axis=1)  # (H,)
        sum_exp = sum_exp * correction + scores_sum

    # ---- incorporate attn_sink ----
    # The sink only adds to the softmax denominator. Using the clamped head
    # offset keeps the load in bounds even if attn_sink has just H_ACTUAL
    # entries; padded heads then get their value zeroed.
    sink_vals = tl.load(attn_sink + safe_offs_h)  # (H,)
    sink_vals = sink_vals * h_mask_f
    sum_exp = sum_exp + tl.exp(sink_vals - scores_max)

    # ---- normalize ----
    acc_o = acc_o / sum_exp[:, None]

    # ---- store output: (H, D), real heads only ----
    o_base = O + pid_b * stride_ob + pid_m * stride_om
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    tl.store(o_ptrs, acc_o.to(tl.bfloat16), mask=h_mask[:, None])

120 

121 

122# --------------------------------------------------------------------------- 

123# Python wrapper 

124# --------------------------------------------------------------------------- 

def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Launch the sparse-attention Triton kernel and return its output.

    Args:
        q: Query tensor of shape ``(b, m, h, d)``. The kernel casts its
            operands to bfloat16, so bf16 inputs are expected.
        kv: Shared key/value tensor, indexed by the kernel as ``(b, n, d)``.
        attn_sink: 1-D per-head sink logits with at least ``h`` entries.
        topk_idxs: Integer tensor ``(b, m, topk)`` of rows of ``kv`` to
            attend to for each query position.
        softmax_scale: Multiplier applied to the raw ``q @ kv^T`` scores.

    Returns:
        A tensor shaped and typed like ``q`` holding the attention output.
    """
    b, m, h, d = q.shape
    topk = topk_idxs.shape[-1]
    o = torch.empty_like(q)

    # Nothing to compute; also avoids launching a zero-sized grid or
    # compiling the kernel with a zero-length head/feature dimension.
    if o.numel() == 0:
        return o

    # H must be >= 16 for tl.dot; pad to next power of 2
    H_padded = max(16, triton.next_power_of_2(h))

    # The kernel indexes attn_sink with unit stride, so hand it a dense
    # buffer: zero-padded to H_padded when short, made contiguous otherwise.
    if attn_sink.shape[0] < H_padded:
        attn_sink_padded = torch.zeros(
            H_padded, dtype=attn_sink.dtype, device=attn_sink.device
        )
        attn_sink_padded[: attn_sink.shape[0]] = attn_sink
    else:
        attn_sink_padded = attn_sink.contiguous()

    # Large D: shrink the top-k tile and use fewer warps to stay within
    # resource limits and lower register pressure.
    if d >= 256:
        BLOCK, num_warps = 16, 2
    else:
        BLOCK, num_warps = 64, 8

    grid = (m, b)  # each program handles ALL h heads
    sparse_attn_triton_kernel[grid](
        q,
        kv,
        o,
        attn_sink_padded,
        topk_idxs,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        q.stride(3),
        kv.stride(0),
        kv.stride(1),
        kv.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        o.stride(3),
        topk_idxs.stride(0),
        topk_idxs.stride(1),
        topk_idxs.stride(2),
        softmax_scale,
        topk,
        h,
        BLOCK=BLOCK,
        D=d,
        H=H_padded,
        num_warps=num_warps,
        num_stages=1,
    )
    return o