Coverage for src/flag_gems/fused/mhc/hc_split_sinkhorn.py: 17%

312 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def mhc_split_sinkhorn_kernel_hcmult_4(
    mixes_ptr,
    hc_scale_ptr,
    hc_base_ptr,
    pre_ptr,
    post_ptr,
    comb_ptr,
    num_tokens,
    BLOCK_N: tl.constexpr,
    SINKHORN_ITERS: tl.constexpr,
):
    """Vectorized kernel for HC_MULT=4.

    Each program processes BLOCK_N consecutive tokens. For every token it
    reads 24 values from ``mixes_ptr`` (elements 0-3 feed ``pre``, 4-7 feed
    ``post``, 8-23 form a row-major 4x4 ``comb`` matrix), 3 scalars from
    ``hc_scale_ptr`` and 24 biases from ``hc_base_ptr``, then writes:

    * ``pre_ptr``  (4 per token):  sigmoid(m * scale[0] + b) + 1e-6
    * ``post_ptr`` (4 per token):  2 * sigmoid(m * scale[1] + b)
    * ``comb_ptr`` (16 per token): m * scale[2] + b, row-softmaxed (+1e-6),
      column-normalized, followed by SINKHORN_ITERS - 1 alternating
      row/column normalizations.

    All buffers are addressed with unit element stride, so callers must pass
    densely packed data. The epsilon is hard-coded to 1e-6.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offs < num_tokens
    # Widen to int64 before scaling: int32 `offs * 24` would wrap around for
    # very large token counts and address the wrong memory.
    tok = offs.to(tl.int64)
    base = tok * 24

    scale_0 = tl.load(hc_scale_ptr + 0)
    scale_1 = tl.load(hc_scale_ptr + 1)
    scale_2 = tl.load(hc_scale_ptr + 2)

    # `other=0.0` gives masked-out lanes a defined value so the math below
    # never runs on garbage; their results are discarded by the masked stores.
    m0 = tl.load(mixes_ptr + base + 0, mask=mask, other=0.0)
    m1 = tl.load(mixes_ptr + base + 1, mask=mask, other=0.0)
    m2 = tl.load(mixes_ptr + base + 2, mask=mask, other=0.0)
    m3 = tl.load(mixes_ptr + base + 3, mask=mask, other=0.0)
    m4 = tl.load(mixes_ptr + base + 4, mask=mask, other=0.0)
    m5 = tl.load(mixes_ptr + base + 5, mask=mask, other=0.0)
    m6 = tl.load(mixes_ptr + base + 6, mask=mask, other=0.0)
    m7 = tl.load(mixes_ptr + base + 7, mask=mask, other=0.0)

    b0 = tl.load(hc_base_ptr + 0)
    b1 = tl.load(hc_base_ptr + 1)
    b2 = tl.load(hc_base_ptr + 2)
    b3 = tl.load(hc_base_ptr + 3)
    b4 = tl.load(hc_base_ptr + 4)
    b5 = tl.load(hc_base_ptr + 5)
    b6 = tl.load(hc_base_ptr + 6)
    b7 = tl.load(hc_base_ptr + 7)

    # pre = sigmoid(m * scale_0 + b) + eps
    pre_base = tok * 4
    tl.store(pre_ptr + pre_base + 0, tl.sigmoid(m0 * scale_0 + b0) + 1e-6, mask=mask)
    tl.store(pre_ptr + pre_base + 1, tl.sigmoid(m1 * scale_0 + b1) + 1e-6, mask=mask)
    tl.store(pre_ptr + pre_base + 2, tl.sigmoid(m2 * scale_0 + b2) + 1e-6, mask=mask)
    tl.store(pre_ptr + pre_base + 3, tl.sigmoid(m3 * scale_0 + b3) + 1e-6, mask=mask)

    # post = 2 * sigmoid(m * scale_1 + b)
    post_base = tok * 4
    tl.store(post_ptr + post_base + 0, 2.0 * tl.sigmoid(m4 * scale_1 + b4), mask=mask)
    tl.store(post_ptr + post_base + 1, 2.0 * tl.sigmoid(m5 * scale_1 + b5), mask=mask)
    tl.store(post_ptr + post_base + 2, 2.0 * tl.sigmoid(m6 * scale_1 + b6), mask=mask)
    tl.store(post_ptr + post_base + 3, 2.0 * tl.sigmoid(m7 * scale_1 + b7), mask=mask)

    # The 4x4 comb block starts right after the 4 pre and 4 post entries.
    cb = 8
    b8 = tl.load(hc_base_ptr + cb + 0)
    b9 = tl.load(hc_base_ptr + cb + 1)
    b10 = tl.load(hc_base_ptr + cb + 2)
    b11 = tl.load(hc_base_ptr + cb + 3)
    b12 = tl.load(hc_base_ptr + cb + 4)
    b13 = tl.load(hc_base_ptr + cb + 5)
    b14 = tl.load(hc_base_ptr + cb + 6)
    b15 = tl.load(hc_base_ptr + cb + 7)
    b16 = tl.load(hc_base_ptr + cb + 8)
    b17 = tl.load(hc_base_ptr + cb + 9)
    b18 = tl.load(hc_base_ptr + cb + 10)
    b19 = tl.load(hc_base_ptr + cb + 11)
    b20 = tl.load(hc_base_ptr + cb + 12)
    b21 = tl.load(hc_base_ptr + cb + 13)
    b22 = tl.load(hc_base_ptr + cb + 14)
    b23 = tl.load(hc_base_ptr + cb + 15)

    # cm_<row><col> = m * scale_2 + b, row-major.
    cm_00 = tl.load(mixes_ptr + base + cb + 0, mask=mask, other=0.0) * scale_2 + b8
    cm_01 = tl.load(mixes_ptr + base + cb + 1, mask=mask, other=0.0) * scale_2 + b9
    cm_02 = tl.load(mixes_ptr + base + cb + 2, mask=mask, other=0.0) * scale_2 + b10
    cm_03 = tl.load(mixes_ptr + base + cb + 3, mask=mask, other=0.0) * scale_2 + b11
    cm_10 = tl.load(mixes_ptr + base + cb + 4, mask=mask, other=0.0) * scale_2 + b12
    cm_11 = tl.load(mixes_ptr + base + cb + 5, mask=mask, other=0.0) * scale_2 + b13
    cm_12 = tl.load(mixes_ptr + base + cb + 6, mask=mask, other=0.0) * scale_2 + b14
    cm_13 = tl.load(mixes_ptr + base + cb + 7, mask=mask, other=0.0) * scale_2 + b15
    cm_20 = tl.load(mixes_ptr + base + cb + 8, mask=mask, other=0.0) * scale_2 + b16
    cm_21 = tl.load(mixes_ptr + base + cb + 9, mask=mask, other=0.0) * scale_2 + b17
    cm_22 = tl.load(mixes_ptr + base + cb + 10, mask=mask, other=0.0) * scale_2 + b18
    cm_23 = tl.load(mixes_ptr + base + cb + 11, mask=mask, other=0.0) * scale_2 + b19
    cm_30 = tl.load(mixes_ptr + base + cb + 12, mask=mask, other=0.0) * scale_2 + b20
    cm_31 = tl.load(mixes_ptr + base + cb + 13, mask=mask, other=0.0) * scale_2 + b21
    cm_32 = tl.load(mixes_ptr + base + cb + 14, mask=mask, other=0.0) * scale_2 + b22
    cm_33 = tl.load(mixes_ptr + base + cb + 15, mask=mask, other=0.0) * scale_2 + b23

    # Row 0: max-subtracted softmax, then + eps.
    rm = tl.maximum(tl.maximum(cm_00, cm_01), tl.maximum(cm_02, cm_03))
    cm_00 = tl.exp(cm_00 - rm)
    cm_01 = tl.exp(cm_01 - rm)
    cm_02 = tl.exp(cm_02 - rm)
    cm_03 = tl.exp(cm_03 - rm)
    inv_rs = 1.0 / (cm_00 + cm_01 + cm_02 + cm_03)
    cm_00 = cm_00 * inv_rs + 1e-6
    cm_01 = cm_01 * inv_rs + 1e-6
    cm_02 = cm_02 * inv_rs + 1e-6
    cm_03 = cm_03 * inv_rs + 1e-6

    # Row 1.
    rm = tl.maximum(tl.maximum(cm_10, cm_11), tl.maximum(cm_12, cm_13))
    cm_10 = tl.exp(cm_10 - rm)
    cm_11 = tl.exp(cm_11 - rm)
    cm_12 = tl.exp(cm_12 - rm)
    cm_13 = tl.exp(cm_13 - rm)
    inv_rs = 1.0 / (cm_10 + cm_11 + cm_12 + cm_13)
    cm_10 = cm_10 * inv_rs + 1e-6
    cm_11 = cm_11 * inv_rs + 1e-6
    cm_12 = cm_12 * inv_rs + 1e-6
    cm_13 = cm_13 * inv_rs + 1e-6

    # Row 2.
    rm = tl.maximum(tl.maximum(cm_20, cm_21), tl.maximum(cm_22, cm_23))
    cm_20 = tl.exp(cm_20 - rm)
    cm_21 = tl.exp(cm_21 - rm)
    cm_22 = tl.exp(cm_22 - rm)
    cm_23 = tl.exp(cm_23 - rm)
    inv_rs = 1.0 / (cm_20 + cm_21 + cm_22 + cm_23)
    cm_20 = cm_20 * inv_rs + 1e-6
    cm_21 = cm_21 * inv_rs + 1e-6
    cm_22 = cm_22 * inv_rs + 1e-6
    cm_23 = cm_23 * inv_rs + 1e-6

    # Row 3.
    rm = tl.maximum(tl.maximum(cm_30, cm_31), tl.maximum(cm_32, cm_33))
    cm_30 = tl.exp(cm_30 - rm)
    cm_31 = tl.exp(cm_31 - rm)
    cm_32 = tl.exp(cm_32 - rm)
    cm_33 = tl.exp(cm_33 - rm)
    inv_rs = 1.0 / (cm_30 + cm_31 + cm_32 + cm_33)
    cm_30 = cm_30 * inv_rs + 1e-6
    cm_31 = cm_31 * inv_rs + 1e-6
    cm_32 = cm_32 * inv_rs + 1e-6
    cm_33 = cm_33 * inv_rs + 1e-6

    # First column normalization: divide each column by (column sum + eps).
    inv_cs0 = 1.0 / (cm_00 + cm_10 + cm_20 + cm_30 + 1e-6)
    inv_cs1 = 1.0 / (cm_01 + cm_11 + cm_21 + cm_31 + 1e-6)
    inv_cs2 = 1.0 / (cm_02 + cm_12 + cm_22 + cm_32 + 1e-6)
    inv_cs3 = 1.0 / (cm_03 + cm_13 + cm_23 + cm_33 + 1e-6)
    cm_00 *= inv_cs0
    cm_10 *= inv_cs0
    cm_20 *= inv_cs0
    cm_30 *= inv_cs0
    cm_01 *= inv_cs1
    cm_11 *= inv_cs1
    cm_21 *= inv_cs1
    cm_31 *= inv_cs1
    cm_02 *= inv_cs2
    cm_12 *= inv_cs2
    cm_22 *= inv_cs2
    cm_32 *= inv_cs2
    cm_03 *= inv_cs3
    cm_13 *= inv_cs3
    cm_23 *= inv_cs3
    cm_33 *= inv_cs3

    # Remaining iterations: alternate row and column normalization.
    for _ in range(SINKHORN_ITERS - 1):
        inv_rs0 = 1.0 / (cm_00 + cm_01 + cm_02 + cm_03 + 1e-6)
        inv_rs1 = 1.0 / (cm_10 + cm_11 + cm_12 + cm_13 + 1e-6)
        inv_rs2 = 1.0 / (cm_20 + cm_21 + cm_22 + cm_23 + 1e-6)
        inv_rs3 = 1.0 / (cm_30 + cm_31 + cm_32 + cm_33 + 1e-6)
        cm_00 *= inv_rs0
        cm_01 *= inv_rs0
        cm_02 *= inv_rs0
        cm_03 *= inv_rs0
        cm_10 *= inv_rs1
        cm_11 *= inv_rs1
        cm_12 *= inv_rs1
        cm_13 *= inv_rs1
        cm_20 *= inv_rs2
        cm_21 *= inv_rs2
        cm_22 *= inv_rs2
        cm_23 *= inv_rs2
        cm_30 *= inv_rs3
        cm_31 *= inv_rs3
        cm_32 *= inv_rs3
        cm_33 *= inv_rs3

        inv_cs0 = 1.0 / (cm_00 + cm_10 + cm_20 + cm_30 + 1e-6)
        inv_cs1 = 1.0 / (cm_01 + cm_11 + cm_21 + cm_31 + 1e-6)
        inv_cs2 = 1.0 / (cm_02 + cm_12 + cm_22 + cm_32 + 1e-6)
        inv_cs3 = 1.0 / (cm_03 + cm_13 + cm_23 + cm_33 + 1e-6)
        cm_00 *= inv_cs0
        cm_01 *= inv_cs1
        cm_02 *= inv_cs2
        cm_03 *= inv_cs3
        cm_10 *= inv_cs0
        cm_11 *= inv_cs1
        cm_12 *= inv_cs2
        cm_13 *= inv_cs3
        cm_20 *= inv_cs0
        cm_21 *= inv_cs1
        cm_22 *= inv_cs2
        cm_23 *= inv_cs3
        cm_30 *= inv_cs0
        cm_31 *= inv_cs1
        cm_32 *= inv_cs2
        cm_33 *= inv_cs3

    # Write the 4x4 result back row-major, 16 values per token.
    co = tok * 16
    tl.store(comb_ptr + co + 0, cm_00, mask=mask)
    tl.store(comb_ptr + co + 1, cm_01, mask=mask)
    tl.store(comb_ptr + co + 2, cm_02, mask=mask)
    tl.store(comb_ptr + co + 3, cm_03, mask=mask)
    tl.store(comb_ptr + co + 4, cm_10, mask=mask)
    tl.store(comb_ptr + co + 5, cm_11, mask=mask)
    tl.store(comb_ptr + co + 6, cm_12, mask=mask)
    tl.store(comb_ptr + co + 7, cm_13, mask=mask)
    tl.store(comb_ptr + co + 8, cm_20, mask=mask)
    tl.store(comb_ptr + co + 9, cm_21, mask=mask)
    tl.store(comb_ptr + co + 10, cm_22, mask=mask)
    tl.store(comb_ptr + co + 11, cm_23, mask=mask)
    tl.store(comb_ptr + co + 12, cm_30, mask=mask)
    tl.store(comb_ptr + co + 13, cm_31, mask=mask)
    tl.store(comb_ptr + co + 14, cm_32, mask=mask)
    tl.store(comb_ptr + co + 15, cm_33, mask=mask)

218 

219 

@triton.jit
def mhc_split_sinkhorn_kernel_generic(
    mixes_ptr,
    hc_scale_ptr,
    hc_base_ptr,
    pre_ptr,
    post_ptr,
    comb_ptr,
    num_tokens,
    SINKHORN_ITERS: tl.constexpr,
    HC_MULT: tl.constexpr,
    MIX_HC: tl.constexpr,
):
    """Generic split+sinkhorn kernel for arbitrary HC_MULT (one token per program).

    Per token, reads MIX_HC values from ``mixes_ptr``: the first HC_MULT feed
    ``pre``, the next HC_MULT feed ``post`` and the remaining
    HC_MULT * HC_MULT form a row-major ``comb`` matrix. ``hc_base_ptr`` is
    indexed with the same layout and ``hc_scale_ptr`` supplies three scalars.

    The comb matrix is normalized in place inside ``comb_ptr``: it is first
    filled with ``m * scale[2] + b``, then row-softmaxed (+1e-6),
    column-normalized, and finally run through SINKHORN_ITERS - 1 alternating
    row/column normalizations. All buffers are addressed with unit element
    stride. The epsilon is hard-coded to 1e-6.
    """
    pid_n = tl.program_id(0)
    if pid_n >= num_tokens:
        return

    # Widen to int64 before scaling so per-token element offsets cannot wrap
    # around in int32 for very large token counts.
    tok = pid_n.to(tl.int64)
    base = tok * MIX_HC
    pre_base = tok * HC_MULT
    post_base = tok * HC_MULT
    comb_base = tok * (HC_MULT * HC_MULT)

    scale_0 = tl.load(hc_scale_ptr + 0)
    scale_1 = tl.load(hc_scale_ptr + 1)
    scale_2 = tl.load(hc_scale_ptr + 2)

    # pre = sigmoid(m * scale_0 + b) + eps ; post = 2 * sigmoid(m * scale_1 + b)
    for j in tl.static_range(HC_MULT):
        pre_idx = j
        post_idx = HC_MULT + j
        pre_m = tl.load(mixes_ptr + base + pre_idx)
        post_m = tl.load(mixes_ptr + base + post_idx)
        pre_b = tl.load(hc_base_ptr + pre_idx)
        post_b = tl.load(hc_base_ptr + post_idx)
        tl.store(pre_ptr + pre_base + j, tl.sigmoid(pre_m * scale_0 + pre_b) + 1e-6)
        tl.store(post_ptr + post_base + j, 2.0 * tl.sigmoid(post_m * scale_1 + post_b))

    # The comb block starts after the HC_MULT pre and HC_MULT post entries.
    comb_offset = 2 * HC_MULT

    # Stage the scaled-and-biased comb logits in the output buffer; every
    # later pass reads and rewrites them in place.
    for row in tl.static_range(HC_MULT):
        for col in tl.static_range(HC_MULT):
            idx = comb_offset + row * HC_MULT + col
            out_idx = row * HC_MULT + col
            m = tl.load(mixes_ptr + base + idx)
            b = tl.load(hc_base_ptr + idx)
            tl.store(comb_ptr + comb_base + out_idx, m * scale_2 + b)

    # Row-wise max-subtracted softmax, then + eps.
    for row in tl.static_range(HC_MULT):
        row_ptr0 = comb_ptr + comb_base + row * HC_MULT
        row_max = tl.load(row_ptr0)
        # Column 0 already seeds row_max, so the scan starts at column 1.
        for col in tl.static_range(1, HC_MULT):
            row_ptr = comb_ptr + comb_base + row * HC_MULT + col
            row_max = tl.maximum(row_max, tl.load(row_ptr))
        row_sum = 0.0
        for col in tl.static_range(HC_MULT):
            row_ptr = comb_ptr + comb_base + row * HC_MULT + col
            v = tl.exp(tl.load(row_ptr) - row_max)
            row_sum += v
            tl.store(row_ptr, v)
        inv_row_sum = 1.0 / row_sum
        for col in tl.static_range(HC_MULT):
            row_ptr = comb_ptr + comb_base + row * HC_MULT + col
            v = tl.load(row_ptr) * inv_row_sum + 1e-6
            tl.store(row_ptr, v)

    # First column normalization: divide each column by (column sum + eps).
    for col in tl.static_range(HC_MULT):
        col_sum = 0.0
        for row in tl.static_range(HC_MULT):
            ptr = comb_ptr + comb_base + row * HC_MULT + col
            col_sum += tl.load(ptr)
        inv_col_sum = 1.0 / (col_sum + 1e-6)
        for row in tl.static_range(HC_MULT):
            ptr = comb_ptr + comb_base + row * HC_MULT + col
            tl.store(ptr, tl.load(ptr) * inv_col_sum)

    # Remaining iterations: alternate row and column normalization.
    for _ in range(SINKHORN_ITERS - 1):
        for row in tl.static_range(HC_MULT):
            row_sum = 0.0
            for col in tl.static_range(HC_MULT):
                ptr = comb_ptr + comb_base + row * HC_MULT + col
                row_sum += tl.load(ptr)
            inv_row_sum = 1.0 / (row_sum + 1e-6)
            for col in tl.static_range(HC_MULT):
                ptr = comb_ptr + comb_base + row * HC_MULT + col
                tl.store(ptr, tl.load(ptr) * inv_row_sum)

        for col in tl.static_range(HC_MULT):
            col_sum = 0.0
            for row in tl.static_range(HC_MULT):
                ptr = comb_ptr + comb_base + row * HC_MULT + col
                col_sum += tl.load(ptr)
            inv_col_sum = 1.0 / (col_sum + 1e-6)
            for row in tl.static_range(HC_MULT):
                ptr = comb_ptr + comb_base + row * HC_MULT + col
                tl.store(ptr, tl.load(ptr) * inv_col_sum)

315 

316 

def hc_split_sinkhorn(
    mixes: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    hc_mult: int = 4,
    sinkhorn_iters: int = 20,
    eps: float = 1e-6,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split ``mixes`` into pre/post/comb parts and Sinkhorn-normalize comb.

    Args:
        mixes: Tensor whose last dimension is ``(2 + hc_mult) * hc_mult``.
        hc_scale: Tensor of shape ``(3,)`` with the pre/post/comb scales.
        hc_base: Tensor of shape ``((2 + hc_mult) * hc_mult,)`` with biases.
        hc_mult: Side length of the square comb matrix.
        sinkhorn_iters: Total number of normalization rounds.
        eps: Stabilizing epsilon. The Triton kernels hard-code 1e-6, so any
            other value is routed to the PyTorch reference implementation.

    Returns:
        ``(pre, post, comb)`` with shapes ``(*outer, hc_mult)``,
        ``(*outer, hc_mult)`` and ``(*outer, hc_mult, hc_mult)``, where
        ``outer`` is ``mixes.shape[:-1]``. The Triton path allocates float32
        outputs; the fallback returns whatever the reference produces.

    Raises:
        AssertionError: If any of the input shapes is inconsistent with
            ``hc_mult``.
    """
    mix_hc = (2 + hc_mult) * hc_mult
    assert mixes.shape[-1] == mix_hc
    assert hc_scale.shape == (3,)
    assert hc_base.shape == (mix_hc,)

    use_triton = mixes.device.type == "cuda" and eps == 1e-6 and hc_mult >= 1
    if not use_triton:
        return mhc_split_sinkhorn_torch_ref(
            mixes,
            hc_scale,
            hc_base,
            hc_mult=hc_mult,
            sinkhorn_iters=sinkhorn_iters,
            eps=eps,
        )

    outer_shape = mixes.shape[:-1]
    mixes_flat = mixes.reshape(-1, mix_hc).contiguous()
    # The kernels index these buffers with unit stride, so strided views
    # (e.g. slices with a step) must be packed first.
    hc_scale = hc_scale.contiguous()
    hc_base = hc_base.contiguous()
    num_tokens = mixes_flat.shape[0]

    pre = torch.empty(num_tokens, hc_mult, dtype=torch.float32, device=mixes.device)
    post = torch.empty(num_tokens, hc_mult, dtype=torch.float32, device=mixes.device)
    comb = torch.empty(
        num_tokens, hc_mult * hc_mult, dtype=torch.float32, device=mixes.device
    )

    # Nothing to launch for an empty batch; the empty outputs are returned.
    if num_tokens > 0:
        # Triton launches on the *current* CUDA device, which is not
        # necessarily the device that holds `mixes`; pin it for the launch.
        with torch.cuda.device(mixes.device):
            if hc_mult == 4:
                # Tile size / warp count heuristic keyed on the token count.
                if num_tokens <= 256:
                    block_n = 16
                    num_warps = 1
                elif num_tokens <= 2048:
                    block_n = 32
                    num_warps = 1
                elif num_tokens <= 16384:
                    block_n = 128
                    num_warps = 4
                else:
                    block_n = 256
                    num_warps = 8
                grid = (num_tokens + block_n - 1) // block_n

                mhc_split_sinkhorn_kernel_hcmult_4[(grid,)](
                    mixes_flat,
                    hc_scale,
                    hc_base,
                    pre,
                    post,
                    comb,
                    num_tokens,
                    BLOCK_N=block_n,
                    SINKHORN_ITERS=sinkhorn_iters,
                    num_warps=num_warps,
                    num_stages=1,
                )
            else:
                # The generic kernel runs one token per program; its warp
                # count depends only on hc_mult.
                if hc_mult <= 4:
                    num_warps = 1
                elif hc_mult <= 8:
                    num_warps = 2
                else:
                    num_warps = 4

                mhc_split_sinkhorn_kernel_generic[(num_tokens,)](
                    mixes_flat,
                    hc_scale,
                    hc_base,
                    pre,
                    post,
                    comb,
                    num_tokens,
                    SINKHORN_ITERS=sinkhorn_iters,
                    HC_MULT=hc_mult,
                    MIX_HC=mix_hc,
                    num_warps=num_warps,
                    num_stages=1,
                )

    return (
        pre.view(*outer_shape, hc_mult),
        post.view(*outer_shape, hc_mult),
        comb.view(*outer_shape, hc_mult, hc_mult),
    )

408 

409 

def mhc_split_sinkhorn_torch_ref(
    mixes: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    hc_mult: int = 4,
    sinkhorn_iters: int = 20,
    eps: float = 1e-6,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pure-PyTorch reference for the split + Sinkhorn computation.

    The last dimension of ``mixes`` is split into ``hc_mult`` "pre" entries,
    ``hc_mult`` "post" entries and an ``hc_mult x hc_mult`` "comb" matrix;
    ``hc_base`` is sliced with the same layout and ``hc_scale[0..2]`` scale
    the three parts respectively.

    Returns:
        ``(pre, post, comb)`` where
        ``pre = sigmoid(m * scale[0] + b) + eps``,
        ``post = 2 * sigmoid(m * scale[1] + b)`` and ``comb`` is the
        row-softmaxed (+ eps) matrix after one column normalization followed
        by ``sinkhorn_iters - 1`` alternating row/column normalizations.

    Raises:
        AssertionError: If ``mixes.shape[-1] != (2 + hc_mult) * hc_mult``.
    """
    lead_shape = mixes.shape[:-1]
    expected_width = (2 + hc_mult) * hc_mult
    assert mixes.shape[-1] == expected_width

    pre_end = hc_mult
    post_end = 2 * hc_mult

    pre_logits = mixes[..., :pre_end] * hc_scale[0] + hc_base[:pre_end]
    pre = torch.sigmoid(pre_logits) + eps

    post_logits = (
        mixes[..., pre_end:post_end] * hc_scale[1] + hc_base[pre_end:post_end]
    )
    post = 2 * torch.sigmoid(post_logits)

    comb_mix = mixes[..., post_end:].view(*lead_shape, hc_mult, hc_mult)
    comb_bias = hc_base[post_end:].view(hc_mult, hc_mult)
    comb = comb_mix * hc_scale[2] + comb_bias

    # Max-subtracted softmax over each row, then the eps offset.
    row_peak = comb.max(dim=-1, keepdim=True).values
    comb = (comb - row_peak).exp()
    comb = comb / comb.sum(dim=-1, keepdim=True) + eps

    # The first round only needs the column pass; later rounds do rows
    # (dim=-1) and then columns (dim=-2).
    comb = comb / (comb.sum(dim=-2, keepdim=True) + eps)
    for _ in range(sinkhorn_iters - 1):
        for axis in (-1, -2):
            comb = comb / (comb.sum(dim=axis, keepdim=True) + eps)
    return pre, post, comb