Coverage for src/flag_gems/runtime/backend/_hygon/fused/sparse_attention.py: 0%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

6# --------------------------------------------------------------------------- 

7# Triton kernel: sparse attention with attention-sink 

8# grid = (m, b) — one program per (seq_pos, batch), handles ALL heads 

9# Aligned with tilelang version: uses tl.dot (GEMM) instead of vector dot 

10# --------------------------------------------------------------------------- 

@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d), bf16 in the original design
    KV,  # (b, n, d), same dtype as Q (both feed tl.dot)
    O,  # (b, m, h, d), written in O's own element type
    attn_sink,  # (h,) fp32, read with unit stride
    topk_idxs,  # (b, m, topk) int32, -1 marks an unused slot
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,
    topk,
    H_ACTUAL,
    BLOCK: tl.constexpr,
    D: tl.constexpr,
    H: tl.constexpr,
):
    """Sparse attention with an attention sink, one program per (seq_pos, batch).

    Program ``(pid_m, pid_b)`` loads the query rows of all heads at that
    position, walks the ``topk`` gathered KV rows in chunks of ``BLOCK`` and
    maintains an online softmax (running max, running sum of exponentials and
    an un-normalised output accumulator).  The gathered KV rows serve both as
    keys (``Q @ KV^T``) and as values (``P @ KV``).  After the loop the
    per-head ``attn_sink`` logit is added to the softmax denominator only,
    and the normalised result is stored to ``O``.

    Args:
        Q, KV, O, attn_sink, topk_idxs: device pointers; the expected layouts
            are given in the parameter comments above.
        stride_*: element strides of the corresponding tensors.
        scale: multiplier applied to the raw ``Q @ KV^T`` scores.
        topk: number of index slots per (batch, seq_pos) row.
        H_ACTUAL: real number of heads; rows ``>= H_ACTUAL`` of the padded
            ``H`` tile are masked on load and store.
        BLOCK: number of gathered KV rows processed per loop iteration.
        D: head dimension (used as a ``tl.arange`` extent).
        H: padded head-tile size (used as a ``tl.arange`` extent).
    """
    pid_m = tl.program_id(0)
    pid_b = tl.program_id(1)

    offs_h = tl.arange(0, H)
    offs_d = tl.arange(0, D)
    h_mask = offs_h < H_ACTUAL

    # ---- load Q tile: (H, D), all heads at once; padded rows read as 0 ----
    q_base = Q + pid_b * stride_qb + pid_m * stride_qm
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_block = tl.load(q_ptrs, mask=h_mask[:, None], other=0.0)

    # ---- base pointers ----
    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # ---- online softmax state ----
    acc_o = tl.zeros([H, D], dtype=tl.float32)
    scores_max = tl.full([H], float("-inf"), dtype=tl.float32)
    sum_exp = tl.zeros([H], dtype=tl.float32)

    num_blocks = (topk + BLOCK - 1) // BLOCK
    offs_blk = tl.arange(0, BLOCK)

    for t in range(num_blocks):
        # -- gather indices; slots past `topk` behave like unused (-1) slots --
        raw_offs = t * BLOCK + offs_blk  # (BLOCK,)
        idx_mask = raw_offs < topk
        idxs = tl.load(
            idx_base + raw_offs * stride_idxk, mask=idx_mask, other=-1
        )  # (BLOCK,)
        valid_mask = idxs != -1  # (BLOCK,)

        # -- gather KV block: (BLOCK, D) --
        # Widen the row index before scaling by the stride so the element
        # offset cannot wrap around in 32-bit arithmetic for large KV buffers.
        kv_ptrs = (
            kv_base
            + idxs.to(tl.int64)[:, None] * stride_kvn
            + offs_d[None, :] * stride_kvd
        )
        kv_block = tl.load(kv_ptrs, mask=valid_mask[:, None], other=0.0)

        # -- scores: Q @ KV^T -> (H, BLOCK) via GEMM --
        acc_s = tl.dot(q_block, tl.trans(kv_block))
        acc_s = acc_s * scale
        # Unused slots get -inf so they contribute exp(-inf) = 0.
        mask_bias = tl.where(valid_mask, 0.0, float("-inf"))  # (BLOCK,)
        acc_s = acc_s + mask_bias[None, :]

        # -- online softmax update --
        scores_max_prev = scores_max
        block_max = tl.max(acc_s, axis=1)  # (H,)
        scores_max = tl.maximum(scores_max, block_max)

        # While no valid key has been seen the running max is still -inf and
        # "-inf - (-inf)" would be NaN.  Subtract a finite stand-in instead:
        # every exponent below then evaluates to exp(-inf) = 0.  When the max
        # is finite this is exactly the usual update.
        safe_max = tl.where(scores_max == float("-inf"), 0.0, scores_max)

        correction = tl.exp(scores_max_prev - safe_max)  # (H,)
        p = tl.exp(acc_s - safe_max[:, None])  # (H, BLOCK)

        # -- accumulate output: acc_o = acc_o * correction + P @ KV --
        acc_o = acc_o * correction[:, None]
        # Cast P to the KV element type so both tl.dot operands agree.
        acc_o += tl.dot(p.to(kv_block.dtype), kv_block)  # (H, D)

        scores_sum = tl.sum(p, axis=1)  # (H,)
        sum_exp = sum_exp * correction + scores_sum

    # ---- incorporate attn_sink (denominator only) ----
    sink_vals = tl.load(attn_sink + offs_h, mask=h_mask, other=0.0)  # (H,)
    # Same guard as in the loop: with no valid key at all the row reduces to
    # 0 / exp(sink) = 0 instead of going through exp(+inf).
    final_max = tl.where(scores_max == float("-inf"), 0.0, scores_max)
    sum_exp = sum_exp + tl.exp(sink_vals - final_max)

    # ---- normalize ----
    acc_o = acc_o / sum_exp[:, None]

    # ---- store output: (H, D), padded head rows are not written ----
    o_base = O + pid_b * stride_ob + pid_m * stride_om
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    tl.store(o_ptrs, acc_o.to(O.dtype.element_ty), mask=h_mask[:, None])

110 

111 

112# --------------------------------------------------------------------------- 

113# Python wrapper 

114# --------------------------------------------------------------------------- 

def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Launch ``sparse_attn_triton_kernel`` over a ``(m, b)`` grid.

    Args:
        q: query tensor of shape ``(b, m, h, d)``.
        kv: shared key/value tensor of shape ``(b, n, d)``.
        attn_sink: per-head sink logits with at least ``h`` elements; the
            kernel reads them with unit stride, so a contiguous copy is made
            when needed.
        topk_idxs: indices into ``kv``'s second dimension, shape
            ``(b, m, topk)``; the kernel treats ``-1`` as an unused slot.
        softmax_scale: multiplier applied to the attention scores.

    Returns:
        A new tensor created with ``torch.empty_like(q)`` holding the
        attention output.

    Raises:
        ValueError: if the tensor shapes are inconsistent with each other or
            ``d`` is not a power of two (the kernel uses ``d`` as a
            ``tl.arange`` extent).
    """
    if q.dim() != 4:
        raise ValueError(f"q must be 4-D (b, m, h, d), got shape {tuple(q.shape)}")
    b, m, h, d = q.shape

    # The kernel indexes these tensors with the batch / position of `q`
    # without any bounds checks, so mismatches must be rejected up front.
    if kv.dim() != 3 or kv.shape[0] != b or kv.shape[2] != d:
        raise ValueError(
            f"kv must have shape ({b}, n, {d}), got {tuple(kv.shape)}"
        )
    if topk_idxs.dim() != 3 or topk_idxs.shape[0] != b or topk_idxs.shape[1] != m:
        raise ValueError(
            f"topk_idxs must have shape ({b}, {m}, topk), "
            f"got {tuple(topk_idxs.shape)}"
        )
    if attn_sink.numel() < h:
        raise ValueError(
            f"attn_sink must provide at least {h} values, got {attn_sink.numel()}"
        )
    if d <= 0 or d & (d - 1):
        raise ValueError(f"head dim d must be a power of two, got {d}")

    topk = topk_idxs.shape[-1]
    o = torch.empty_like(q)
    if q.numel() == 0:
        # Nothing to compute; skip kernel compilation and launch entirely.
        return o

    # The kernel addresses attn_sink as `ptr + head`, i.e. with unit stride.
    attn_sink = attn_sink.contiguous()

    BLOCK = 16

    # H must be >= 16 for tl.dot; pad to next power of 2
    H_padded = max(16, triton.next_power_of_2(h))

    grid = (m, b)  # each program handles ALL h heads
    sparse_attn_triton_kernel[grid](
        q,
        kv,
        o,
        attn_sink,
        topk_idxs,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        q.stride(3),
        kv.stride(0),
        kv.stride(1),
        kv.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        o.stride(3),
        topk_idxs.stride(0),
        topk_idxs.stride(1),
        topk_idxs.stride(2),
        softmax_scale,
        topk,
        h,
        BLOCK=BLOCK,
        D=d,
        H=H_padded,
        num_warps=8,  # 256 threads, matching tilelang
    )
    return o