Coverage for src/flag_gems/runtime/backend/_mthreads/fused/sparse_attention.py: 0%

67 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, libtuner 

10 

# Module logger. The name is a fixed string (not ``__name__``), so logging
# configuration keyed on this exact name keeps working.
logger = logging.getLogger("flag_gems.runtime.backend._mthreads.ops.sparse_attention")

# Normalized path to the YAML file that sits one directory above the
# directory containing this module.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..", "sparse_attention_mthreads_expand.yaml")
)

19 

20 

def sparse_attention_get_configs():
    """Return the tuning candidates offered to the sparse-attention kernel.

    Only one candidate is provided: ``BLOCK=32`` with ``num_stages=6`` and
    ``num_warps=16``.

    Returns:
        A one-element list of ``triton.Config`` objects.
    """
    meta_params = {"BLOCK": 32}
    only_candidate = triton.Config(meta_params, num_stages=6, num_warps=16)
    return [only_candidate]

25 

26 

@libentry()
@libtuner(
    configs=sparse_attention_get_configs(),
    key=["topk", "H_ACTUAL", "D"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d) bf16
    KV,  # (b, n, d) bf16
    O,  # (b, m, h, d) bf16
    attn_sink,  # (h,) fp32
    topk_idxs,  # (b, m, topk) int32
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,
    topk,
    H_ACTUAL,
    BLOCK: tl.constexpr,
    D: tl.constexpr,
    H: tl.constexpr,
):
    """Gathered (top-k) attention with a per-head sink term.

    Each program handles one (query position, batch) pair: axis 0 of the
    grid selects the query position and axis 1 the batch entry.  For that
    pair it loads an ``H x D`` query tile, walks the ``topk`` selected KV
    row indices in chunks of ``BLOCK``, gathers those rows from ``KV`` and
    accumulates a softmax-weighted sum of them with a running max / running
    sum.  The gathered rows are used both as keys (``Q @ KV^T``) and as
    values (``P @ KV``).

    An index equal to ``-1`` marks an empty slot: its KV row is loaded as
    zeros and its score is forced to ``-inf`` so it gets zero weight.
    After the loop, ``exp(attn_sink - max)`` is added to the softmax
    denominator only (it contributes nothing to the numerator).

    ``H`` is the padded head-tile size used for ``tl.arange``; only rows
    below ``H_ACTUAL`` are read from ``Q`` / ``attn_sink`` and written to
    ``O``.  ``attn_sink`` is addressed with an implicit unit stride.
    """
    pid_m = tl.program_id(0)
    pid_b = tl.program_id(1)

    offs_h = tl.arange(0, H)
    offs_d = tl.arange(0, D)
    # Rows at or beyond H_ACTUAL are padding: loaded as zeros, never stored.
    h_mask = offs_h < H_ACTUAL

    q_base = Q + pid_b * stride_qb + pid_m * stride_qm
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_block = tl.load(q_ptrs, mask=h_mask[:, None], other=0.0)

    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # Running state of the streaming softmax, kept in fp32.
    acc_o = tl.zeros([H, D], dtype=tl.float32)
    scores_max = tl.full([H], float("-inf"), dtype=tl.float32)
    sum_exp = tl.zeros([H], dtype=tl.float32)

    num_blocks = (topk + BLOCK - 1) // BLOCK
    offs_blk = tl.arange(0, BLOCK)

    for t in range(num_blocks):
        raw_offs = t * BLOCK + offs_blk
        idx_mask = raw_offs < topk
        # Slots past `topk` read as -1, i.e. they look like empty slots.
        idxs = tl.load(
            idx_base + raw_offs * stride_idxk,
            mask=idx_mask,
            other=-1,
        )
        valid_mask = idxs != -1

        kv_ptrs = kv_base + idxs[:, None] * stride_kvn + offs_d[None, :] * stride_kvd
        kv_block = tl.load(kv_ptrs, mask=valid_mask[:, None], other=0.0)

        acc_s = tl.dot(q_block, tl.trans(kv_block))
        acc_s = acc_s * scale
        # Empty slots get -inf so they receive zero softmax weight.
        mask_bias = tl.where(valid_mask, 0.0, float("-inf"))
        acc_s = acc_s + mask_bias[None, :]

        scores_max_prev = scores_max
        block_max = tl.max(acc_s, axis=1)
        scores_max = tl.maximum(scores_max, block_max)

        # While no valid slot has been seen the running max is still -inf,
        # and exp(-inf - (-inf)) would be NaN and poison the accumulators.
        # Subtract a finite stand-in instead; every exponent below is then
        # -inf and evaluates to exactly 0.  Once the max is finite this is
        # a no-op.
        safe_max = tl.where(scores_max == float("-inf"), 0.0, scores_max)

        correction = tl.exp(scores_max_prev - safe_max)
        p = tl.exp(acc_s - safe_max[:, None])

        acc_o = acc_o * correction[:, None]
        acc_o += tl.dot(p.to(tl.bfloat16), kv_block)

        scores_sum = tl.sum(p, axis=1)
        sum_exp = sum_exp * correction + scores_sum

    # The sink only enlarges the denominator.  If no valid slot was seen,
    # scores_max is -inf, this term is +inf and the output becomes 0 / inf = 0.
    sink_vals = tl.load(attn_sink + offs_h, mask=h_mask, other=0.0)
    sum_exp = sum_exp + tl.exp(sink_vals - scores_max)

    acc_o = acc_o / sum_exp[:, None]

    o_base = O + pid_b * stride_ob + pid_m * stride_om
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    tl.store(o_ptrs, acc_o.to(tl.bfloat16), mask=h_mask[:, None])

122 

123 

def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Launch the sparse-attention Triton kernel and return its output.

    One kernel program is launched per (query position, batch) pair.  The
    head dimension is padded to the next power of two, and to at least 32,
    for the kernel's head tile; only the real ``h`` heads are written.

    Args:
        q: 4-D tensor unpacked as ``(b, m, h, d)``.
        kv: 3-D tensor unpacked as ``(b, n, d)``; must share ``b`` and ``d``
            with ``q``.
        attn_sink: per-head sink values; at least ``h`` elements.
        topk_idxs: 3-D tensor ``(b, m, topk)`` of selected KV row indices.
            The kernel treats ``-1`` as an empty slot.
        softmax_scale: factor applied to the attention scores.

    Returns:
        A new tensor created with ``torch.empty_like(q)`` holding the
        attention output.

    Raises:
        ValueError: if ``q`` is not 4-D, ``kv`` is not 3-D, or the shapes of
            ``kv``, ``topk_idxs`` or ``attn_sink`` are inconsistent with ``q``.
    """
    b, m, h, d = q.shape
    kv_b, n, kv_d = kv.shape

    # The kernel derives every address from the grid position and the raw
    # strides without bounds checks, so inconsistent shapes would turn into
    # out-of-bounds device reads.  Reject them before launching anything.
    if kv_b != b or kv_d != d:
        raise ValueError(
            f"kv must have shape (b={b}, n, d={d}) to match q, got {tuple(kv.shape)}"
        )
    if topk_idxs.dim() != 3 or tuple(topk_idxs.shape[:2]) != (b, m):
        raise ValueError(
            f"topk_idxs must have shape (b={b}, m={m}, topk), got {tuple(topk_idxs.shape)}"
        )
    if attn_sink.numel() < h:
        raise ValueError(
            f"attn_sink must provide at least h={h} values, got {attn_sink.numel()}"
        )

    topk = topk_idxs.shape[-1]
    # The kernel indexes attn_sink with an implicit unit stride.
    attn_sink = attn_sink.contiguous()
    o = torch.empty_like(q)
    h_padded = max(32, triton.next_power_of_2(h))
    logger.debug(
        "GEMS_MTHREADS SPARSE_ATTENTION, [shape info]: [%s, %s, %s, %s, %s, %s](B, M, KV_LEN, TOPK, H, D)",
        b,
        m,
        n,
        topk,
        h,
        d,
    )
    grid = (m, b)
    with torch_device_fn.device(q.device):
        sparse_attn_triton_kernel[grid](
            q,
            kv,
            o,
            attn_sink,
            topk_idxs,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            kv.stride(0),
            kv.stride(1),
            kv.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            topk_idxs.stride(0),
            topk_idxs.stride(1),
            topk_idxs.stride(2),
            softmax_scale,
            topk,
            h,
            D=d,
            H=h_padded,
        )
    return o