Coverage for src/flag_gems/fused/mhc/mhc_pre.py: 21%

346 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1""" 

2Triton implementation of mHC Pre operator (optimized v2). 

3 

4Key optimizations: 

5- GEMM: torch.mm in bf16 (cuBLAS tensor cores) 

6- sqrsum + norm + mix + sinkhorn + weighted sum: single fused Triton kernel 

7 Two passes over residual: pass 1 computes sqrsum, pass 2 does weighted sum 

8""" 

9 

10import logging 

11import weakref 

12 

13import torch 

14import triton 

15import triton.language as tl 

16 

17logger = logging.getLogger(__name__) 

18 

19 

20_FN_BF16_CACHE: weakref.WeakKeyDictionary[ 

21 torch.Tensor, tuple[int, torch.Tensor] 

22] = weakref.WeakKeyDictionary() 

23 

24 

25def _get_fn_bf16_cached(fn: torch.Tensor) -> torch.Tensor: 

26 if fn.requires_grad or torch.is_grad_enabled(): 

27 return fn.to(dtype=torch.bfloat16) 

28 version = fn._version 

29 cached = _FN_BF16_CACHE.get(fn) 

30 if cached is not None: 

31 cached_version, cached_bf16 = cached 

32 if cached_version == version: 

33 return cached_bf16 

34 fn_bf16 = fn.to(dtype=torch.bfloat16) 

35 _FN_BF16_CACHE[fn] = (version, fn_bf16) 

36 return fn_bf16 

37 

38 

@triton.jit
def _mhc_pre_fused_kernel_hc_mult_4_impl(
    gemm_out_ptr,  # (num_tokens, hc_mult3), float32
    hc_scale_ptr,  # (3,), float32
    hc_base_ptr,  # (hc_mult3,), float32
    residual_ptr,  # (num_tokens, hc_mult, hidden_size), bfloat16
    post_mix_ptr,  # (num_tokens, hc_mult), float32
    comb_mix_ptr,  # (num_tokens, hc_mult*hc_mult), float32
    layer_input_ptr,  # (num_tokens, hidden_size), bfloat16
    num_tokens,
    num_tokens_bucket,
    res_stride_n,
    res_stride_i,
    res_stride_h,
    li_stride_n,
    li_stride_h,
    hidden_size,
    hc_hidden_size,
    rms_eps: tl.constexpr,
    hc_pre_eps: tl.constexpr,
    hc_sinkhorn_eps: tl.constexpr,
    hc_post_mult_value: tl.constexpr,
    sinkhorn_repeat: tl.constexpr,
    HC_MULT3: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Fully fused: sqrsum + RMS norm + sigmoid + Sinkhorn + weighted sum. One token per program.

    Hard-wired to 4 heads: the residual loops cover heads 0..3 and the
    per-token layout of ``gemm_out`` / ``hc_base`` is read as
    [0..3] pre-mix logits, [4..7] post-mix logits, [8..23] a row-major 4x4
    comb-mix matrix. ``gemm_out``, ``post_mix`` and ``comb_mix`` are indexed
    with a fixed row pitch (HC_MULT3, 4 and 16 respectively), i.e. they are
    treated as densely packed.

    ``num_tokens_bucket`` is not read in the body.
    """
    # Token index in int64: per-token offsets such as pid_n * res_stride_n
    # exceed the int32 range once the residual tensor holds 2**31 or more
    # elements, which would otherwise wrap and address the wrong memory.
    pid_n = tl.program_id(0).to(tl.int64)
    if pid_n >= num_tokens:
        return

    # ══ Pass 1: compute sqrsum over all 4 heads ══
    sq = 0.0
    res_base = pid_n * res_stride_n
    for k in tl.static_range(4):
        head_base = res_base + k * res_stride_i
        for h_start in range(0, hidden_size, BLOCK_H):
            h_offsets = h_start + tl.arange(0, BLOCK_H)
            h_mask = h_offsets < hidden_size
            v = tl.load(
                residual_ptr + head_base + h_offsets * res_stride_h,
                mask=h_mask,
                other=0.0,
            ).to(tl.float32)
            sq += tl.sum(v * v)

    # Inverse RMS over the flattened (heads * hidden) vector of this token.
    rms_inv = tl.rsqrt(sq / hc_hidden_size + rms_eps)

    # ══ Load scales ══
    scale_0 = tl.load(hc_scale_ptr + 0)
    scale_1 = tl.load(hc_scale_ptr + 1)
    scale_2 = tl.load(hc_scale_ptr + 2)

    go_base = pid_n * HC_MULT3

    # ══ pre_mix: indices 0..3 ══
    pre_mix_0 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 0) * rms_inv * scale_0
            + tl.load(hc_base_ptr + 0)
        )
        + hc_pre_eps
    )
    pre_mix_1 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 1) * rms_inv * scale_0
            + tl.load(hc_base_ptr + 1)
        )
        + hc_pre_eps
    )
    pre_mix_2 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 2) * rms_inv * scale_0
            + tl.load(hc_base_ptr + 2)
        )
        + hc_pre_eps
    )
    pre_mix_3 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 3) * rms_inv * scale_0
            + tl.load(hc_base_ptr + 3)
        )
        + hc_pre_eps
    )

    # ══ post_mix: indices 4..7 ══
    post_0 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 4) * rms_inv * scale_1
            + tl.load(hc_base_ptr + 4)
        )
        * hc_post_mult_value
    )
    tl.store(post_mix_ptr + pid_n * 4 + 0, post_0)
    post_1 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 5) * rms_inv * scale_1
            + tl.load(hc_base_ptr + 5)
        )
        * hc_post_mult_value
    )
    tl.store(post_mix_ptr + pid_n * 4 + 1, post_1)
    post_2 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 6) * rms_inv * scale_1
            + tl.load(hc_base_ptr + 6)
        )
        * hc_post_mult_value
    )
    tl.store(post_mix_ptr + pid_n * 4 + 2, post_2)
    post_3 = (
        tl.sigmoid(
            tl.load(gemm_out_ptr + go_base + 7) * rms_inv * scale_1
            + tl.load(hc_base_ptr + 7)
        )
        * hc_post_mult_value
    )
    tl.store(post_mix_ptr + pid_n * 4 + 3, post_3)

    # ══ comb_mix: indices 8..23 → 4x4 Sinkhorn ══
    # cm_ij holds row i, column j of the matrix; everything stays in scalars.
    cb = 8
    cm_00 = tl.load(gemm_out_ptr + go_base + cb + 0) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 0
    )
    cm_01 = tl.load(gemm_out_ptr + go_base + cb + 1) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 1
    )
    cm_02 = tl.load(gemm_out_ptr + go_base + cb + 2) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 2
    )
    cm_03 = tl.load(gemm_out_ptr + go_base + cb + 3) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 3
    )
    cm_10 = tl.load(gemm_out_ptr + go_base + cb + 4) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 4
    )
    cm_11 = tl.load(gemm_out_ptr + go_base + cb + 5) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 5
    )
    cm_12 = tl.load(gemm_out_ptr + go_base + cb + 6) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 6
    )
    cm_13 = tl.load(gemm_out_ptr + go_base + cb + 7) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 7
    )
    cm_20 = tl.load(gemm_out_ptr + go_base + cb + 8) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 8
    )
    cm_21 = tl.load(gemm_out_ptr + go_base + cb + 9) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 9
    )
    cm_22 = tl.load(gemm_out_ptr + go_base + cb + 10) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 10
    )
    cm_23 = tl.load(gemm_out_ptr + go_base + cb + 11) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 11
    )
    cm_30 = tl.load(gemm_out_ptr + go_base + cb + 12) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 12
    )
    cm_31 = tl.load(gemm_out_ptr + go_base + cb + 13) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 13
    )
    cm_32 = tl.load(gemm_out_ptr + go_base + cb + 14) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 14
    )
    cm_33 = tl.load(gemm_out_ptr + go_base + cb + 15) * rms_inv * scale_2 + tl.load(
        hc_base_ptr + cb + 15
    )

    # ── Sinkhorn iteration ──
    # First step: max-subtracted softmax along each row, plus eps.
    rm = tl.maximum(tl.maximum(cm_00, cm_01), tl.maximum(cm_02, cm_03))
    cm_00 = tl.exp(cm_00 - rm)
    cm_01 = tl.exp(cm_01 - rm)
    cm_02 = tl.exp(cm_02 - rm)
    cm_03 = tl.exp(cm_03 - rm)
    rs = cm_00 + cm_01 + cm_02 + cm_03
    inv_rs = 1.0 / rs
    cm_00 = cm_00 * inv_rs + hc_sinkhorn_eps
    cm_01 = cm_01 * inv_rs + hc_sinkhorn_eps
    cm_02 = cm_02 * inv_rs + hc_sinkhorn_eps
    cm_03 = cm_03 * inv_rs + hc_sinkhorn_eps

    rm = tl.maximum(tl.maximum(cm_10, cm_11), tl.maximum(cm_12, cm_13))
    cm_10 = tl.exp(cm_10 - rm)
    cm_11 = tl.exp(cm_11 - rm)
    cm_12 = tl.exp(cm_12 - rm)
    cm_13 = tl.exp(cm_13 - rm)
    rs = cm_10 + cm_11 + cm_12 + cm_13
    inv_rs = 1.0 / rs
    cm_10 = cm_10 * inv_rs + hc_sinkhorn_eps
    cm_11 = cm_11 * inv_rs + hc_sinkhorn_eps
    cm_12 = cm_12 * inv_rs + hc_sinkhorn_eps
    cm_13 = cm_13 * inv_rs + hc_sinkhorn_eps

    rm = tl.maximum(tl.maximum(cm_20, cm_21), tl.maximum(cm_22, cm_23))
    cm_20 = tl.exp(cm_20 - rm)
    cm_21 = tl.exp(cm_21 - rm)
    cm_22 = tl.exp(cm_22 - rm)
    cm_23 = tl.exp(cm_23 - rm)
    rs = cm_20 + cm_21 + cm_22 + cm_23
    inv_rs = 1.0 / rs
    cm_20 = cm_20 * inv_rs + hc_sinkhorn_eps
    cm_21 = cm_21 * inv_rs + hc_sinkhorn_eps
    cm_22 = cm_22 * inv_rs + hc_sinkhorn_eps
    cm_23 = cm_23 * inv_rs + hc_sinkhorn_eps

    rm = tl.maximum(tl.maximum(cm_30, cm_31), tl.maximum(cm_32, cm_33))
    cm_30 = tl.exp(cm_30 - rm)
    cm_31 = tl.exp(cm_31 - rm)
    cm_32 = tl.exp(cm_32 - rm)
    cm_33 = tl.exp(cm_33 - rm)
    rs = cm_30 + cm_31 + cm_32 + cm_33
    inv_rs = 1.0 / rs
    cm_30 = cm_30 * inv_rs + hc_sinkhorn_eps
    cm_31 = cm_31 * inv_rs + hc_sinkhorn_eps
    cm_32 = cm_32 * inv_rs + hc_sinkhorn_eps
    cm_33 = cm_33 * inv_rs + hc_sinkhorn_eps

    # Column normalisation that completes the first iteration.
    cs0 = cm_00 + cm_10 + cm_20 + cm_30
    cs1 = cm_01 + cm_11 + cm_21 + cm_31
    cs2 = cm_02 + cm_12 + cm_22 + cm_32
    cs3 = cm_03 + cm_13 + cm_23 + cm_33
    inv_cs0 = 1.0 / (cs0 + hc_sinkhorn_eps)
    inv_cs1 = 1.0 / (cs1 + hc_sinkhorn_eps)
    inv_cs2 = 1.0 / (cs2 + hc_sinkhorn_eps)
    inv_cs3 = 1.0 / (cs3 + hc_sinkhorn_eps)
    cm_00 *= inv_cs0
    cm_10 *= inv_cs0
    cm_20 *= inv_cs0
    cm_30 *= inv_cs0
    cm_01 *= inv_cs1
    cm_11 *= inv_cs1
    cm_21 *= inv_cs1
    cm_31 *= inv_cs1
    cm_02 *= inv_cs2
    cm_12 *= inv_cs2
    cm_22 *= inv_cs2
    cm_32 *= inv_cs2
    cm_03 *= inv_cs3
    cm_13 *= inv_cs3
    cm_23 *= inv_cs3
    cm_33 *= inv_cs3

    # Remaining iterations: row normalisation followed by column normalisation.
    for _ in tl.static_range(sinkhorn_repeat - 1):
        rs0 = cm_00 + cm_01 + cm_02 + cm_03
        rs1 = cm_10 + cm_11 + cm_12 + cm_13
        rs2 = cm_20 + cm_21 + cm_22 + cm_23
        rs3 = cm_30 + cm_31 + cm_32 + cm_33
        inv_rs0 = 1.0 / (rs0 + hc_sinkhorn_eps)
        inv_rs1 = 1.0 / (rs1 + hc_sinkhorn_eps)
        inv_rs2 = 1.0 / (rs2 + hc_sinkhorn_eps)
        inv_rs3 = 1.0 / (rs3 + hc_sinkhorn_eps)
        cm_00 *= inv_rs0
        cm_01 *= inv_rs0
        cm_02 *= inv_rs0
        cm_03 *= inv_rs0
        cm_10 *= inv_rs1
        cm_11 *= inv_rs1
        cm_12 *= inv_rs1
        cm_13 *= inv_rs1
        cm_20 *= inv_rs2
        cm_21 *= inv_rs2
        cm_22 *= inv_rs2
        cm_23 *= inv_rs2
        cm_30 *= inv_rs3
        cm_31 *= inv_rs3
        cm_32 *= inv_rs3
        cm_33 *= inv_rs3
        cs0 = cm_00 + cm_10 + cm_20 + cm_30
        cs1 = cm_01 + cm_11 + cm_21 + cm_31
        cs2 = cm_02 + cm_12 + cm_22 + cm_32
        cs3 = cm_03 + cm_13 + cm_23 + cm_33
        inv_cs0 = 1.0 / (cs0 + hc_sinkhorn_eps)
        inv_cs1 = 1.0 / (cs1 + hc_sinkhorn_eps)
        inv_cs2 = 1.0 / (cs2 + hc_sinkhorn_eps)
        inv_cs3 = 1.0 / (cs3 + hc_sinkhorn_eps)
        cm_00 *= inv_cs0
        cm_01 *= inv_cs1
        cm_02 *= inv_cs2
        cm_03 *= inv_cs3
        cm_10 *= inv_cs0
        cm_11 *= inv_cs1
        cm_12 *= inv_cs2
        cm_13 *= inv_cs3
        cm_20 *= inv_cs0
        cm_21 *= inv_cs1
        cm_22 *= inv_cs2
        cm_23 *= inv_cs3
        cm_30 *= inv_cs0
        cm_31 *= inv_cs1
        cm_32 *= inv_cs2
        cm_33 *= inv_cs3

    # Write the normalised 4x4 matrix back in row-major order.
    co = pid_n * 16
    tl.store(comb_mix_ptr + co + 0, cm_00)
    tl.store(comb_mix_ptr + co + 1, cm_01)
    tl.store(comb_mix_ptr + co + 2, cm_02)
    tl.store(comb_mix_ptr + co + 3, cm_03)
    tl.store(comb_mix_ptr + co + 4, cm_10)
    tl.store(comb_mix_ptr + co + 5, cm_11)
    tl.store(comb_mix_ptr + co + 6, cm_12)
    tl.store(comb_mix_ptr + co + 7, cm_13)
    tl.store(comb_mix_ptr + co + 8, cm_20)
    tl.store(comb_mix_ptr + co + 9, cm_21)
    tl.store(comb_mix_ptr + co + 10, cm_22)
    tl.store(comb_mix_ptr + co + 11, cm_23)
    tl.store(comb_mix_ptr + co + 12, cm_30)
    tl.store(comb_mix_ptr + co + 13, cm_31)
    tl.store(comb_mix_ptr + co + 14, cm_32)
    tl.store(comb_mix_ptr + co + 15, cm_33)

    # ══ Pass 2: weighted sum layer_input = sum_k(pre_mix_k * residual[n, k, :]) ══
    for h_start in range(0, hidden_size, BLOCK_H):
        h_offsets = h_start + tl.arange(0, BLOCK_H)
        h_mask = h_offsets < hidden_size
        r0 = tl.load(
            residual_ptr + res_base + 0 * res_stride_i + h_offsets * res_stride_h,
            mask=h_mask,
            other=0.0,
        ).to(tl.float32)
        r1 = tl.load(
            residual_ptr + res_base + 1 * res_stride_i + h_offsets * res_stride_h,
            mask=h_mask,
            other=0.0,
        ).to(tl.float32)
        acc = pre_mix_0 * r0 + pre_mix_1 * r1
        r2 = tl.load(
            residual_ptr + res_base + 2 * res_stride_i + h_offsets * res_stride_h,
            mask=h_mask,
            other=0.0,
        ).to(tl.float32)
        r3 = tl.load(
            residual_ptr + res_base + 3 * res_stride_i + h_offsets * res_stride_h,
            mask=h_mask,
            other=0.0,
        ).to(tl.float32)
        acc += pre_mix_2 * r2 + pre_mix_3 * r3
        tl.store(
            layer_input_ptr + pid_n * li_stride_n + h_offsets * li_stride_h,
            acc.to(tl.bfloat16),
            mask=h_mask,
        )

382 

383 

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": block_h}, num_warps=warps, num_stages=stages)
        # (BLOCK_H, num_warps, num_stages), kept in the original order.
        for block_h, warps, stages in (
            (256, 4, 1),
            (256, 8, 1),
            (512, 4, 1),
            (512, 8, 1),
            (1024, 4, 1),
            (1024, 8, 1),
            (1024, 16, 1),
            (256, 4, 2),
            (512, 4, 2),
            (512, 8, 2),
        )
    ],
    key=["hidden_size", "num_tokens_bucket"],
)
@triton.jit
def mhc_pre_fused_kernel_hc_mult_4(
    gemm_out_ptr,  # (num_tokens, hc_mult3), float32
    hc_scale_ptr,  # (3,), float32
    hc_base_ptr,  # (hc_mult3,), float32
    residual_ptr,  # (num_tokens, hc_mult, hidden_size), bfloat16
    post_mix_ptr,  # (num_tokens, hc_mult), float32
    comb_mix_ptr,  # (num_tokens, hc_mult*hc_mult), float32
    layer_input_ptr,  # (num_tokens, hidden_size), bfloat16
    num_tokens,
    num_tokens_bucket,
    res_stride_n,
    res_stride_i,
    res_stride_h,
    li_stride_n,
    li_stride_h,
    hidden_size,
    hc_hidden_size,
    rms_eps: tl.constexpr,
    hc_pre_eps: tl.constexpr,
    hc_sinkhorn_eps: tl.constexpr,
    hc_post_mult_value: tl.constexpr,
    sinkhorn_repeat: tl.constexpr,
    HC_MULT3: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Autotuned entry point for the 4-head fused kernel.

    Forwards every argument unchanged to
    ``_mhc_pre_fused_kernel_hc_mult_4_impl``. BLOCK_H, num_warps and
    num_stages are chosen by the autotuner, keyed on ``hidden_size`` and
    ``num_tokens_bucket``.
    """
    _mhc_pre_fused_kernel_hc_mult_4_impl(
        gemm_out_ptr,
        hc_scale_ptr,
        hc_base_ptr,
        residual_ptr,
        post_mix_ptr,
        comb_mix_ptr,
        layer_input_ptr,
        num_tokens,
        num_tokens_bucket,
        res_stride_n,
        res_stride_i,
        res_stride_h,
        li_stride_n,
        li_stride_h,
        hidden_size,
        hc_hidden_size,
        rms_eps,
        hc_pre_eps,
        hc_sinkhorn_eps,
        hc_post_mult_value,
        sinkhorn_repeat,
        HC_MULT3,
        BLOCK_H,
    )

450 

451 

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 256}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 256}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 1024}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 1024}, num_warps=8, num_stages=1),
    ],
    key=["hidden_size", "num_tokens_bucket", "HC"],
)
@triton.jit
def mhc_pre_generic_kernel(
    gemm_out_ptr,  # (num_tokens, hc_mult3), float32
    hc_scale_ptr,  # (3,), float32
    hc_base_ptr,  # (hc_mult3,), float32
    residual_ptr,  # (num_tokens, HC, hidden_size), bfloat16
    post_mix_ptr,  # (num_tokens, HC), float32
    comb_mix_ptr,  # (num_tokens, HC*HC), float32
    layer_input_ptr,  # (num_tokens, hidden_size), bfloat16
    num_tokens,
    num_tokens_bucket,
    res_stride_n,
    res_stride_i,
    res_stride_h,
    li_stride_n,
    li_stride_h,
    hidden_size,
    hc_hidden_size,
    rms_eps: tl.constexpr,
    hc_pre_eps: tl.constexpr,
    hc_sinkhorn_eps: tl.constexpr,
    hc_post_mult_value: tl.constexpr,
    sinkhorn_repeat: tl.constexpr,
    HC: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """mHC pre kernel for an arbitrary head count ``HC``. One token per program.

    Per token it computes the sum of squares of the residual, derives the
    inverse RMS, writes the sigmoid post-mix weights, Sinkhorn-normalises the
    HC x HC comb-mix matrix and finally writes the pre-mix weighted sum of
    the residual heads.

    The comb-mix matrix is not held in registers: the token's slice of
    ``comb_mix_ptr`` doubles as scratch storage and is rewritten in place by
    every normalisation step. ``gemm_out``, ``post_mix`` and ``comb_mix`` are
    indexed with a fixed row pitch, i.e. they are treated as densely packed.

    ``num_tokens_bucket`` is not read in the body.
    """
    # Token index in int64: per-token offsets such as pid_n * res_stride_n
    # exceed the int32 range once the residual tensor holds 2**31 or more
    # elements, which would otherwise wrap and address the wrong memory.
    pid_n = tl.program_id(0).to(tl.int64)
    if pid_n >= num_tokens:
        return

    res_base = pid_n * res_stride_n
    go_base = pid_n * (HC * 2 + HC * HC)
    comb_base = pid_n * (HC * HC)

    # Pass 1 over the residual: sum of squares across all heads.
    sq = 0.0
    for k in tl.static_range(HC):
        head_base = res_base + k * res_stride_i
        for h_start in range(0, hidden_size, BLOCK_H):
            h_offsets = h_start + tl.arange(0, BLOCK_H)
            h_mask = h_offsets < hidden_size
            v = tl.load(
                residual_ptr + head_base + h_offsets * res_stride_h,
                mask=h_mask,
                other=0.0,
            ).to(tl.float32)
            sq += tl.sum(v * v)

    rms_inv = tl.rsqrt(sq / hc_hidden_size + rms_eps)

    scale_0 = tl.load(hc_scale_ptr + 0)
    scale_1 = tl.load(hc_scale_ptr + 1)
    scale_2 = tl.load(hc_scale_ptr + 2)

    # post_mix: logits HC..2*HC-1.
    for i in tl.static_range(HC):
        post_i = (
            tl.sigmoid(
                tl.load(gemm_out_ptr + go_base + HC + i) * rms_inv * scale_1
                + tl.load(hc_base_ptr + HC + i)
            )
            * hc_post_mult_value
        )
        tl.store(post_mix_ptr + pid_n * HC + i, post_i)

    # comb_mix: logits 2*HC onwards, staged into the output buffer.
    cb = 2 * HC
    for i in tl.static_range(HC):
        for j in tl.static_range(HC):
            idx = i * HC + j
            v = tl.load(
                gemm_out_ptr + go_base + cb + idx
            ) * rms_inv * scale_2 + tl.load(hc_base_ptr + cb + idx)
            tl.store(comb_mix_ptr + comb_base + idx, v)

    # First Sinkhorn step: max-subtracted softmax along each row, plus eps.
    for i in tl.static_range(HC):
        row_max = tl.load(comb_mix_ptr + comb_base + i * HC + 0)
        for j in tl.static_range(1, HC):
            row_max = tl.maximum(
                row_max, tl.load(comb_mix_ptr + comb_base + i * HC + j)
            )

        row_sum = 0.0
        for j in tl.static_range(HC):
            e = tl.exp(tl.load(comb_mix_ptr + comb_base + i * HC + j) - row_max)
            tl.store(comb_mix_ptr + comb_base + i * HC + j, e)
            row_sum += e

        inv_row_sum = 1.0 / row_sum
        for j in tl.static_range(HC):
            v = tl.load(comb_mix_ptr + comb_base + i * HC + j)
            tl.store(
                comb_mix_ptr + comb_base + i * HC + j, v * inv_row_sum + hc_sinkhorn_eps
            )

    # Column normalisation that completes the first iteration.
    for j in tl.static_range(HC):
        col_sum = 0.0
        for i in tl.static_range(HC):
            col_sum += tl.load(comb_mix_ptr + comb_base + i * HC + j)
        inv_col_sum = 1.0 / (col_sum + hc_sinkhorn_eps)
        for i in tl.static_range(HC):
            v = tl.load(comb_mix_ptr + comb_base + i * HC + j)
            tl.store(comb_mix_ptr + comb_base + i * HC + j, v * inv_col_sum)

    # Remaining iterations: row normalisation followed by column normalisation.
    for _ in tl.static_range(sinkhorn_repeat - 1):
        for i in tl.static_range(HC):
            row_sum = 0.0
            for j in tl.static_range(HC):
                row_sum += tl.load(comb_mix_ptr + comb_base + i * HC + j)
            inv_row_sum = 1.0 / (row_sum + hc_sinkhorn_eps)
            for j in tl.static_range(HC):
                v = tl.load(comb_mix_ptr + comb_base + i * HC + j)
                tl.store(comb_mix_ptr + comb_base + i * HC + j, v * inv_row_sum)

        for j in tl.static_range(HC):
            col_sum = 0.0
            for i in tl.static_range(HC):
                col_sum += tl.load(comb_mix_ptr + comb_base + i * HC + j)
            inv_col_sum = 1.0 / (col_sum + hc_sinkhorn_eps)
            for i in tl.static_range(HC):
                v = tl.load(comb_mix_ptr + comb_base + i * HC + j)
                tl.store(comb_mix_ptr + comb_base + i * HC + j, v * inv_col_sum)

    # Pass 2 over the residual: layer_input = sum_k(pre_mix_k * residual[n, k, :]).
    for h_start in range(0, hidden_size, BLOCK_H):
        h_offsets = h_start + tl.arange(0, BLOCK_H)
        h_mask = h_offsets < hidden_size
        acc = tl.zeros([BLOCK_H], dtype=tl.float32)

        for k in tl.static_range(HC):
            # pre_mix weight of head k (logits 0..HC-1); recomputed for every
            # hidden-dimension block rather than kept across iterations.
            pre_k = (
                tl.sigmoid(
                    tl.load(gemm_out_ptr + go_base + k) * rms_inv * scale_0
                    + tl.load(hc_base_ptr + k)
                )
                + hc_pre_eps
            )
            rk = tl.load(
                residual_ptr + res_base + k * res_stride_i + h_offsets * res_stride_h,
                mask=h_mask,
                other=0.0,
            ).to(tl.float32)
            acc += pre_k * rk

        tl.store(
            layer_input_ptr + pid_n * li_stride_n + h_offsets * li_stride_h,
            acc.to(tl.bfloat16),
            mask=h_mask,
        )

608 

609 

def mhc_pre(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
    n_splits: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the mHC pre block: one bf16 GEMM followed by one Triton kernel.

    Inputs with ``hc_mult == 4`` go through the specialised fused kernel;
    any other head count goes through the generic kernel.

    Args:
        residual: bfloat16 tensor of shape ``(..., hc_mult, hidden_size)``.
        fn: float32 tensor of shape
            ``(2 * hc_mult + hc_mult**2, hc_mult * hidden_size)``.
        hc_scale: three scale factors (pre, post, comb), read by the kernels
            at offsets 0, 1 and 2.
        hc_base: bias added to each of the ``2 * hc_mult + hc_mult**2``
            logits, read by the kernels with unit stride.
        rms_eps: epsilon added inside the RMS normalisation.
        hc_pre_eps: constant added to the pre-mix sigmoid.
        hc_sinkhorn_eps: epsilon used by the Sinkhorn normalisation.
        hc_post_mult_value: multiplier applied to the post-mix sigmoid.
        sinkhorn_repeat: number of Sinkhorn iterations.
        n_splits: accepted but not used by this implementation.

    Returns:
        ``(post_mix, comb_mix, layer_input)`` shaped ``(..., hc_mult, 1)``
        float32, ``(..., hc_mult, hc_mult)`` float32 and
        ``(..., hidden_size)`` bfloat16.
    """
    assert residual.dtype == torch.bfloat16
    assert fn.dtype == torch.float32

    hc_mult = residual.shape[-2]
    hidden_size = residual.shape[-1]
    hc_mult3 = hc_mult * 2 + hc_mult * hc_mult
    hc_hidden_size = hc_mult * hidden_size

    assert fn.shape == (hc_mult3, hc_hidden_size)

    outer_shape = residual.shape[:-2]
    residual_flat = residual.reshape(-1, hc_mult, hidden_size).contiguous()
    num_tokens = residual_flat.shape[0]
    device = residual.device

    # Autotune bucket: 1 for <=512 tokens, then one step per exceeded
    # threshold, up to 5 for more than 4096 tokens.
    num_tokens_bucket = 1 + sum(
        num_tokens > limit for limit in (512, 1024, 2048, 4096)
    )

    # ── Step 1: GEMM via cuBLAS (bf16 tensor cores) ──
    x_flat = residual_flat.reshape(num_tokens, hc_hidden_size)
    fn_bf16 = _get_fn_bf16_cached(fn)
    gemm_out = torch.mm(x_flat, fn_bf16.t()).float()

    # ── Step 2: Fused sqrsum + norm + mix + sinkhorn + weighted sum ──
    post_mix = torch.empty(num_tokens, hc_mult, dtype=torch.float32, device=device)
    comb_mix = torch.empty(
        num_tokens, hc_mult * hc_mult, dtype=torch.float32, device=device
    )
    layer_input = torch.empty(
        num_tokens, hidden_size, dtype=torch.bfloat16, device=device
    )

    # Both kernels take the same leading arguments; only the final constexpr
    # (HC_MULT3 vs HC) differs.
    positional = (
        gemm_out,
        hc_scale,
        hc_base,
        residual_flat,
        post_mix,
        comb_mix,
        layer_input,
        num_tokens,
        num_tokens_bucket,
        *residual_flat.stride(),
        *layer_input.stride(),
        hidden_size,
        hc_hidden_size,
    )
    constants = dict(
        rms_eps=rms_eps,
        hc_pre_eps=hc_pre_eps,
        hc_sinkhorn_eps=hc_sinkhorn_eps,
        hc_post_mult_value=hc_post_mult_value,
        sinkhorn_repeat=sinkhorn_repeat,
    )
    grid = (num_tokens,)

    if hc_mult == 4:
        mhc_pre_fused_kernel_hc_mult_4[grid](
            *positional, **constants, HC_MULT3=hc_mult3
        )
    else:
        mhc_pre_generic_kernel[grid](*positional, **constants, HC=hc_mult)

    return (
        post_mix.view(*outer_shape, hc_mult, 1),
        comb_mix.view(*outer_shape, hc_mult, hc_mult),
        layer_input.view(*outer_shape, hidden_size),
    )

723 

724 

725# ───────────────────────── Reference implementations ───────────────────────── 

726 

727 

def sinkhorn_normalize_ref(x: torch.Tensor, repeat: int, eps: float) -> torch.Tensor:
    """Reference Sinkhorn normalisation over the last two dimensions of ``x``.

    The first iteration is a softmax over the last dimension (plus ``eps``)
    followed by a normalisation over the second-to-last dimension. Each of
    the remaining ``repeat - 1`` iterations normalises the last dimension and
    then the second-to-last one. ``eps`` is added to every divisor.
    """
    x = x.softmax(-1) + eps
    x = x / (x.sum(-2, keepdim=True) + eps)
    for _ in range(repeat - 1):
        x = x / (x.sum(-1, keepdim=True) + eps)
        x = x / (x.sum(-2, keepdim=True) + eps)
    return x


def mhc_pre_ref(
    residual: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_eps: float,
    hc_pre_eps: float,
    hc_sinkhorn_eps: float,
    hc_post_mult_value: float,
    sinkhorn_repeat: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """PyTorch reference for the mHC pre block.

    Accepts any number of leading dimensions, matching ``mhc_pre``.

    Args:
        residual: tensor of shape ``(..., hc_mult, hidden_size)``.
        fn: projection of shape
            ``(2 * hc_mult + hc_mult**2, hc_mult * hidden_size)``.
        hc_scale: three scale factors applied to the pre, post and comb
            logits respectively.
        hc_base: bias added to the scaled logits.
        rms_eps: epsilon added inside the RMS normalisation.
        hc_pre_eps: constant added to the pre-mix sigmoid.
        hc_sinkhorn_eps: epsilon used by the Sinkhorn normalisation.
        hc_post_mult_value: multiplier applied to the post-mix sigmoid.
        sinkhorn_repeat: number of Sinkhorn iterations.

    Returns:
        ``(post_mix, res_mix, layer_input)`` shaped ``(..., hc_mult, 1)``,
        ``(..., hc_mult, hc_mult)`` and ``(..., hidden_size)`` (bfloat16).
    """
    hc_mult = residual.shape[-2]
    residual_flat = residual.flatten(-2, -1).float()
    sqrsum = residual_flat.square().sum(-1)
    # Project first, then apply the per-token inverse RMS to the logits.
    mixes = (
        residual_flat @ fn.T * (sqrsum.unsqueeze(-1) / fn.shape[-1] + rms_eps).rsqrt()
    )
    hc_scale_expanded = torch.cat(
        [
            hc_scale[0].expand(hc_mult),
            hc_scale[1].expand(hc_mult),
            hc_scale[2].expand(hc_mult * hc_mult),
        ]
    )
    mixes = mixes * hc_scale_expanded + hc_base
    # Slice the logit axis with ``...`` so inputs with several leading
    # dimensions are handled the same way as a flat (num_tokens, ...) input.
    pre_mix = mixes[..., :hc_mult].sigmoid().unsqueeze(-1) + hc_pre_eps
    post_mix = (
        mixes[..., hc_mult : 2 * hc_mult].sigmoid() * hc_post_mult_value
    ).unsqueeze(-1)
    res_mix = mixes[..., 2 * hc_mult :].unflatten(-1, (hc_mult, hc_mult))
    res_mix = sinkhorn_normalize_ref(
        res_mix, repeat=sinkhorn_repeat, eps=hc_sinkhorn_eps
    )
    layer_input = (residual * pre_mix).sum(-2).bfloat16()
    return post_mix, res_mix, layer_input