Coverage for src/flag_gems/runtime/backend/_sunrise/ops/ctc_loss.py: 0%

446 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9 

logger = logging.getLogger(__name__)

# Integer codes for the ``reduction`` argument. The same values are handed to
# the Triton kernels as their ``REDUCTION`` constexpr (1 selects the mean path).
_REDUCTION_NONE, _REDUCTION_MEAN, _REDUCTION_SUM = 0, 1, 2

# Memo of (min, max, sum) statistics for length tensors, filled by
# _length_stats(); it is wiped entirely once it holds
# _LENGTH_STATS_CACHE_LIMIT entries.
_LENGTH_STATS_CACHE = {}
_LENGTH_STATS_CACHE_LIMIT = 256

18 

19 

# Prefer Triton's native barrier. When this Triton build does not provide
# tl.debug_barrier, a no-op JIT function with the same name is defined so the
# kernels below, which call _debug_barrier(), still compile.
# NOTE(review): with the no-op fallback the cross-lane reads of scratch memory
# in those kernels are not ordered by a barrier; confirm this is only used on
# backends where that is safe.
_debug_barrier = getattr(tl, "debug_barrier", None)
if _debug_barrier is None:

    @triton.jit
    def _debug_barrier():
        return

27 

28 

# Elementwise log(exp(a) + exp(b)), computed as max + log(1 + exp(min - max)).
# Returns -inf when both operands are -inf (where the shifted form would yield
# NaN from -inf - -inf).
@triton.jit
def _logaddexp(a, b):
    hi = tl.maximum(a, b)
    lo = tl.minimum(a, b)
    combined = hi + tl.log(1.0 + tl.exp(lo - hi))
    return tl.where(hi == -float("inf"), -float("inf"), combined)

38 

39 

# Elementwise log(exp(a) + exp(b) + exp(c)); the c term only participates
# where use_c is true (elsewhere it is treated as -inf). The sum is shifted by
# the running maximum for stability, and the result is -inf when every
# participating operand is -inf.
@triton.jit
def _logaddexp3(a, b, c, use_c):
    c_term = tl.where(use_c, c, -float("inf"))
    peak = tl.maximum(tl.maximum(a, b), c_term)
    all_neg_inf = peak == -float("inf")
    # Shift by 0 instead of -inf so the exponentials below stay finite.
    shift = tl.where(all_neg_inf, 0.0, peak)
    total = tl.exp(a - shift) + tl.exp(b - shift) + tl.exp(c_term - shift)
    return tl.where(all_neg_inf, -float("inf"), peak + tl.log(total))

51 

52 

# CTC forward (alpha) recursion for one batch element, keeping every time step.
#
# One program per batch element (program_id(0)). Lane s of `states` is a state
# of the extended label sequence: even states are blanks and odd state 2k+1
# emits target label k, giving 2 * target_len + 1 valid states. log_probs is
# indexed as a contiguous (T, N, C) buffer.
#
# All T rows of STATE_COUNT_MAX entries in log_alpha for this batch element are
# written, so the backward kernel can read any time step; rows with
# t >= input_len hold -inf. The per-sample negative log-likelihood is stored
# to neg_log_likelihood[batch].
#
# Targets are either concatenated (TARGET_1D; this sample starts at
# target_offsets[batch], the exclusive prefix sum of target_lengths computed
# by the host) or padded with a row stride of MAX_TARGET.
@libentry()
@triton.jit
def _ctc_loss_forward_kernel(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    target_offsets,
    neg_log_likelihood,
    log_alpha,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    batch = tl.program_id(0)
    states = tl.arange(0, BLOCK_S)

    input_len = tl.load(input_lengths + batch)
    target_len = tl.load(target_lengths + batch)
    state_count = target_len * 2 + 1
    valid_state = states < state_count
    stored_state = states < STATE_COUNT_MAX

    is_blank_state = (states % 2) == 0
    target_index = (states - 1) // 2
    target_mask = (target_index >= 0) & (target_index < target_len)
    target_safe_index = tl.where(target_mask, target_index, 0)

    if TARGET_1D:
        # O(1) lookup of the precomputed start offset, consistent with the
        # other CTC kernels (instead of summing target_lengths[0:batch]).
        target_origin = tl.load(target_offsets + batch)
    else:
        target_origin = batch * MAX_TARGET
    target_ptrs = targets + target_origin + target_safe_index

    target_value = tl.load(target_ptrs, mask=target_mask, other=BLANK)
    labels = tl.where(is_blank_state, BLANK, target_value)

    # The skip transition s-2 -> s is allowed only into a label state whose
    # label differs from the previous label. It does not depend on t, so it is
    # computed once here rather than re-loading targets at every time step.
    prev_target_index = tl.where(target_index > 0, target_index - 1, 0)
    prev_target_value = tl.load(
        targets + target_origin + prev_target_index,
        mask=target_mask & (target_index > 0),
        other=BLANK,
    )
    skip_allowed = (
        (~is_blank_state) & (target_index > 0) & (target_value != prev_target_value)
    )

    alpha_batch = log_alpha + batch * T * STATE_COUNT_MAX

    # t = 0: only the leading blank and (if any) the first label are reachable.
    t0_active = input_len > 0
    init_state = (states == 0) | ((states == 1) & (target_len > 0))
    init_logp = tl.load(
        log_probs + batch * C + labels,
        mask=init_state & stored_state & t0_active,
        other=0.0,
    ).to(tl.float32)
    alpha = tl.where(init_state & valid_state & t0_active, init_logp, -float("inf"))
    tl.store(alpha_batch + states, alpha, mask=stored_state)
    _debug_barrier()

    for t in tl.range(1, T):
        prev_base = alpha_batch + (t - 1) * STATE_COUNT_MAX
        prev0 = tl.load(prev_base + states, mask=stored_state, other=-float("inf")).to(
            tl.float32
        )
        prev1 = tl.load(
            prev_base + tl.where(states > 0, states - 1, 0),
            mask=(states > 0) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)
        prev2 = tl.load(
            prev_base + tl.where(states > 1, states - 2, 0),
            mask=(states > 1) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)

        acc = _logaddexp3(prev0, prev1, prev2, skip_allowed)

        logp = tl.load(
            log_probs + t * N * C + batch * C + labels,
            mask=valid_state & (t < input_len),
            other=0.0,
        ).to(tl.float32)
        alpha = tl.where(valid_state & (t < input_len), acc + logp, -float("inf"))
        tl.store(
            alpha_batch + t * STATE_COUNT_MAX + states,
            alpha,
            mask=stored_state,
        )
        _debug_barrier()

    if input_len <= 0:
        # No frames: only an empty target has a (probability-one) alignment.
        loss = tl.where(target_len == 0, 0.0, float("inf"))
    else:
        _debug_barrier()
        final_base = alpha_batch + (input_len - 1) * STATE_COUNT_MAX
        # A valid path ends in the trailing blank or the last label.
        last = tl.load(final_base + state_count - 1).to(tl.float32)
        prev_last = tl.load(
            final_base + tl.where(target_len > 0, state_count - 2, 0),
            mask=target_len > 0,
            other=-float("inf"),
        ).to(tl.float32)
        log_likelihood = _logaddexp(last, prev_last)
        loss = -log_likelihood

    tl.store(neg_log_likelihood + batch, loss)

176 

177 

# CTC forward recursion used when no gradient is needed. Only the per-sample
# negative log-likelihood is produced.
#
# Alpha lives in a two-row ping-pong buffer per batch element: row t % 2 of
# scratch_alpha, 2 * STATE_COUNT_MAX entries per sample. Stores are masked off
# once t >= input_len, so row (input_len - 1) % 2 still holds the final alpha
# after the loop. State layout, log_probs indexing and target addressing match
# _ctc_loss_forward_kernel.
@libentry()
@triton.jit
def _ctc_loss_forward_no_grad_kernel(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    target_offsets,
    neg_log_likelihood,
    scratch_alpha,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    batch = tl.program_id(0)
    states = tl.arange(0, BLOCK_S)

    target_len = tl.load(target_lengths + batch)
    state_count = target_len * 2 + 1
    valid_state = states < state_count
    stored_state = states < STATE_COUNT_MAX

    # Even states are blanks; odd state 2k+1 emits target label k.
    is_blank_state = (states % 2) == 0
    target_index = (states - 1) // 2
    target_mask = (target_index >= 0) & (target_index < target_len)
    target_safe_index = tl.where(target_mask, target_index, 0)

    if TARGET_1D:
        target_origin = tl.load(target_offsets + batch)
    else:
        target_origin = batch * MAX_TARGET
    target_ptrs = targets + target_origin + target_safe_index

    target_value = tl.load(target_ptrs, mask=target_mask, other=BLANK)
    labels = tl.where(is_blank_state, BLANK, target_value)

    # Skip transition s-2 -> s: only into a label state whose label differs
    # from the previous one. Independent of t, so computed once.
    prev_target_index = tl.where(target_index > 0, target_index - 1, 0)
    prev_target_value = tl.load(
        targets + target_origin + prev_target_index,
        mask=target_mask & (target_index > 0),
        other=BLANK,
    )
    skip_allowed = (
        (~is_blank_state) & (target_index > 0) & (target_value != prev_target_value)
    )

    input_len = tl.load(input_lengths + batch)
    init_state = (states == 0) | ((states == 1) & (target_len > 0))
    init_logp = tl.load(
        log_probs + batch * C + labels,
        mask=init_state & stored_state & (input_len > 0),
        other=0.0,
    ).to(tl.float32)
    alpha = tl.where(
        init_state & valid_state & (input_len > 0), init_logp, -float("inf")
    )
    scratch_batch = scratch_alpha + batch * 2 * STATE_COUNT_MAX
    tl.store(scratch_batch + states, alpha, mask=stored_state)
    _debug_barrier()

    for t in tl.range(1, T):
        prev_base = scratch_batch + ((t - 1) % 2) * STATE_COUNT_MAX
        cur_base = scratch_batch + (t % 2) * STATE_COUNT_MAX
        prev0 = tl.load(prev_base + states, mask=stored_state, other=-float("inf")).to(
            tl.float32
        )
        prev1 = tl.load(
            prev_base + tl.where(states > 0, states - 1, 0),
            mask=(states > 0) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)
        prev2 = tl.load(
            prev_base + tl.where(states > 1, states - 2, 0),
            mask=(states > 1) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)

        acc = _logaddexp3(prev0, prev1, prev2, skip_allowed)
        logp = tl.load(
            log_probs + t * N * C + batch * C + labels,
            mask=valid_state & (t < input_len),
            other=0.0,
        ).to(tl.float32)
        alpha = tl.where(valid_state & (t < input_len), acc + logp, -float("inf"))
        # Freeze the buffer once past this sample's input length.
        tl.store(cur_base + states, alpha, mask=stored_state & (t < input_len))
        _debug_barrier()

    if input_len <= 0:
        loss = tl.where(target_len == 0, 0.0, float("inf"))
    else:
        _debug_barrier()
        final_base = scratch_batch + ((input_len - 1) % 2) * STATE_COUNT_MAX
        last = tl.load(final_base + state_count - 1).to(tl.float32)
        prev_last = tl.load(
            final_base + tl.where(target_len > 0, state_count - 2, 0),
            mask=target_len > 0,
            other=-float("inf"),
        ).to(tl.float32)
        loss = -_logaddexp(last, prev_last)

    tl.store(neg_log_likelihood + batch, loss)

285 

286 

# CTC forward recursion specialised for the case where every sample uses all T
# frames. The host takes this path only without a log_probs gradient, for
# batched input, zero_infinity off, reduction mean or sum, T > 0 and
# min(input_lengths) == T, so no per-sample input-length masking is done here.
#
# Alpha uses the same two-row ping-pong scratch layout as
# _ctc_loss_forward_no_grad_kernel. Each program writes its sample's
# contribution to the reduced loss into contrib[batch]: the negative
# log-likelihood for sum, or nll / max(target_len, 1) / N for mean (REDUCTION
# == 1). The host then sums contrib.
@libentry()
@triton.jit
def _ctc_loss_forward_full_length_reduce_kernel(
    log_probs,
    targets,
    target_lengths,
    target_offsets,
    contrib,
    scratch_alpha,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    REDUCTION: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    batch = tl.program_id(0)
    states = tl.arange(0, BLOCK_S)

    target_len = tl.load(target_lengths + batch)
    state_count = target_len * 2 + 1
    valid_state = states < state_count
    stored_state = states < STATE_COUNT_MAX

    # Even states are blanks; odd state 2k+1 emits target label k.
    is_blank_state = (states % 2) == 0
    target_index = (states - 1) // 2
    target_mask = (target_index >= 0) & (target_index < target_len)
    target_safe_index = tl.where(target_mask, target_index, 0)

    if TARGET_1D:
        target_origin = tl.load(target_offsets + batch)
    else:
        target_origin = batch * MAX_TARGET
    target_ptrs = targets + target_origin + target_safe_index

    target_value = tl.load(target_ptrs, mask=target_mask, other=BLANK)
    labels = tl.where(is_blank_state, BLANK, target_value)

    # Skip transition s-2 -> s: only into a label state whose label differs
    # from the previous one. Independent of t, so computed once.
    prev_target_index = tl.where(target_index > 0, target_index - 1, 0)
    prev_target_value = tl.load(
        targets + target_origin + prev_target_index,
        mask=target_mask & (target_index > 0),
        other=BLANK,
    )
    skip_allowed = (
        (~is_blank_state) & (target_index > 0) & (target_value != prev_target_value)
    )

    init_state = (states == 0) | ((states == 1) & (target_len > 0))
    init_logp = tl.load(
        log_probs + batch * C + labels,
        mask=init_state & stored_state,
        other=0.0,
    ).to(tl.float32)
    alpha = tl.where(init_state & valid_state, init_logp, -float("inf"))
    scratch_batch = scratch_alpha + batch * 2 * STATE_COUNT_MAX
    tl.store(scratch_batch + states, alpha, mask=stored_state)
    _debug_barrier()

    for t in tl.range(1, T):
        prev_base = scratch_batch + ((t - 1) % 2) * STATE_COUNT_MAX
        cur_base = scratch_batch + (t % 2) * STATE_COUNT_MAX
        prev0 = tl.load(prev_base + states, mask=stored_state, other=-float("inf")).to(
            tl.float32
        )
        prev1 = tl.load(
            prev_base + tl.where(states > 0, states - 1, 0),
            mask=(states > 0) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)
        prev2 = tl.load(
            prev_base + tl.where(states > 1, states - 2, 0),
            mask=(states > 1) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)

        acc = _logaddexp3(prev0, prev1, prev2, skip_allowed)
        logp = tl.load(
            log_probs + t * N * C + batch * C + labels,
            mask=valid_state,
            other=0.0,
        ).to(tl.float32)
        alpha = tl.where(valid_state, acc + logp, -float("inf"))
        tl.store(cur_base + states, alpha, mask=stored_state)
        _debug_barrier()

    if T <= 0:
        loss = tl.where(target_len == 0, 0.0, float("inf"))
    else:
        _debug_barrier()
        final_base = scratch_batch + ((T - 1) % 2) * STATE_COUNT_MAX
        last = tl.load(final_base + state_count - 1).to(tl.float32)
        prev_last = tl.load(
            final_base + tl.where(target_len > 0, state_count - 2, 0),
            mask=target_len > 0,
            other=-float("inf"),
        ).to(tl.float32)
        loss = -_logaddexp(last, prev_last)

    if REDUCTION == 1:
        # Mean reduction: divide by clamped target length and by the batch size
        # so that summing contrib yields the mean.
        loss = loss / tl.maximum(target_len, 1).to(tl.float32) / N
    tl.store(contrib + batch, loss)

393 

394 

# Seed grad_input with the exp(log_probs) * scale part of the CTC gradient.
#
# One lane per element of the flattened log_probs buffer, which is indexed as
# a contiguous (T, N, C) layout: flat offset o maps to t = o // (N * C) and
# batch = (o // C) % N. For t < input_lengths[batch] the lane writes
# exp(log_probs[o]) * scale; every other in-range position gets 0. The
# posterior term is subtracted afterwards by _ctc_loss_backward_kernel.
#
# scale is grad_output[batch] when REDUCTION == 0, otherwise the scalar
# grad_output, further divided by max(target_len, 1) * N when REDUCTION == 1.
# With ZERO_INFINITY, samples whose saved negative log-likelihood is +inf get
# scale 0. Within t < input_len, NaN is written where scale != 0 and log_probs
# is -inf, and (without ZERO_INFINITY) wherever the saved loss is +inf.
@libentry()
@triton.jit
def _ctc_loss_init_grad_kernel(
    log_probs,
    input_lengths,
    target_lengths,
    neg_log_likelihood,
    grad_output,
    grad_input,
    total: tl.constexpr,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    REDUCTION: tl.constexpr,
    ZERO_INFINITY: tl.constexpr,
    BLOCK: tl.constexpr,
):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < total
    # Recover (t, batch) from the flat offset of the (T, N, C) buffer.
    batch = (offsets // C) % N
    t = offsets // (N * C)

    input_len = tl.load(input_lengths + batch, mask=mask, other=0)
    target_len = tl.load(target_lengths + batch, mask=mask, other=1)
    nll = tl.load(neg_log_likelihood + batch, mask=mask, other=0.0).to(tl.float32)

    # Upstream gradient: per-sample for reduction "none", a scalar otherwise.
    if REDUCTION == 0:
        scale = tl.load(grad_output + batch, mask=mask, other=0.0).to(tl.float32)
    else:
        scale = tl.load(grad_output).to(tl.float32)
    if REDUCTION == 1:
        # Mean reduction divides each sample's loss by max(target_len, 1) * N.
        denom = tl.maximum(target_len, 1).to(tl.float32) * N
        scale = scale / denom

    if ZERO_INFINITY:
        scale = tl.where(nll == float("inf"), 0.0, scale)

    logp = tl.load(log_probs + offsets, mask=mask, other=-float("inf")).to(tl.float32)
    grad = tl.where((t < input_len) & mask, tl.exp(logp) * scale, 0.0)
    nan_grad = float("nan")
    # Contributing positions with log_probs == -inf are marked NaN.
    grad = tl.where(
        (t < input_len) & mask & (scale != 0.0) & (logp == -float("inf")),
        nan_grad,
        grad,
    )
    if not ZERO_INFINITY:
        # Infeasible alignment (+inf loss) without zero_infinity: NaN gradient.
        grad = tl.where((t < input_len) & mask & (nll == float("inf")), nan_grad, grad)
    tl.store(grad_input + offsets, grad, mask=mask)

443 

444 

# CTC backward (beta) recursion for one batch element. It accumulates the
# posterior term of the gradient into grad_input.
#
# One program per batch element (program_id(0)). Lanes index extended-label
# states exactly as in the forward kernels (even = blank, odd 2k+1 = label k).
# Beta is kept in a two-row ping-pong buffer in scratch_beta
# (2 * STATE_COUNT_MAX entries per sample). The loop walks
# t = input_len - 1, ..., input_len - T; steps with t < 0 are inactive.
#
# The beta used at time t excludes the emission at t: it is initialised to 0
# on the final states, and the emission at t is folded in when the beta for
# t - 1 is produced. Combined with the saved log_alpha, which includes it,
# exp(alpha_t + beta_t + nll) is the state posterior. For every active step,
# -scale * posterior is atomically added to grad_input[t, batch, label(s)].
# The host launches this after _ctc_loss_init_grad_kernel has written
# exp(log_probs) * scale there.
@libentry()
@triton.jit
def _ctc_loss_backward_kernel(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    target_offsets,
    neg_log_likelihood,
    log_alpha,
    grad_output,
    grad_input,
    scratch_beta,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    REDUCTION: tl.constexpr,
    ZERO_INFINITY: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    batch = tl.program_id(0)
    states = tl.arange(0, BLOCK_S)

    input_len = tl.load(input_lengths + batch)
    target_len = tl.load(target_lengths + batch)
    nll = tl.load(neg_log_likelihood + batch).to(tl.float32)
    state_count = target_len * 2 + 1
    valid_state = states < state_count
    stored_state = states < STATE_COUNT_MAX

    is_blank_state = (states % 2) == 0
    target_index = (states - 1) // 2
    target_mask = (target_index >= 0) & (target_index < target_len)
    target_safe_index = tl.where(target_mask, target_index, 0)

    if TARGET_1D:
        target_origin = tl.load(target_offsets + batch)
        target_ptrs = targets + target_origin + target_safe_index
    else:
        target_origin = batch * MAX_TARGET
        target_ptrs = targets + target_origin + target_safe_index

    target_value = tl.load(target_ptrs, mask=target_mask, other=BLANK)
    labels = tl.where(is_blank_state, BLANK, target_value)

    # Label emitted by state s + 1 (target of the "advance" transition).
    state1 = states + 1
    is_blank_state1 = (state1 % 2) == 0
    target_index1 = (state1 - 1) // 2
    target_mask1 = (target_index1 >= 0) & (target_index1 < target_len)
    target_safe_index1 = tl.where(target_mask1, target_index1, 0)
    target_ptrs1 = targets + target_origin + target_safe_index1
    target_value1 = tl.load(target_ptrs1, mask=target_mask1, other=BLANK)
    labels1 = tl.where(is_blank_state1, BLANK, target_value1)

    # Label emitted by state s + 2 (target of the "skip" transition). Only used
    # where s is a label state, so s + 2 is also a label state.
    state2 = states + 2
    target_index2 = (state2 - 1) // 2
    target_mask2 = (target_index2 >= 0) & (target_index2 < target_len)
    target_safe_index2 = tl.where(target_mask2, target_index2, 0)
    target_ptrs2 = targets + target_origin + target_safe_index2
    target_value2 = tl.load(target_ptrs2, mask=target_mask2, other=BLANK)
    labels2 = target_value2

    # Upstream gradient scale; same rules as _ctc_loss_init_grad_kernel.
    if REDUCTION == 0:
        scale = tl.load(grad_output + batch).to(tl.float32)
    else:
        scale = tl.load(grad_output).to(tl.float32)
    if REDUCTION == 1:
        denom = tl.maximum(target_len, 1).to(tl.float32) * N
        scale = scale / denom

    if ZERO_INFINITY:
        scale = tl.where(nll == float("inf"), 0.0, scale)

    # beta at t = input_len - 1: 0 (log 1) on the trailing blank and the last
    # label, -inf elsewhere.
    beta_init = tl.where(
        ((states == state_count - 1) | ((states == state_count - 2) & (target_len > 0)))
        & valid_state
        & (input_len > 0),
        0.0,
        -float("inf"),
    )
    scratch_batch = scratch_beta + batch * 2 * STATE_COUNT_MAX
    tl.store(scratch_batch + states, beta_init, mask=stored_state)
    _debug_barrier()
    log_likelihood = tl.where(scale != 0.0, -nll, 0.0)

    for step in tl.range(0, T):
        t = input_len - 1 - step
        active = t >= 0
        safe_t = tl.where(active, t, 0)
        beta_base = scratch_batch + (step % 2) * STATE_COUNT_MAX
        next_beta_base = scratch_batch + ((step + 1) % 2) * STATE_COUNT_MAX
        beta = tl.load(beta_base + states, mask=stored_state, other=-float("inf")).to(
            tl.float32
        )

        alpha_t = tl.load(
            log_alpha + batch * T * STATE_COUNT_MAX + safe_t * STATE_COUNT_MAX + states,
            mask=active & stored_state,
            other=-float("inf"),
        ).to(tl.float32)
        # Posterior of being in state s at time t.
        log_post = alpha_t + beta - log_likelihood
        posterior = tl.where(
            active & valid_state & (scale != 0.0),
            tl.exp(log_post),
            0.0,
        )
        # Several states share a label (e.g. all blanks), hence atomics.
        tl.atomic_add(
            grad_input + safe_t * N * C + batch * C + labels,
            -scale * posterior,
            sem="relaxed",
            mask=active & valid_state & stored_state,
        )

        # beta for t - 1: stay in s, advance to s + 1, or skip to s + 2, each
        # paying the emission log-probability at t of the destination label.
        stay = beta + tl.load(
            log_probs + safe_t * N * C + batch * C + labels,
            mask=active & valid_state,
            other=-float("inf"),
        ).to(tl.float32)
        next1 = tl.load(
            beta_base + states + 1,
            mask=(states + 1 < state_count) & stored_state,
            other=-float("inf"),
        ).to(tl.float32) + tl.load(
            log_probs + safe_t * N * C + batch * C + labels1,
            mask=active & (states + 1 < state_count) & stored_state,
            other=-float("inf"),
        ).to(
            tl.float32
        )
        # Skip s -> s + 2 only from a label state into a different label.
        skip_allowed = (
            (~is_blank_state)
            & (states + 2 < state_count)
            & (target_value != target_value2)
        )
        next2 = tl.load(
            beta_base + states + 2,
            mask=(states + 2 < state_count) & stored_state,
            other=-float("inf"),
        ).to(tl.float32) + tl.load(
            log_probs + safe_t * N * C + batch * C + labels2,
            mask=active & skip_allowed & stored_state,
            other=-float("inf"),
        ).to(
            tl.float32
        )

        beta_next = _logaddexp3(stay, next1, next2, skip_allowed)
        tl.store(
            next_beta_base + states,
            tl.where(active, beta_next, -float("inf")),
            mask=stored_state,
        )
        _debug_barrier()

601 _debug_barrier() 

602 

603 

def _reduction_enum(reduction):
    """Map a ``reduction`` argument to its integer code.

    Strings ``"none"``, ``"mean"`` and ``"sum"`` map to ``_REDUCTION_NONE``,
    ``_REDUCTION_MEAN`` and ``_REDUCTION_SUM``. Any non-string value is passed
    through ``int()`` unchanged; range checking is left to the caller.

    Raises:
        ValueError: for any other string.
    """
    if not isinstance(reduction, str):
        return int(reduction)
    if reduction == "none":
        return _REDUCTION_NONE
    if reduction == "mean":
        return _REDUCTION_MEAN
    if reduction == "sum":
        return _REDUCTION_SUM
    raise ValueError(
        "ctc_loss expected reduction to be one of 'none', 'mean', or 'sum', "
        f"but got {reduction!r}"
    )

617 

618 

# Integer dtypes accepted for lengths and targets. torch.bool is not included.
_INTEGRAL_DTYPES = {
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
    torch.uint8,
}

626 

627 

def _is_integral_dtype(dtype):
    """Return True if *dtype* is one of the integer dtypes in ``_INTEGRAL_DTYPES``."""
    is_integral = dtype in _INTEGRAL_DTYPES
    return is_integral

630 

631 

def _lengths_to_tensor(lengths, device, name):
    """Normalize a lengths argument to a flat, contiguous int64 tensor on *device*.

    Args:
        lengths: an integer tensor, or anything ``torch.tensor`` accepts
            (e.g. a list of ints).
        device: target device for the result.
        name: argument name used in error messages.

    Returns:
        A 1-D ``torch.long`` tensor; a 0-d input becomes shape ``(1,)``.

    Raises:
        RuntimeError: if the values are not of an integral dtype.
    """
    if not torch.is_tensor(lengths):
        converted = torch.tensor(lengths, device=device)
    elif _is_integral_dtype(lengths.dtype):
        converted = lengths.to(device=device)
    else:
        raise RuntimeError(f"{name} must be integral")

    if not _is_integral_dtype(converted.dtype):
        raise RuntimeError(f"{name} must be integral")
    if converted.dtype != torch.long:
        converted = converted.to(dtype=torch.long)

    if converted.ndim == 0:
        return converted.reshape(1)
    return converted.reshape(-1).contiguous()

644 

645 

def _length_stats(lengths):
    """Return ``(min, max, sum)`` of an integer lengths tensor as Python ints.

    For tensors, results are memoized in ``_LENGTH_STATS_CACHE`` under the key
    (device type, device index, data pointer, numel, version counter). Each
    entry also stores the tensor itself. That reference keeps the storage
    alive, so its address cannot be handed to a different tensor while the
    entry exists. The cache is cleared wholesale once it reaches
    ``_LENGTH_STATS_CACHE_LIMIT`` entries.

    The reductions run on a single host copy of ``lengths``. An empty tensor
    yields ``(0, 0, 0)`` instead of raising from ``min()``.
    """
    key = None
    if torch.is_tensor(lengths):
        key = (
            lengths.device.type,
            lengths.device.index,
            lengths.data_ptr(),
            lengths.numel(),
            lengths._version,
        )
        cached = _LENGTH_STATS_CACHE.get(key)
        if cached is not None:
            return cached[1]

    # [sunrise fix] Could not run 'aten::min' with arguments from the 'ptpu' backend.
    # Reduce on the host, copying the tensor there once rather than once per
    # statistic.
    host_lengths = lengths.cpu()
    if host_lengths.numel() == 0:
        stats = (0, 0, 0)
    else:
        stats = (
            int(host_lengths.min()),
            int(host_lengths.max()),
            int(host_lengths.sum()),
        )

    if key is not None:
        if len(_LENGTH_STATS_CACHE) >= _LENGTH_STATS_CACHE_LIMIT:
            _LENGTH_STATS_CACHE.clear()
        _LENGTH_STATS_CACHE[key] = (lengths, stats)
    return stats

670 

671 

def _compute_dtype(dtype):
    """Return the dtype the kernels compute in for inputs of *dtype*.

    Half-precision inputs (float16, bfloat16) are promoted to float32; every
    other dtype is returned unchanged.
    """
    half_precision = (torch.float16, torch.bfloat16)
    return torch.float32 if dtype in half_precision else dtype

676 

677 

def _finalize_ctc_output(
    raw_neg_log_likelihood,
    target_lengths,
    reduction,
    zero_infinity,
    unbatched,
    original_dtype,
):
    """Turn per-sample CTC losses into the value returned to the caller.

    Args:
        raw_neg_log_likelihood: float32 per-sample losses, shape ``(N,)``.
        target_lengths: int64 target lengths, shape ``(N,)``.
        reduction: ``_REDUCTION_NONE``, ``_REDUCTION_MEAN`` or ``_REDUCTION_SUM``.
        zero_infinity: replace infinite losses with 0 before reducing.
        unbatched: drop the batch dimension of a ``"none"`` result.
        original_dtype: dtype of the user's ``log_probs``; the result is cast to it.

    Returns:
        The per-sample losses (``"none"``), their sum (``"sum"``), or the mean
        of ``loss / max(target_length, 1)`` (``"mean"``).
    """
    neg_log_likelihood = raw_neg_log_likelihood
    if zero_infinity:
        neg_log_likelihood = torch.where(
            torch.isinf(neg_log_likelihood),
            torch.zeros(
                (), dtype=neg_log_likelihood.dtype, device=neg_log_likelihood.device
            ),
            neg_log_likelihood,
        )

    if reduction == _REDUCTION_NONE:
        output = neg_log_likelihood.squeeze(0) if unbatched else neg_log_likelihood
    elif reduction == _REDUCTION_SUM:
        output = neg_log_likelihood.sum()
    else:
        # [sunrise fix] Could not run 'aten::min' & 'aten::mean' with arguments
        # from the 'ptpu' backend, so the clamp and the mean run on the host.
        denom = target_lengths.cpu().clamp_min(1).to(target_lengths.device)
        output = (
            (neg_log_likelihood / denom).cpu().mean().to(neg_log_likelihood.device)
        )

    if output.dtype != original_dtype:
        output = output.to(original_dtype)
    return output


class CtcLossFunction(torch.autograd.Function):
    """Autograd function computing CTC loss with the Triton kernels in this module.

    ``forward`` validates and normalizes its inputs:

    * log-probs become a contiguous ``(T, N, C)`` tensor in the compute dtype;
    * targets and both length arguments are moved to the log-probs device;
    * lengths are converted to int64.

    It then runs one of three kernels:

    * ``_ctc_loss_forward_full_length_reduce_kernel`` when no gradient is
      needed, the input is batched, ``zero_infinity`` is off, the reduction is
      mean or sum and every input length equals ``T``;
    * ``_ctc_loss_forward_no_grad_kernel`` for the other no-gradient cases;
    * ``_ctc_loss_forward_kernel`` when a gradient w.r.t. ``log_probs`` is
      needed; it saves the full alpha table for ``backward``.

    ``backward`` seeds the gradient with ``_ctc_loss_init_grad_kernel`` and then
    subtracts the state posteriors with ``_ctc_loss_backward_kernel``. A
    gradient is returned for ``log_probs`` only.
    """

    @staticmethod
    def forward(
        ctx,
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        blank=0,
        reduction="mean",
        zero_infinity=False,
    ):
        reduction = _reduction_enum(reduction)
        if reduction not in (_REDUCTION_NONE, _REDUCTION_MEAN, _REDUCTION_SUM):
            raise ValueError(f"ctc_loss got invalid reduction enum {reduction}")

        if log_probs.ndim not in (2, 3):
            raise RuntimeError(
                "ctc_loss expects log_probs to be a 2D or 3D tensor, "
                f"but got {log_probs.ndim}D"
            )
        if not torch.is_floating_point(log_probs):
            raise RuntimeError(f'"ctc_loss" not implemented for {log_probs.dtype}')
        if blank < 0 or blank >= log_probs.shape[-1]:
            raise RuntimeError("blank must be in label range")

        device = log_probs.device
        original_dtype = log_probs.dtype
        compute_dtype = _compute_dtype(original_dtype)
        unbatched = log_probs.ndim == 2
        batch_size = 1 if unbatched else log_probs.shape[1]

        work_log_probs = log_probs.unsqueeze(1) if unbatched else log_probs
        work_log_probs = work_log_probs.contiguous()
        if work_log_probs.dtype != compute_dtype:
            work_log_probs = work_log_probs.to(compute_dtype)
        num_steps = work_log_probs.shape[0]
        num_classes = work_log_probs.shape[2]

        # The kernels index targets directly, so they must be on the same
        # device as log_probs; the lengths are moved there below as well.
        if torch.is_floating_point(targets):
            work_targets = targets.to(device=device, dtype=torch.long).contiguous()
        elif _is_integral_dtype(targets.dtype):
            work_targets = targets.to(device=device).contiguous()
        else:
            raise RuntimeError("ctc_loss targets must be integral or floating point")
        work_input_lengths = _lengths_to_tensor(input_lengths, device, "input_lengths")
        work_target_lengths = _lengths_to_tensor(
            target_lengths, device, "target_lengths"
        )
        if work_input_lengths.numel() != batch_size:
            raise RuntimeError(
                f"ctc_loss expected input_lengths to have size {batch_size}, "
                f"but got {work_input_lengths.numel()}"
            )
        if work_target_lengths.numel() != batch_size:
            raise RuntimeError(
                f"ctc_loss expected target_lengths to have size {batch_size}, "
                f"but got {work_target_lengths.numel()}"
            )
        min_input_length, max_input_length, _ = _length_stats(work_input_lengths)
        min_target_length, max_target, total_target_length = _length_stats(
            work_target_lengths
        )
        if min_input_length < 0 or max_input_length > num_steps:
            raise RuntimeError("ctc_loss input_lengths must be in [0, T]")
        if min_target_length < 0:
            raise RuntimeError("ctc_loss target_lengths must be non-negative")

        # Extended label sequence: blank, l0, blank, l1, ..., blank.
        state_count_max = 2 * max_target + 1
        target_stride = max_target
        if work_targets.ndim == 1:
            target_1d = True
            if total_target_length != work_targets.numel():
                raise RuntimeError(
                    "ctc_loss expected concatenated targets length to equal "
                    "sum(target_lengths)"
                )
            # Exclusive prefix sum: where each sample's labels start.
            work_target_offsets = (
                work_target_lengths.cumsum(0) - work_target_lengths
            ).contiguous()
        elif work_targets.ndim == 2:
            target_1d = False
            # Without this check, too few rows would make the kernels read past
            # the end of targets.
            if work_targets.shape[0] != batch_size:
                raise RuntimeError(
                    f"ctc_loss expected padded targets to have {batch_size} rows, "
                    f"but got {work_targets.shape[0]}"
                )
            if max_target > work_targets.shape[1]:
                raise RuntimeError(
                    "ctc_loss target_lengths cannot exceed padded target width"
                )
            target_stride = work_targets.shape[1]
            # Unused by the kernels when TARGET_1D is False; any int64 tensor works.
            work_target_offsets = work_target_lengths
        else:
            raise RuntimeError(
                "ctc_loss expects targets to be a 1D concatenated tensor or a "
                f"2D padded tensor, but got {work_targets.ndim}D"
            )

        block_s = triton.next_power_of_2(state_count_max)

        if not ctx.needs_input_grad[0]:
            full_length_reduce = (
                not unbatched
                and not zero_infinity
                and reduction in (_REDUCTION_MEAN, _REDUCTION_SUM)
                and min_input_length == num_steps
                and num_steps > 0
            )
            if full_length_reduce:
                contrib = torch.empty((batch_size,), dtype=torch.float32, device=device)
                scratch_alpha = torch.empty(
                    (batch_size, 2, state_count_max),
                    dtype=torch.float32,
                    device=device,
                )
                with torch_device_fn.device(device):
                    _ctc_loss_forward_full_length_reduce_kernel[(batch_size,)](
                        work_log_probs,
                        work_targets,
                        work_target_lengths,
                        work_target_offsets,
                        contrib,
                        scratch_alpha,
                        num_steps,
                        batch_size,
                        num_classes,
                        target_stride,
                        state_count_max,
                        blank,
                        target_1d,
                        reduction,
                        block_s,
                    )
                output = contrib.sum()
                if output.dtype != original_dtype:
                    output = output.to(original_dtype)
                return output

            raw_neg_log_likelihood = torch.empty(
                (batch_size,), dtype=torch.float32, device=device
            )
            scratch_alpha = torch.empty(
                (batch_size, 2, state_count_max),
                dtype=torch.float32,
                device=device,
            )
            with torch_device_fn.device(device):
                _ctc_loss_forward_no_grad_kernel[(batch_size,)](
                    work_log_probs,
                    work_targets,
                    work_input_lengths,
                    work_target_lengths,
                    work_target_offsets,
                    raw_neg_log_likelihood,
                    scratch_alpha,
                    num_steps,
                    batch_size,
                    num_classes,
                    target_stride,
                    state_count_max,
                    blank,
                    target_1d,
                    block_s,
                )
            return _finalize_ctc_output(
                raw_neg_log_likelihood,
                work_target_lengths,
                reduction,
                zero_infinity,
                unbatched,
                original_dtype,
            )

        raw_neg_log_likelihood = torch.empty(
            (batch_size,), dtype=torch.float32, device=device
        )
        # Full alpha table, read back by the backward kernel.
        log_alpha = torch.empty(
            (batch_size, num_steps, state_count_max),
            dtype=torch.float32,
            device=device,
        )
        with torch_device_fn.device(device):
            _ctc_loss_forward_kernel[(batch_size,)](
                work_log_probs,
                work_targets,
                work_input_lengths,
                work_target_lengths,
                work_target_offsets,
                raw_neg_log_likelihood,
                log_alpha,
                num_steps,
                batch_size,
                num_classes,
                target_stride,
                state_count_max,
                blank,
                target_1d,
                block_s,
            )
        output = _finalize_ctc_output(
            raw_neg_log_likelihood,
            work_target_lengths,
            reduction,
            zero_infinity,
            unbatched,
            original_dtype,
        )

        # The raw (pre-zero_infinity) losses are saved; the gradient kernels
        # apply zero_infinity themselves.
        ctx.save_for_backward(
            work_log_probs,
            work_targets,
            work_input_lengths,
            work_target_lengths,
            work_target_offsets,
            raw_neg_log_likelihood,
            log_alpha,
        )
        ctx.blank = blank
        ctx.reduction = reduction
        ctx.zero_infinity = zero_infinity
        ctx.unbatched = unbatched
        ctx.batch_size = batch_size
        ctx.original_dtype = original_dtype
        ctx.max_target = target_stride
        ctx.state_count_max = state_count_max
        ctx.target_1d = target_1d

        return output

    @staticmethod
    def backward(ctx, grad_output):
        (
            work_log_probs,
            work_targets,
            work_input_lengths,
            work_target_lengths,
            work_target_offsets,
            neg_log_likelihood,
            log_alpha,
        ) = ctx.saved_tensors

        grad_output = grad_output.contiguous()
        device = work_log_probs.device
        num_steps = work_log_probs.shape[0]
        num_classes = work_log_probs.shape[2]

        grad_log_probs = torch.empty_like(work_log_probs)
        total = work_log_probs.numel()
        block = 256
        scratch_beta = torch.empty(
            (ctx.batch_size, 2, ctx.state_count_max),
            dtype=torch.float32,
            device=device,
        )
        block_s = triton.next_power_of_2(ctx.state_count_max)

        # Both launches run under the log-probs device context. The second
        # accumulates into the values written by the first.
        with torch_device_fn.device(device):
            _ctc_loss_init_grad_kernel[(triton.cdiv(total, block),)](
                work_log_probs,
                work_input_lengths,
                work_target_lengths,
                neg_log_likelihood,
                grad_output,
                grad_log_probs,
                total,
                num_steps,
                ctx.batch_size,
                num_classes,
                ctx.reduction,
                ctx.zero_infinity,
                block,
            )
            _ctc_loss_backward_kernel[(ctx.batch_size,)](
                work_log_probs,
                work_targets,
                work_input_lengths,
                work_target_lengths,
                work_target_offsets,
                neg_log_likelihood,
                log_alpha,
                grad_output,
                grad_log_probs,
                scratch_beta,
                num_steps,
                ctx.batch_size,
                num_classes,
                ctx.max_target,
                ctx.state_count_max,
                ctx.blank,
                ctx.target_1d,
                ctx.reduction,
                ctx.zero_infinity,
                block_s,
            )

        if ctx.unbatched:
            grad_log_probs = grad_log_probs.squeeze(1)
        if grad_log_probs.dtype != ctx.original_dtype:
            grad_log_probs = grad_log_probs.to(ctx.original_dtype)

        return grad_log_probs, None, None, None, None, None, None

1024 

1025 

def ctc_loss(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    blank=0,
    reduction="mean",
    zero_infinity=False,
):
    """Compute the CTC loss through ``CtcLossFunction``.

    Parameters have the same order and defaults as
    ``torch.nn.functional.ctc_loss``. A debug message is logged on every call.
    """
    logger.debug("GEMS CTC LOSS")
    call_args = (
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        blank,
        reduction,
        zero_infinity,
    )
    return CtcLossFunction.apply(*call_args)