Coverage for src/flag_gems/fused/topk_softplus_sqrt.py: 24%

88 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# SPDX-License-Identifier: Apache-2.0 

2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project 

3# 

4# Adapted from the vLLM project (https://github.com/vllm-project/vllm). 

5# Source: vllm/model_executor/layers/fused_moe/topk_softplus_sqrt_kernels.cu 

6# 

7# This Triton implementation is based on the CUDA kernel from vLLM 0.20.0. 

8# The kernel fuses softplus, sqrt, top-k selection, and optional renormalization 

9# for MoE gating in models like DeepSeek-V3/V4. 

10 

11"""TopK Softplus-Sqrt gating kernel in Triton. 

12 

13Optimized v27: num_warps=1 + all v19-v26 wins. 

14Key insight: For 256 experts, CUDA uses exactly 1 warp (32 threads) per row, 

15with each thread holding 8 elements. Using num_warps=4 adds warp scheduling 

16overhead without helping the 256-element reduction. Combining num_warps=1 

17(matching CUDA's single-warp-per-row) with tensor caching (v19), score-arithmetic 

18weight extraction (v20), and the max+compare index recovery (v26) should 

19minimize overhead. 

20 

Renormalization and routed scaling are applied in a final step, once all
top-k weights and their sum are known.

23""" 

24 

25import logging 

26 

27import triton 

28import triton.language as tl 

29 

30logger = logging.getLogger(__name__) 

31 

32 

@triton.jit
def _fused_topk_kernel(
    gating_ptr,
    topk_weights_ptr,
    topk_indices_ptr,
    token_expert_indices_ptr,
    e_score_correction_bias_ptr,
    num_tokens,
    num_experts: tl.constexpr,
    topk: tl.constexpr,
    renormalize: tl.constexpr,
    routed_scaling_factor,
    HAS_BIAS: tl.constexpr,
    BLOCK_E: tl.constexpr,
):
    """Score-based top-k over sqrt(softplus(logits)); one program per token.

    For token ``pid`` the kernel computes ``raw = sqrt(softplus(logit))`` for
    every expert. When ``HAS_BIAS`` is set, it adds ``e_score_correction_bias``
    to form the selection scores. It then picks the ``topk`` highest scores,
    with ties resolved toward the lowest expert index.

    The weight reported for a selected expert is its unbiased ``raw`` value:
    the bias only influences *which* experts are chosen. Weights are then
    optionally renormalized by their sum and multiplied by
    ``routed_scaling_factor``.

    Outputs are written at flat position ``pid * topk + k`` for k in
    [0, topk). ``token_expert_indices`` receives that flat position itself.

    Selected weights and indices are kept in register vectors of length
    ``BLOCK_E`` (hence ``topk <= BLOCK_E``). They are written with one masked
    store per output at the end, instead of being stored every iteration and
    re-read from global memory to apply the scale.
    """
    # Result slots live in BLOCK_E-wide vectors, so topk must fit in them.
    tl.static_assert(topk <= BLOCK_E, "topk must not exceed BLOCK_E")

    pid = tl.program_id(0)
    if pid >= num_tokens:
        return

    expert_offsets = tl.arange(0, BLOCK_E)
    emask = expert_offsets < num_experts

    row_base = pid * num_experts
    x = tl.load(
        gating_ptr + row_base + expert_offsets, mask=emask, other=0.0
    ).to(tl.float32)

    # Fused softplus + sqrt; softplus(x) is taken as x itself above 20.
    x = tl.where(x > 20.0, x, tl.log(1.0 + tl.exp(x)))
    raw = tl.sqrt(x)

    # Selection scores (optionally bias-corrected); padding lanes can never win.
    if HAS_BIAS:
        bias = tl.load(
            e_score_correction_bias_ptr + expert_offsets, mask=emask, other=0.0
        ).to(tl.float32)
        scores = raw + bias
    else:
        scores = raw
    scores = tl.where(emask, scores, -float("inf"))

    # Register-resident results; slot k holds the k-th selected expert.
    weights = tl.zeros([BLOCK_E], dtype=tl.float32)
    indices = tl.zeros([BLOCK_E], dtype=tl.int32)
    weight_sum = 0.0

    for k_idx in tl.static_range(topk):
        max_score = tl.max(scores, axis=0)
        # Among lanes equal to the max, the lowest index gets the highest
        # priority, so max(priority) recovers the first arg-max.
        is_max = scores == max_score
        match_priority = tl.where(is_max, BLOCK_E - expert_offsets, 0)
        eidx = (BLOCK_E - tl.max(match_priority, axis=0)).to(tl.int32)
        is_winner = expert_offsets == eidx

        if HAS_BIAS:
            # Read the unbiased value directly: (raw + bias) - bias is not
            # exact in floating point. If no lane matched the max (possible
            # with NaN scores), eidx == BLOCK_E and this yields 0 instead of
            # loading the bias out of bounds.
            w = tl.sum(tl.where(is_winner, raw, 0.0), axis=0)
        else:
            w = max_score

        weight_sum += w
        slot = expert_offsets == k_idx
        weights = tl.where(slot, w, weights)
        indices = tl.where(slot, eidx, indices)

        # Remove the winner from later rounds.
        scores = tl.where(is_winner, -float("inf"), scores)

    # A non-positive sum leaves the weights unnormalized (divide by 1).
    if renormalize:
        scale = routed_scaling_factor / tl.where(weight_sum > 0.0, weight_sum, 1.0)
    else:
        scale = routed_scaling_factor

    # One masked burst store per output tensor.
    kmask = expert_offsets < topk
    out_offsets = pid * topk + expert_offsets
    tl.store(topk_weights_ptr + out_offsets, weights * scale, mask=kmask)
    tl.store(topk_indices_ptr + out_offsets, indices, mask=kmask)
    tl.store(
        token_expert_indices_ptr + out_offsets,
        out_offsets.to(tl.int32),
        mask=kmask,
    )

110 

111 

@triton.jit
def _hash_kernel(
    gating_ptr,
    topk_weights_ptr,
    topk_indices_ptr,
    token_expert_indices_ptr,
    e_score_correction_bias_ptr,
    input_tokens_ptr,
    hash_indices_table_ptr,
    num_tokens,
    num_experts: tl.constexpr,
    topk: tl.constexpr,
    renormalize: tl.constexpr,
    routed_scaling_factor,
    HAS_BIAS: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Hash-routing variant: one program per token, experts come from a table.

    The ``topk`` expert ids are read from ``hash_indices_table_ptr`` at row
    ``input_tokens_ptr[token]`` rather than chosen by score. For that reason
    ``HAS_BIAS`` and ``e_score_correction_bias_ptr`` are accepted but unused.

    Each selected expert's weight is ``sqrt(softplus(logit))``. The weights
    are optionally divided by their sum (when positive) and then multiplied
    by ``routed_scaling_factor``. Results go to flat positions
    ``token * topk + k``, which are also written to ``token_expert_indices``.
    """
    token = tl.program_id(0)
    if token >= num_tokens:
        return

    # Expert ids for this token, looked up through its vocabulary id.
    slots = tl.arange(0, BLOCK_K)
    slot_ok = slots < topk
    vocab_id = tl.load(input_tokens_ptr + token)
    table_row = hash_indices_table_ptr + vocab_id * topk
    chosen = tl.load(table_row + slots, mask=slot_ok, other=0)

    # sqrt(softplus(logit)) for the full row; softplus is the identity above 20.
    cols = tl.arange(0, BLOCK_E)
    col_ok = cols < num_experts
    logits = tl.load(
        gating_ptr + token * num_experts + cols, mask=col_ok, other=0.0
    )
    logits = logits.to(tl.float32)
    soft = tl.log(1.0 + tl.exp(logits))
    gate = tl.sqrt(tl.where(logits > 20.0, logits, soft))

    # Pick out each chosen expert's gate value, accumulating the sum in order.
    total = 0.0
    picked = tl.zeros([BLOCK_K], dtype=tl.float32)
    for k in tl.static_range(topk):
        at_k = slots == k
        expert = tl.sum(tl.where(at_k, chosen, 0))
        wk = tl.sum(tl.where(cols == expert, gate, 0.0))
        total += wk
        picked = tl.where(at_k, wk, picked)

    if renormalize:
        denom = tl.where(total > 0.0, total, 1.0)
        factor = routed_scaling_factor / denom
    else:
        factor = routed_scaling_factor

    # Single masked store per output.
    dst = token * topk + slots
    tl.store(topk_weights_ptr + dst, picked * factor, mask=slot_ok)
    tl.store(topk_indices_ptr + dst, chosen, mask=slot_ok)
    tl.store(token_expert_indices_ptr + dst, dst.to(tl.int32), mask=slot_ok)

178 

179 

def topk_softplus_sqrt(
    topk_weights,
    topk_indices,
    token_expert_indices,
    gating_output,
    renormalize,
    routed_scaling_factor,
    correction_bias=None,
    input_ids=None,
    tid2eid=None,
):
    """Fused topk + softplus + sqrt kernel for MoE gating.

    Interface aligned with vLLM CUDA operator:
        void topk_softplus_sqrt(Tensor& topk_weights, Tensor& topk_indices,
                                Tensor& token_expert_indices, Tensor& gating_output,
                                bool renormalize, double routed_scaling_factor,
                                const c10::optional<Tensor>& correction_bias,
                                const c10::optional<Tensor>& input_ids,
                                const c10::optional<Tensor>& tid2eid);

    Results are written in place into the three output tensors; the function
    returns ``None``.

    Args:
        topk_weights: Output tensor [num_tokens, topk], dtype float32. Its
            second dimension defines ``topk``.
        topk_indices: Output tensor [num_tokens, topk], dtype int32.
        token_expert_indices: Output tensor [num_tokens, topk], dtype int32.
        gating_output: Gating logits [num_tokens, num_experts].
        renormalize: Whether to divide the selected weights by their sum
            (skipped when the sum is not positive).
        routed_scaling_factor: Scaling factor applied to every final weight.
        correction_bias: Optional bias for expert scores [num_experts]. It is
            used only for expert selection in score mode, and ignored in hash
            mode.
        input_ids: Token IDs for hash mode [num_tokens].
        tid2eid: Hash table mapping token IDs to expert indices. Hash mode is
            used only when both ``input_ids`` and ``tid2eid`` are given.

    Raises:
        ValueError: If an output tensor is not contiguous or cannot hold
            ``num_tokens x topk`` results, or if ``correction_bias`` /
            ``input_ids`` are too small for the gating input.
    """
    logger.debug("GEMS TOPK_SOFTPLUS_SQRT")
    num_tokens, num_experts = gating_output.shape
    topk = topk_weights.shape[1]

    # Nothing to compute. Returning early also avoids launching with
    # BLOCK_K == 0, since triton.next_power_of_2(0) returns 0.
    if num_tokens == 0 or topk == 0:
        return

    # Both kernels write outputs at flat offset ``token * topk + k``. That
    # requires contiguous row-major buffers with exactly ``topk`` columns and
    # at least ``num_tokens`` rows; anything else would scatter results or
    # write out of bounds.
    for name, out in (
        ("topk_weights", topk_weights),
        ("topk_indices", topk_indices),
        ("token_expert_indices", token_expert_indices),
    ):
        if out.dim() != 2 or out.shape[0] < num_tokens or out.shape[1] != topk:
            raise ValueError(
                f"{name} must have shape [{num_tokens}, {topk}], "
                f"got {tuple(out.shape)}"
            )
        if not out.is_contiguous():
            raise ValueError(f"{name} must be contiguous")

    # The kernels read logits with row stride ``num_experts`` and the 1-D
    # inputs with unit stride. ``.contiguous()`` is a no-op for tensors that
    # already have that layout.
    gating_output = gating_output.contiguous()
    if correction_bias is not None:
        if correction_bias.numel() < num_experts:
            raise ValueError(
                f"correction_bias must hold at least {num_experts} values, "
                f"got {correction_bias.numel()}"
            )
        correction_bias = correction_bias.contiguous()

    BLOCK_E = triton.next_power_of_2(num_experts)
    grid = (num_tokens,)
    has_bias = correction_bias is not None
    # The kernels need a valid pointer even without a bias. They never load
    # through it when HAS_BIAS is False, so any tensor will do.
    bias_arg = correction_bias if has_bias else gating_output

    if input_ids is not None and tid2eid is not None:
        if input_ids.numel() < num_tokens:
            raise ValueError(
                f"input_ids must hold at least {num_tokens} token ids, "
                f"got {input_ids.numel()}"
            )
        # NOTE(review): the kernel reads tid2eid[token_id * topk + k], i.e. it
        # assumes a row-major table with exactly ``topk`` columns. Confirm
        # against the vLLM caller.
        _hash_kernel[grid](
            gating_output,
            topk_weights,
            topk_indices,
            token_expert_indices,
            bias_arg,
            input_ids.contiguous(),
            tid2eid.contiguous(),
            num_tokens=num_tokens,
            num_experts=num_experts,
            topk=topk,
            renormalize=renormalize,
            routed_scaling_factor=routed_scaling_factor,
            HAS_BIAS=has_bias,
            BLOCK_E=BLOCK_E,
            BLOCK_K=triton.next_power_of_2(topk),
            num_warps=1,
            num_stages=1,
        )
        return

    _fused_topk_kernel[grid](
        gating_output,
        topk_weights,
        topk_indices,
        token_expert_indices,
        bias_arg,
        num_tokens=num_tokens,
        num_experts=num_experts,
        topk=topk,
        renormalize=renormalize,
        routed_scaling_factor=routed_scaling_factor,
        HAS_BIAS=has_bias,
        BLOCK_E=BLOCK_E,
        num_warps=1,
        num_stages=1,
    )