Coverage for src/flag_gems/fused/mhc/mhc_pre.py: 21%
346 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1"""
2Triton implementation of mHC Pre operator (optimized v2).
4Key optimizations:
5- GEMM: torch.mm in bf16 (cuBLAS tensor cores)
6- sqrsum + norm + mix + sinkhorn + weighted sum: single fused Triton kernel
7 Two passes over residual: pass 1 computes sqrsum, pass 2 does weighted sum
8"""
10import logging
11import weakref
13import torch
14import triton
15import triton.language as tl
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
20_FN_BF16_CACHE: weakref.WeakKeyDictionary[
21 torch.Tensor, tuple[int, torch.Tensor]
22] = weakref.WeakKeyDictionary()
25def _get_fn_bf16_cached(fn: torch.Tensor) -> torch.Tensor:
26 if fn.requires_grad or torch.is_grad_enabled():
27 return fn.to(dtype=torch.bfloat16)
28 version = fn._version
29 cached = _FN_BF16_CACHE.get(fn)
30 if cached is not None:
31 cached_version, cached_bf16 = cached
32 if cached_version == version:
33 return cached_bf16
34 fn_bf16 = fn.to(dtype=torch.bfloat16)
35 _FN_BF16_CACHE[fn] = (version, fn_bf16)
36 return fn_bf16
@triton.jit
def _mhc_pre_fused_kernel_hc_mult_4_impl(
    gemm_out_ptr,  # (num_tokens, hc_mult3), float32
    hc_scale_ptr,  # (3,), float32
    hc_base_ptr,  # (hc_mult3,), float32
    residual_ptr,  # (num_tokens, hc_mult, hidden_size), bfloat16
    post_mix_ptr,  # (num_tokens, hc_mult), float32
    comb_mix_ptr,  # (num_tokens, hc_mult*hc_mult), float32
    layer_input_ptr,  # (num_tokens, hidden_size), bfloat16
    num_tokens,
    num_tokens_bucket,
    res_stride_n,
    res_stride_i,
    res_stride_h,
    li_stride_n,
    li_stride_h,
    hidden_size,
    hc_hidden_size,
    rms_eps: tl.constexpr,
    hc_pre_eps: tl.constexpr,
    hc_sinkhorn_eps: tl.constexpr,
    hc_post_mult_value: tl.constexpr,
    sinkhorn_repeat: tl.constexpr,
    HC_MULT3: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Fully fused: sqrsum + RMS norm + sigmoid + Sinkhorn + weighted sum. One token per program.

    Specialised for 4 heads. For the token handled by this program the kernel:

    1. Sums the squares of all 4 residual heads and derives
       ``rms_inv = rsqrt(sq / hc_hidden_size + rms_eps)``.
    2. Reads the token's ``gemm_out`` row (row stride ``HC_MULT3``):
       entries 0..3 give ``pre_mix`` (``scale_0``, sigmoid, ``+ hc_pre_eps``),
       entries 4..7 give ``post_mix`` (``scale_1``, sigmoid,
       ``* hc_post_mult_value``) and entries 8..23 give the 4x4 ``comb_mix``
       logits (``scale_2``).  Every entry is scaled by ``rms_inv`` and offset
       by the matching ``hc_base`` element.
    3. Normalises the 4x4 matrix: row softmax ``+ hc_sinkhorn_eps``, a column
       normalisation, then ``sinkhorn_repeat - 1`` further row/column rounds.
    4. Writes ``layer_input = sum_k pre_mix_k * residual[n, k, :]`` as bfloat16.

    ``num_tokens_bucket`` is not read in this body.
    """
    # Promote the program id to int64 so that offsets such as
    # ``pid_n * res_stride_n`` cannot wrap around in 32-bit arithmetic when
    # ``num_tokens * hc_mult * hidden_size`` exceeds the int32 range.
    pid_n = tl.program_id(0).to(tl.int64)
    if pid_n >= num_tokens:
        return

    # ══ Pass 1: compute sqrsum over all 4 heads ══
    sq = 0.0
    res_base = pid_n * res_stride_n
    for k in tl.static_range(4):
        head_base = res_base + k * res_stride_i
        for h_start in range(0, hidden_size, BLOCK_H):
            h_offsets = h_start + tl.arange(0, BLOCK_H)
            h_mask = h_offsets < hidden_size
            v = tl.load(
                residual_ptr + head_base + h_offsets * res_stride_h,
                mask=h_mask,
                other=0.0,
            ).to(tl.float32)
            sq += tl.sum(v * v)

    rms_inv = tl.rsqrt(sq / hc_hidden_size + rms_eps)

    # ══ Load scales ══
    scale_0 = tl.load(hc_scale_ptr + 0)
    scale_1 = tl.load(hc_scale_ptr + 1)
    scale_2 = tl.load(hc_scale_ptr + 2)

    go_base = pid_n * HC_MULT3

    # ══ pre_mix: indices 0..3 ══
    pre_mix_0 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 0) * rms_inv * scale_0
            + tl.load(hc_base_ptr + 0)
        )
        + hc_pre_eps
    )
    pre_mix_1 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 1) * rms_inv * scale_0
            + tl.load(hc_base_ptr + 1)
        )
        + hc_pre_eps
    )
    pre_mix_2 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 2) * rms_inv * scale_0
            + tl.load(hc_base_ptr + 2)
        )
        + hc_pre_eps
    )
    pre_mix_3 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 3) * rms_inv * scale_0
            + tl.load(hc_base_ptr + 3)
        )
        + hc_pre_eps
    )

    # ══ post_mix: indices 4..7 ══
    post_0 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 4) * rms_inv * scale_1
            + tl.load(hc_base_ptr + 4)
        )
        * hc_post_mult_value
    )
    tl.store(post_mix_ptr + pid_n * 4 + 0, post_0)
    post_1 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 5) * rms_inv * scale_1
            + tl.load(hc_base_ptr + 5)
        )
        * hc_post_mult_value
    )
    tl.store(post_mix_ptr + pid_n * 4 + 1, post_1)
    post_2 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 6) * rms_inv * scale_1
            + tl.load(hc_base_ptr + 6)
        )
        * hc_post_mult_value
    )
    tl.store(post_mix_ptr + pid_n * 4 + 2, post_2)
    post_3 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 7) * rms_inv * scale_1
            + tl.load(hc_base_ptr + 7)
        )
        * hc_post_mult_value
    )
    tl.store(post_mix_ptr + pid_n * 4 + 3, post_3)

    # ══ comb_mix: indices 8..23 → 4x4 Sinkhorn ══
    # cm_ij holds row i, column j of the 4x4 matrix (row-major in gemm_out).
    cb = 8
    cm_00 = tl.load(gemm_out_ptr + go_base + cb + 0) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 0
    )
    cm_01 = tl.load(gemm_out_ptr + go_base + cb + 1) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 1
    )
    cm_02 = tl.load(gemm_out_ptr + go_base + cb + 2) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 2
    )
    cm_03 = tl.load(gemm_out_ptr + go_base + cb + 3) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 3
    )
    cm_10 = tl.load(gemm_out_ptr + go_base + cb + 4) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 4
    )
    cm_11 = tl.load(gemm_out_ptr + go_base + cb + 5) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 5
    )
    cm_12 = tl.load(gemm_out_ptr + go_base + cb + 6) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 6
    )
    cm_13 = tl.load(gemm_out_ptr + go_base + cb + 7) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 7
    )
    cm_20 = tl.load(gemm_out_ptr + go_base + cb + 8) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 8
    )
    cm_21 = tl.load(gemm_out_ptr + go_base + cb + 9) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 9
    )
    cm_22 = tl.load(gemm_out_ptr + go_base + cb + 10) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 10
    )
    cm_23 = tl.load(gemm_out_ptr + go_base + cb + 11) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 11
    )
    cm_30 = tl.load(gemm_out_ptr + go_base + cb + 12) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 12
    )
    cm_31 = tl.load(gemm_out_ptr + go_base + cb + 13) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 13
    )
    cm_32 = tl.load(gemm_out_ptr + go_base + cb + 14) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 14
    )
    cm_33 = tl.load(gemm_out_ptr + go_base + cb + 15) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 15
    )

    # ── Sinkhorn iteration ──
    # First round, rows: softmax of each row (max-subtracted), plus eps.
    rm = tl.maximum(tl.maximum(cm_00, cm_01), tl.maximum(cm_02, cm_03))
    cm_00 = tl.exp(cm_00 - rm)
    cm_01 = tl.exp(cm_01 - rm)
    cm_02 = tl.exp(cm_02 - rm)
    cm_03 = tl.exp(cm_03 - rm)
    rs = cm_00 + cm_01 + cm_02 + cm_03
    inv_rs = 1.0 / rs
    cm_00 = cm_00 * inv_rs + hc_sinkhorn_eps
    cm_01 = cm_01 * inv_rs + hc_sinkhorn_eps
    cm_02 = cm_02 * inv_rs + hc_sinkhorn_eps
    cm_03 = cm_03 * inv_rs + hc_sinkhorn_eps

    rm = tl.maximum(tl.maximum(cm_10, cm_11), tl.maximum(cm_12, cm_13))
    cm_10 = tl.exp(cm_10 - rm)
    cm_11 = tl.exp(cm_11 - rm)
    cm_12 = tl.exp(cm_12 - rm)
    cm_13 = tl.exp(cm_13 - rm)
    rs = cm_10 + cm_11 + cm_12 + cm_13
    inv_rs = 1.0 / rs
    cm_10 = cm_10 * inv_rs + hc_sinkhorn_eps
    cm_11 = cm_11 * inv_rs + hc_sinkhorn_eps
    cm_12 = cm_12 * inv_rs + hc_sinkhorn_eps
    cm_13 = cm_13 * inv_rs + hc_sinkhorn_eps

    rm = tl.maximum(tl.maximum(cm_20, cm_21), tl.maximum(cm_22, cm_23))
    cm_20 = tl.exp(cm_20 - rm)
    cm_21 = tl.exp(cm_21 - rm)
    cm_22 = tl.exp(cm_22 - rm)
    cm_23 = tl.exp(cm_23 - rm)
    rs = cm_20 + cm_21 + cm_22 + cm_23
    inv_rs = 1.0 / rs
    cm_20 = cm_20 * inv_rs + hc_sinkhorn_eps
    cm_21 = cm_21 * inv_rs + hc_sinkhorn_eps
    cm_22 = cm_22 * inv_rs + hc_sinkhorn_eps
    cm_23 = cm_23 * inv_rs + hc_sinkhorn_eps

    rm = tl.maximum(tl.maximum(cm_30, cm_31), tl.maximum(cm_32, cm_33))
    cm_30 = tl.exp(cm_30 - rm)
    cm_31 = tl.exp(cm_31 - rm)
    cm_32 = tl.exp(cm_32 - rm)
    cm_33 = tl.exp(cm_33 - rm)
    rs = cm_30 + cm_31 + cm_32 + cm_33
    inv_rs = 1.0 / rs
    cm_30 = cm_30 * inv_rs + hc_sinkhorn_eps
    cm_31 = cm_31 * inv_rs + hc_sinkhorn_eps
    cm_32 = cm_32 * inv_rs + hc_sinkhorn_eps
    cm_33 = cm_33 * inv_rs + hc_sinkhorn_eps

    # First round, columns: divide each column by (column sum + eps).
    cs0 = cm_00 + cm_10 + cm_20 + cm_30
    cs1 = cm_01 + cm_11 + cm_21 + cm_31
    cs2 = cm_02 + cm_12 + cm_22 + cm_32
    cs3 = cm_03 + cm_13 + cm_23 + cm_33
    inv_cs0 = 1.0 / (cs0 + hc_sinkhorn_eps)
    inv_cs1 = 1.0 / (cs1 + hc_sinkhorn_eps)
    inv_cs2 = 1.0 / (cs2 + hc_sinkhorn_eps)
    inv_cs3 = 1.0 / (cs3 + hc_sinkhorn_eps)
    cm_00 *= inv_cs0
    cm_10 *= inv_cs0
    cm_20 *= inv_cs0
    cm_30 *= inv_cs0
    cm_01 *= inv_cs1
    cm_11 *= inv_cs1
    cm_21 *= inv_cs1
    cm_31 *= inv_cs1
    cm_02 *= inv_cs2
    cm_12 *= inv_cs2
    cm_22 *= inv_cs2
    cm_32 *= inv_cs2
    cm_03 *= inv_cs3
    cm_13 *= inv_cs3
    cm_23 *= inv_cs3
    cm_33 *= inv_cs3

    # Remaining rounds: row normalisation followed by column normalisation.
    for _ in tl.static_range(sinkhorn_repeat - 1):
        rs0 = cm_00 + cm_01 + cm_02 + cm_03
        rs1 = cm_10 + cm_11 + cm_12 + cm_13
        rs2 = cm_20 + cm_21 + cm_22 + cm_23
        rs3 = cm_30 + cm_31 + cm_32 + cm_33
        inv_rs0 = 1.0 / (rs0 + hc_sinkhorn_eps)
        inv_rs1 = 1.0 / (rs1 + hc_sinkhorn_eps)
        inv_rs2 = 1.0 / (rs2 + hc_sinkhorn_eps)
        inv_rs3 = 1.0 / (rs3 + hc_sinkhorn_eps)
        cm_00 *= inv_rs0
        cm_01 *= inv_rs0
        cm_02 *= inv_rs0
        cm_03 *= inv_rs0
        cm_10 *= inv_rs1
        cm_11 *= inv_rs1
        cm_12 *= inv_rs1
        cm_13 *= inv_rs1
        cm_20 *= inv_rs2
        cm_21 *= inv_rs2
        cm_22 *= inv_rs2
        cm_23 *= inv_rs2
        cm_30 *= inv_rs3
        cm_31 *= inv_rs3
        cm_32 *= inv_rs3
        cm_33 *= inv_rs3
        cs0 = cm_00 + cm_10 + cm_20 + cm_30
        cs1 = cm_01 + cm_11 + cm_21 + cm_31
        cs2 = cm_02 + cm_12 + cm_22 + cm_32
        cs3 = cm_03 + cm_13 + cm_23 + cm_33
        inv_cs0 = 1.0 / (cs0 + hc_sinkhorn_eps)
        inv_cs1 = 1.0 / (cs1 + hc_sinkhorn_eps)
        inv_cs2 = 1.0 / (cs2 + hc_sinkhorn_eps)
        inv_cs3 = 1.0 / (cs3 + hc_sinkhorn_eps)
        cm_00 *= inv_cs0
        cm_01 *= inv_cs1
        cm_02 *= inv_cs2
        cm_03 *= inv_cs3
        cm_10 *= inv_cs0
        cm_11 *= inv_cs1
        cm_12 *= inv_cs2
        cm_13 *= inv_cs3
        cm_20 *= inv_cs0
        cm_21 *= inv_cs1
        cm_22 *= inv_cs2
        cm_23 *= inv_cs3
        cm_30 *= inv_cs0
        cm_31 *= inv_cs1
        cm_32 *= inv_cs2
        cm_33 *= inv_cs3

    # Write the normalised 4x4 matrix back in row-major order.
    co = pid_n * 16
    tl.store(comb_mix_ptr + co + 0, cm_00)
    tl.store(comb_mix_ptr + co + 1, cm_01)
    tl.store(comb_mix_ptr + co + 2, cm_02)
    tl.store(comb_mix_ptr + co + 3, cm_03)
    tl.store(comb_mix_ptr + co + 4, cm_10)
    tl.store(comb_mix_ptr + co + 5, cm_11)
    tl.store(comb_mix_ptr + co + 6, cm_12)
    tl.store(comb_mix_ptr + co + 7, cm_13)
    tl.store(comb_mix_ptr + co + 8, cm_20)
    tl.store(comb_mix_ptr + co + 9, cm_21)
    tl.store(comb_mix_ptr + co + 10, cm_22)
    tl.store(comb_mix_ptr + co + 11, cm_23)
    tl.store(comb_mix_ptr + co + 12, cm_30)
    tl.store(comb_mix_ptr + co + 13, cm_31)
    tl.store(comb_mix_ptr + co + 14, cm_32)
    tl.store(comb_mix_ptr + co + 15, cm_33)

    # ══ Pass 2: weighted sum layer_input = sum_k(pre_mix_k * residual[n, k, :]) ══
    for h_start in range(0, hidden_size, BLOCK_H):
        h_offsets = h_start + tl.arange(0, BLOCK_H)
        h_mask = h_offsets < hidden_size
        r0 = tl.load(
            residual_ptr + res_base + 0 * res_stride_i + h_offsets * res_stride_h,
            mask=h_mask,
            other=0.0,
        ).to(tl.float32)
        r1 = tl.load(
            residual_ptr + res_base + 1 * res_stride_i + h_offsets * res_stride_h,
            mask=h_mask,
            other=0.0,
        ).to(tl.float32)
        acc = pre_mix_0 * r0 + pre_mix_1 * r1
        r2 = tl.load(
            residual_ptr + res_base + 2 * res_stride_i + h_offsets * res_stride_h,
            mask=h_mask,
            other=0.0,
        ).to(tl.float32)
        r3 = tl.load(
            residual_ptr + res_base + 3 * res_stride_i + h_offsets * res_stride_h,
            mask=h_mask,
            other=0.0,
        ).to(tl.float32)
        acc += pre_mix_2 * r2 + pre_mix_3 * r3
        tl.store(
            layer_input_ptr + pid_n * li_stride_n + h_offsets * li_stride_h,
            acc.to(tl.bfloat16),
            mask=h_mask,
        )
@triton.autotune(
    # One Config per (BLOCK_H, num_warps, num_stages) row, in this order.
    configs=[
        triton.Config({"BLOCK_H": block_h}, num_warps=warps, num_stages=stages)
        for block_h, warps, stages in (
            (256, 4, 1),
            (256, 8, 1),
            (512, 4, 1),
            (512, 8, 1),
            (1024, 4, 1),
            (1024, 8, 1),
            (1024, 16, 1),
            (256, 4, 2),
            (512, 4, 2),
            (512, 8, 2),
        )
    ],
    key=["hidden_size", "num_tokens_bucket"],
)
@triton.jit
def mhc_pre_fused_kernel_hc_mult_4(
    gemm_out_ptr,  # (num_tokens, hc_mult3), float32
    hc_scale_ptr,  # (3,), float32
    hc_base_ptr,  # (hc_mult3,), float32
    residual_ptr,  # (num_tokens, hc_mult, hidden_size), bfloat16
    post_mix_ptr,  # (num_tokens, hc_mult), float32
    comb_mix_ptr,  # (num_tokens, hc_mult*hc_mult), float32
    layer_input_ptr,  # (num_tokens, hidden_size), bfloat16
    num_tokens,
    num_tokens_bucket,
    res_stride_n,
    res_stride_i,
    res_stride_h,
    li_stride_n,
    li_stride_h,
    hidden_size,
    hc_hidden_size,
    rms_eps: tl.constexpr,
    hc_pre_eps: tl.constexpr,
    hc_sinkhorn_eps: tl.constexpr,
    hc_post_mult_value: tl.constexpr,
    sinkhorn_repeat: tl.constexpr,
    HC_MULT3: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Autotuned entry point for the 4-head fused kernel.

    Forwards every argument, unchanged and in the same order, to
    ``_mhc_pre_fused_kernel_hc_mult_4_impl``.  ``BLOCK_H`` comes from the
    autotune config selected for the ``(hidden_size, num_tokens_bucket)`` key.
    """
    _mhc_pre_fused_kernel_hc_mult_4_impl(
        # tensors
        gemm_out_ptr,
        hc_scale_ptr,
        hc_base_ptr,
        residual_ptr,
        post_mix_ptr,
        comb_mix_ptr,
        layer_input_ptr,
        # sizes and strides
        num_tokens,
        num_tokens_bucket,
        res_stride_n,
        res_stride_i,
        res_stride_h,
        li_stride_n,
        li_stride_h,
        hidden_size,
        hc_hidden_size,
        # compile-time constants
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
        HC_MULT3,
        BLOCK_H,
    )
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 256}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 256}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 1024}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 1024}, num_warps=8, num_stages=1),
    ],
    key=["hidden_size", "num_tokens_bucket", "HC"],
)
@triton.jit
def mhc_pre_generic_kernel(
    gemm_out_ptr,  # (num_tokens, hc_mult3), float32
    hc_scale_ptr,  # (3,), float32
    hc_base_ptr,  # (hc_mult3,), float32
    residual_ptr,  # (num_tokens, HC, hidden_size), bfloat16
    post_mix_ptr,  # (num_tokens, HC), float32
    comb_mix_ptr,  # (num_tokens, HC*HC), float32
    layer_input_ptr,  # (num_tokens, hidden_size), bfloat16
    num_tokens,
    num_tokens_bucket,
    res_stride_n,
    res_stride_i,
    res_stride_h,
    li_stride_n,
    li_stride_h,
    hidden_size,
    hc_hidden_size,
    rms_eps: tl.constexpr,
    hc_pre_eps: tl.constexpr,
    hc_sinkhorn_eps: tl.constexpr,
    hc_post_mult_value: tl.constexpr,
    sinkhorn_repeat: tl.constexpr,
    HC: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """mHC pre kernel for an arbitrary compile-time head count ``HC``. One token per program.

    Performs the same steps as the 4-head specialisation, with loops over
    ``HC`` unrolled through ``tl.static_range``.  The token's ``gemm_out`` row
    has ``HC * 2 + HC * HC`` entries: ``[0, HC)`` for ``pre_mix``,
    ``[HC, 2*HC)`` for ``post_mix`` and ``[2*HC, 2*HC + HC*HC)`` for the
    row-major ``HC x HC`` ``comb_mix`` logits.

    The token's slice of ``comb_mix_ptr`` doubles as working storage: the
    matrix is written there first and then normalised in place through
    repeated loads and stores.

    ``num_tokens_bucket`` is not read in this body.
    """
    # Promote the program id to int64 so that offsets such as
    # ``pid_n * res_stride_n`` cannot wrap around in 32-bit arithmetic when
    # ``num_tokens * HC * hidden_size`` exceeds the int32 range.
    pid_n = tl.program_id(0).to(tl.int64)
    if pid_n >= num_tokens:
        return

    res_base = pid_n * res_stride_n
    go_base = pid_n * (HC * 2 + HC * HC)
    comb_base = pid_n * (HC * HC)

    # Pass 1: sum of squares over every head of this token's residual.
    sq = 0.0
    for k in tl.static_range(HC):
        head_base = res_base + k * res_stride_i
        for h_start in range(0, hidden_size, BLOCK_H):
            h_offsets = h_start + tl.arange(0, BLOCK_H)
            h_mask = h_offsets < hidden_size
            v = tl.load(
                residual_ptr + head_base + h_offsets * res_stride_h,
                mask=h_mask,
                other=0.0,
            ).to(tl.float32)
            sq += tl.sum(v * v)

    rms_inv = tl.rsqrt(sq / hc_hidden_size + rms_eps)

    scale_0 = tl.load(hc_scale_ptr + 0)
    scale_1 = tl.load(hc_scale_ptr + 1)
    scale_2 = tl.load(hc_scale_ptr + 2)

    # post_mix: gemm_out entries [HC, 2*HC).
    for i in tl.static_range(HC):
        post_i = (
            tl.sigmoid(
                tl.load(gemm_out_ptr + go_base + HC + i) * rms_inv * scale_1
                + tl.load(hc_base_ptr + HC + i)
            )
            * hc_post_mult_value
        )
        tl.store(post_mix_ptr + pid_n * HC + i, post_i)

    # comb_mix logits: gemm_out entries [2*HC, 2*HC + HC*HC), staged into the
    # output buffer so the normalisation below can run in place.
    cb = 2 * HC
    for i in tl.static_range(HC):
        for j in tl.static_range(HC):
            idx = i * HC + j
            v = tl.load(
                gemm_out_ptr + go_base + cb + idx
            ) * rms_inv * scale_2 + tl.load(hc_base_ptr + cb + idx)
            tl.store(comb_mix_ptr + comb_base + idx, v)

    # First round, rows: softmax of each row (max-subtracted), plus eps.
    for i in tl.static_range(HC):
        row_max = tl.load(comb_mix_ptr + comb_base + i * HC + 0)
        for j in tl.static_range(1, HC):
            row_max = tl.maximum(
                row_max, tl.load(comb_mix_ptr + comb_base + i * HC + j)
            )

        row_sum = 0.0
        for j in tl.static_range(HC):
            e = tl.exp(tl.load(comb_mix_ptr + comb_base + i * HC + j) - row_max)
            tl.store(comb_mix_ptr + comb_base + i * HC + j, e)
            row_sum += e

        inv_row_sum = 1.0 / row_sum
        for j in tl.static_range(HC):
            v = tl.load(comb_mix_ptr + comb_base + i * HC + j)
            tl.store(
                comb_mix_ptr + comb_base + i * HC + j, v * inv_row_sum + hc_sinkhorn_eps
            )

    # First round, columns: divide each column by (column sum + eps).
    for j in tl.static_range(HC):
        col_sum = 0.0
        for i in tl.static_range(HC):
            col_sum += tl.load(comb_mix_ptr + comb_base + i * HC + j)
        inv_col_sum = 1.0 / (col_sum + hc_sinkhorn_eps)
        for i in tl.static_range(HC):
            v = tl.load(comb_mix_ptr + comb_base + i * HC + j)
            tl.store(comb_mix_ptr + comb_base + i * HC + j, v * inv_col_sum)

    # Remaining rounds: row normalisation followed by column normalisation.
    for _ in tl.static_range(sinkhorn_repeat - 1):
        for i in tl.static_range(HC):
            row_sum = 0.0
            for j in tl.static_range(HC):
                row_sum += tl.load(comb_mix_ptr + comb_base + i * HC + j)
            inv_row_sum = 1.0 / (row_sum + hc_sinkhorn_eps)
            for j in tl.static_range(HC):
                v = tl.load(comb_mix_ptr + comb_base + i * HC + j)
                tl.store(comb_mix_ptr + comb_base + i * HC + j, v * inv_row_sum)

        for j in tl.static_range(HC):
            col_sum = 0.0
            for i in tl.static_range(HC):
                col_sum += tl.load(comb_mix_ptr + comb_base + i * HC + j)
            inv_col_sum = 1.0 / (col_sum + hc_sinkhorn_eps)
            for i in tl.static_range(HC):
                v = tl.load(comb_mix_ptr + comb_base + i * HC + j)
                tl.store(comb_mix_ptr + comb_base + i * HC + j, v * inv_col_sum)

    # Pass 2: layer_input = sum_k(pre_mix_k * residual[n, k, :]).
    # pre_mix_k (gemm_out entries [0, HC)) is recomputed for every hidden block.
    for h_start in range(0, hidden_size, BLOCK_H):
        h_offsets = h_start + tl.arange(0, BLOCK_H)
        h_mask = h_offsets < hidden_size
        acc = tl.zeros([BLOCK_H], dtype=tl.float32)

        for k in tl.static_range(HC):
            pre_k = (
                tl.sigmoid(
                    tl.load(gemm_out_ptr + go_base + k) * rms_inv * scale_0
                    + tl.load(hc_base_ptr + k)
                )
                + hc_pre_eps
            )
            rk = tl.load(
                residual_ptr + res_base + k * res_stride_i + h_offsets * res_stride_h,
                mask=h_mask,
                other=0.0,
            ).to(tl.float32)
            acc += pre_k * rk

        tl.store(
            layer_input_ptr + pid_n * li_stride_n + h_offsets * li_stride_h,
            acc.to(tl.bfloat16),
            mask=h_mask,
        )
def mhc_pre(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Run the optimized mHC pre block.

    The projection ``residual @ fn.T`` is computed with ``torch.mm`` on bf16
    operands; everything else happens in one Triton kernel launch with one
    program per token:

    - hc_mult == 4: specialized fused Triton kernel
    - hc_mult != 4: generic Triton kernel aligned to reference math

    Args:
        residual: bfloat16 tensor of shape ``(*outer, hc_mult, hidden_size)``.
        fn: float32 tensor of shape
            ``(hc_mult * 2 + hc_mult * hc_mult, hc_mult * hidden_size)``.
        hc_scale: passed to the kernel, which reads its first three elements.
        hc_base: passed to the kernel, which reads one element per GEMM column.
        rms_eps, hc_pre_eps, hc_sinkhorn_eps, hc_post_mult_value,
        sinkhorn_repeat: forwarded to the kernel as compile-time constants.
        n_splits: accepted but not used by this implementation.

    Returns:
        ``(post_mix, comb_mix, layer_input)`` with shapes
        ``(*outer, hc_mult, 1)``, ``(*outer, hc_mult, hc_mult)`` (both float32)
        and ``(*outer, hidden_size)`` (bfloat16).
    """
    assert residual.dtype == torch.bfloat16
    assert fn.dtype == torch.float32

    hc_mult = residual.size(-2)
    hidden_size = residual.size(-1)
    mix_width = hc_mult * (hc_mult + 2)  # pre + post + hc_mult x hc_mult comb
    flat_width = hc_mult * hidden_size

    assert fn.shape == (mix_width, flat_width)

    outer_shape = residual.shape[:-2]
    residual_flat = residual.reshape(-1, hc_mult, hidden_size).contiguous()
    num_tokens = residual_flat.shape[0]
    device = residual.device

    # Autotune key: 1 for <=512 tokens, then 2/3/4 for <=1024/2048/4096, else 5.
    num_tokens_bucket = 1 + sum(
        num_tokens > limit for limit in (512, 1024, 2048, 4096)
    )

    # ── Step 1: GEMM via cuBLAS (bf16 tensor cores) ──
    gemm_out = torch.mm(
        residual_flat.reshape(num_tokens, flat_width),
        _get_fn_bf16_cached(fn).t(),
    ).float()

    # ── Step 2: Fused sqrsum + norm + mix + sinkhorn + weighted sum ──
    post_mix = torch.empty(num_tokens, hc_mult, dtype=torch.float32, device=device)
    comb_mix = torch.empty(
        num_tokens, hc_mult * hc_mult, dtype=torch.float32, device=device
    )
    layer_input = torch.empty(
        num_tokens, hidden_size, dtype=torch.bfloat16, device=device
    )

    # Both kernels share their positional parameters and most constants.
    launch_args = (
        gemm_out,
        hc_scale,
        hc_base,
        residual_flat,
        post_mix,
        comb_mix,
        layer_input,
        num_tokens,
        num_tokens_bucket,
        *residual_flat.stride(),
        *layer_input.stride(),
        hidden_size,
        flat_width,
    )
    launch_consts = dict(
        rms_eps=rms_eps,
        hc_pre_eps=hc_pre_eps,
        hc_sinkhorn_eps=hc_sinkhorn_eps,
        hc_post_mult_value=hc_post_mult_value,
        sinkhorn_repeat=sinkhorn_repeat,
    )
    grid = (num_tokens,)

    if hc_mult == 4:
        mhc_pre_fused_kernel_hc_mult_4[grid](
            *launch_args, **launch_consts, HC_MULT3=mix_width
        )
    else:
        mhc_pre_generic_kernel[grid](*launch_args, **launch_consts, HC=hc_mult)

    return (
        post_mix.view(*outer_shape, hc_mult, 1),
        comb_mix.view(*outer_shape, hc_mult, hc_mult),
        layer_input.view(*outer_shape, hidden_size),
    )
725# ───────────────────────── Reference implementations ─────────────────────────
def sinkhorn_normalize_ref(x: torch.Tensor, repeat: int, eps: float) -> torch.Tensor:
    """Reference Sinkhorn-style normalisation over the last two dimensions.

    The first round applies a softmax over the last dimension, adds ``eps``,
    then divides by the sum over the second-to-last dimension plus ``eps``.
    Each of the ``repeat - 1`` remaining rounds divides by the sum over the
    last dimension plus ``eps`` and then by the sum over the second-to-last
    dimension plus ``eps``.
    """

    def _scale(mat: torch.Tensor, axis: int) -> torch.Tensor:
        # Divide by the (eps-padded) sum along ``axis``.
        return mat / (mat.sum(axis, keepdim=True) + eps)

    balanced = _scale(torch.softmax(x, dim=-1) + eps, -2)
    for _ in range(repeat - 1):
        balanced = _scale(_scale(balanced, -1), -2)
    return balanced
def mhc_pre_ref(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """PyTorch reference for the mHC pre block.

    Accepts ``residual`` of shape ``(*outer, hc_mult, hidden_size)`` with any
    number of leading dimensions, matching what ``mhc_pre`` accepts.

    Args:
        residual: input of shape ``(*outer, hc_mult, hidden_size)``.
        fn: projection of shape ``(hc_mult * 2 + hc_mult**2, hc_mult * hidden_size)``.
        hc_scale: three scales, applied to the pre, post and comb columns.
        hc_base: per-column offset added after scaling.
        rms_eps: epsilon inside the RMS normalisation.
        hc_pre_eps: added to the pre-mix sigmoid.
        hc_sinkhorn_eps: epsilon used by ``sinkhorn_normalize_ref``.
        hc_post_mult_value: multiplier applied to the post-mix sigmoid.
        sinkhorn_repeat: number of Sinkhorn rounds.

    Returns:
        ``(post_mix, res_mix, layer_input)`` with shapes
        ``(*outer, hc_mult, 1)``, ``(*outer, hc_mult, hc_mult)`` and
        ``(*outer, hidden_size)`` (bfloat16).
    """
    hc_mult = residual.shape[-2]
    residual_flat = residual.flatten(-2, -1).float()
    sqrsum = residual_flat.square().sum(-1)
    mixes = (
        residual_flat @ fn.T * (sqrsum.unsqueeze(-1) / fn.shape[-1] + rms_eps).rsqrt()
    )
    hc_scale_expanded = torch.cat(
        [
            hc_scale[0].expand(hc_mult),
            hc_scale[1].expand(hc_mult),
            hc_scale[2].expand(hc_mult * hc_mult),
        ]
    )
    mixes = mixes * hc_scale_expanded + hc_base
    # Index only the last dimension (``...``) so any number of leading
    # dimensions is handled; ``mixes[:, ...]`` would slice the second
    # dimension, which is only correct for 3-D ``residual``.
    pre_mix = mixes[..., :hc_mult].sigmoid().unsqueeze(-1) + hc_pre_eps
    post_mix = (
        mixes[..., hc_mult : 2 * hc_mult].sigmoid() * hc_post_mult_value
    ).unsqueeze(-1)
    res_mix = mixes[..., 2 * hc_mult :].view(*mixes.shape[:-1], hc_mult, hc_mult)
    res_mix = sinkhorn_normalize_ref(
        res_mix, repeat=sinkhorn_repeat, eps=hc_sinkhorn_eps
    )
    layer_input = (residual * pre_mix).sum(-2).bfloat16()
    return post_mix, res_mix, layer_input