Coverage for src/flag_gems/fused/topk_softplus_sqrt.py: 24%
88 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# SPDX-License-Identifier: Apache-2.0
2# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3#
4# Adapted from the vLLM project (https://github.com/vllm-project/vllm).
5# Source: vllm/model_executor/layers/fused_moe/topk_softplus_sqrt_kernels.cu
6#
7# This Triton implementation is based on the CUDA kernel from vLLM 0.20.0.
8# The kernel fuses softplus, sqrt, top-k selection, and optional renormalization
9# for MoE gating in models like DeepSeek-V3/V4.
"""TopK Softplus-Sqrt gating kernels in Triton.

Each token row is handled by one Triton program launched with num_warps=1.
For 256 experts this mirrors the CUDA kernel, which uses exactly one warp
(32 threads) per row with each thread holding 8 elements; launching with
num_warps=4 was found to add warp-scheduling overhead without speeding up the
256-element reductions.

Top-k selection reduces the row with a max and then a compare to recover the
winning expert index, with ties resolved to the lowest index. Renormalization
and scaling are applied once per row, after all top-k weights of that row are
known.
"""
25import logging
27import triton
28import triton.language as tl
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
@triton.jit
def _fused_topk_kernel(
    gating_ptr,
    topk_weights_ptr,
    topk_indices_ptr,
    token_expert_indices_ptr,
    e_score_correction_bias_ptr,
    num_tokens,
    num_experts: tl.constexpr,
    topk: tl.constexpr,
    renormalize: tl.constexpr,
    routed_scaling_factor,
    HAS_BIAS: tl.constexpr,
    BLOCK_E: tl.constexpr,
):
    """Select the top-k experts of one token row (one program per row).

    Scores are ``sqrt(softplus(logit))`` computed in float32, plus the
    correction bias when ``HAS_BIAS``. The bias only affects which experts
    are selected: the stored weights are the unbiased ``sqrt(softplus(...))``
    values, multiplied by ``routed_scaling_factor`` and, when ``renormalize``,
    divided by their row sum (a non-positive sum is treated as 1). Ties
    between equal scores resolve to the lowest expert index.

    For row ``pid``, ``topk`` entries starting at offset ``pid * topk`` are
    written to ``topk_weights_ptr``, ``topk_indices_ptr`` and
    ``token_expert_indices_ptr`` (the last receives ``pid * topk + k``).
    Tensors are addressed as dense rows (``pid * num_experts`` for the input).
    Assumes ``topk <= num_experts``; only the first ``BLOCK_E`` slots can be
    written. ``e_score_correction_bias_ptr`` is read only when ``HAS_BIAS``.
    """
    pid = tl.program_id(0)
    if pid >= num_tokens:
        return

    expert_offsets = tl.arange(0, BLOCK_E)
    emask = expert_offsets < num_experts

    row_base = pid * num_experts
    logits = tl.load(
        gating_ptr + row_base + expert_offsets, mask=emask, other=0.0
    ).to(tl.float32)

    # Fused softplus + sqrt. For x > 20, softplus(x) equals x to float32
    # precision, so the identity branch is used there.
    # NOTE(review): log(1 + exp(x)) loses precision for very negative x
    # (1 + exp(x) rounds to exactly 1 below about x = -16.6); a log1p-based
    # form would be more accurate if parity with the CUDA reference allows.
    softplus = tl.where(logits > 20.0, logits, tl.log(1.0 + tl.exp(logits)))
    raw = tl.sqrt(softplus)

    # Scores for top-k selection (with optional bias).
    if HAS_BIAS:
        bias = tl.load(
            e_score_correction_bias_ptr + expert_offsets, mask=emask, other=0.0
        ).to(tl.float32)
        scores = raw + bias
    else:
        scores = raw
    # Padding lanes can never win the max.
    scores = tl.where(emask, scores, -float("inf"))

    # Results stay in registers: slot k of these vectors holds the k-th pick,
    # and only the first `topk` slots are written back at the end.
    sel_weights = tl.zeros([BLOCK_E], dtype=tl.float32)
    sel_experts = tl.zeros([BLOCK_E], dtype=tl.int32)
    weight_sum = 0.0

    for k_idx in tl.static_range(topk):
        max_score = tl.max(scores, axis=0)
        # Among lanes equal to the max, the lowest expert index has the
        # highest priority (BLOCK_E - index), so it is the one recovered.
        priority = tl.where(scores == max_score, BLOCK_E - expert_offsets, 0)
        eidx = (BLOCK_E - tl.max(priority, axis=0)).to(tl.int32)
        is_winner = expert_offsets == eidx

        if HAS_BIAS:
            # Pick the unbiased value straight from registers. Recomputing it
            # as (raw + bias) - bias would add float32 rounding error and
            # needs a data-dependent load of bias[eidx].
            w = tl.sum(tl.where(is_winner, raw, 0.0), axis=0)
        else:
            w = max_score

        weight_sum += w
        is_slot = expert_offsets == k_idx
        sel_weights = tl.where(is_slot, w, sel_weights)
        sel_experts = tl.where(is_slot, eidx, sel_experts)

        # Remove the winner from later rounds.
        scores = tl.where(is_winner, -float("inf"), scores)

    # Renormalize (optional) and scale, once all weights are known.
    if renormalize:
        scale = routed_scaling_factor / tl.where(weight_sum > 0.0, weight_sum, 1.0)
    else:
        scale = routed_scaling_factor

    # Single burst store per output; no store-load-store round trip.
    out_base = pid * topk
    kmask = expert_offsets < topk
    tl.store(
        topk_weights_ptr + out_base + expert_offsets,
        sel_weights * scale,
        mask=kmask,
    )
    tl.store(topk_indices_ptr + out_base + expert_offsets, sel_experts, mask=kmask)
    tl.store(
        token_expert_indices_ptr + out_base + expert_offsets,
        (out_base + expert_offsets).to(tl.int32),
        mask=kmask,
    )
@triton.jit
def _hash_kernel(
    gating_ptr,
    topk_weights_ptr,
    topk_indices_ptr,
    token_expert_indices_ptr,
    e_score_correction_bias_ptr,
    input_tokens_ptr,
    hash_indices_table_ptr,
    num_tokens,
    num_experts: tl.constexpr,
    topk: tl.constexpr,
    renormalize: tl.constexpr,
    routed_scaling_factor,
    HAS_BIAS: tl.constexpr,
    BLOCK_E: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Hash mode: expert indices come from lookup table.

    One program per token row. The ``topk`` expert ids are read from
    ``hash_indices_table_ptr`` at offset ``token_id * topk``, where
    ``token_id = input_tokens_ptr[pid]``. Each weight is
    ``sqrt(softplus(logit))`` of the matching gating logit, computed in
    float32. Weights are multiplied by ``routed_scaling_factor`` and, when
    ``renormalize``, divided by their sum (a non-positive sum is treated as 1).

    Only the ``topk`` selected logits are loaded and transformed, not the whole
    row. Expert ids outside ``[0, num_experts)`` are never dereferenced; they
    get weight 0 and are stored to ``topk_indices_ptr`` unchanged.

    ``e_score_correction_bias_ptr``, ``HAS_BIAS`` and ``BLOCK_E`` are accepted
    but not read by this kernel.
    """
    pid = tl.program_id(0)
    if pid >= num_tokens:
        return

    k_offsets = tl.arange(0, BLOCK_K)
    kmask = k_offsets < topk

    # Expert ids for this token come from the table row of its token id.
    token_id = tl.load(input_tokens_ptr + pid)
    expert_ids = tl.load(
        hash_indices_table_ptr + token_id * topk + k_offsets, mask=kmask, other=0
    )

    # Gather just the selected logits; out-of-range ids are masked so a bad
    # table entry cannot read outside this token's row.
    valid = kmask & (expert_ids >= 0) & (expert_ids < num_experts)
    logits = tl.load(
        gating_ptr + pid * num_experts + expert_ids, mask=valid, other=0.0
    ).to(tl.float32)

    # Fused softplus + sqrt; for x > 20 softplus(x) equals x in float32.
    softplus = tl.where(logits > 20.0, logits, tl.log(1.0 + tl.exp(logits)))
    weights = tl.where(valid, tl.sqrt(softplus), 0.0)

    # Apply renormalization + scaling.
    if renormalize:
        weight_sum = tl.sum(weights, axis=0)
        scale = routed_scaling_factor / tl.where(weight_sum > 0.0, weight_sum, 1.0)
    else:
        scale = routed_scaling_factor

    # Single burst store per output.
    out_base = pid * topk
    tl.store(topk_weights_ptr + out_base + k_offsets, weights * scale, mask=kmask)
    tl.store(topk_indices_ptr + out_base + k_offsets, expert_ids, mask=kmask)
    tl.store(
        token_expert_indices_ptr + out_base + k_offsets,
        (out_base + k_offsets).to(tl.int32),
        mask=kmask,
    )
def topk_softplus_sqrt(
    topk_weights,
    topk_indices,
    token_expert_indices,
    gating_output,
    renormalize,
    routed_scaling_factor,
    correction_bias=None,
    input_ids=None,
    tid2eid=None,
):
    """Fused topk + softplus + sqrt kernel for MoE gating.

    Interface aligned with vLLM CUDA operator:
        void topk_softplus_sqrt(Tensor& topk_weights, Tensor& topk_indices,
                                Tensor& token_expert_indices, Tensor& gating_output,
                                bool renormalize, double routed_scaling_factor,
                                const c10::optional<Tensor>& correction_bias,
                                const c10::optional<Tensor>& input_ids,
                                const c10::optional<Tensor>& tid2eid);

    Results are written in place into the three output tensors and nothing is
    returned. Hash mode is used only when both ``input_ids`` and ``tid2eid``
    are given; otherwise experts are picked by top-k selection. Non-contiguous
    inputs are made contiguous, and non-contiguous outputs are filled through a
    contiguous buffer that is copied back.

    Args:
        topk_weights: Output tensor [num_tokens, topk], dtype float32
        topk_indices: Output tensor [num_tokens, topk], dtype int32
        token_expert_indices: Output tensor [num_tokens, topk], dtype int32.
            Filled with ``token * topk + k``. NOTE(review): vLLM's
            topk_softmax kernels use ``k * num_tokens + token`` for this
            output; confirm the layout of the CUDA topk_softplus_sqrt reference.
        gating_output: Gating logits [num_tokens, num_experts]
        renormalize: Whether to renormalize weights
        routed_scaling_factor: Scaling factor for final weights
        correction_bias: Optional bias for expert scores [num_experts]; only
            affects selection, not the returned weights.
        input_ids: Token IDs for hash mode [num_tokens]
        tid2eid: Hash table mapping tokens to expert indices; for token id
            ``t`` the ``topk`` entries starting at flat offset ``t * topk``
            are used.

    Raises:
        ValueError: In top-k selection mode, if ``topk > num_experts``.
    """
    logger.debug("GEMS TOPK_SOFTPLUS_SQRT")
    num_tokens, num_experts = gating_output.shape
    topk = topk_weights.shape[1]

    # Nothing to write; also avoids a zero-sized BLOCK_K in hash mode.
    if num_tokens == 0 or topk == 0:
        return

    use_hash = input_ids is not None and tid2eid is not None
    if not use_hash and topk > num_experts:
        raise ValueError(
            f"topk ({topk}) must not exceed the number of experts ({num_experts})"
        )

    # The kernels address rows as pid * num_experts (input) and pid * topk
    # (outputs), i.e. they assume dense row-major tensors. .contiguous()
    # returns the tensor itself when it already is contiguous.
    gating_output = gating_output.contiguous()
    if correction_bias is not None:
        correction_bias = correction_bias.contiguous()
    outputs = (topk_weights, topk_indices, token_expert_indices)
    out_bufs = [t.contiguous() for t in outputs]

    # The kernels do not read the bias pointer when HAS_BIAS is False, so any
    # valid tensor works as a placeholder.
    bias_arg = correction_bias if correction_bias is not None else gating_output
    launch_kwargs = {
        "num_tokens": num_tokens,
        "num_experts": num_experts,
        "topk": topk,
        "renormalize": bool(renormalize),
        "routed_scaling_factor": float(routed_scaling_factor),
        "HAS_BIAS": correction_bias is not None,
        "BLOCK_E": triton.next_power_of_2(num_experts),
        "num_warps": 1,
        "num_stages": 1,
    }
    grid = (num_tokens,)

    if use_hash:
        _hash_kernel[grid](
            gating_output,
            *out_bufs,
            bias_arg,
            input_ids.contiguous(),
            tid2eid.contiguous(),
            BLOCK_K=triton.next_power_of_2(topk),
            **launch_kwargs,
        )
    else:
        _fused_topk_kernel[grid](gating_output, *out_bufs, bias_arg, **launch_kwargs)

    # Copy results back into caller tensors that were not contiguous.
    for dst, buf in zip(outputs, out_bufs):
        if buf is not dst:
            dst.copy_(buf)