Coverage for src/flag_gems/fused/mhc/mhc_bwd.py: 29%

249 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1""" 

Triton implementation of mHC Backward (Sinkhorn implicit CG differentiation).

This kernel computes the gradient of the Sinkhorn normalization using
implicit differentiation via the conjugate gradient (CG) method.

Algorithm:
Given R = Sinkhorn(M) and upstream gradient dR, we solve for dM as follows:
1. Compute b1 = sum(R * dR, dim=-1) and b2 = sum(R * dR, dim=-2).
2. Solve the linear system A*x = b using CG, where A is the Sinkhorn Jacobian,
   x = (x1, x2), b = (b1, b2), and A*(x1, x2) = (x1 + R @ x2, R.T @ x1 + x2).
3. Result: dM[i, j] = (dR[i, j] - x1[i] - x2[j]) * R[i, j].

12""" 

13 

14import torch 

15import triton 

16import triton.language as tl 

17 

18EPS = 1e-10 

19 

20 

def _get_autotune_configs():
    """Build the autotune search space: every tile size paired with every warp count.

    Returns:
        A list of ``triton.Config`` objects, ordered with the tile size as the
        outer dimension and the warp count as the inner dimension.
    """
    tile_sizes = (1, 2, 4, 8, 16, 32)
    warp_counts = (1, 2, 4, 8)
    return [
        triton.Config({"TILE_SIZE": tile}, num_warps=warps)
        for tile in tile_sizes
        for warps in warp_counts
    ]

28 

29 

@triton.autotune(
    configs=_get_autotune_configs(),
    key=["seqlen", "n_stream"],
)
@triton.jit
def _mhc_bwd_kernel(
    # Pointers to tensors
    out_ptr,  # (seqlen, n_stream, n_stream), float32 - Sinkhorn output R
    dout_ptr,  # (seqlen, n_stream, n_stream), float32 - upstream gradient dR
    res_ptr,  # (seqlen, n_stream, n_stream), float32 - result dM
    # Dimensions
    seqlen,
    n_stream,
    # Strides
    out_stride_s,
    out_stride_i,
    out_stride_j,
    dout_stride_s,
    dout_stride_i,
    dout_stride_j,
    res_stride_s,
    res_stride_i,
    res_stride_j,
    # Number of CG iterations
    cg_iters: tl.constexpr,
    # Constants
    TILE_SIZE: tl.constexpr,
    N_STREAM: tl.constexpr,
):
    """Generic-size kernel: writes the element-wise product ``dR * R`` per matrix.

    Each program handles ``TILE_SIZE`` consecutive sequence indices starting at
    ``program_id(0) * TILE_SIZE`` and skips indices that are ``>= seqlen``.

    NOTE(review): despite the module's description, this kernel performs no CG
    solve. ``cg_iters`` and ``n_stream`` are accepted but never read, and the
    stored value is ``dR[i, j] * R[i, j]`` without the ``x1``/``x2`` correction.
    It looks like a placeholder; confirm before launching it from a wrapper.

    Args:
        out_ptr, dout_ptr, res_ptr: Base pointers of R, dR and the result.
        seqlen: Number of matrices; used only for the bounds check.
        n_stream: Runtime matrix size; part of the autotune key, unused here.
        *_stride_s, *_stride_i, *_stride_j: Element strides of each tensor
            along the sequence, row and column axes.
        cg_iters: Unused (see note above).
        TILE_SIZE: Number of sequence indices processed by one program.
        N_STREAM: Compile-time matrix size that bounds the i/j loops.
    """
    pid = tl.program_id(0)
    tile_start = pid * TILE_SIZE

    for t in range(TILE_SIZE):
        seq_idx = tile_start + t
        # Bounds guard written as `if` instead of `continue`; the Triton JIT
        # front end may not accept `continue` - TODO confirm.
        if seq_idx < seqlen:
            base_out = seq_idx * out_stride_s
            base_dout = seq_idx * dout_stride_s
            base_res = seq_idx * res_stride_s

            for i in range(N_STREAM):
                for j in range(N_STREAM):
                    r_val = tl.load(
                        out_ptr + base_out + i * out_stride_i + j * out_stride_j
                    )
                    dr_val = tl.load(
                        dout_ptr + base_dout + i * dout_stride_i + j * dout_stride_j
                    )
                    tl.store(
                        res_ptr + base_res + i * res_stride_i + j * res_stride_j,
                        dr_val * r_val,
                    )

94 

95 

@triton.jit
def _mhc_bwd_kernel_n4(
    # Pointers to tensors
    out_ptr,  # (seqlen, 4, 4), float32 - Sinkhorn output R
    dout_ptr,  # (seqlen, 4, 4), float32 - upstream gradient dR
    res_ptr,  # (seqlen, 4, 4), float32 - result dM
    seqlen,
    cg_iters: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """Sinkhorn backward for n_stream=4 with a fully unrolled CG solve.

    Each program handles ``BLOCK_S`` consecutive 4x4 matrices, one per lane.
    Entry (i, j) of matrix s is addressed at element offset ``s * 16 + i * 4 + j``,
    i.e. all three tensors are assumed to be contiguous row-major (seqlen, 4, 4).

    Per matrix the kernel:
      1. forms b1 (row sums) and b2 (column sums) of ``R * dR``;
      2. runs ``cg_iters`` CG steps on ``A*(x1, x2) = (b1, b2)`` with
         ``A*(p1, p2) = (R @ p2 + p1, R.T @ p1 + p2)``, starting from x = 0;
      3. stores ``(dR[i, j] - x1[i] - x2[j]) * R[i, j]``.

    Lanes past ``seqlen`` load zeros and are masked out on store.

    Args:
        out_ptr: Base pointer of the Sinkhorn output R.
        dout_ptr: Base pointer of the upstream gradient dR.
        res_ptr: Base pointer of the result dM.
        seqlen: Number of 4x4 matrices.
        cg_iters: Number of CG iterations (compile-time constant).
        BLOCK_S: Number of matrices processed by one program.
    """
    pid = tl.program_id(0)
    seq_start = pid * BLOCK_S
    seq_offsets = seq_start + tl.arange(0, BLOCK_S)
    mask = seq_offsets < seqlen

    # Widen before scaling so the element offset cannot wrap in 32-bit
    # arithmetic when seqlen * 16 reaches 2**31.
    base = seq_offsets.to(tl.int64) * 16  # 4*4 = 16
    base_out = base
    base_dout = base
    base_res = base

    # Load R matrix
    R_00 = tl.load(out_ptr + base_out + 0, mask=mask, other=0.0)
    R_01 = tl.load(out_ptr + base_out + 1, mask=mask, other=0.0)
    R_02 = tl.load(out_ptr + base_out + 2, mask=mask, other=0.0)
    R_03 = tl.load(out_ptr + base_out + 3, mask=mask, other=0.0)
    R_10 = tl.load(out_ptr + base_out + 4, mask=mask, other=0.0)
    R_11 = tl.load(out_ptr + base_out + 5, mask=mask, other=0.0)
    R_12 = tl.load(out_ptr + base_out + 6, mask=mask, other=0.0)
    R_13 = tl.load(out_ptr + base_out + 7, mask=mask, other=0.0)
    R_20 = tl.load(out_ptr + base_out + 8, mask=mask, other=0.0)
    R_21 = tl.load(out_ptr + base_out + 9, mask=mask, other=0.0)
    R_22 = tl.load(out_ptr + base_out + 10, mask=mask, other=0.0)
    R_23 = tl.load(out_ptr + base_out + 11, mask=mask, other=0.0)
    R_30 = tl.load(out_ptr + base_out + 12, mask=mask, other=0.0)
    R_31 = tl.load(out_ptr + base_out + 13, mask=mask, other=0.0)
    R_32 = tl.load(out_ptr + base_out + 14, mask=mask, other=0.0)
    R_33 = tl.load(out_ptr + base_out + 15, mask=mask, other=0.0)

    # Load dR matrix
    dR_00 = tl.load(dout_ptr + base_dout + 0, mask=mask, other=0.0)
    dR_01 = tl.load(dout_ptr + base_dout + 1, mask=mask, other=0.0)
    dR_02 = tl.load(dout_ptr + base_dout + 2, mask=mask, other=0.0)
    dR_03 = tl.load(dout_ptr + base_dout + 3, mask=mask, other=0.0)
    dR_10 = tl.load(dout_ptr + base_dout + 4, mask=mask, other=0.0)
    dR_11 = tl.load(dout_ptr + base_dout + 5, mask=mask, other=0.0)
    dR_12 = tl.load(dout_ptr + base_dout + 6, mask=mask, other=0.0)
    dR_13 = tl.load(dout_ptr + base_dout + 7, mask=mask, other=0.0)
    dR_20 = tl.load(dout_ptr + base_dout + 8, mask=mask, other=0.0)
    dR_21 = tl.load(dout_ptr + base_dout + 9, mask=mask, other=0.0)
    dR_22 = tl.load(dout_ptr + base_dout + 10, mask=mask, other=0.0)
    dR_23 = tl.load(dout_ptr + base_dout + 11, mask=mask, other=0.0)
    dR_30 = tl.load(dout_ptr + base_dout + 12, mask=mask, other=0.0)
    dR_31 = tl.load(dout_ptr + base_dout + 13, mask=mask, other=0.0)
    dR_32 = tl.load(dout_ptr + base_dout + 14, mask=mask, other=0.0)
    dR_33 = tl.load(dout_ptr + base_dout + 15, mask=mask, other=0.0)

    # Compute RdR = R * dR (element-wise)
    RdR_00 = R_00 * dR_00
    RdR_01 = R_01 * dR_01
    RdR_02 = R_02 * dR_02
    RdR_03 = R_03 * dR_03
    RdR_10 = R_10 * dR_10
    RdR_11 = R_11 * dR_11
    RdR_12 = R_12 * dR_12
    RdR_13 = R_13 * dR_13
    RdR_20 = R_20 * dR_20
    RdR_21 = R_21 * dR_21
    RdR_22 = R_22 * dR_22
    RdR_23 = R_23 * dR_23
    RdR_30 = R_30 * dR_30
    RdR_31 = R_31 * dR_31
    RdR_32 = R_32 * dR_32
    RdR_33 = R_33 * dR_33

    # b1 = sum(RdR, dim=-1) -> b1[i] = sum_j(RdR[i,j])
    b1_0 = RdR_00 + RdR_01 + RdR_02 + RdR_03
    b1_1 = RdR_10 + RdR_11 + RdR_12 + RdR_13
    b1_2 = RdR_20 + RdR_21 + RdR_22 + RdR_23
    b1_3 = RdR_30 + RdR_31 + RdR_32 + RdR_33

    # b2 = sum(RdR, dim=-2) -> b2[j] = sum_i(RdR[i,j])
    b2_0 = RdR_00 + RdR_10 + RdR_20 + RdR_30
    b2_1 = RdR_01 + RdR_11 + RdR_21 + RdR_31
    b2_2 = RdR_02 + RdR_12 + RdR_22 + RdR_32
    b2_3 = RdR_03 + RdR_13 + RdR_23 + RdR_33

    # Initialize CG: x = 0, r = b - A*x = b, p = r
    x1_0 = tl.zeros_like(b1_0)
    x1_1 = tl.zeros_like(b1_1)
    x1_2 = tl.zeros_like(b1_2)
    x1_3 = tl.zeros_like(b1_3)
    x2_0 = tl.zeros_like(b2_0)
    x2_1 = tl.zeros_like(b2_1)
    x2_2 = tl.zeros_like(b2_2)
    x2_3 = tl.zeros_like(b2_3)

    # With x = 0 the initial residual is just b
    r1_0 = b1_0
    r1_1 = b1_1
    r1_2 = b1_2
    r1_3 = b1_3
    r2_0 = b2_0
    r2_1 = b2_1
    r2_2 = b2_2
    r2_3 = b2_3

    # p = r
    p1_0 = r1_0
    p1_1 = r1_1
    p1_2 = r1_2
    p1_3 = r1_3
    p2_0 = r2_0
    p2_1 = r2_1
    p2_2 = r2_2
    p2_3 = r2_3

    # r_normsq = dot(r, r)
    r_normsq = (
        r1_0 * r1_0
        + r1_1 * r1_1
        + r1_2 * r1_2
        + r1_3 * r1_3
        + r2_0 * r2_0
        + r2_1 * r2_1
        + r2_2 * r2_2
        + r2_3 * r2_3
    )

    # CG iterations (the Python wrapper defaults to 2 * n_stream = 8)
    for _ in range(cg_iters):
        # Ap1 = R @ p2 + p1
        Ap1_0 = (R_00 * p2_0 + R_01 * p2_1 + R_02 * p2_2 + R_03 * p2_3) + p1_0
        Ap1_1 = (R_10 * p2_0 + R_11 * p2_1 + R_12 * p2_2 + R_13 * p2_3) + p1_1
        Ap1_2 = (R_20 * p2_0 + R_21 * p2_1 + R_22 * p2_2 + R_23 * p2_3) + p1_2
        Ap1_3 = (R_30 * p2_0 + R_31 * p2_1 + R_32 * p2_2 + R_33 * p2_3) + p1_3

        # Ap2 = R.T @ p1 + p2
        Ap2_0 = (R_00 * p1_0 + R_10 * p1_1 + R_20 * p1_2 + R_30 * p1_3) + p2_0
        Ap2_1 = (R_01 * p1_0 + R_11 * p1_1 + R_21 * p1_2 + R_31 * p1_3) + p2_1
        Ap2_2 = (R_02 * p1_0 + R_12 * p1_1 + R_22 * p1_2 + R_32 * p1_3) + p2_2
        Ap2_3 = (R_03 * p1_0 + R_13 * p1_1 + R_23 * p1_2 + R_33 * p1_3) + p2_3

        # pAp = dot(p, Ap)
        pAp = (
            p1_0 * Ap1_0
            + p1_1 * Ap1_1
            + p1_2 * Ap1_2
            + p1_3 * Ap1_3
            + p2_0 * Ap2_0
            + p2_1 * Ap2_1
            + p2_2 * Ap2_2
            + p2_3 * Ap2_3
        )

        # alpha = r_normsq / (pAp + eps); the literal equals the module-level EPS
        alpha = r_normsq / (pAp + 1e-10)

        # x = x + alpha * p
        x1_0 = x1_0 + alpha * p1_0
        x1_1 = x1_1 + alpha * p1_1
        x1_2 = x1_2 + alpha * p1_2
        x1_3 = x1_3 + alpha * p1_3
        x2_0 = x2_0 + alpha * p2_0
        x2_1 = x2_1 + alpha * p2_1
        x2_2 = x2_2 + alpha * p2_2
        x2_3 = x2_3 + alpha * p2_3

        # r = r - alpha * Ap
        r1_0 = r1_0 - alpha * Ap1_0
        r1_1 = r1_1 - alpha * Ap1_1
        r1_2 = r1_2 - alpha * Ap1_2
        r1_3 = r1_3 - alpha * Ap1_3
        r2_0 = r2_0 - alpha * Ap2_0
        r2_1 = r2_1 - alpha * Ap2_1
        r2_2 = r2_2 - alpha * Ap2_2
        r2_3 = r2_3 - alpha * Ap2_3

        # r_new_normsq = dot(r, r)
        r_new_normsq = (
            r1_0 * r1_0
            + r1_1 * r1_1
            + r1_2 * r1_2
            + r1_3 * r1_3
            + r2_0 * r2_0
            + r2_1 * r2_1
            + r2_2 * r2_2
            + r2_3 * r2_3
        )

        # beta = r_new_normsq / (r_normsq + eps)
        beta = r_new_normsq / (r_normsq + 1e-10)

        # p = r + beta * p
        p1_0 = r1_0 + beta * p1_0
        p1_1 = r1_1 + beta * p1_1
        p1_2 = r1_2 + beta * p1_2
        p1_3 = r1_3 + beta * p1_3
        p2_0 = r2_0 + beta * p2_0
        p2_1 = r2_1 + beta * p2_1
        p2_2 = r2_2 + beta * p2_2
        p2_3 = r2_3 + beta * p2_3

        r_normsq = r_new_normsq

    # Compute result: res = (dR - x1 - x2) * R
    # res[i,j] = (dR[i,j] - x1[i] - x2[j]) * R[i,j]
    res_00 = (dR_00 - x1_0 - x2_0) * R_00
    res_01 = (dR_01 - x1_0 - x2_1) * R_01
    res_02 = (dR_02 - x1_0 - x2_2) * R_02
    res_03 = (dR_03 - x1_0 - x2_3) * R_03
    res_10 = (dR_10 - x1_1 - x2_0) * R_10
    res_11 = (dR_11 - x1_1 - x2_1) * R_11
    res_12 = (dR_12 - x1_1 - x2_2) * R_12
    res_13 = (dR_13 - x1_1 - x2_3) * R_13
    res_20 = (dR_20 - x1_2 - x2_0) * R_20
    res_21 = (dR_21 - x1_2 - x2_1) * R_21
    res_22 = (dR_22 - x1_2 - x2_2) * R_22
    res_23 = (dR_23 - x1_2 - x2_3) * R_23
    res_30 = (dR_30 - x1_3 - x2_0) * R_30
    res_31 = (dR_31 - x1_3 - x2_1) * R_31
    res_32 = (dR_32 - x1_3 - x2_2) * R_32
    res_33 = (dR_33 - x1_3 - x2_3) * R_33

    # Store results
    tl.store(res_ptr + base_res + 0, res_00, mask=mask)
    tl.store(res_ptr + base_res + 1, res_01, mask=mask)
    tl.store(res_ptr + base_res + 2, res_02, mask=mask)
    tl.store(res_ptr + base_res + 3, res_03, mask=mask)
    tl.store(res_ptr + base_res + 4, res_10, mask=mask)
    tl.store(res_ptr + base_res + 5, res_11, mask=mask)
    tl.store(res_ptr + base_res + 6, res_12, mask=mask)
    tl.store(res_ptr + base_res + 7, res_13, mask=mask)
    tl.store(res_ptr + base_res + 8, res_20, mask=mask)
    tl.store(res_ptr + base_res + 9, res_21, mask=mask)
    tl.store(res_ptr + base_res + 10, res_22, mask=mask)
    tl.store(res_ptr + base_res + 11, res_23, mask=mask)
    tl.store(res_ptr + base_res + 12, res_30, mask=mask)
    tl.store(res_ptr + base_res + 13, res_31, mask=mask)
    tl.store(res_ptr + base_res + 14, res_32, mask=mask)
    tl.store(res_ptr + base_res + 15, res_33, mask=mask)

335 

336 

def mhc_bwd(
    out: torch.Tensor,
    dout: torch.Tensor,
    cg_iters: int = None,
) -> torch.Tensor:
    """Compute Sinkhorn backward using implicit CG differentiation.

    Inputs are made contiguous and converted to float32. When ``n_stream == 4``
    the unrolled Triton kernel is launched; every other size is delegated to
    the PyTorch reference implementation ``mhc_bwd_ref``.

    Args:
        out: Sinkhorn output R, shape (seqlen, n_stream, n_stream).
        dout: Upstream gradient dR, same shape as out.
        cg_iters: Number of CG iterations. Defaults to 2 * n_stream.

    Returns:
        Gradient w.r.t. pre-Sinkhorn input, same shape as out, float32.

    Raises:
        AssertionError: If the shapes differ, the tensors are not 3D, or the
            last two dimensions are not equal.
    """
    assert out.shape == dout.shape, "out and dout must have same shape"
    assert out.ndim == 3, "Expected 3D tensors (seqlen, n_stream, n_stream)"
    assert out.shape[1] == out.shape[2], "n_stream dimensions must match"

    seqlen, n_stream, _ = out.shape
    if cg_iters is None:
        cg_iters = 2 * n_stream

    # Ensure contiguous and float32
    out = out.contiguous().float()
    dout = dout.contiguous().float()

    # Only n_stream == 4 has a dedicated kernel; everything else uses the
    # reference path, which allocates its own result.
    if n_stream != 4:
        return mhc_bwd_ref(out, dout, cg_iters=cg_iters)

    # Allocate the output only on the path that actually writes into it.
    res = torch.empty_like(out)
    if seqlen == 0:
        # Nothing to compute; skip the kernel launch entirely.
        return res

    BLOCK_S = 64
    grid = (triton.cdiv(seqlen, BLOCK_S),)
    _mhc_bwd_kernel_n4[grid](
        out,
        dout,
        res,
        seqlen,
        cg_iters,
        BLOCK_S=BLOCK_S,
    )
    return res

383 

384 

def mhc_bwd_ref(
    out: torch.Tensor,
    dout: torch.Tensor,
    cg_iters: int = None,
) -> torch.Tensor:
    """PyTorch reference implementation of Sinkhorn backward via implicit CG.

    Solves ``A*(x1, x2) = (b1, b2)`` with a fixed number of CG steps, where
    ``b1``/``b2`` are the row/column sums of ``R * dR`` and
    ``A*(p1, p2) = (R @ p2 + p1, R.T @ p1 + p2)``, then returns
    ``(dR - x1[..., :, None] - x2[..., None, :]) * R``.

    All reductions act on the last two dimensions, so any number of leading
    batch dimensions (including none) is accepted.

    Args:
        out: Sinkhorn output R, shape (..., n_stream, n_stream).
        dout: Upstream gradient dR, same shape as out.
        cg_iters: Number of CG iterations. Defaults to 2 * n_stream.

    Returns:
        Gradient w.r.t. pre-Sinkhorn input, same shape as out, float32.
    """
    n_stream = out.shape[-1]
    if cg_iters is None:
        cg_iters = 2 * n_stream

    R = out.float()
    dR = dout.float()

    # Right-hand side: row sums and column sums of R * dR
    RdR = R * dR
    b1 = RdR.sum(dim=-1)  # (..., n_stream)
    b2 = RdR.sum(dim=-2)  # (..., n_stream)

    def matvec(r, x1_in, x2_in):
        # y1[i] = sum_j(R[i,j] * x2[j]) + x1[i]
        y1 = (r * x2_in.unsqueeze(-2)).sum(dim=-1) + x1_in
        # y2[j] = sum_i(R[i,j] * x1[i]) + x2[j]
        y2 = (r * x1_in.unsqueeze(-1)).sum(dim=-2) + x2_in
        return y1, y2

    # Initialize CG: x = 0, so r = b - A*x = b and p = r
    x1 = torch.zeros_like(b1)
    x2 = torch.zeros_like(b2)
    r1, r2 = b1.clone(), b2.clone()
    p1, p2 = r1.clone(), r2.clone()
    r_normsq = (r1 * r1 + r2 * r2).sum(dim=-1)  # (...,)

    for _ in range(cg_iters):
        # Ap = A * p
        Ap1, Ap2 = matvec(R, p1, p2)

        # pAp = dot(p, Ap)
        pAp = (p1 * Ap1 + p2 * Ap2).sum(dim=-1)  # (...,)

        # alpha = r_normsq / (pAp + eps); EPS keeps the division finite once
        # the residual has reached zero
        alpha = r_normsq / (pAp + EPS)
        alpha = alpha.unsqueeze(-1)  # (..., 1)

        # x = x + alpha * p
        x1 = x1 + alpha * p1
        x2 = x2 + alpha * p2

        # r = r - alpha * Ap
        r1 = r1 - alpha * Ap1
        r2 = r2 - alpha * Ap2

        # r_new_normsq = dot(r, r)
        r_new_normsq = (r1 * r1 + r2 * r2).sum(dim=-1)

        # beta = r_new_normsq / (r_normsq + eps)
        beta = r_new_normsq / (r_normsq + EPS)
        beta = beta.unsqueeze(-1)

        # p = r + beta * p
        p1 = r1 + beta * p1
        p2 = r2 + beta * p2

        r_normsq = r_new_normsq

    # res = (dR - x1 - x2) * R, with x1 broadcast over columns and x2 over rows
    res = (dR - x1.unsqueeze(-1) - x2.unsqueeze(-2)) * R
    return res

465 

466 

def sinkhorn_forward(
    M: torch.Tensor, iters: int = 20
) -> tuple[torch.Tensor, torch.Tensor]:
    """Sinkhorn normalization forward pass.

    Alternately normalizes columns (dim -2) and rows (dim -1) of ``exp(M)``
    for ``iters`` rounds. The first column normalization of ``exp(M)`` is
    exactly ``softmax(M, dim=-2)``, so it is computed that way: this avoids
    the inf/inf = NaN that the naive form produces when ``exp(M)`` overflows.

    Args:
        M: Input logits, shape (..., n, n).
        iters: Number of Sinkhorn iterations. With ``iters <= 0`` no
            normalization is applied and R is a copy of P.

    Returns:
        (R, P) where P = exp(M) and R is the doubly-stochastic matrix.
        P is returned unnormalized and may still contain inf for very large
        logits.
    """
    P = torch.exp(M)
    if iters <= 0:
        return P.clone(), P

    # First round: column normalization via softmax (overflow-safe), then rows.
    # dtype=P.dtype keeps the result type identical to exp(M) for any input dtype.
    R = torch.softmax(M, dim=-2, dtype=P.dtype)
    R = R / R.sum(-1, keepdim=True)

    # Remaining rounds operate on already-normalized values.
    for _ in range(iters - 1):
        R = R / R.sum(-2, keepdim=True)
        R = R / R.sum(-1, keepdim=True)
    return R, P