Coverage for src/flag_gems/fused/mhc/mhc_bwd.py: 29%
249 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1"""
Triton implementation of mHC Backward (Sinkhorn implicit CG differentiation).

This kernel computes the gradient of the Sinkhorn normalization using
implicit differentiation via the conjugate gradient (CG) method.

Algorithm:
Given R = Sinkhorn(M) and upstream gradient dR, we solve for dM using:
1. Compute b1 = sum(R * dR, dim=-1), b2 = sum(R * dR, dim=-2)
2. Solve the linear system A*x = b with CG, where x = (x1, x2), b = (b1, b2)
   and A is the Sinkhorn Jacobian: A*(x1, x2) = (x1 + R @ x2, x2 + R.T @ x1)
3. Result: dM[i, j] = (dR[i, j] - x1[i] - x2[j]) * R[i, j]
12"""
14import torch
15import triton
16import triton.language as tl
18EPS = 1e-10
def _get_autotune_configs():
    """Build the autotune search space for ``_mhc_bwd_kernel``.

    Returns:
        A list of ``triton.Config`` objects, one per (TILE_SIZE, num_warps)
        pair, ordered with TILE_SIZE as the outer axis and num_warps as the
        inner axis.
    """
    tile_sizes = (1, 2, 4, 8, 16, 32)
    warp_counts = (1, 2, 4, 8)
    return [
        triton.Config({"TILE_SIZE": tile_size}, num_warps=warps)
        for tile_size in tile_sizes
        for warps in warp_counts
    ]
@triton.autotune(
    configs=_get_autotune_configs(),
    key=["seqlen", "n_stream"],
)
@triton.jit
def _mhc_bwd_kernel(
    # Pointers to tensors
    out_ptr,  # (seqlen, n_stream, n_stream), float32 - Sinkhorn output R
    dout_ptr,  # (seqlen, n_stream, n_stream), float32 - upstream gradient dR
    res_ptr,  # (seqlen, n_stream, n_stream), float32 - result dM
    # Dimensions
    seqlen,
    n_stream,
    # Strides
    out_stride_s,
    out_stride_i,
    out_stride_j,
    dout_stride_s,
    dout_stride_i,
    dout_stride_j,
    res_stride_s,
    res_stride_i,
    res_stride_j,
    # Number of CG iterations
    cg_iters: tl.constexpr,
    # Constants
    TILE_SIZE: tl.constexpr,
    N_STREAM: tl.constexpr,
):
    """Generic-size Sinkhorn backward kernel - one tile of sequences per program.

    Each program handles the ``TILE_SIZE`` sequence positions starting at
    ``program_id(0) * TILE_SIZE`` and, for every (i, j), writes
    ``res[s, i, j] = dR[s, i, j] * R[s, i, j]``.

    NOTE(review): as written this kernel does NOT run the CG solve and does
    not subtract the x1/x2 correction terms; ``cg_iters`` and ``n_stream`` are
    accepted but unused (``n_stream`` only serves as an autotune key). It
    looks like a placeholder for the general-size path - confirm before
    launching it in place of the PyTorch reference.

    Args:
        out_ptr, dout_ptr, res_ptr: Base pointers of R, dR and the result.
        seqlen: Number of sequence positions; positions >= seqlen are skipped.
        n_stream: Runtime matrix size (unused in the body).
        *_stride_s / *_stride_i / *_stride_j: Element strides of each tensor
            along the sequence, row and column axes.
        cg_iters: Number of CG iterations (unused in the body).
        TILE_SIZE: Sequence positions handled per program.
        N_STREAM: Compile-time matrix size used for the i/j loops.
    """
    pid = tl.program_id(0)
    tile_start = pid * TILE_SIZE

    for t in range(TILE_SIZE):
        seq_idx = tile_start + t
        # Guard the ragged last tile with an `if` rather than `continue`,
        # which Triton's JIT front end is understood not to support
        # (TODO confirm against the Triton version in use).
        if seq_idx < seqlen:
            base_out = seq_idx * out_stride_s
            base_dout = seq_idx * dout_stride_s
            base_res = seq_idx * res_stride_s

            for i in range(N_STREAM):
                for j in range(N_STREAM):
                    r_val = tl.load(
                        out_ptr + base_out + i * out_stride_i + j * out_stride_j
                    )
                    dr_val = tl.load(
                        dout_ptr + base_dout + i * dout_stride_i + j * dout_stride_j
                    )
                    tl.store(
                        res_ptr + base_res + i * res_stride_i + j * res_stride_j,
                        dr_val * r_val,
                    )
@triton.jit
def _mhc_bwd_kernel_n4(
    # Pointers to tensors
    out_ptr,  # (seqlen, 4, 4), float32 - Sinkhorn output R
    dout_ptr,  # (seqlen, 4, 4), float32 - upstream gradient dR
    res_ptr,  # (seqlen, 4, 4), float32 - result dM
    seqlen,
    cg_iters: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """Sinkhorn backward for n_stream=4, optimized with unrolled CG.

    Each program handles BLOCK_S consecutive sequence positions, one per lane
    of the ``tl.arange`` vector; lanes at or beyond ``seqlen`` are masked out
    on every load and store. All 16 matrix entries and the 4+4 CG vector
    components are kept as separate per-lane values, so the whole CG solve is
    written out element by element.

    Element (i, j) of sequence s is addressed at offset ``s * 16 + 4 * i + j``
    for all three tensors, i.e. the kernel takes no strides and relies on a
    contiguous (seqlen, 4, 4) layout.

    Args:
        out_ptr: Pointer to the Sinkhorn output R.
        dout_ptr: Pointer to the upstream gradient dR.
        res_ptr: Pointer to the output buffer for dM.
        seqlen: Number of sequence positions.
        cg_iters: Number of CG iterations to run (compile-time constant).
        BLOCK_S: Number of sequence positions per program.
    """
    pid = tl.program_id(0)
    seq_start = pid * BLOCK_S
    seq_offsets = seq_start + tl.arange(0, BLOCK_S)
    # Lanes past the end of the sequence axis neither load nor store.
    mask = seq_offsets < seqlen

    base_out = seq_offsets * 16  # 4*4 = 16
    base_dout = seq_offsets * 16
    base_res = seq_offsets * 16

    # Load R matrix (masked lanes read 0.0, which makes b, r and p all zero
    # for them, so alpha and beta evaluate to 0 / 1e-10 = 0 below)
    R_00 = tl.load(out_ptr + base_out + 0, mask=mask, other=0.0)
    R_01 = tl.load(out_ptr + base_out + 1, mask=mask, other=0.0)
    R_02 = tl.load(out_ptr + base_out + 2, mask=mask, other=0.0)
    R_03 = tl.load(out_ptr + base_out + 3, mask=mask, other=0.0)
    R_10 = tl.load(out_ptr + base_out + 4, mask=mask, other=0.0)
    R_11 = tl.load(out_ptr + base_out + 5, mask=mask, other=0.0)
    R_12 = tl.load(out_ptr + base_out + 6, mask=mask, other=0.0)
    R_13 = tl.load(out_ptr + base_out + 7, mask=mask, other=0.0)
    R_20 = tl.load(out_ptr + base_out + 8, mask=mask, other=0.0)
    R_21 = tl.load(out_ptr + base_out + 9, mask=mask, other=0.0)
    R_22 = tl.load(out_ptr + base_out + 10, mask=mask, other=0.0)
    R_23 = tl.load(out_ptr + base_out + 11, mask=mask, other=0.0)
    R_30 = tl.load(out_ptr + base_out + 12, mask=mask, other=0.0)
    R_31 = tl.load(out_ptr + base_out + 13, mask=mask, other=0.0)
    R_32 = tl.load(out_ptr + base_out + 14, mask=mask, other=0.0)
    R_33 = tl.load(out_ptr + base_out + 15, mask=mask, other=0.0)

    # Load dR matrix
    dR_00 = tl.load(dout_ptr + base_dout + 0, mask=mask, other=0.0)
    dR_01 = tl.load(dout_ptr + base_dout + 1, mask=mask, other=0.0)
    dR_02 = tl.load(dout_ptr + base_dout + 2, mask=mask, other=0.0)
    dR_03 = tl.load(dout_ptr + base_dout + 3, mask=mask, other=0.0)
    dR_10 = tl.load(dout_ptr + base_dout + 4, mask=mask, other=0.0)
    dR_11 = tl.load(dout_ptr + base_dout + 5, mask=mask, other=0.0)
    dR_12 = tl.load(dout_ptr + base_dout + 6, mask=mask, other=0.0)
    dR_13 = tl.load(dout_ptr + base_dout + 7, mask=mask, other=0.0)
    dR_20 = tl.load(dout_ptr + base_dout + 8, mask=mask, other=0.0)
    dR_21 = tl.load(dout_ptr + base_dout + 9, mask=mask, other=0.0)
    dR_22 = tl.load(dout_ptr + base_dout + 10, mask=mask, other=0.0)
    dR_23 = tl.load(dout_ptr + base_dout + 11, mask=mask, other=0.0)
    dR_30 = tl.load(dout_ptr + base_dout + 12, mask=mask, other=0.0)
    dR_31 = tl.load(dout_ptr + base_dout + 13, mask=mask, other=0.0)
    dR_32 = tl.load(dout_ptr + base_dout + 14, mask=mask, other=0.0)
    dR_33 = tl.load(dout_ptr + base_dout + 15, mask=mask, other=0.0)

    # Compute RdR = R * dR (element-wise)
    RdR_00 = R_00 * dR_00
    RdR_01 = R_01 * dR_01
    RdR_02 = R_02 * dR_02
    RdR_03 = R_03 * dR_03
    RdR_10 = R_10 * dR_10
    RdR_11 = R_11 * dR_11
    RdR_12 = R_12 * dR_12
    RdR_13 = R_13 * dR_13
    RdR_20 = R_20 * dR_20
    RdR_21 = R_21 * dR_21
    RdR_22 = R_22 * dR_22
    RdR_23 = R_23 * dR_23
    RdR_30 = R_30 * dR_30
    RdR_31 = R_31 * dR_31
    RdR_32 = R_32 * dR_32
    RdR_33 = R_33 * dR_33

    # b1 = sum(RdR, dim=-1) -> b1[i] = sum_j(RdR[i,j])
    b1_0 = RdR_00 + RdR_01 + RdR_02 + RdR_03
    b1_1 = RdR_10 + RdR_11 + RdR_12 + RdR_13
    b1_2 = RdR_20 + RdR_21 + RdR_22 + RdR_23
    b1_3 = RdR_30 + RdR_31 + RdR_32 + RdR_33

    # b2 = sum(RdR, dim=-2) -> b2[j] = sum_i(RdR[i,j])
    b2_0 = RdR_00 + RdR_10 + RdR_20 + RdR_30
    b2_1 = RdR_01 + RdR_11 + RdR_21 + RdR_31
    b2_2 = RdR_02 + RdR_12 + RdR_22 + RdR_32
    b2_3 = RdR_03 + RdR_13 + RdR_23 + RdR_33

    # Initialize CG: x = 0, r = b - A*x = b, p = r
    # The unknown is the stacked vector x = (x1, x2) with 4 + 4 components.
    x1_0 = tl.zeros_like(b1_0)
    x1_1 = tl.zeros_like(b1_1)
    x1_2 = tl.zeros_like(b1_2)
    x1_3 = tl.zeros_like(b1_3)
    x2_0 = tl.zeros_like(b2_0)
    x2_1 = tl.zeros_like(b2_1)
    x2_2 = tl.zeros_like(b2_2)
    x2_3 = tl.zeros_like(b2_3)

    # Compute A*x where x=0 -> r = b
    r1_0 = b1_0
    r1_1 = b1_1
    r1_2 = b1_2
    r1_3 = b1_3
    r2_0 = b2_0
    r2_1 = b2_1
    r2_2 = b2_2
    r2_3 = b2_3

    # p = r
    p1_0 = r1_0
    p1_1 = r1_1
    p1_2 = r1_2
    p1_3 = r1_3
    p2_0 = r2_0
    p2_1 = r2_1
    p2_2 = r2_2
    p2_3 = r2_3

    # r_normsq = dot(r, r)
    r_normsq = (
        r1_0 * r1_0
        + r1_1 * r1_1
        + r1_2 * r1_2
        + r1_3 * r1_3
        + r2_0 * r2_0
        + r2_1 * r2_1
        + r2_2 * r2_2
        + r2_3 * r2_3
    )

    # CG iterations: a fixed count with no early exit on convergence
    # (mhc_bwd defaults cg_iters to 2 * n_stream = 8 for n_stream=4)
    for _ in range(cg_iters):
        # Ap = A * p, first half: Ap1 = R @ p2 + p1
        Ap1_0 = (R_00 * p2_0 + R_01 * p2_1 + R_02 * p2_2 + R_03 * p2_3) + p1_0
        Ap1_1 = (R_10 * p2_0 + R_11 * p2_1 + R_12 * p2_2 + R_13 * p2_3) + p1_1
        Ap1_2 = (R_20 * p2_0 + R_21 * p2_1 + R_22 * p2_2 + R_23 * p2_3) + p1_2
        Ap1_3 = (R_30 * p2_0 + R_31 * p2_1 + R_32 * p2_2 + R_33 * p2_3) + p1_3

        # Ap = A * p, second half: Ap2 = R.T @ p1 + p2
        Ap2_0 = (R_00 * p1_0 + R_10 * p1_1 + R_20 * p1_2 + R_30 * p1_3) + p2_0
        Ap2_1 = (R_01 * p1_0 + R_11 * p1_1 + R_21 * p1_2 + R_31 * p1_3) + p2_1
        Ap2_2 = (R_02 * p1_0 + R_12 * p1_1 + R_22 * p1_2 + R_32 * p1_3) + p2_2
        Ap2_3 = (R_03 * p1_0 + R_13 * p1_1 + R_23 * p1_2 + R_33 * p1_3) + p2_3

        # pAp = dot(p, Ap)
        pAp = (
            p1_0 * Ap1_0
            + p1_1 * Ap1_1
            + p1_2 * Ap1_2
            + p1_3 * Ap1_3
            + p2_0 * Ap2_0
            + p2_1 * Ap2_1
            + p2_2 * Ap2_2
            + p2_3 * Ap2_3
        )

        # alpha = r_normsq / (pAp + eps); the 1e-10 literal has the same
        # value as the module-level EPS used by the PyTorch reference
        alpha = r_normsq / (pAp + 1e-10)

        # x = x + alpha * p
        x1_0 = x1_0 + alpha * p1_0
        x1_1 = x1_1 + alpha * p1_1
        x1_2 = x1_2 + alpha * p1_2
        x1_3 = x1_3 + alpha * p1_3
        x2_0 = x2_0 + alpha * p2_0
        x2_1 = x2_1 + alpha * p2_1
        x2_2 = x2_2 + alpha * p2_2
        x2_3 = x2_3 + alpha * p2_3

        # r = r - alpha * Ap
        r1_0 = r1_0 - alpha * Ap1_0
        r1_1 = r1_1 - alpha * Ap1_1
        r1_2 = r1_2 - alpha * Ap1_2
        r1_3 = r1_3 - alpha * Ap1_3
        r2_0 = r2_0 - alpha * Ap2_0
        r2_1 = r2_1 - alpha * Ap2_1
        r2_2 = r2_2 - alpha * Ap2_2
        r2_3 = r2_3 - alpha * Ap2_3

        # r_new_normsq = dot(r, r)
        r_new_normsq = (
            r1_0 * r1_0
            + r1_1 * r1_1
            + r1_2 * r1_2
            + r1_3 * r1_3
            + r2_0 * r2_0
            + r2_1 * r2_1
            + r2_2 * r2_2
            + r2_3 * r2_3
        )

        # beta = r_new_normsq / (r_normsq + eps)
        beta = r_new_normsq / (r_normsq + 1e-10)

        # p = r + beta * p
        p1_0 = r1_0 + beta * p1_0
        p1_1 = r1_1 + beta * p1_1
        p1_2 = r1_2 + beta * p1_2
        p1_3 = r1_3 + beta * p1_3
        p2_0 = r2_0 + beta * p2_0
        p2_1 = r2_1 + beta * p2_1
        p2_2 = r2_2 + beta * p2_2
        p2_3 = r2_3 + beta * p2_3

        # Carry the residual norm into the next iteration
        r_normsq = r_new_normsq

    # Compute result: res = (dR - x1 - x2) * R
    # res[i,j] = (dR[i,j] - x1[i] - x2[j]) * R[i,j]
    res_00 = (dR_00 - x1_0 - x2_0) * R_00
    res_01 = (dR_01 - x1_0 - x2_1) * R_01
    res_02 = (dR_02 - x1_0 - x2_2) * R_02
    res_03 = (dR_03 - x1_0 - x2_3) * R_03
    res_10 = (dR_10 - x1_1 - x2_0) * R_10
    res_11 = (dR_11 - x1_1 - x2_1) * R_11
    res_12 = (dR_12 - x1_1 - x2_2) * R_12
    res_13 = (dR_13 - x1_1 - x2_3) * R_13
    res_20 = (dR_20 - x1_2 - x2_0) * R_20
    res_21 = (dR_21 - x1_2 - x2_1) * R_21
    res_22 = (dR_22 - x1_2 - x2_2) * R_22
    res_23 = (dR_23 - x1_2 - x2_3) * R_23
    res_30 = (dR_30 - x1_3 - x2_0) * R_30
    res_31 = (dR_31 - x1_3 - x2_1) * R_31
    res_32 = (dR_32 - x1_3 - x2_2) * R_32
    res_33 = (dR_33 - x1_3 - x2_3) * R_33

    # Store results (same s * 16 + 4 * i + j layout as the inputs)
    tl.store(res_ptr + base_res + 0, res_00, mask=mask)
    tl.store(res_ptr + base_res + 1, res_01, mask=mask)
    tl.store(res_ptr + base_res + 2, res_02, mask=mask)
    tl.store(res_ptr + base_res + 3, res_03, mask=mask)
    tl.store(res_ptr + base_res + 4, res_10, mask=mask)
    tl.store(res_ptr + base_res + 5, res_11, mask=mask)
    tl.store(res_ptr + base_res + 6, res_12, mask=mask)
    tl.store(res_ptr + base_res + 7, res_13, mask=mask)
    tl.store(res_ptr + base_res + 8, res_20, mask=mask)
    tl.store(res_ptr + base_res + 9, res_21, mask=mask)
    tl.store(res_ptr + base_res + 10, res_22, mask=mask)
    tl.store(res_ptr + base_res + 11, res_23, mask=mask)
    tl.store(res_ptr + base_res + 12, res_30, mask=mask)
    tl.store(res_ptr + base_res + 13, res_31, mask=mask)
    tl.store(res_ptr + base_res + 14, res_32, mask=mask)
    tl.store(res_ptr + base_res + 15, res_33, mask=mask)
def mhc_bwd(
    out: torch.Tensor,
    dout: torch.Tensor,
    cg_iters: int = None,
) -> torch.Tensor:
    """Compute Sinkhorn backward using implicit CG differentiation.

    Dispatches to the unrolled Triton kernel when ``n_stream == 4`` and to the
    PyTorch reference implementation (``mhc_bwd_ref``) for every other size.
    An empty batch (``seqlen == 0``) returns an empty float32 tensor without
    launching a kernel.

    Args:
        out: Sinkhorn output R, shape (seqlen, n_stream, n_stream). Converted
            to a contiguous float32 tensor before use.
        dout: Upstream gradient dR, same shape as out. Converted likewise.
        cg_iters: Number of CG iterations. Defaults to 2 * n_stream.

    Returns:
        Gradient w.r.t. pre-Sinkhorn input, float32, same shape as out.

    Raises:
        AssertionError: If the shapes differ, the input is not 3-D, or the
            last two dimensions are not equal.
    """
    assert out.shape == dout.shape, "out and dout must have same shape"
    assert out.ndim == 3, "Expected 3D tensors (seqlen, n_stream, n_stream)"
    assert out.shape[1] == out.shape[2], "n_stream dimensions must match"

    seqlen, n_stream, _ = out.shape
    if cg_iters is None:
        cg_iters = 2 * n_stream

    # The n4 kernel addresses elements as s * 16 + 4 * i + j, so it needs
    # contiguous float32 storage.
    out = out.contiguous().float()
    dout = dout.contiguous().float()

    # Nothing to compute for an empty batch; skip the kernel launch entirely.
    if seqlen == 0:
        return torch.empty_like(out)

    # Only n_stream == 4 has a Triton kernel; everything else uses the
    # reference path, which allocates its own result tensor.
    if n_stream != 4:
        return mhc_bwd_ref(out, dout, cg_iters=cg_iters)

    res = torch.empty_like(out)
    BLOCK_S = 64  # sequence positions handled by one program
    grid = (triton.cdiv(seqlen, BLOCK_S),)
    _mhc_bwd_kernel_n4[grid](
        out,
        dout,
        res,
        seqlen,
        cg_iters,
        BLOCK_S=BLOCK_S,
    )
    return res
def mhc_bwd_ref(
    out: torch.Tensor,
    dout: torch.Tensor,
    cg_iters: int = None,
) -> torch.Tensor:
    """PyTorch reference implementation of Sinkhorn backward via implicit CG.

    Solves ``A x = b`` for the stacked unknown ``x = (x1, x2)`` with a fixed
    number of conjugate-gradient steps, where ``b1 = sum(R * dR, dim=-1)``,
    ``b2 = sum(R * dR, dim=-2)`` and
    ``A (x1, x2) = (x1 + R @ x2, x2 + R.T @ x1)``, then returns
    ``(dR - x1[..., :, None] - x2[..., None, :]) * R``.

    Every operation works on the trailing two dimensions, so any number of
    leading batch dimensions is accepted.

    Args:
        out: Sinkhorn output R, shape (..., n_stream, n_stream); typically
            (seqlen, n_stream, n_stream). Cast to float32.
        dout: Upstream gradient dR, same shape as out. Cast to float32.
        cg_iters: Number of CG iterations. Defaults to 2 * n_stream.

    Returns:
        Gradient w.r.t. pre-Sinkhorn input, float32, same shape as out.
    """
    n_stream = out.shape[-2]
    if cg_iters is None:
        cg_iters = 2 * n_stream

    R = out.float()
    dR = dout.float()

    # Right-hand side: row sums and column sums of R * dR.
    RdR = R * dR
    b1 = RdR.sum(dim=-1)  # (..., n_stream)
    b2 = RdR.sum(dim=-2)  # (..., n_stream)

    def matvec(v1, v2):
        # y1[i] = sum_j(R[i,j] * v2[j]) + v1[i]
        y1 = (R * v2.unsqueeze(-2)).sum(dim=-1) + v1
        # y2[j] = sum_i(R[i,j] * v1[i]) + v2[j]
        y2 = (R * v1.unsqueeze(-1)).sum(dim=-2) + v2
        return y1, y2

    # CG start: x = 0, so the residual r = b - A*x equals b, and p = r.
    x1 = torch.zeros_like(b1)
    x2 = torch.zeros_like(b2)
    r1, r2 = b1.clone(), b2.clone()
    p1, p2 = r1.clone(), r2.clone()
    r_normsq = (r1 * r1 + r2 * r2).sum(dim=-1)  # one scalar per matrix

    # Fixed iteration count with no convergence check.
    for _ in range(cg_iters):
        Ap1, Ap2 = matvec(p1, p2)

        # Step length along p; EPS keeps the division finite once the
        # residual (and hence pAp) has collapsed to zero.
        pAp = (p1 * Ap1 + p2 * Ap2).sum(dim=-1)
        alpha = (r_normsq / (pAp + EPS)).unsqueeze(-1)

        x1 = x1 + alpha * p1
        x2 = x2 + alpha * p2

        r1 = r1 - alpha * Ap1
        r2 = r2 - alpha * Ap2

        # Next search direction from the ratio of new to old residual norms.
        r_new_normsq = (r1 * r1 + r2 * r2).sum(dim=-1)
        beta = (r_new_normsq / (r_normsq + EPS)).unsqueeze(-1)

        p1 = r1 + beta * p1
        p2 = r2 + beta * p2

        r_normsq = r_new_normsq

    # res[i,j] = (dR[i,j] - x1[i] - x2[j]) * R[i,j]
    return (dR - x1.unsqueeze(-1) - x2.unsqueeze(-2)) * R
def sinkhorn_forward(
    M: torch.Tensor, iters: int = 20
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run Sinkhorn normalization on a batch of square logit matrices.

    Each iteration divides by the sums over dim -2 and then by the sums over
    dim -1, so the last step applied is always the dim -1 normalization.

    Args:
        M: Input logits, shape (..., n, n).
        iters: Number of Sinkhorn iterations.

    Returns:
        (R, P) where P = exp(M) and R is the normalized matrix (approximately
        doubly stochastic after enough iterations). With ``iters == 0`` R is
        an unnormalized copy of P.
    """
    exp_logits = M.exp()
    # Work on a copy so the returned P never shares storage with R.
    balanced = exp_logits.clone()
    for _ in range(iters):
        for axis in (-2, -1):
            balanced = balanced / balanced.sum(axis, keepdim=True)
    return balanced, exp_logits