Coverage for src/flag_gems/runtime/backend/_mthreads/fused/sparse_attention.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11 

# NOTE(review): the logger name says `ops` although this module lives under
# `fused/`; it is left unchanged so existing log filters keep matching.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops.sparse_attention"
)

# Tuning-space YAML that the kernel decorator reads when USE_FLAGTUNE=1.
# It sits one directory above this module's own directory.
EXPAND_CONFIG_FILENAME = os.path.normpath(
    os.path.join(os.path.dirname(__file__), "..", "sparse_attention_mthreads_expand.yaml")
)

20 

21 

def sparse_attention_get_configs():
    """Return the fallback autotuning candidates for the sparse-attention kernel.

    Used when USE_FLAGTUNE is not "1": a single fixed configuration with a
    32-wide index block, 6 pipeline stages and 16 warps.
    """
    default_config = triton.Config({"BLOCK": 32}, num_stages=6, num_warps=16)
    return [default_config]

26 

27 

# Sparse attention with a per-head attention sink.
#
# One program handles one (query position m, batch b) pair and all heads at
# once. For each query, only the KV rows listed in topk_idxs[b, m, :] are
# gathered. Each gathered KV row is used both as the key (for the scores) and
# as the value (for the output). Softmax is computed online over BLOCK-sized
# chunks of the index list. A per-head sink logit is added to the denominator
# at the end; the sink has no value vector.
@libentry()
@libtuner(
    configs=runtime.ops_get_configs(
        "sparse_attention", yaml_path=EXPAND_CONFIG_FILENAME
    )
    if os.environ.get("USE_FLAGTUNE") == "1"
    else sparse_attention_get_configs(),
    key=["topk", "H_ACTUAL", "D"],
    strategy=runtime.get_expand_config(
        "sparse_attention", yaml_path=EXPAND_CONFIG_FILENAME
    )["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32"],
    warmup=5,
    rep=5,
)
@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d) bf16
    KV,  # (b, n, d) bf16; each gathered row serves as both key and value
    O,  # (b, m, h, d) bf16
    attn_sink,  # (h,) fp32; read with unit stride
    topk_idxs,  # (b, m, topk) int32; negative entries (e.g. -1) are padding
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,
    topk,
    H_ACTUAL,
    BLOCK: tl.constexpr,
    D: tl.constexpr,
    H: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_b = tl.program_id(1)

    offs_h = tl.arange(0, H)
    offs_d = tl.arange(0, D)
    # The head axis is padded up to the tile size H; only the first H_ACTUAL
    # rows are real heads.
    h_mask = offs_h < H_ACTUAL

    q_base = Q + pid_b * stride_qb + pid_m * stride_qm
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    q_block = tl.load(q_ptrs, mask=h_mask[:, None], other=0.0)

    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    acc_o = tl.zeros([H, D], dtype=tl.float32)
    # Start the running max at a large *finite* negative value instead of
    # -inf. If every index in a block is invalid before any valid one has been
    # seen, block_max is -inf. With a -inf start, (-inf) - (-inf) = NaN would
    # then poison both `correction` and `p`. With a finite start,
    # fully-masked blocks contribute exactly zero.
    scores_max = tl.full([H], -1.0e30, dtype=tl.float32)
    sum_exp = tl.zeros([H], dtype=tl.float32)

    num_blocks = (topk + BLOCK - 1) // BLOCK  # ceil(topk / BLOCK)
    offs_blk = tl.arange(0, BLOCK)

    for t in range(num_blocks):
        raw_offs = t * BLOCK + offs_blk
        idx_mask = raw_offs < topk
        idxs = tl.load(
            idx_base + raw_offs * stride_idxk,
            mask=idx_mask,
            other=-1,
        )
        # Skip padding slots (negative indices) and the tail past `topk`,
        # which was filled with -1 above. This way no negative row is ever
        # gathered.
        valid_mask = idxs >= 0

        kv_ptrs = kv_base + idxs[:, None] * stride_kvn + offs_d[None, :] * stride_kvd
        kv_block = tl.load(kv_ptrs, mask=valid_mask[:, None], other=0.0)

        # Scaled scores of every head against the gathered rows: (H, BLOCK).
        acc_s = tl.dot(q_block, tl.trans(kv_block))
        acc_s = acc_s * scale
        mask_bias = tl.where(valid_mask, 0.0, float("-inf"))
        acc_s = acc_s + mask_bias[None, :]

        # Online softmax: rescale the running accumulators to the new max.
        scores_max_prev = scores_max
        block_max = tl.max(acc_s, axis=1)
        scores_max = tl.maximum(scores_max, block_max)

        correction = tl.exp(scores_max_prev - scores_max)
        p = tl.exp(acc_s - scores_max[:, None])

        acc_o = acc_o * correction[:, None]
        # Cast P to the KV element type so both tl.dot operands match.
        acc_o += tl.dot(p.to(kv_block.dtype), kv_block)

        scores_sum = tl.sum(p, axis=1)
        sum_exp = sum_exp * correction + scores_sum

    # The sink logit joins the softmax denominator (in the same max-shifted
    # units) but contributes no value. A row with no valid index therefore
    # ends up with output 0.
    sink_vals = tl.load(attn_sink + offs_h, mask=h_mask, other=0.0)
    sum_exp = sum_exp + tl.exp(sink_vals - scores_max)

    acc_o = acc_o / sum_exp[:, None]

    o_base = O + pid_b * stride_ob + pid_m * stride_om
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    tl.store(o_ptrs, acc_o.to(O.dtype.element_ty), mask=h_mask[:, None])

131 

132 

def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Run top-k sparse attention with a per-head attention sink.

    For every (batch, query position) pair, attend over only the rows of
    ``kv`` selected by ``topk_idxs``. Each selected row serves as both key and
    value. A per-head sink logit is added to the softmax denominator.

    Args:
        q: Queries of shape ``(b, m, h, d)``; the kernel comments expect bf16.
            ``d`` is used as a ``tl.arange`` extent, so it must satisfy
            Triton's power-of-two requirement.
        kv: Key/value rows of shape ``(b, n, d)``.
        attn_sink: Per-head sink logits of shape ``(h,)``.
        topk_idxs: Indices into the ``n`` axis of ``kv``, shape
            ``(b, m, topk)``. Entries equal to ``-1`` are treated as padding
            and skipped.
        softmax_scale: Multiplier applied to the ``q @ kv^T`` scores.

    Returns:
        A new tensor with the shape, dtype and device of ``q``.
    """
    b, m, h, d = q.shape
    _, n, _ = kv.shape
    topk = topk_idxs.shape[-1]
    o = torch.empty_like(q)
    # Nothing to compute. Also avoids an empty launch grid and zero-extent
    # tl.arange when h or d is 0.
    if o.numel() == 0:
        return o
    # The kernel indexes attn_sink as `attn_sink + head` (unit stride), so a
    # strided view would read the wrong values. No-op if already contiguous.
    attn_sink = attn_sink.contiguous()
    # Pad the head axis to a power of two, and to at least 32 rows for the
    # tl.dot tile.
    h_padded = max(32, triton.next_power_of_2(h))
    logger.debug(
        "GEMS_MTHREADS SPARSE_ATTENTION, [shape info]: [%s, %s, %s, %s, %s, %s](B, M, KV_LEN, TOPK, H, D)",
        b,
        m,
        n,
        topk,
        h,
        d,
    )
    # One program per (query position, batch) pair; heads are handled inside.
    grid = (m, b)
    with torch_device_fn.device(q.device):
        sparse_attn_triton_kernel[grid](
            q,
            kv,
            o,
            attn_sink,
            topk_idxs,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            kv.stride(0),
            kv.stride(1),
            kv.stride(2),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            topk_idxs.stride(0),
            topk_idxs.stride(1),
            topk_idxs.stride(2),
            softmax_scale,
            topk,
            h,
            D=d,
            H=h_padded,
        )
    return o