Coverage for src/flag_gems/ops/ctc_loss.py: 41%

446 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9 

# Module-level logger; ctc_loss() emits one debug record through it per call.
logger = logging.getLogger(__name__)


# Integer codes for the `reduction` argument. _reduction_enum() maps the
# strings "none" / "mean" / "sum" onto them, and the kernels receive the same
# values through their REDUCTION compile-time parameter.
_REDUCTION_NONE = 0
_REDUCTION_MEAN = 1
_REDUCTION_SUM = 2
# Memo used by _length_stats(): maps a tensor-identity key to (tensor, stats).
_LENGTH_STATS_CACHE = {}
# When the memo already holds this many entries, _length_stats() clears it
# completely before inserting a new one.
_LENGTH_STATS_CACHE_LIMIT = 256

18 

19 

# Kernels in this module call _debug_barrier() unconditionally.  When the
# installed Triton does not expose `tl.debug_barrier`, substitute a JIT
# function that does nothing; otherwise alias Triton's own implementation.
if not hasattr(tl, "debug_barrier"):

    @triton.jit
    def _debug_barrier():
        return

else:
    _debug_barrier = tl.debug_barrier

27 

28 

@triton.jit
def _logaddexp(a, b):
    """Return log(exp(a) + exp(b)) computed relative to the larger operand.

    When both operands are -inf the result is -inf; the explicit select
    avoids evaluating to NaN through (-inf) - (-inf).
    """
    hi = tl.maximum(a, b)
    lo = tl.minimum(a, b)
    both_empty = hi == -float("inf")
    combined = hi + tl.log(1.0 + tl.exp(lo - hi))
    return tl.where(both_empty, -float("inf"), combined)

38 

39 

@triton.jit
def _logaddexp3(a, b, c, use_c):
    """Return log(exp(a) + exp(b) + exp(c)), dropping `c` where `use_c` is false.

    Lanes with `use_c` false treat `c` as -inf.  If all three terms are -inf
    the result is -inf.
    """
    c_term = tl.where(use_c, c, -float("inf"))
    peak = tl.maximum(tl.maximum(a, b), c_term)
    all_empty = peak == -float("inf")
    # Subtract 0 instead of -inf in the empty case so the exponent arguments
    # never become (-inf) - (-inf).
    shift = tl.where(all_empty, 0.0, peak)
    total = tl.exp(a - shift) + tl.exp(b - shift) + tl.exp(c_term - shift)
    return tl.where(all_empty, -float("inf"), peak + tl.log(total))

51 

52 

@libentry()
@triton.jit
def _ctc_loss_forward_kernel(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    target_offsets,
    neg_log_likelihood,
    log_alpha,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """Run the CTC alpha recursion for one sample and keep every alpha row.

    One program (`tl.program_id(0)`) handles one batch element.  Lanes
    `0 .. BLOCK_S-1` correspond to the states of the blank-extended target:
    even states emit BLANK, odd state `s` emits target `(s - 1) // 2`.

    Args:
        log_probs: read at `t * N * C + batch * C + label`.
        targets: label buffer; per-sample origin is `target_offsets[batch]`
            when TARGET_1D, otherwise `batch * MAX_TARGET`.
        input_lengths, target_lengths: per-sample lengths.
        target_offsets: per-sample start offset into `targets` (read only
            when TARGET_1D).
        neg_log_likelihood: output, one value per sample.
        log_alpha: output, written at
            `batch * T * STATE_COUNT_MAX + t * STATE_COUNT_MAX + state`.
            Rows with `t >= input_len` are filled with -inf.
        T, N, C: compile-time sizes used in the index arithmetic above.
        MAX_TARGET: row stride of `targets` in the non-1D case.
        STATE_COUNT_MAX: row width of `log_alpha`.
        BLANK: blank label index.
        TARGET_1D: selects how the per-sample target origin is found.
        BLOCK_S: number of lanes; states >= STATE_COUNT_MAX are never stored.
    """
    batch = tl.program_id(0)
    states = tl.arange(0, BLOCK_S)

    input_len = tl.load(input_lengths + batch)
    target_len = tl.load(target_lengths + batch)
    state_count = target_len * 2 + 1
    # valid_state: state exists for this sample; stored_state: lane has a
    # slot in the log_alpha row.
    valid_state = states < state_count
    stored_state = states < STATE_COUNT_MAX

    is_blank_state = (states % 2) == 0
    target_index = (states - 1) // 2
    target_mask = (target_index >= 0) & (target_index < target_len)
    target_safe_index = tl.where(target_mask, target_index, 0)

    if TARGET_1D:
        # Read the precomputed start offset, as the other kernels in this
        # module do, instead of re-summing target_lengths in every program.
        target_origin = tl.load(target_offsets + batch)
        target_ptrs = targets + target_origin + target_safe_index
    else:
        target_origin = batch * MAX_TARGET
        target_ptrs = targets + target_origin + target_safe_index

    target_value = tl.load(target_ptrs, mask=target_mask, other=BLANK)
    labels = tl.where(is_blank_state, BLANK, target_value)

    # The skip transition (s-2 -> s) is allowed only for a non-blank state
    # whose label differs from the previous target label.  This does not
    # depend on t, so it is computed once, outside the time loop.
    prev_target_index = tl.where(target_index > 0, target_index - 1, 0)
    prev_target_value = tl.load(
        targets + target_origin + prev_target_index,
        mask=target_mask & (target_index > 0),
        other=BLANK,
    )
    skip_allowed = (
        (~is_blank_state) & (target_index > 0) & (target_value != prev_target_value)
    )

    # t = 0: only state 0 (and state 1 when there is at least one target)
    # can be occupied.
    t0_active = input_len > 0
    init_state = (states == 0) | ((states == 1) & (target_len > 0))
    init_logp = tl.load(
        log_probs + batch * C + labels,
        mask=init_state & stored_state & t0_active,
        other=0.0,
    ).to(tl.float32)
    alpha = tl.where(init_state & valid_state & t0_active, init_logp, -float("inf"))
    if T > 0:
        # With T == 0 the caller's log_alpha buffer has no rows, so there is
        # no row 0 to write.
        tl.store(
            log_alpha + batch * T * STATE_COUNT_MAX + states,
            alpha,
            mask=stored_state,
        )
    _debug_barrier()

    for t in tl.range(1, T):
        prev_base = log_alpha + batch * T * STATE_COUNT_MAX + (t - 1) * STATE_COUNT_MAX
        # alpha[t-1] at states s, s-1 and s-2 (-inf where the state is absent).
        prev0 = tl.load(prev_base + states, mask=stored_state, other=-float("inf")).to(
            tl.float32
        )
        prev1 = tl.load(
            prev_base + tl.where(states > 0, states - 1, 0),
            mask=(states > 0) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)
        prev2 = tl.load(
            prev_base + tl.where(states > 1, states - 2, 0),
            mask=(states > 1) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)

        acc = _logaddexp3(prev0, prev1, prev2, skip_allowed)

        logp = tl.load(
            log_probs + t * N * C + batch * C + labels,
            mask=valid_state & (t < input_len),
            other=0.0,
        ).to(tl.float32)
        alpha = tl.where(valid_state & (t < input_len), acc + logp, -float("inf"))
        tl.store(
            log_alpha + batch * T * STATE_COUNT_MAX + t * STATE_COUNT_MAX + states,
            alpha,
            mask=stored_state,
        )
        # The next iteration reads neighbouring lanes' values back from
        # log_alpha, hence the barrier after each row is written.
        _debug_barrier()

    if input_len <= 0:
        # No frames: only the empty target has probability one.
        loss = tl.where(target_len == 0, 0.0, float("inf"))
    else:
        _debug_barrier()
        final_base = (
            log_alpha + batch * T * STATE_COUNT_MAX + (input_len - 1) * STATE_COUNT_MAX
        )
        # The likelihood combines the last state and, when the target is
        # non-empty, the second-to-last state of the final valid frame.
        last = tl.load(final_base + state_count - 1).to(tl.float32)
        prev_last = tl.load(
            final_base + tl.where(target_len > 0, state_count - 2, 0),
            mask=target_len > 0,
            other=-float("inf"),
        ).to(tl.float32)
        log_likelihood = _logaddexp(last, prev_last)
        loss = -log_likelihood

    tl.store(neg_log_likelihood + batch, loss)

176 

177 

@libentry()
@triton.jit
def _ctc_loss_forward_no_grad_kernel(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    target_offsets,
    neg_log_likelihood,
    scratch_alpha,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """Compute the per-sample CTC negative log-likelihood without keeping alpha.

    Same recursion as the gradient-capable forward kernel, but only two alpha
    rows per sample are kept, alternating by `t % 2` inside
    `scratch_alpha[batch * 2 * STATE_COUNT_MAX : ...]`.

    Args:
        log_probs: read at `t * N * C + batch * C + label`.
        targets: label buffer; per-sample origin is `target_offsets[batch]`
            when TARGET_1D, otherwise `batch * MAX_TARGET`.
        input_lengths, target_lengths: per-sample lengths.
        target_offsets: per-sample start offset into `targets` (1-D case).
        neg_log_likelihood: output, one value per sample.
        scratch_alpha: workspace of two rows of STATE_COUNT_MAX per sample.
        T, N, C, MAX_TARGET, STATE_COUNT_MAX, BLANK, TARGET_1D, BLOCK_S:
            compile-time sizes/flags used in the indexing below.
    """
    batch = tl.program_id(0)
    states = tl.arange(0, BLOCK_S)

    target_len = tl.load(target_lengths + batch)
    state_count = target_len * 2 + 1
    valid_state = states < state_count
    stored_state = states < STATE_COUNT_MAX

    is_blank_state = (states % 2) == 0
    target_index = (states - 1) // 2
    target_mask = (target_index >= 0) & (target_index < target_len)
    target_safe_index = tl.where(target_mask, target_index, 0)

    if TARGET_1D:
        target_origin = tl.load(target_offsets + batch)
        target_ptrs = targets + target_origin + target_safe_index
    else:
        target_origin = batch * MAX_TARGET
        target_ptrs = targets + target_origin + target_safe_index

    target_value = tl.load(target_ptrs, mask=target_mask, other=BLANK)
    labels = tl.where(is_blank_state, BLANK, target_value)

    # Skip-transition mask: independent of t, so computed once here rather
    # than re-loading the previous target label on every timestep.
    prev_target_index = tl.where(target_index > 0, target_index - 1, 0)
    prev_target_value = tl.load(
        targets + target_origin + prev_target_index,
        mask=target_mask & (target_index > 0),
        other=BLANK,
    )
    skip_allowed = (
        (~is_blank_state) & (target_index > 0) & (target_value != prev_target_value)
    )

    input_len = tl.load(input_lengths + batch)
    init_state = (states == 0) | ((states == 1) & (target_len > 0))
    init_logp = tl.load(
        log_probs + batch * C + labels,
        mask=init_state & stored_state & (input_len > 0),
        other=0.0,
    ).to(tl.float32)
    alpha = tl.where(
        init_state & valid_state & (input_len > 0), init_logp, -float("inf")
    )
    scratch_batch = scratch_alpha + batch * 2 * STATE_COUNT_MAX
    tl.store(scratch_batch + states, alpha, mask=stored_state)
    _debug_barrier()

    for t in tl.range(1, T):
        # Ping-pong between the two scratch rows.
        prev_base = scratch_batch + ((t - 1) % 2) * STATE_COUNT_MAX
        cur_base = scratch_batch + (t % 2) * STATE_COUNT_MAX
        prev0 = tl.load(prev_base + states, mask=stored_state, other=-float("inf")).to(
            tl.float32
        )
        prev1 = tl.load(
            prev_base + tl.where(states > 0, states - 1, 0),
            mask=(states > 0) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)
        prev2 = tl.load(
            prev_base + tl.where(states > 1, states - 2, 0),
            mask=(states > 1) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)

        acc = _logaddexp3(prev0, prev1, prev2, skip_allowed)
        logp = tl.load(
            log_probs + t * N * C + batch * C + labels,
            mask=valid_state & (t < input_len),
            other=0.0,
        ).to(tl.float32)
        alpha = tl.where(valid_state & (t < input_len), acc + logp, -float("inf"))
        # Frames past input_len must not overwrite the scratch rows: the row
        # for t = input_len - 1 is read back after the loop.
        tl.store(cur_base + states, alpha, mask=stored_state & (t < input_len))
        _debug_barrier()

    if input_len <= 0:
        loss = tl.where(target_len == 0, 0.0, float("inf"))
    else:
        _debug_barrier()
        final_base = scratch_batch + ((input_len - 1) % 2) * STATE_COUNT_MAX
        last = tl.load(final_base + state_count - 1).to(tl.float32)
        prev_last = tl.load(
            final_base + tl.where(target_len > 0, state_count - 2, 0),
            mask=target_len > 0,
            other=-float("inf"),
        ).to(tl.float32)
        loss = -_logaddexp(last, prev_last)

    tl.store(neg_log_likelihood + batch, loss)

285 

286 

@libentry()
@triton.jit
def _ctc_loss_forward_full_length_reduce_kernel(
    log_probs,
    targets,
    target_lengths,
    target_offsets,
    contrib,
    scratch_alpha,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    REDUCTION: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    """Forward-only CTC loss for samples that all use the full T frames.

    No per-sample input length is read: every timestep in `[0, T)` is
    processed.  The per-sample loss is written to `contrib[batch]`; when
    `REDUCTION == 1` it is first divided by `max(target_len, 1)` and by `N`,
    so that summing `contrib` on the host yields the mean-reduced loss.

    Args:
        log_probs: read at `t * N * C + batch * C + label`.
        targets: label buffer; per-sample origin is `target_offsets[batch]`
            when TARGET_1D, otherwise `batch * MAX_TARGET`.
        target_lengths: per-sample target lengths.
        target_offsets: per-sample start offset into `targets` (1-D case).
        contrib: output, one value per sample.
        scratch_alpha: workspace of two rows of STATE_COUNT_MAX per sample.
        T, N, C, MAX_TARGET, STATE_COUNT_MAX, BLANK, TARGET_1D, REDUCTION,
        BLOCK_S: compile-time sizes/flags used in the indexing below.
    """
    batch = tl.program_id(0)
    states = tl.arange(0, BLOCK_S)

    target_len = tl.load(target_lengths + batch)
    state_count = target_len * 2 + 1
    valid_state = states < state_count
    stored_state = states < STATE_COUNT_MAX

    is_blank_state = (states % 2) == 0
    target_index = (states - 1) // 2
    target_mask = (target_index >= 0) & (target_index < target_len)
    target_safe_index = tl.where(target_mask, target_index, 0)

    if TARGET_1D:
        target_origin = tl.load(target_offsets + batch)
        target_ptrs = targets + target_origin + target_safe_index
    else:
        target_origin = batch * MAX_TARGET
        target_ptrs = targets + target_origin + target_safe_index

    target_value = tl.load(target_ptrs, mask=target_mask, other=BLANK)
    labels = tl.where(is_blank_state, BLANK, target_value)

    # Skip-transition mask: independent of t, so computed once here rather
    # than re-loading the previous target label on every timestep.
    prev_target_index = tl.where(target_index > 0, target_index - 1, 0)
    prev_target_value = tl.load(
        targets + target_origin + prev_target_index,
        mask=target_mask & (target_index > 0),
        other=BLANK,
    )
    skip_allowed = (
        (~is_blank_state) & (target_index > 0) & (target_value != prev_target_value)
    )

    init_state = (states == 0) | ((states == 1) & (target_len > 0))
    init_logp = tl.load(
        log_probs + batch * C + labels,
        mask=init_state & stored_state,
        other=0.0,
    ).to(tl.float32)
    alpha = tl.where(init_state & valid_state, init_logp, -float("inf"))
    scratch_batch = scratch_alpha + batch * 2 * STATE_COUNT_MAX
    tl.store(scratch_batch + states, alpha, mask=stored_state)
    _debug_barrier()

    for t in tl.range(1, T):
        # Ping-pong between the two scratch rows.
        prev_base = scratch_batch + ((t - 1) % 2) * STATE_COUNT_MAX
        cur_base = scratch_batch + (t % 2) * STATE_COUNT_MAX
        prev0 = tl.load(prev_base + states, mask=stored_state, other=-float("inf")).to(
            tl.float32
        )
        prev1 = tl.load(
            prev_base + tl.where(states > 0, states - 1, 0),
            mask=(states > 0) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)
        prev2 = tl.load(
            prev_base + tl.where(states > 1, states - 2, 0),
            mask=(states > 1) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)

        acc = _logaddexp3(prev0, prev1, prev2, skip_allowed)
        logp = tl.load(
            log_probs + t * N * C + batch * C + labels,
            mask=valid_state,
            other=0.0,
        ).to(tl.float32)
        alpha = tl.where(valid_state, acc + logp, -float("inf"))
        tl.store(cur_base + states, alpha, mask=stored_state)
        _debug_barrier()

    if T <= 0:
        loss = tl.where(target_len == 0, 0.0, float("inf"))
    else:
        _debug_barrier()
        final_base = scratch_batch + ((T - 1) % 2) * STATE_COUNT_MAX
        last = tl.load(final_base + state_count - 1).to(tl.float32)
        prev_last = tl.load(
            final_base + tl.where(target_len > 0, state_count - 2, 0),
            mask=target_len > 0,
            other=-float("inf"),
        ).to(tl.float32)
        loss = -_logaddexp(last, prev_last)

    if REDUCTION == 1:
        loss = loss / tl.maximum(target_len, 1).to(tl.float32) / N
    tl.store(contrib + batch, loss)

393 

394 

@libentry()
@triton.jit
def _ctc_loss_init_grad_kernel(
    log_probs,
    input_lengths,
    target_lengths,
    neg_log_likelihood,
    grad_output,
    grad_input,
    total: tl.constexpr,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    REDUCTION: tl.constexpr,
    ZERO_INFINITY: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Fill `grad_input` elementwise with `exp(log_probs) * scale`.

    Each program covers BLOCK consecutive flat offsets of `log_probs`.  The
    flat offset is decomposed as `t = offset // (N * C)` and
    `batch = (offset // C) % N`.  Elements with `t >= input_len` receive 0.

    `scale` is the incoming gradient for the sample: `grad_output[batch]`
    when `REDUCTION == 0`, otherwise the scalar `grad_output[0]`, further
    divided by `max(target_len, 1) * N` when `REDUCTION == 1`.  With
    ZERO_INFINITY, samples whose loss is +inf get scale 0.

    In-range elements are set to NaN when the scale is non-zero and the log
    probability is -inf, and (without ZERO_INFINITY) when the sample's loss
    is +inf.
    """
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < total
    t = offsets // (N * C)
    batch = (offsets // C) % N

    nll = tl.load(neg_log_likelihood + batch, mask=mask, other=0.0).to(tl.float32)
    target_len = tl.load(target_lengths + batch, mask=mask, other=1)
    input_len = tl.load(input_lengths + batch, mask=mask, other=0)

    if REDUCTION == 0:
        scale = tl.load(grad_output + batch, mask=mask, other=0.0).to(tl.float32)
    else:
        scale = tl.load(grad_output).to(tl.float32)
    if REDUCTION == 1:
        scale = scale / (tl.maximum(target_len, 1).to(tl.float32) * N)

    if ZERO_INFINITY:
        scale = tl.where(nll == float("inf"), 0.0, scale)

    logp = tl.load(log_probs + offsets, mask=mask, other=-float("inf")).to(tl.float32)
    in_range = (t < input_len) & mask
    nan_grad = float("nan")

    grad = tl.where(in_range, tl.exp(logp) * scale, 0.0)
    poisoned = in_range & (scale != 0.0) & (logp == -float("inf"))
    grad = tl.where(poisoned, nan_grad, grad)
    if not ZERO_INFINITY:
        grad = tl.where(in_range & (nll == float("inf")), nan_grad, grad)
    tl.store(grad_input + offsets, grad, mask=mask)

443 

444 

@libentry()
@triton.jit
def _ctc_loss_backward_kernel(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    target_offsets,
    neg_log_likelihood,
    log_alpha,
    grad_output,
    grad_input,
    scratch_beta,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    REDUCTION: tl.constexpr,
    ZERO_INFINITY: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    # Backward (beta) recursion for one sample (`tl.program_id(0)`), walking
    # from the last valid frame to frame 0.  At each frame it atomically adds
    # `-scale * exp(alpha + beta - log_likelihood)` to `grad_input` at the
    # label column of every valid state.
    #
    # Reads: log_probs at `t * N * C + batch * C + label`; log_alpha at
    # `batch * T * STATE_COUNT_MAX + t * STATE_COUNT_MAX + state`;
    # neg_log_likelihood[batch]; grad_output (per sample when REDUCTION == 0,
    # element 0 otherwise).
    # Workspace: scratch_beta holds two rows of STATE_COUNT_MAX per sample,
    # alternated by `step % 2`.
    batch = tl.program_id(0)
    states = tl.arange(0, BLOCK_S)

    input_len = tl.load(input_lengths + batch)
    target_len = tl.load(target_lengths + batch)
    nll = tl.load(neg_log_likelihood + batch).to(tl.float32)
    state_count = target_len * 2 + 1
    valid_state = states < state_count
    stored_state = states < STATE_COUNT_MAX

    # Label of each state s: BLANK for even s, target (s - 1) // 2 otherwise.
    is_blank_state = (states % 2) == 0
    target_index = (states - 1) // 2
    target_mask = (target_index >= 0) & (target_index < target_len)
    target_safe_index = tl.where(target_mask, target_index, 0)

    if TARGET_1D:
        target_origin = tl.load(target_offsets + batch)
        target_ptrs = targets + target_origin + target_safe_index
    else:
        target_origin = batch * MAX_TARGET
        target_ptrs = targets + target_origin + target_safe_index

    target_value = tl.load(target_ptrs, mask=target_mask, other=BLANK)
    labels = tl.where(is_blank_state, BLANK, target_value)

    # Label of state s + 1, computed the same way.
    state1 = states + 1
    is_blank_state1 = (state1 % 2) == 0
    target_index1 = (state1 - 1) // 2
    target_mask1 = (target_index1 >= 0) & (target_index1 < target_len)
    target_safe_index1 = tl.where(target_mask1, target_index1, 0)
    target_ptrs1 = targets + target_origin + target_safe_index1
    target_value1 = tl.load(target_ptrs1, mask=target_mask1, other=BLANK)
    labels1 = tl.where(is_blank_state1, BLANK, target_value1)

    # Target value associated with state s + 2.  No blank substitution is
    # applied here; it is only consumed under `skip_allowed`, which requires
    # state s to be non-blank.
    state2 = states + 2
    target_index2 = (state2 - 1) // 2
    target_mask2 = (target_index2 >= 0) & (target_index2 < target_len)
    target_safe_index2 = tl.where(target_mask2, target_index2, 0)
    target_ptrs2 = targets + target_origin + target_safe_index2
    target_value2 = tl.load(target_ptrs2, mask=target_mask2, other=BLANK)
    labels2 = target_value2

    # Incoming gradient for this sample, with the mean-reduction divisor
    # max(target_len, 1) * N folded in when REDUCTION == 1.
    if REDUCTION == 0:
        scale = tl.load(grad_output + batch).to(tl.float32)
    else:
        scale = tl.load(grad_output).to(tl.float32)
        if REDUCTION == 1:
            denom = tl.maximum(target_len, 1).to(tl.float32) * N
            scale = scale / denom

    if ZERO_INFINITY:
        # Samples whose loss is +inf contribute no gradient.
        scale = tl.where(nll == float("inf"), 0.0, scale)

    # Initial beta row: 0 at the last state and, for a non-empty target, at
    # the second-to-last state; -inf elsewhere (and everywhere when the
    # sample has no frames).
    beta_init = tl.where(
        ((states == state_count - 1) | ((states == state_count - 2) & (target_len > 0)))
        & valid_state
        & (input_len > 0),
        0.0,
        -float("inf"),
    )
    scratch_batch = scratch_beta + batch * 2 * STATE_COUNT_MAX
    tl.store(scratch_batch + states, beta_init, mask=stored_state)
    _debug_barrier()
    # When scale is 0 the posterior below is forced to 0 anyway; 0.0 is used
    # in place of -nll in that case.
    log_likelihood = tl.where(scale != 0.0, -nll, 0.0)

    for step in tl.range(0, T):
        # step 0 is frame input_len - 1; steps past frame 0 are inactive.
        t = input_len - 1 - step
        active = t >= 0
        safe_t = tl.where(active, t, 0)
        beta_base = scratch_batch + (step % 2) * STATE_COUNT_MAX
        next_beta_base = scratch_batch + ((step + 1) % 2) * STATE_COUNT_MAX
        beta = tl.load(beta_base + states, mask=stored_state, other=-float("inf")).to(
            tl.float32
        )

        alpha_t = tl.load(
            log_alpha + batch * T * STATE_COUNT_MAX + safe_t * STATE_COUNT_MAX + states,
            mask=active & stored_state,
            other=-float("inf"),
        ).to(tl.float32)
        log_post = alpha_t + beta - log_likelihood
        posterior = tl.where(
            active & valid_state & (scale != 0.0),
            tl.exp(log_post),
            0.0,
        )
        # Several states can carry the same label (every even state maps to
        # BLANK), so lanes may target the same grad_input element; the update
        # is done with an atomic add.
        tl.atomic_add(
            grad_input + safe_t * N * C + batch * C + labels,
            -scale * posterior,
            sem="relaxed",
            mask=active & valid_state & stored_state,
        )

        # Candidates for the row handed to the next step (frame t - 1):
        # beta[s'] plus the frame-t log-probability of s', for s' = s, s + 1
        # and (when the skip is allowed) s + 2.
        stay = beta + tl.load(
            log_probs + safe_t * N * C + batch * C + labels,
            mask=active & valid_state,
            other=-float("inf"),
        ).to(tl.float32)
        next1 = tl.load(
            beta_base + states + 1,
            mask=(states + 1 < state_count) & stored_state,
            other=-float("inf"),
        ).to(tl.float32) + tl.load(
            log_probs + safe_t * N * C + batch * C + labels1,
            mask=active & (states + 1 < state_count) & stored_state,
            other=-float("inf"),
        ).to(
            tl.float32
        )
        # Skipping from s to s + 2 requires s to be non-blank, s + 2 to exist
        # and the two target values to differ.
        skip_allowed = (
            (~is_blank_state)
            & (states + 2 < state_count)
            & (target_value != target_value2)
        )
        next2 = tl.load(
            beta_base + states + 2,
            mask=(states + 2 < state_count) & stored_state,
            other=-float("inf"),
        ).to(tl.float32) + tl.load(
            log_probs + safe_t * N * C + batch * C + labels2,
            mask=active & skip_allowed & stored_state,
            other=-float("inf"),
        ).to(
            tl.float32
        )

        beta_next = _logaddexp3(stay, next1, next2, skip_allowed)
        tl.store(
            next_beta_base + states,
            tl.where(active, beta_next, -float("inf")),
            mask=stored_state,
        )
        # The next step reads neighbouring lanes' values back from scratch.
        _debug_barrier()

602 

603 

604def _reduction_enum(reduction): 

605 if isinstance(reduction, str): 

606 if reduction == "none": 

607 return _REDUCTION_NONE 

608 if reduction == "mean": 

609 return _REDUCTION_MEAN 

610 if reduction == "sum": 

611 return _REDUCTION_SUM 

612 raise ValueError( 

613 "ctc_loss expected reduction to be one of 'none', 'mean', or 'sum', " 

614 f"but got {reduction!r}" 

615 ) 

616 return int(reduction) 

617 

618 

619_INTEGRAL_DTYPES = { 

620 torch.uint8, 

621 torch.int8, 

622 torch.int16, 

623 torch.int32, 

624 torch.int64, 

625} 

626 

627 

628def _is_integral_dtype(dtype): 

629 return dtype in _INTEGRAL_DTYPES 

630 

631 

def _lengths_to_tensor(lengths, device, name):
    """Normalize a lengths argument to a 1-D int64 tensor on `device`.

    Args:
        lengths: a tensor, or anything `torch.tensor` accepts.
        device: device the result is placed on.
        name: argument name used in the error message.

    Returns:
        A tensor of dtype torch.long with exactly one dimension (a 0-d input
        becomes a single-element tensor).

    Raises:
        RuntimeError: if the lengths do not have an integral dtype.  For a
            tensor input this is checked before any device transfer.
    """
    if torch.is_tensor(lengths):
        dtype_ok = _is_integral_dtype(lengths.dtype)
        out = lengths.to(device=device) if dtype_ok else None
    else:
        out = torch.tensor(lengths, device=device)
        dtype_ok = _is_integral_dtype(out.dtype)
    if not dtype_ok:
        raise RuntimeError(f"{name} must be integral")

    if out.dtype != torch.long:
        out = out.to(dtype=torch.long)
    if out.ndim == 0:
        return out.reshape(1)
    return out.reshape(-1).contiguous()

644 

645 

def _length_stats(lengths):
    """Return `(min, max, sum)` of `lengths` as Python ints, with memoization.

    For tensor inputs the result is cached under a key made of the device,
    `data_ptr()`, `numel()` and the tensor's `_version` counter.  The tensor
    itself is stored next to the statistics (presumably so its storage
    address cannot be recycled while the entry exists — confirm).  The three
    reductions are stacked so that a single `.cpu()` transfer fetches them.
    """
    cache_key = None
    if torch.is_tensor(lengths):
        dev = lengths.device
        cache_key = (
            dev.type,
            dev.index,
            lengths.data_ptr(),
            lengths.numel(),
            lengths._version,
        )
        entry = _LENGTH_STATS_CACHE.get(cache_key)
        if entry is not None:
            return entry[1]

    reduced = torch.stack((lengths.min(), lengths.max(), lengths.sum())).cpu()
    stats = tuple(int(item) for item in reduced.tolist())
    if cache_key is None:
        return stats

    # Crude bound on memory: drop everything once the limit is reached.
    if len(_LENGTH_STATS_CACHE) >= _LENGTH_STATS_CACHE_LIMIT:
        _LENGTH_STATS_CACHE.clear()
    _LENGTH_STATS_CACHE[cache_key] = (lengths, stats)
    return stats

667 

668 

669def _compute_dtype(dtype): 

670 if dtype in (torch.float16, torch.bfloat16): 

671 return torch.float32 

672 return dtype 

673 

674 

class CtcLossFunction(torch.autograd.Function):
    """Autograd wrapper around the Triton CTC loss kernels in this module.

    `forward` validates and normalizes the inputs, then picks one of three
    kernels:

    * a fused forward/reduce kernel when no gradient is needed, the input is
      batched, `zero_infinity` is off, the reduction is mean or sum and every
      sample uses all T frames;
    * a forward-only kernel with a two-row workspace when no gradient is
      needed;
    * a forward kernel that keeps every alpha row when `log_probs` requires
      grad, so that `backward` can reuse it.
    """

    @staticmethod
    def _finalize_loss(
        raw_neg_log_likelihood,
        target_lengths,
        reduction,
        zero_infinity,
        unbatched,
        original_dtype,
    ):
        """Apply zero_infinity, the reduction and the output dtype cast.

        Args:
            raw_neg_log_likelihood: per-sample losses as written by a kernel.
            target_lengths: per-sample target lengths (mean-reduction divisor,
                clamped to at least 1).
            reduction: one of the `_REDUCTION_*` codes.
            zero_infinity: replace infinite per-sample losses by zero.
            unbatched: drop the batch dimension for `_REDUCTION_NONE`.
            original_dtype: dtype of the user's `log_probs`.
        """
        neg_log_likelihood = raw_neg_log_likelihood
        if zero_infinity:
            neg_log_likelihood = torch.where(
                torch.isinf(neg_log_likelihood),
                torch.zeros(
                    (),
                    dtype=neg_log_likelihood.dtype,
                    device=neg_log_likelihood.device,
                ),
                neg_log_likelihood,
            )

        if reduction == _REDUCTION_NONE:
            output = neg_log_likelihood
            if unbatched:
                output = output.squeeze(0)
        elif reduction == _REDUCTION_SUM:
            output = neg_log_likelihood.sum()
        else:
            denom = target_lengths.clamp_min(1)
            output = (neg_log_likelihood / denom).mean()

        if output.dtype != original_dtype:
            output = output.to(original_dtype)
        return output

    @staticmethod
    def forward(
        ctx,
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        blank=0,
        reduction="mean",
        zero_infinity=False,
    ):
        """Compute the CTC loss.

        Args:
            log_probs: 2-D (unbatched) or 3-D floating-point tensor; the last
                dimension is the class dimension and, for 3-D input,
                dimension 1 is the batch.
            targets: 1-D concatenated or 2-D padded tensor of labels.
            input_lengths, target_lengths: integral tensors or sequences with
                one entry per batch element.
            blank: blank label index, must lie in `[0, C)`.
            reduction: "none", "mean", "sum" or the matching integer code.
            zero_infinity: replace infinite losses (and their gradients) by 0.

        Raises:
            ValueError: for an unknown reduction.
            RuntimeError: for invalid shapes, dtypes, lengths or blank index.
        """
        reduction = _reduction_enum(reduction)
        if reduction not in (_REDUCTION_NONE, _REDUCTION_MEAN, _REDUCTION_SUM):
            raise ValueError(f"ctc_loss got invalid reduction enum {reduction}")

        if log_probs.ndim not in (2, 3):
            raise RuntimeError(
                "ctc_loss expects log_probs to be a 2D or 3D tensor, "
                f"but got {log_probs.ndim}D"
            )
        if not torch.is_floating_point(log_probs):
            raise RuntimeError(f'"ctc_loss" not implemented for {log_probs.dtype}')
        if blank < 0 or blank >= log_probs.shape[-1]:
            raise RuntimeError("blank must be in label range")

        device = log_probs.device
        original_dtype = log_probs.dtype
        compute_dtype = _compute_dtype(original_dtype)
        unbatched = log_probs.ndim == 2
        batch_size = 1 if unbatched else log_probs.shape[1]

        # The kernels index log_probs as a contiguous (T, N, C) buffer.
        work_log_probs = log_probs.unsqueeze(1) if unbatched else log_probs
        work_log_probs = work_log_probs.contiguous()
        if work_log_probs.dtype != compute_dtype:
            work_log_probs = work_log_probs.to(compute_dtype)

        # Targets are placed on the same device as log_probs, like the length
        # tensors below, so the kernels never receive a pointer from another
        # device.
        if torch.is_floating_point(targets):
            work_targets = targets.to(device=device, dtype=torch.long).contiguous()
        elif _is_integral_dtype(targets.dtype):
            work_targets = targets.to(device=device).contiguous()
        else:
            raise RuntimeError("ctc_loss targets must be integral or floating point")
        work_input_lengths = _lengths_to_tensor(input_lengths, device, "input_lengths")
        work_target_lengths = _lengths_to_tensor(
            target_lengths, device, "target_lengths"
        )
        if work_input_lengths.numel() != batch_size:
            raise RuntimeError(
                f"ctc_loss expected input_lengths to have size {batch_size}, "
                f"but got {work_input_lengths.numel()}"
            )
        if work_target_lengths.numel() != batch_size:
            raise RuntimeError(
                f"ctc_loss expected target_lengths to have size {batch_size}, "
                f"but got {work_target_lengths.numel()}"
            )
        min_input_length, max_input_length, _ = _length_stats(work_input_lengths)
        min_target_length, max_target, total_target_length = _length_stats(
            work_target_lengths
        )
        if min_input_length < 0 or max_input_length > work_log_probs.shape[0]:
            raise RuntimeError("ctc_loss input_lengths must be in [0, T]")
        if min_target_length < 0:
            raise RuntimeError("ctc_loss target_lengths must be non-negative")

        # Blank-extended target: 2 * L + 1 states for a target of length L.
        state_count_max = 2 * max_target + 1
        target_stride = max_target
        if work_targets.ndim == 1:
            target_1d = True
            if total_target_length != work_targets.numel():
                raise RuntimeError(
                    "ctc_loss expected concatenated targets length to equal "
                    "sum(target_lengths)"
                )
            # Exclusive prefix sum: start offset of each sample's labels.
            work_target_offsets = (
                work_target_lengths.cumsum(0) - work_target_lengths
            ).contiguous()
        elif work_targets.ndim == 2:
            target_1d = False
            if max_target > work_targets.shape[1]:
                raise RuntimeError(
                    "ctc_loss target_lengths cannot exceed padded target width"
                )
            target_stride = work_targets.shape[1]
            # Not read by the kernels in the 2-D case; any valid tensor works.
            work_target_offsets = work_target_lengths
        else:
            raise RuntimeError(
                "ctc_loss expects targets to be a 1D concatenated tensor or a "
                f"2D padded tensor, but got {work_targets.ndim}D"
            )

        needs_log_probs_grad = ctx.needs_input_grad[0]
        block_s = triton.next_power_of_2(state_count_max)

        if not needs_log_probs_grad:
            if (
                not unbatched
                and not zero_infinity
                and reduction in (_REDUCTION_MEAN, _REDUCTION_SUM)
                and min_input_length == work_log_probs.shape[0]
                and work_log_probs.shape[0] > 0
            ):
                # Fast path: every sample uses all T frames, so the kernel
                # skips input_lengths and pre-scales each sample's loss.
                contrib = torch.empty(
                    (batch_size,), dtype=torch.float32, device=device
                )
                scratch_alpha = torch.empty(
                    (batch_size, 2, state_count_max),
                    dtype=torch.float32,
                    device=device,
                )
                with torch_device_fn.device(device):
                    _ctc_loss_forward_full_length_reduce_kernel[(batch_size,)](
                        work_log_probs,
                        work_targets,
                        work_target_lengths,
                        work_target_offsets,
                        contrib,
                        scratch_alpha,
                        work_log_probs.shape[0],
                        batch_size,
                        work_log_probs.shape[2],
                        target_stride,
                        state_count_max,
                        blank,
                        target_1d,
                        reduction,
                        block_s,
                    )
                output = contrib.sum()
                if output.dtype != original_dtype:
                    output = output.to(original_dtype)
                return output

            raw_neg_log_likelihood = torch.empty(
                (batch_size,), dtype=torch.float32, device=device
            )
            scratch_alpha = torch.empty(
                (batch_size, 2, state_count_max),
                dtype=torch.float32,
                device=device,
            )
            with torch_device_fn.device(device):
                _ctc_loss_forward_no_grad_kernel[(batch_size,)](
                    work_log_probs,
                    work_targets,
                    work_input_lengths,
                    work_target_lengths,
                    work_target_offsets,
                    raw_neg_log_likelihood,
                    scratch_alpha,
                    work_log_probs.shape[0],
                    batch_size,
                    work_log_probs.shape[2],
                    target_stride,
                    state_count_max,
                    blank,
                    target_1d,
                    block_s,
                )
            return CtcLossFunction._finalize_loss(
                raw_neg_log_likelihood,
                work_target_lengths,
                reduction,
                zero_infinity,
                unbatched,
                original_dtype,
            )

        raw_neg_log_likelihood = torch.empty(
            (batch_size,), dtype=torch.float32, device=device
        )

        # Every alpha row is kept: backward combines it with beta.
        log_alpha = torch.empty(
            (batch_size, work_log_probs.shape[0], state_count_max),
            dtype=torch.float32,
            device=device,
        )
        with torch_device_fn.device(device):
            _ctc_loss_forward_kernel[(batch_size,)](
                work_log_probs,
                work_targets,
                work_input_lengths,
                work_target_lengths,
                work_target_offsets,
                raw_neg_log_likelihood,
                log_alpha,
                work_log_probs.shape[0],
                batch_size,
                work_log_probs.shape[2],
                target_stride,
                state_count_max,
                blank,
                target_1d,
                block_s,
            )
        output = CtcLossFunction._finalize_loss(
            raw_neg_log_likelihood,
            work_target_lengths,
            reduction,
            zero_infinity,
            unbatched,
            original_dtype,
        )

        # The un-zeroed losses are saved: the backward kernels need to see
        # +inf to apply the zero_infinity / NaN handling themselves.
        ctx.save_for_backward(
            work_log_probs,
            work_targets,
            work_input_lengths,
            work_target_lengths,
            work_target_offsets,
            raw_neg_log_likelihood,
            log_alpha,
        )
        ctx.blank = blank
        ctx.reduction = reduction
        ctx.zero_infinity = zero_infinity
        ctx.unbatched = unbatched
        ctx.batch_size = batch_size
        ctx.original_dtype = original_dtype
        ctx.max_target = target_stride
        ctx.state_count_max = state_count_max
        ctx.target_1d = target_1d

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient with respect to `log_probs` (None for the rest).

        Two kernels run back to back on the device of `log_probs`: the first
        fills the gradient buffer with `exp(log_probs) * scale`, the second
        accumulates the negative scaled state posteriors into it.
        """
        (
            work_log_probs,
            work_targets,
            work_input_lengths,
            work_target_lengths,
            work_target_offsets,
            neg_log_likelihood,
            log_alpha,
        ) = ctx.saved_tensors

        grad_output = grad_output.contiguous()
        device = work_log_probs.device

        grad_log_probs = torch.empty_like(work_log_probs)
        total = work_log_probs.numel()
        block = 256
        scratch_beta = torch.empty(
            (ctx.batch_size, 2, ctx.state_count_max),
            dtype=torch.float32,
            device=device,
        )
        block_s = triton.next_power_of_2(ctx.state_count_max)

        # Both launches sit inside the device guard so that neither runs on
        # whatever device happens to be current.
        with torch_device_fn.device(device):
            _ctc_loss_init_grad_kernel[(triton.cdiv(total, block),)](
                work_log_probs,
                work_input_lengths,
                work_target_lengths,
                neg_log_likelihood,
                grad_output,
                grad_log_probs,
                total,
                work_log_probs.shape[0],
                ctx.batch_size,
                work_log_probs.shape[2],
                ctx.reduction,
                ctx.zero_infinity,
                block,
            )
            _ctc_loss_backward_kernel[(ctx.batch_size,)](
                work_log_probs,
                work_targets,
                work_input_lengths,
                work_target_lengths,
                work_target_offsets,
                neg_log_likelihood,
                log_alpha,
                grad_output,
                grad_log_probs,
                scratch_beta,
                work_log_probs.shape[0],
                ctx.batch_size,
                work_log_probs.shape[2],
                ctx.max_target,
                ctx.state_count_max,
                ctx.blank,
                ctx.target_1d,
                ctx.reduction,
                ctx.zero_infinity,
                block_s,
            )

        if ctx.unbatched:
            grad_log_probs = grad_log_probs.squeeze(1)
        if grad_log_probs.dtype != ctx.original_dtype:
            grad_log_probs = grad_log_probs.to(ctx.original_dtype)

        return grad_log_probs, None, None, None, None, None, None

1002 

1003 

def ctc_loss(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    blank=0,
    reduction="mean",
    zero_infinity=False,
):
    """Public entry point: log the call and dispatch to `CtcLossFunction`.

    All arguments are forwarded positionally and unchanged; see
    `CtcLossFunction.forward` for their meaning.
    """
    logger.debug("GEMS CTC LOSS")
    call_args = (
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        blank,
        reduction,
        zero_infinity,
    )
    return CtcLossFunction.apply(*call_args)