Coverage for src/flag_gems/fused/mhc/hc_split_sinkhorn.py: 17%
312 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit
def mhc_split_sinkhorn_kernel_hcmult_4(
    mixes_ptr,
    hc_scale_ptr,
    hc_base_ptr,
    pre_ptr,
    post_ptr,
    comb_ptr,
    num_tokens,
    BLOCK_N: tl.constexpr,
    SINKHORN_ITERS: tl.constexpr,
):
    """Vectorized kernel for HC_MULT=4.

    Each program handles ``BLOCK_N`` consecutive tokens. All per-token
    intermediates are kept as tensor values; the only global-memory writes
    are the final ``pre``/``post``/``comb`` outputs (no scratch round-trips).

    Per-token memory layout, as indexed below:
      * ``mixes``: 24 contiguous values -- ``[0:4]`` pre logits, ``[4:8]``
        post logits, ``[8:24]`` the 4x4 comb logits in row-major order.
      * ``hc_base``: 24 biases laid out like one ``mixes`` row.
      * ``hc_scale``: 3 scalars -- ``[0]`` pre, ``[1]`` post, ``[2]`` comb.
      * outputs: ``pre``/``post`` rows of 4 values, ``comb`` rows of 16
        (row-major 4x4).

    Computes:
      * ``pre  = sigmoid(m * scale[0] + b) + 1e-6``
      * ``post = 2 * sigmoid(m * scale[1] + b)``
      * ``comb``: row-wise softmax of ``m * scale[2] + b`` plus 1e-6 per
        entry, then one column normalisation, then ``SINKHORN_ITERS - 1``
        further row-then-column normalisation rounds (1e-6 added to every
        divisor).

    The 1e-6 epsilon is hard-coded; the Python launcher only dispatches to
    this kernel when ``eps == 1e-6`` and ``hc_mult == 4``.
    """
    pid = tl.program_id(0)
    # This program's BLOCK_N token indices; lanes past num_tokens are masked.
    offs = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offs < num_tokens
    # Row stride of mixes: (2 + 4) * 4 = 24 logits per token.
    base = offs * 24

    # Section scales: [0] pre, [1] post, [2] comb.
    scale_0 = tl.load(hc_scale_ptr + 0)
    scale_1 = tl.load(hc_scale_ptr + 1)
    scale_2 = tl.load(hc_scale_ptr + 2)

    # Logits 0..3 feed `pre`, 4..7 feed `post`. The masked loads pass no
    # `other=`, so out-of-range lanes hold unspecified values; this is harmless
    # because every store below uses the same mask.
    m0 = tl.load(mixes_ptr + base + 0, mask=mask)
    m1 = tl.load(mixes_ptr + base + 1, mask=mask)
    m2 = tl.load(mixes_ptr + base + 2, mask=mask)
    m3 = tl.load(mixes_ptr + base + 3, mask=mask)
    m4 = tl.load(mixes_ptr + base + 4, mask=mask)
    m5 = tl.load(mixes_ptr + base + 5, mask=mask)
    m6 = tl.load(mixes_ptr + base + 6, mask=mask)
    m7 = tl.load(mixes_ptr + base + 7, mask=mask)

    b0 = tl.load(hc_base_ptr + 0)
    b1 = tl.load(hc_base_ptr + 1)
    b2 = tl.load(hc_base_ptr + 2)
    b3 = tl.load(hc_base_ptr + 3)
    b4 = tl.load(hc_base_ptr + 4)
    b5 = tl.load(hc_base_ptr + 5)
    b6 = tl.load(hc_base_ptr + 6)
    b7 = tl.load(hc_base_ptr + 7)

    # pre = sigmoid(scaled logit + bias) + 1e-6, 4 values per token.
    pre_base = offs * 4
    tl.store(pre_ptr + pre_base + 0, tl.sigmoid(m0 * scale_0 + b0) + 1e-6, mask=mask)
    tl.store(pre_ptr + pre_base + 1, tl.sigmoid(m1 * scale_0 + b1) + 1e-6, mask=mask)
    tl.store(pre_ptr + pre_base + 2, tl.sigmoid(m2 * scale_0 + b2) + 1e-6, mask=mask)
    tl.store(pre_ptr + pre_base + 3, tl.sigmoid(m3 * scale_0 + b3) + 1e-6, mask=mask)

    # post = 2 * sigmoid(scaled logit + bias), 4 values per token.
    post_base = offs * 4
    tl.store(post_ptr + post_base + 0, 2.0 * tl.sigmoid(m4 * scale_1 + b4), mask=mask)
    tl.store(post_ptr + post_base + 1, 2.0 * tl.sigmoid(m5 * scale_1 + b5), mask=mask)
    tl.store(post_ptr + post_base + 2, 2.0 * tl.sigmoid(m6 * scale_1 + b6), mask=mask)
    tl.store(post_ptr + post_base + 3, 2.0 * tl.sigmoid(m7 * scale_1 + b7), mask=mask)

    # The 4x4 comb section starts after the 4 pre + 4 post entries.
    cb = 8
    b8 = tl.load(hc_base_ptr + cb + 0)
    b9 = tl.load(hc_base_ptr + cb + 1)
    b10 = tl.load(hc_base_ptr + cb + 2)
    b11 = tl.load(hc_base_ptr + cb + 3)
    b12 = tl.load(hc_base_ptr + cb + 4)
    b13 = tl.load(hc_base_ptr + cb + 5)
    b14 = tl.load(hc_base_ptr + cb + 6)
    b15 = tl.load(hc_base_ptr + cb + 7)
    b16 = tl.load(hc_base_ptr + cb + 8)
    b17 = tl.load(hc_base_ptr + cb + 9)
    b18 = tl.load(hc_base_ptr + cb + 10)
    b19 = tl.load(hc_base_ptr + cb + 11)
    b20 = tl.load(hc_base_ptr + cb + 12)
    b21 = tl.load(hc_base_ptr + cb + 13)
    b22 = tl.load(hc_base_ptr + cb + 14)
    b23 = tl.load(hc_base_ptr + cb + 15)

    # cm_rc is entry (row r, column c) of the scaled comb logits (row-major).
    cm_00 = tl.load(mixes_ptr + base + cb + 0, mask=mask) * scale_2 + b8
    cm_01 = tl.load(mixes_ptr + base + cb + 1, mask=mask) * scale_2 + b9
    cm_02 = tl.load(mixes_ptr + base + cb + 2, mask=mask) * scale_2 + b10
    cm_03 = tl.load(mixes_ptr + base + cb + 3, mask=mask) * scale_2 + b11
    cm_10 = tl.load(mixes_ptr + base + cb + 4, mask=mask) * scale_2 + b12
    cm_11 = tl.load(mixes_ptr + base + cb + 5, mask=mask) * scale_2 + b13
    cm_12 = tl.load(mixes_ptr + base + cb + 6, mask=mask) * scale_2 + b14
    cm_13 = tl.load(mixes_ptr + base + cb + 7, mask=mask) * scale_2 + b15
    cm_20 = tl.load(mixes_ptr + base + cb + 8, mask=mask) * scale_2 + b16
    cm_21 = tl.load(mixes_ptr + base + cb + 9, mask=mask) * scale_2 + b17
    cm_22 = tl.load(mixes_ptr + base + cb + 10, mask=mask) * scale_2 + b18
    cm_23 = tl.load(mixes_ptr + base + cb + 11, mask=mask) * scale_2 + b19
    cm_30 = tl.load(mixes_ptr + base + cb + 12, mask=mask) * scale_2 + b20
    cm_31 = tl.load(mixes_ptr + base + cb + 13, mask=mask) * scale_2 + b21
    cm_32 = tl.load(mixes_ptr + base + cb + 14, mask=mask) * scale_2 + b22
    cm_33 = tl.load(mixes_ptr + base + cb + 15, mask=mask) * scale_2 + b23

    # Row-wise softmax, one row at a time: subtract the row max so exp() cannot
    # overflow, normalise by the row sum, then add 1e-6 to every entry.
    # Row 0.
    rm = tl.maximum(tl.maximum(cm_00, cm_01), tl.maximum(cm_02, cm_03))
    cm_00 = tl.exp(cm_00 - rm)
    cm_01 = tl.exp(cm_01 - rm)
    cm_02 = tl.exp(cm_02 - rm)
    cm_03 = tl.exp(cm_03 - rm)
    inv_rs = 1.0 / (cm_00 + cm_01 + cm_02 + cm_03)
    cm_00 = cm_00 * inv_rs + 1e-6
    cm_01 = cm_01 * inv_rs + 1e-6
    cm_02 = cm_02 * inv_rs + 1e-6
    cm_03 = cm_03 * inv_rs + 1e-6

    # Row 1.
    rm = tl.maximum(tl.maximum(cm_10, cm_11), tl.maximum(cm_12, cm_13))
    cm_10 = tl.exp(cm_10 - rm)
    cm_11 = tl.exp(cm_11 - rm)
    cm_12 = tl.exp(cm_12 - rm)
    cm_13 = tl.exp(cm_13 - rm)
    inv_rs = 1.0 / (cm_10 + cm_11 + cm_12 + cm_13)
    cm_10 = cm_10 * inv_rs + 1e-6
    cm_11 = cm_11 * inv_rs + 1e-6
    cm_12 = cm_12 * inv_rs + 1e-6
    cm_13 = cm_13 * inv_rs + 1e-6

    # Row 2.
    rm = tl.maximum(tl.maximum(cm_20, cm_21), tl.maximum(cm_22, cm_23))
    cm_20 = tl.exp(cm_20 - rm)
    cm_21 = tl.exp(cm_21 - rm)
    cm_22 = tl.exp(cm_22 - rm)
    cm_23 = tl.exp(cm_23 - rm)
    inv_rs = 1.0 / (cm_20 + cm_21 + cm_22 + cm_23)
    cm_20 = cm_20 * inv_rs + 1e-6
    cm_21 = cm_21 * inv_rs + 1e-6
    cm_22 = cm_22 * inv_rs + 1e-6
    cm_23 = cm_23 * inv_rs + 1e-6

    # Row 3.
    rm = tl.maximum(tl.maximum(cm_30, cm_31), tl.maximum(cm_32, cm_33))
    cm_30 = tl.exp(cm_30 - rm)
    cm_31 = tl.exp(cm_31 - rm)
    cm_32 = tl.exp(cm_32 - rm)
    cm_33 = tl.exp(cm_33 - rm)
    inv_rs = 1.0 / (cm_30 + cm_31 + cm_32 + cm_33)
    cm_30 = cm_30 * inv_rs + 1e-6
    cm_31 = cm_31 * inv_rs + 1e-6
    cm_32 = cm_32 * inv_rs + 1e-6
    cm_33 = cm_33 * inv_rs + 1e-6

    # First column normalisation (the softmax above acts as the first row step).
    inv_cs0 = 1.0 / (cm_00 + cm_10 + cm_20 + cm_30 + 1e-6)
    inv_cs1 = 1.0 / (cm_01 + cm_11 + cm_21 + cm_31 + 1e-6)
    inv_cs2 = 1.0 / (cm_02 + cm_12 + cm_22 + cm_32 + 1e-6)
    inv_cs3 = 1.0 / (cm_03 + cm_13 + cm_23 + cm_33 + 1e-6)
    cm_00 *= inv_cs0
    cm_10 *= inv_cs0
    cm_20 *= inv_cs0
    cm_30 *= inv_cs0
    cm_01 *= inv_cs1
    cm_11 *= inv_cs1
    cm_21 *= inv_cs1
    cm_31 *= inv_cs1
    cm_02 *= inv_cs2
    cm_12 *= inv_cs2
    cm_22 *= inv_cs2
    cm_32 *= inv_cs2
    cm_03 *= inv_cs3
    cm_13 *= inv_cs3
    cm_23 *= inv_cs3
    cm_33 *= inv_cs3

    # Remaining SINKHORN_ITERS - 1 rounds: row normalisation, then column
    # normalisation. When SINKHORN_ITERS <= 1 the loop body never runs, so
    # exactly one round (the one above) is always applied.
    for _ in range(SINKHORN_ITERS - 1):
        inv_rs0 = 1.0 / (cm_00 + cm_01 + cm_02 + cm_03 + 1e-6)
        inv_rs1 = 1.0 / (cm_10 + cm_11 + cm_12 + cm_13 + 1e-6)
        inv_rs2 = 1.0 / (cm_20 + cm_21 + cm_22 + cm_23 + 1e-6)
        inv_rs3 = 1.0 / (cm_30 + cm_31 + cm_32 + cm_33 + 1e-6)
        cm_00 *= inv_rs0
        cm_01 *= inv_rs0
        cm_02 *= inv_rs0
        cm_03 *= inv_rs0
        cm_10 *= inv_rs1
        cm_11 *= inv_rs1
        cm_12 *= inv_rs1
        cm_13 *= inv_rs1
        cm_20 *= inv_rs2
        cm_21 *= inv_rs2
        cm_22 *= inv_rs2
        cm_23 *= inv_rs2
        cm_30 *= inv_rs3
        cm_31 *= inv_rs3
        cm_32 *= inv_rs3
        cm_33 *= inv_rs3

        inv_cs0 = 1.0 / (cm_00 + cm_10 + cm_20 + cm_30 + 1e-6)
        inv_cs1 = 1.0 / (cm_01 + cm_11 + cm_21 + cm_31 + 1e-6)
        inv_cs2 = 1.0 / (cm_02 + cm_12 + cm_22 + cm_32 + 1e-6)
        inv_cs3 = 1.0 / (cm_03 + cm_13 + cm_23 + cm_33 + 1e-6)
        cm_00 *= inv_cs0
        cm_01 *= inv_cs1
        cm_02 *= inv_cs2
        cm_03 *= inv_cs3
        cm_10 *= inv_cs0
        cm_11 *= inv_cs1
        cm_12 *= inv_cs2
        cm_13 *= inv_cs3
        cm_20 *= inv_cs0
        cm_21 *= inv_cs1
        cm_22 *= inv_cs2
        cm_23 *= inv_cs3
        cm_30 *= inv_cs0
        cm_31 *= inv_cs1
        cm_32 *= inv_cs2
        cm_33 *= inv_cs3

    # Write the 4x4 result row-major: 16 values per token.
    co = offs * 16
    tl.store(comb_ptr + co + 0, cm_00, mask=mask)
    tl.store(comb_ptr + co + 1, cm_01, mask=mask)
    tl.store(comb_ptr + co + 2, cm_02, mask=mask)
    tl.store(comb_ptr + co + 3, cm_03, mask=mask)
    tl.store(comb_ptr + co + 4, cm_10, mask=mask)
    tl.store(comb_ptr + co + 5, cm_11, mask=mask)
    tl.store(comb_ptr + co + 6, cm_12, mask=mask)
    tl.store(comb_ptr + co + 7, cm_13, mask=mask)
    tl.store(comb_ptr + co + 8, cm_20, mask=mask)
    tl.store(comb_ptr + co + 9, cm_21, mask=mask)
    tl.store(comb_ptr + co + 10, cm_22, mask=mask)
    tl.store(comb_ptr + co + 11, cm_23, mask=mask)
    tl.store(comb_ptr + co + 12, cm_30, mask=mask)
    tl.store(comb_ptr + co + 13, cm_31, mask=mask)
    tl.store(comb_ptr + co + 14, cm_32, mask=mask)
    tl.store(comb_ptr + co + 15, cm_33, mask=mask)
@triton.jit
def mhc_split_sinkhorn_kernel_generic(
    mixes_ptr,
    hc_scale_ptr,
    hc_base_ptr,
    pre_ptr,
    post_ptr,
    comb_ptr,
    num_tokens,
    SINKHORN_ITERS: tl.constexpr,
    HC_MULT: tl.constexpr,
    MIX_HC: tl.constexpr,
):
    """Generic split+sinkhorn kernel for arbitrary HC_MULT (one token per program).

    Per-token layout, as indexed below: ``mixes`` rows hold ``MIX_HC`` values
    -- ``[0:HC_MULT]`` pre logits, ``[HC_MULT:2*HC_MULT]`` post logits, then
    ``HC_MULT * HC_MULT`` comb logits in row-major order. ``hc_base`` uses the
    same layout, and ``hc_scale[0..2]`` scales the pre/post/comb sections.

    Computes the same math as the HC_MULT=4 kernel: ``pre = sigmoid(.) + 1e-6``,
    ``post = 2 * sigmoid(.)``, and for ``comb`` a row softmax plus 1e-6, one
    column normalisation, then ``SINKHORN_ITERS - 1`` row/column rounds. The
    1e-6 epsilon is hard-coded.

    All ``tl.static_range`` loops are unrolled at compile time, so code size
    grows with ``HC_MULT ** 2`` per Sinkhorn step.

    The ``comb`` output doubles as scratch space: the scaled logits are stored
    there first and then normalised in place through scalar ``tl.load`` /
    ``tl.store`` round-trips.
    NOTE(review): this depends on the program reading back its own global
    stores. Consider keeping the matrix in registers as a padded 2-D tile
    instead -- confirm before changing.
    """
    pid_n = tl.program_id(0)
    # Defensive bound check; the launcher uses a grid of exactly num_tokens.
    if pid_n >= num_tokens:
        return

    # Flat offsets of this token's rows in mixes / pre / post / comb.
    base = pid_n * MIX_HC
    pre_base = pid_n * HC_MULT
    post_base = pid_n * HC_MULT
    comb_base = pid_n * (HC_MULT * HC_MULT)

    # Section scales: [0] pre, [1] post, [2] comb.
    scale_0 = tl.load(hc_scale_ptr + 0)
    scale_1 = tl.load(hc_scale_ptr + 1)
    scale_2 = tl.load(hc_scale_ptr + 2)

    # pre / post: element-wise sigmoid of the scaled, biased logits.
    for j in tl.static_range(HC_MULT):
        pre_idx = j
        post_idx = HC_MULT + j
        pre_m = tl.load(mixes_ptr + base + pre_idx)
        post_m = tl.load(mixes_ptr + base + post_idx)
        pre_b = tl.load(hc_base_ptr + pre_idx)
        post_b = tl.load(hc_base_ptr + post_idx)
        tl.store(pre_ptr + pre_base + j, tl.sigmoid(pre_m * scale_0 + pre_b) + 1e-6)
        tl.store(post_ptr + post_base + j, 2.0 * tl.sigmoid(post_m * scale_1 + post_b))

    # The comb logits start after the pre and post sections.
    comb_offset = 2 * HC_MULT

    # Stage the scaled, biased comb logits in the comb output buffer.
    for row in tl.static_range(HC_MULT):
        for col in tl.static_range(HC_MULT):
            idx = comb_offset + row * HC_MULT + col
            out_idx = row * HC_MULT + col
            m = tl.load(mixes_ptr + base + idx)
            b = tl.load(hc_base_ptr + idx)
            tl.store(comb_ptr + comb_base + out_idx, m * scale_2 + b)

    # Row-wise softmax in place: row max (for stability), exp and row sum,
    # then normalise and add 1e-6 to every entry.
    for row in tl.static_range(HC_MULT):
        row_ptr0 = comb_ptr + comb_base + row * HC_MULT
        row_max = tl.load(row_ptr0)
        for col in tl.static_range(HC_MULT):
            row_ptr = comb_ptr + comb_base + row * HC_MULT + col
            row_max = tl.maximum(row_max, tl.load(row_ptr))
        row_sum = 0.0
        for col in tl.static_range(HC_MULT):
            row_ptr = comb_ptr + comb_base + row * HC_MULT + col
            v = tl.exp(tl.load(row_ptr) - row_max)
            row_sum += v
            tl.store(row_ptr, v)
        inv_row_sum = 1.0 / row_sum
        for col in tl.static_range(HC_MULT):
            row_ptr = comb_ptr + comb_base + row * HC_MULT + col
            v = tl.load(row_ptr) * inv_row_sum + 1e-6
            tl.store(row_ptr, v)

    # First column normalisation (completes Sinkhorn round 1).
    for col in tl.static_range(HC_MULT):
        col_sum = 0.0
        for row in tl.static_range(HC_MULT):
            ptr = comb_ptr + comb_base + row * HC_MULT + col
            col_sum += tl.load(ptr)
        inv_col_sum = 1.0 / (col_sum + 1e-6)
        for row in tl.static_range(HC_MULT):
            ptr = comb_ptr + comb_base + row * HC_MULT + col
            tl.store(ptr, tl.load(ptr) * inv_col_sum)

    # Remaining SINKHORN_ITERS - 1 rounds of row then column normalisation;
    # no extra rounds when SINKHORN_ITERS <= 1.
    for _ in range(SINKHORN_ITERS - 1):
        for row in tl.static_range(HC_MULT):
            row_sum = 0.0
            for col in tl.static_range(HC_MULT):
                ptr = comb_ptr + comb_base + row * HC_MULT + col
                row_sum += tl.load(ptr)
            inv_row_sum = 1.0 / (row_sum + 1e-6)
            for col in tl.static_range(HC_MULT):
                ptr = comb_ptr + comb_base + row * HC_MULT + col
                tl.store(ptr, tl.load(ptr) * inv_row_sum)

        for col in tl.static_range(HC_MULT):
            col_sum = 0.0
            for row in tl.static_range(HC_MULT):
                ptr = comb_ptr + comb_base + row * HC_MULT + col
                col_sum += tl.load(ptr)
            inv_col_sum = 1.0 / (col_sum + 1e-6)
            for row in tl.static_range(HC_MULT):
                ptr = comb_ptr + comb_base + row * HC_MULT + col
                tl.store(ptr, tl.load(ptr) * inv_col_sum)
def _hcmult4_launch_config(num_tokens: int) -> tuple[int, int]:
    """Return ``(BLOCK_N, num_warps)`` for the HC_MULT=4 kernel, tiered by token count."""
    if num_tokens <= 256:
        return 16, 1
    if num_tokens <= 2048:
        return 32, 1
    if num_tokens <= 16384:
        return 128, 4
    return 256, 8


def _generic_num_warps(hc_mult: int) -> int:
    """Return ``num_warps`` for the one-token-per-program generic kernel."""
    if hc_mult <= 4:
        return 1
    if hc_mult <= 8:
        return 2
    return 4


def hc_split_sinkhorn(
    mixes: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    hc_mult: int = 4,
    sinkhorn_iters: int = 20,
    eps: float = 1e-6,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Split mixing logits into ``pre``/``post`` and a Sinkhorn-normalised ``comb``.

    The last dimension of ``mixes`` has ``(2 + hc_mult) * hc_mult`` entries:
    ``hc_mult`` pre logits, ``hc_mult`` post logits, then ``hc_mult ** 2``
    comb logits forming a row-major ``hc_mult x hc_mult`` matrix. The comb
    matrix is row-softmaxed, then alternately row/column normalised for
    ``max(sinkhorn_iters, 1)`` rounds in total.

    CUDA inputs with ``eps == 1e-6`` and ``hc_mult >= 1`` run a Triton kernel
    (a specialised vectorised one when ``hc_mult == 4``). Every other case
    falls back to ``mhc_split_sinkhorn_torch_ref``.

    Args:
        mixes: logits of shape ``(*lead, (2 + hc_mult) * hc_mult)``.
        hc_scale: shape ``(3,)``; scales for the pre, post and comb sections.
        hc_base: shape ``((2 + hc_mult) * hc_mult,)``; biases laid out like
            one row of ``mixes``.
        hc_mult: matrix size ``n`` of the ``n x n`` comb block.
        sinkhorn_iters: number of Sinkhorn rounds.
        eps: stabiliser added to ``pre`` and to the normalisation divisors.

    Returns:
        ``(pre, post, comb)`` with shapes ``(*lead, hc_mult)``,
        ``(*lead, hc_mult)`` and ``(*lead, hc_mult, hc_mult)``. The Triton path
        always returns float32. The fallback returns whatever dtype PyTorch
        type promotion yields.

    Raises:
        AssertionError: if ``mixes``, ``hc_scale`` or ``hc_base`` has the wrong shape.
    """
    mix_hc = (2 + hc_mult) * hc_mult
    assert mixes.shape[-1] == mix_hc, (
        f"expected mixes.shape[-1] == (2 + hc_mult) * hc_mult == {mix_hc}, "
        f"got {mixes.shape[-1]}"
    )
    assert hc_scale.shape == (3,), (
        f"expected hc_scale.shape == (3,), got {tuple(hc_scale.shape)}"
    )
    assert hc_base.shape == (mix_hc,), (
        f"expected hc_base.shape == ({mix_hc},), got {tuple(hc_base.shape)}"
    )

    # The Triton kernels hard-code eps = 1e-6 and only run on CUDA.
    if mixes.device.type != "cuda" or eps != 1e-6 or hc_mult < 1:
        return mhc_split_sinkhorn_torch_ref(
            mixes,
            hc_scale,
            hc_base,
            hc_mult=hc_mult,
            sinkhorn_iters=sinkhorn_iters,
            eps=eps,
        )

    outer_shape = mixes.shape[:-1]
    mixes_flat = mixes.reshape(-1, mix_hc).contiguous()
    num_tokens = mixes_flat.shape[0]
    # The kernels read the parameter vectors by flat offset (ptr + i), so a
    # strided view (e.g. a slice of a larger parameter) would be read wrongly.
    # .contiguous() is a no-op for tensors that are already dense.
    hc_scale = hc_scale.contiguous()
    hc_base = hc_base.contiguous()

    device = mixes.device
    pre = torch.empty(num_tokens, hc_mult, dtype=torch.float32, device=device)
    post = torch.empty(num_tokens, hc_mult, dtype=torch.float32, device=device)
    comb = torch.empty(
        num_tokens, hc_mult * hc_mult, dtype=torch.float32, device=device
    )

    # Nothing to compute (and no kernel to launch) for an empty batch.
    if num_tokens > 0:
        if hc_mult == 4:
            block_n, num_warps = _hcmult4_launch_config(num_tokens)
            grid = (num_tokens + block_n - 1) // block_n
            mhc_split_sinkhorn_kernel_hcmult_4[(grid,)](
                mixes_flat,
                hc_scale,
                hc_base,
                pre,
                post,
                comb,
                num_tokens,
                BLOCK_N=block_n,
                SINKHORN_ITERS=sinkhorn_iters,
                num_warps=num_warps,
                num_stages=1,
            )
        else:
            # One program per token.
            mhc_split_sinkhorn_kernel_generic[(num_tokens,)](
                mixes_flat,
                hc_scale,
                hc_base,
                pre,
                post,
                comb,
                num_tokens,
                SINKHORN_ITERS=sinkhorn_iters,
                HC_MULT=hc_mult,
                MIX_HC=mix_hc,
                num_warps=_generic_num_warps(hc_mult),
                num_stages=1,
            )

    return (
        pre.view(*outer_shape, hc_mult),
        post.view(*outer_shape, hc_mult),
        comb.view(*outer_shape, hc_mult, hc_mult),
    )
def mhc_split_sinkhorn_torch_ref(
    mixes: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    hc_mult: int = 4,
    sinkhorn_iters: int = 20,
    eps: float = 1e-6,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Eager PyTorch reference for the split + Sinkhorn computation.

    The last dimension of ``mixes`` is split into three consecutive sections
    of sizes ``hc_mult``, ``hc_mult`` and ``hc_mult ** 2``. ``hc_base`` is
    sliced the same way, and ``hc_scale[0]``, ``[1]`` and ``[2]`` scale the
    three sections.

    Returns:
        ``(pre, post, comb)`` shaped ``(*lead, hc_mult)``, ``(*lead, hc_mult)``
        and ``(*lead, hc_mult, hc_mult)``, where

        * ``pre  = sigmoid(x_pre * hc_scale[0] + b_pre) + eps``
        * ``post = 2 * sigmoid(x_post * hc_scale[1] + b_post)``
        * ``comb`` is the row softmax (plus ``eps``) of the scaled comb
          logits, column-normalised once, then refined by
          ``sinkhorn_iters - 1`` further row-then-column normalisation rounds
          (``eps`` added to every divisor).

    Raises:
        AssertionError: if ``mixes.shape[-1] != (2 + hc_mult) * hc_mult``.
    """
    h = hc_mult
    width = (2 + h) * h
    lead = mixes.shape[:-1]
    assert mixes.shape[-1] == width

    def _rescale(mat: torch.Tensor, dim: int) -> torch.Tensor:
        # Divide by the sum along `dim`; eps keeps the divisor away from zero.
        return mat / (mat.sum(dim=dim, keepdim=True) + eps)

    pre_in = mixes[..., :h] * hc_scale[0] + hc_base[:h]
    post_in = mixes[..., h : 2 * h] * hc_scale[1] + hc_base[h : 2 * h]
    pre = torch.sigmoid(pre_in) + eps
    post = 2 * torch.sigmoid(post_in)

    logits = mixes[..., 2 * h :].unflatten(-1, (h, h)) * hc_scale[2] + hc_base[
        2 * h :
    ].view(h, h)
    # Numerically stable row softmax, then eps is added to every entry.
    row_max = torch.max(logits, dim=-1, keepdim=True).values
    weights = torch.exp(logits - row_max)
    comb = weights / weights.sum(dim=-1, keepdim=True) + eps
    comb = _rescale(comb, -2)
    for _ in range(sinkhorn_iters - 1):
        comb = _rescale(_rescale(comb, -1), -2)
    assert comb.shape == (*lead, h, h)
    return pre, post, comb