Coverage for src/flag_gems/ops/ctc_loss.py: 41%
446 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
# Logger named after this module.
logger = logging.getLogger(__name__)

# Integer codes for the reduction modes. They are forwarded to the Triton
# kernels as the REDUCTION constexpr, where 0 and 1 are compared literally.
_REDUCTION_NONE, _REDUCTION_MEAN, _REDUCTION_SUM = 0, 1, 2

# Memo used by _length_stats; it is emptied once it holds
# _LENGTH_STATS_CACHE_LIMIT entries.
_LENGTH_STATS_CACHE = {}
_LENGTH_STATS_CACHE_LIMIT = 256
# Use Triton's own barrier when this Triton build exposes one; otherwise
# fall back to a jitted function that does nothing, so the kernels can call
# `_debug_barrier()` unconditionally.
_debug_barrier = getattr(tl, "debug_barrier", None)
if _debug_barrier is None:

    @triton.jit
    def _debug_barrier():
        return
@triton.jit
def _logaddexp(a, b):
    # Returns log(exp(a) + exp(b)). The larger operand is factored out so
    # that exp() only ever sees a non-positive argument.
    hi = tl.maximum(a, b)
    lo = tl.minimum(a, b)
    shifted_sum = hi + tl.log(1.0 + tl.exp(lo - hi))
    # When the larger operand is -inf both are, and (lo - hi) would be NaN;
    # return -inf for that case explicitly.
    both_empty = hi == -float("inf")
    return tl.where(both_empty, -float("inf"), shifted_sum)
@triton.jit
def _logaddexp3(a, b, c, use_c):
    # Returns log(exp(a) + exp(b) + exp(c)), where the third term is dropped
    # (treated as -inf) wherever `use_c` is false.
    third = tl.where(use_c, c, -float("inf"))
    peak = tl.maximum(tl.maximum(a, b), third)
    all_empty = peak == -float("inf")
    # Shift by 0 instead of -inf when every term is -inf so the
    # subtractions below do not produce NaN.
    shift = tl.where(all_empty, 0.0, peak)
    total = tl.exp(a - shift) + tl.exp(b - shift) + tl.exp(third - shift)
    return tl.where(all_empty, -float("inf"), peak + tl.log(total))
# CTC forward (alpha) recursion that keeps the whole alpha lattice.
#
# One program handles one batch element (program_id(0)); the BLOCK_S lanes
# are the extended-label states 0 .. 2 * target_len (even = blank, odd =
# target symbol).
#
# Memory layout implied by the indexing below:
#   log_probs          : [T, N, C], contiguous
#   targets            : flat with per-batch start in `target_offsets` when
#                        TARGET_1D, else rows of width MAX_TARGET
#   log_alpha          : [N, T, STATE_COUNT_MAX], written for every t
#   neg_log_likelihood : [N], receives -log p(target | input)
#
# Positions with t >= input_len, and states beyond 2 * target_len, are
# stored as -inf.
@libentry()
@triton.jit
def _ctc_loss_forward_kernel(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    target_offsets,
    neg_log_likelihood,
    log_alpha,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    batch = tl.program_id(0)
    states = tl.arange(0, BLOCK_S)

    input_len = tl.load(input_lengths + batch)
    target_len = tl.load(target_lengths + batch)
    state_count = target_len * 2 + 1
    # valid_state: state exists for this sample; stored_state: lane fits in
    # the allocated row (BLOCK_S is padded up to a power of two).
    valid_state = states < state_count
    stored_state = states < STATE_COUNT_MAX

    is_blank_state = (states % 2) == 0
    target_index = (states - 1) // 2
    target_mask = (target_index >= 0) & (target_index < target_len)
    target_safe_index = tl.where(target_mask, target_index, 0)

    if TARGET_1D:
        # Start of this sample inside the concatenated targets. Read the
        # precomputed offset, as the sibling kernels do, rather than
        # re-summing all preceding target lengths in every program.
        target_origin = tl.load(target_offsets + batch)
        target_ptrs = targets + target_origin + target_safe_index
    else:
        target_origin = batch * MAX_TARGET
        target_ptrs = targets + target_origin + target_safe_index

    target_value = tl.load(target_ptrs, mask=target_mask, other=BLANK)
    labels = tl.where(is_blank_state, BLANK, target_value)

    # The skip transition s-2 -> s is allowed only into a non-blank state
    # whose symbol differs from the previous target symbol. This does not
    # depend on t, so it is computed once ahead of the time loop.
    prev_target_index = tl.where(target_index > 0, target_index - 1, 0)
    prev_target_value = tl.load(
        targets + target_origin + prev_target_index,
        mask=target_mask & (target_index > 0),
        other=BLANK,
    )
    skip_allowed = (
        (~is_blank_state) & (target_index > 0) & (target_value != prev_target_value)
    )

    # t = 0: only the leading blank and (if any) the first symbol are live.
    t0_active = input_len > 0
    init_state = (states == 0) | ((states == 1) & (target_len > 0))
    init_logp = tl.load(
        log_probs + batch * C + labels,
        mask=init_state & stored_state & t0_active,
        other=0.0,
    ).to(tl.float32)
    alpha = tl.where(init_state & valid_state & t0_active, init_logp, -float("inf"))
    tl.store(
        log_alpha + batch * T * STATE_COUNT_MAX + states,
        alpha,
        mask=stored_state,
    )
    # The next step reads neighbouring lanes' values back from memory.
    _debug_barrier()

    for t in tl.range(1, T):
        prev_base = log_alpha + batch * T * STATE_COUNT_MAX + (t - 1) * STATE_COUNT_MAX
        prev0 = tl.load(prev_base + states, mask=stored_state, other=-float("inf")).to(
            tl.float32
        )
        prev1 = tl.load(
            prev_base + tl.where(states > 0, states - 1, 0),
            mask=(states > 0) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)
        prev2 = tl.load(
            prev_base + tl.where(states > 1, states - 2, 0),
            mask=(states > 1) & stored_state,
            other=-float("inf"),
        ).to(tl.float32)

        acc = _logaddexp3(prev0, prev1, prev2, skip_allowed)

        logp = tl.load(
            log_probs + t * N * C + batch * C + labels,
            mask=valid_state & (t < input_len),
            other=0.0,
        ).to(tl.float32)
        alpha = tl.where(valid_state & (t < input_len), acc + logp, -float("inf"))
        tl.store(
            log_alpha + batch * T * STATE_COUNT_MAX + t * STATE_COUNT_MAX + states,
            alpha,
            mask=stored_state,
        )
        _debug_barrier()

    if input_len <= 0:
        # An empty input can only emit the empty target.
        loss = tl.where(target_len == 0, 0.0, float("inf"))
    else:
        _debug_barrier()
        final_base = (
            log_alpha + batch * T * STATE_COUNT_MAX + (input_len - 1) * STATE_COUNT_MAX
        )
        # Valid paths end in the trailing blank or in the last symbol.
        last = tl.load(final_base + state_count - 1).to(tl.float32)
        prev_last = tl.load(
            final_base + tl.where(target_len > 0, state_count - 2, 0),
            mask=target_len > 0,
            other=-float("inf"),
        ).to(tl.float32)
        log_likelihood = _logaddexp(last, prev_last)
        loss = -log_likelihood

    tl.store(neg_log_likelihood + batch, loss)
# CTC forward recursion for the no-gradient path.
#
# Same recursion as the alpha kernel that keeps the full lattice, but only
# two alpha rows per batch element are kept in `scratch_alpha`
# ([N, 2, STATE_COUNT_MAX] according to the indexing) and used alternately
# (row t % 2). One program handles one batch element.
@libentry()
@triton.jit
def _ctc_loss_forward_no_grad_kernel(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    target_offsets,
    neg_log_likelihood,
    scratch_alpha,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    pid = tl.program_id(0)
    lane = tl.arange(0, BLOCK_S)

    tgt_len = tl.load(target_lengths + pid)
    n_states = tgt_len * 2 + 1
    in_lattice = lane < n_states
    in_buffer = lane < STATE_COUNT_MAX

    # Even lanes are blanks, odd lanes map to target position (lane - 1) // 2.
    on_blank = (lane % 2) == 0
    tgt_pos = (lane - 1) // 2
    tgt_ok = (tgt_pos >= 0) & (tgt_pos < tgt_len)
    tgt_pos_safe = tl.where(tgt_ok, tgt_pos, 0)

    if TARGET_1D:
        tgt_start = tl.load(target_offsets + pid)
    else:
        tgt_start = pid * MAX_TARGET
    tgt_row = targets + tgt_start

    tgt_sym = tl.load(tgt_row + tgt_pos_safe, mask=tgt_ok, other=BLANK)
    symbol = tl.where(on_blank, BLANK, tgt_sym)

    in_len = tl.load(input_lengths + pid)
    has_input = in_len > 0
    # At t = 0 only the leading blank and the first symbol can be occupied.
    seed = (lane == 0) | ((lane == 1) & (tgt_len > 0))
    seed_logp = tl.load(
        log_probs + pid * C + symbol,
        mask=seed & in_buffer & has_input,
        other=0.0,
    ).to(tl.float32)
    alpha = tl.where(seed & in_lattice & has_input, seed_logp, -float("inf"))
    rows = scratch_alpha + pid * 2 * STATE_COUNT_MAX
    tl.store(rows + lane, alpha, mask=in_buffer)
    _debug_barrier()

    for t in tl.range(1, T):
        src = rows + ((t - 1) % 2) * STATE_COUNT_MAX
        dst = rows + (t % 2) * STATE_COUNT_MAX
        from_same = tl.load(src + lane, mask=in_buffer, other=-float("inf")).to(
            tl.float32
        )
        from_prev = tl.load(
            src + tl.where(lane > 0, lane - 1, 0),
            mask=(lane > 0) & in_buffer,
            other=-float("inf"),
        ).to(tl.float32)
        from_skip = tl.load(
            src + tl.where(lane > 1, lane - 2, 0),
            mask=(lane > 1) & in_buffer,
            other=-float("inf"),
        ).to(tl.float32)

        # Skipping over a blank is legal only between two different symbols.
        prev_pos = tl.where(tgt_pos > 0, tgt_pos - 1, 0)
        prev_sym = tl.load(
            tgt_row + prev_pos,
            mask=tgt_ok & (tgt_pos > 0),
            other=BLANK,
        )
        can_skip = (~on_blank) & (tgt_pos > 0) & (tgt_sym != prev_sym)

        merged = _logaddexp3(from_same, from_prev, from_skip, can_skip)
        live = in_lattice & (t < in_len)
        emit = tl.load(
            log_probs + t * N * C + pid * C + symbol,
            mask=live,
            other=0.0,
        ).to(tl.float32)
        alpha = tl.where(live, merged + emit, -float("inf"))
        # Nothing is written once t reaches in_len, so the row for
        # t = in_len - 1 is still intact when it is read after the loop.
        tl.store(dst + lane, alpha, mask=in_buffer & (t < in_len))
        _debug_barrier()

    if in_len <= 0:
        loss = tl.where(tgt_len == 0, 0.0, float("inf"))
    else:
        _debug_barrier()
        final_row = rows + ((in_len - 1) % 2) * STATE_COUNT_MAX
        end_blank = tl.load(final_row + n_states - 1).to(tl.float32)
        end_symbol = tl.load(
            final_row + tl.where(tgt_len > 0, n_states - 2, 0),
            mask=tgt_len > 0,
            other=-float("inf"),
        ).to(tl.float32)
        loss = -_logaddexp(end_blank, end_symbol)

    tl.store(neg_log_likelihood + pid, loss)
# CTC forward recursion specialised for "every sequence uses all T frames".
#
# There is no input_lengths argument: every time step 0 .. T-1 is treated
# as valid. Each program writes its per-sample contribution to `contrib`;
# when REDUCTION == 1 the loss is already divided by max(target_len, 1)
# and by N, otherwise the raw negative log-likelihood is stored.
# Only two alpha rows per sample are kept in `scratch_alpha`.
@libentry()
@triton.jit
def _ctc_loss_forward_full_length_reduce_kernel(
    log_probs,
    targets,
    target_lengths,
    target_offsets,
    contrib,
    scratch_alpha,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    REDUCTION: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    pid = tl.program_id(0)
    lane = tl.arange(0, BLOCK_S)

    tgt_len = tl.load(target_lengths + pid)
    n_states = tgt_len * 2 + 1
    in_lattice = lane < n_states
    in_buffer = lane < STATE_COUNT_MAX

    on_blank = (lane % 2) == 0
    tgt_pos = (lane - 1) // 2
    tgt_ok = (tgt_pos >= 0) & (tgt_pos < tgt_len)
    tgt_pos_safe = tl.where(tgt_ok, tgt_pos, 0)

    if TARGET_1D:
        tgt_start = tl.load(target_offsets + pid)
    else:
        tgt_start = pid * MAX_TARGET
    tgt_row = targets + tgt_start

    tgt_sym = tl.load(tgt_row + tgt_pos_safe, mask=tgt_ok, other=BLANK)
    symbol = tl.where(on_blank, BLANK, tgt_sym)

    # t = 0: leading blank and first symbol only.
    seed = (lane == 0) | ((lane == 1) & (tgt_len > 0))
    seed_logp = tl.load(
        log_probs + pid * C + symbol,
        mask=seed & in_buffer,
        other=0.0,
    ).to(tl.float32)
    alpha = tl.where(seed & in_lattice, seed_logp, -float("inf"))
    rows = scratch_alpha + pid * 2 * STATE_COUNT_MAX
    tl.store(rows + lane, alpha, mask=in_buffer)
    _debug_barrier()

    for t in tl.range(1, T):
        src = rows + ((t - 1) % 2) * STATE_COUNT_MAX
        dst = rows + (t % 2) * STATE_COUNT_MAX
        from_same = tl.load(src + lane, mask=in_buffer, other=-float("inf")).to(
            tl.float32
        )
        from_prev = tl.load(
            src + tl.where(lane > 0, lane - 1, 0),
            mask=(lane > 0) & in_buffer,
            other=-float("inf"),
        ).to(tl.float32)
        from_skip = tl.load(
            src + tl.where(lane > 1, lane - 2, 0),
            mask=(lane > 1) & in_buffer,
            other=-float("inf"),
        ).to(tl.float32)

        prev_pos = tl.where(tgt_pos > 0, tgt_pos - 1, 0)
        prev_sym = tl.load(
            tgt_row + prev_pos,
            mask=tgt_ok & (tgt_pos > 0),
            other=BLANK,
        )
        can_skip = (~on_blank) & (tgt_pos > 0) & (tgt_sym != prev_sym)

        merged = _logaddexp3(from_same, from_prev, from_skip, can_skip)
        emit = tl.load(
            log_probs + t * N * C + pid * C + symbol,
            mask=in_lattice,
            other=0.0,
        ).to(tl.float32)
        alpha = tl.where(in_lattice, merged + emit, -float("inf"))
        tl.store(dst + lane, alpha, mask=in_buffer)
        _debug_barrier()

    if T <= 0:
        loss = tl.where(tgt_len == 0, 0.0, float("inf"))
    else:
        _debug_barrier()
        final_row = rows + ((T - 1) % 2) * STATE_COUNT_MAX
        end_blank = tl.load(final_row + n_states - 1).to(tl.float32)
        end_symbol = tl.load(
            final_row + tl.where(tgt_len > 0, n_states - 2, 0),
            mask=tgt_len > 0,
            other=-float("inf"),
        ).to(tl.float32)
        loss = -_logaddexp(end_blank, end_symbol)

    if REDUCTION == 1:
        # Mean reduction: normalise by target length, then by batch size.
        loss = loss / tl.maximum(tgt_len, 1).to(tl.float32) / N
    tl.store(contrib + pid, loss)
# Element-wise initialisation of the log_probs gradient.
#
# Each lane handles one flat element of the [T, N, C] log_probs tensor and
# writes exp(log_prob) * scale for t < input_len, 0 elsewhere. `scale` is
# the incoming gradient (per sample when REDUCTION == 0, a single scalar
# otherwise, further divided by max(target_len, 1) * N when
# REDUCTION == 1). NaN is written where a -inf log-prob meets a non-zero
# scale, and, unless ZERO_INFINITY, for samples whose loss is +inf.
@libentry()
@triton.jit
def _ctc_loss_init_grad_kernel(
    log_probs,
    input_lengths,
    target_lengths,
    neg_log_likelihood,
    grad_output,
    grad_input,
    total: tl.constexpr,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    REDUCTION: tl.constexpr,
    ZERO_INFINITY: tl.constexpr,
    BLOCK: tl.constexpr,
):
    flat = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    in_bounds = flat < total
    # Decompose the flat index of the [T, N, C] layout.
    sample = (flat // C) % N
    step = flat // (N * C)

    in_len = tl.load(input_lengths + sample, mask=in_bounds, other=0)
    tgt_len = tl.load(target_lengths + sample, mask=in_bounds, other=1)
    nll = tl.load(neg_log_likelihood + sample, mask=in_bounds, other=0.0).to(
        tl.float32
    )

    if REDUCTION == 0:
        scale = tl.load(grad_output + sample, mask=in_bounds, other=0.0).to(
            tl.float32
        )
    else:
        scale = tl.load(grad_output).to(tl.float32)
        if REDUCTION == 1:
            denom = tl.maximum(tgt_len, 1).to(tl.float32) * N
            scale = scale / denom

    if ZERO_INFINITY:
        # Samples with an infinite loss contribute no gradient.
        scale = tl.where(nll == float("inf"), 0.0, scale)

    logp = tl.load(log_probs + flat, mask=in_bounds, other=-float("inf")).to(
        tl.float32
    )
    live = (step < in_len) & in_bounds
    grad = tl.where(live, tl.exp(logp) * scale, 0.0)
    nan_grad = float("nan")
    poisoned = live & (scale != 0.0) & (logp == -float("inf"))
    grad = tl.where(poisoned, nan_grad, grad)
    if not ZERO_INFINITY:
        grad = tl.where(live & (nll == float("inf")), nan_grad, grad)
    tl.store(grad_input + flat, grad, mask=in_bounds)
# CTC backward (beta) recursion with gradient accumulation.
#
# One program handles one batch element. Walking t from input_len - 1 down
# to 0, it combines the stored alpha row with the current beta row into the
# state posterior exp(alpha + beta - log_likelihood) and atomically adds
# -scale * posterior to grad_input at the label of each state. Beta is kept
# in two alternating rows of `scratch_beta`. `grad_input` is only added to
# here, so it must already hold its initial value on entry.
@libentry()
@triton.jit
def _ctc_loss_backward_kernel(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    target_offsets,
    neg_log_likelihood,
    log_alpha,
    grad_output,
    grad_input,
    scratch_beta,
    T: tl.constexpr,
    N: tl.constexpr,
    C: tl.constexpr,
    MAX_TARGET: tl.constexpr,
    STATE_COUNT_MAX: tl.constexpr,
    BLANK: tl.constexpr,
    TARGET_1D: tl.constexpr,
    REDUCTION: tl.constexpr,
    ZERO_INFINITY: tl.constexpr,
    BLOCK_S: tl.constexpr,
):
    pid = tl.program_id(0)
    lane = tl.arange(0, BLOCK_S)

    in_len = tl.load(input_lengths + pid)
    tgt_len = tl.load(target_lengths + pid)
    nll = tl.load(neg_log_likelihood + pid).to(tl.float32)
    n_states = tgt_len * 2 + 1
    in_lattice = lane < n_states
    in_buffer = lane < STATE_COUNT_MAX

    # Label of state s.
    on_blank = (lane % 2) == 0
    tgt_pos = (lane - 1) // 2
    tgt_ok = (tgt_pos >= 0) & (tgt_pos < tgt_len)
    tgt_pos_safe = tl.where(tgt_ok, tgt_pos, 0)

    if TARGET_1D:
        tgt_start = tl.load(target_offsets + pid)
    else:
        tgt_start = pid * MAX_TARGET
    tgt_row = targets + tgt_start

    tgt_sym = tl.load(tgt_row + tgt_pos_safe, mask=tgt_ok, other=BLANK)
    symbol = tl.where(on_blank, BLANK, tgt_sym)

    # Label of state s + 1.
    lane1 = lane + 1
    on_blank1 = (lane1 % 2) == 0
    tgt_pos1 = (lane1 - 1) // 2
    tgt_ok1 = (tgt_pos1 >= 0) & (tgt_pos1 < tgt_len)
    tgt_pos1_safe = tl.where(tgt_ok1, tgt_pos1, 0)
    tgt_sym1 = tl.load(tgt_row + tgt_pos1_safe, mask=tgt_ok1, other=BLANK)
    symbol1 = tl.where(on_blank1, BLANK, tgt_sym1)

    # Label of state s + 2; no blank substitution is applied here, the
    # value is only consumed under `can_skip`, which requires a non-blank s.
    lane2 = lane + 2
    tgt_pos2 = (lane2 - 1) // 2
    tgt_ok2 = (tgt_pos2 >= 0) & (tgt_pos2 < tgt_len)
    tgt_pos2_safe = tl.where(tgt_ok2, tgt_pos2, 0)
    tgt_sym2 = tl.load(tgt_row + tgt_pos2_safe, mask=tgt_ok2, other=BLANK)
    symbol2 = tgt_sym2

    if REDUCTION == 0:
        scale = tl.load(grad_output + pid).to(tl.float32)
    else:
        scale = tl.load(grad_output).to(tl.float32)
        if REDUCTION == 1:
            denom = tl.maximum(tgt_len, 1).to(tl.float32) * N
            scale = scale / denom

    if ZERO_INFINITY:
        scale = tl.where(nll == float("inf"), 0.0, scale)

    # At the last valid frame, beta is 0 for the two legal end states.
    is_end_state = (lane == n_states - 1) | ((lane == n_states - 2) & (tgt_len > 0))
    beta_init = tl.where(
        is_end_state & in_lattice & (in_len > 0),
        0.0,
        -float("inf"),
    )
    rows = scratch_beta + pid * 2 * STATE_COUNT_MAX
    tl.store(rows + lane, beta_init, mask=in_buffer)
    _debug_barrier()
    log_likelihood = tl.where(scale != 0.0, -nll, 0.0)

    alpha_rows = log_alpha + pid * T * STATE_COUNT_MAX
    has_next1 = lane + 1 < n_states
    has_next2 = lane + 2 < n_states

    for step in tl.range(0, T):
        # Walk time backwards from in_len - 1; iterations with t < 0 are
        # inactive and only write -inf into the next beta row.
        t = in_len - 1 - step
        active = t >= 0
        safe_t = tl.where(active, t, 0)
        cur_row = rows + (step % 2) * STATE_COUNT_MAX
        nxt_row = rows + ((step + 1) % 2) * STATE_COUNT_MAX
        beta = tl.load(cur_row + lane, mask=in_buffer, other=-float("inf")).to(
            tl.float32
        )

        alpha_t = tl.load(
            alpha_rows + safe_t * STATE_COUNT_MAX + lane,
            mask=active & in_buffer,
            other=-float("inf"),
        ).to(tl.float32)
        log_post = alpha_t + beta - log_likelihood
        posterior = tl.where(
            active & in_lattice & (scale != 0.0),
            tl.exp(log_post),
            0.0,
        )
        # Several states can share a label, hence the atomic accumulation.
        tl.atomic_add(
            grad_input + safe_t * N * C + pid * C + symbol,
            -scale * posterior,
            sem="relaxed",
            mask=active & in_lattice & in_buffer,
        )

        frame = log_probs + safe_t * N * C + pid * C
        stay = beta + tl.load(
            frame + symbol,
            mask=active & in_lattice,
            other=-float("inf"),
        ).to(tl.float32)

        beta_next1 = tl.load(
            cur_row + lane + 1,
            mask=has_next1 & in_buffer,
            other=-float("inf"),
        ).to(tl.float32)
        logp_next1 = tl.load(
            frame + symbol1,
            mask=active & has_next1 & in_buffer,
            other=-float("inf"),
        ).to(tl.float32)
        next1 = beta_next1 + logp_next1

        can_skip = (~on_blank) & has_next2 & (tgt_sym != tgt_sym2)
        beta_next2 = tl.load(
            cur_row + lane + 2,
            mask=has_next2 & in_buffer,
            other=-float("inf"),
        ).to(tl.float32)
        logp_next2 = tl.load(
            frame + symbol2,
            mask=active & can_skip & in_buffer,
            other=-float("inf"),
        ).to(tl.float32)
        next2 = beta_next2 + logp_next2

        beta_prev = _logaddexp3(stay, next1, next2, can_skip)
        tl.store(
            nxt_row + lane,
            tl.where(active, beta_prev, -float("inf")),
            mask=in_buffer,
        )
        _debug_barrier()
def _reduction_enum(reduction):
    """Translate a reduction spec into its integer code.

    Strings must be ``"none"``, ``"mean"`` or ``"sum"``; anything that is
    not a string is passed through ``int()`` unchanged in meaning.

    Raises:
        ValueError: if ``reduction`` is a string other than the three above.
    """
    if not isinstance(reduction, str):
        return int(reduction)

    codes_by_name = {
        "none": _REDUCTION_NONE,
        "mean": _REDUCTION_MEAN,
        "sum": _REDUCTION_SUM,
    }
    code = codes_by_name.get(reduction)
    if code is None:
        raise ValueError(
            "ctc_loss expected reduction to be one of 'none', 'mean', or 'sum', "
            f"but got {reduction!r}"
        )
    return code
# Integer dtypes accepted for targets and length tensors.
_INTEGRAL_DTYPES = {
    torch.int64,
    torch.int32,
    torch.int16,
    torch.int8,
    torch.uint8,
}


def _is_integral_dtype(dtype):
    """Return True when ``dtype`` is one of the dtypes in ``_INTEGRAL_DTYPES``."""
    return dtype in _INTEGRAL_DTYPES
def _lengths_to_tensor(lengths, device, name):
    """Normalise a lengths argument to a 1-D ``torch.long`` tensor on ``device``.

    Args:
        lengths: a tensor, or anything ``torch.tensor`` accepts.
        device: device the result is placed on.
        name: argument name used in the error message.

    Returns:
        A flattened, contiguous int64 tensor; a 0-dim input becomes shape (1,).

    Raises:
        RuntimeError: if the (converted) input does not have an integer dtype.
    """
    if torch.is_tensor(lengths):
        # Check the dtype before the device transfer.
        if not _is_integral_dtype(lengths.dtype):
            raise RuntimeError(f"{name} must be integral")
        result = lengths.to(device=device)
    else:
        result = torch.tensor(lengths, device=device)
        if not _is_integral_dtype(result.dtype):
            raise RuntimeError(f"{name} must be integral")

    if result.dtype != torch.long:
        result = result.to(dtype=torch.long)

    if result.ndim == 0:
        return result.reshape(1)
    return result.reshape(-1).contiguous()
def _length_stats(lengths):
    """Return ``(min, max, sum)`` of ``lengths`` as Python ints.

    The three reductions are stacked and copied to the CPU in one transfer.
    Results are memoised per tensor, keyed on device, data pointer, element
    count, dtype, strides and version counter. The cached entry also keeps a
    reference to the tensor so its memory cannot be reused by another tensor
    while the entry exists. The cache is emptied when it reaches
    ``_LENGTH_STATS_CACHE_LIMIT`` entries.

    Inference tensors (created under ``torch.inference_mode()``) do not track
    a version counter and reading ``_version`` on them raises, so they are
    never cached.

    Args:
        lengths: tensor of lengths; must be non-empty for ``min``/``max``.

    Returns:
        Tuple ``(min, max, sum)`` of Python ints.
    """
    key = None
    if torch.is_tensor(lengths) and not lengths.is_inference():
        key = (
            lengths.device.type,
            lengths.device.index,
            lengths.data_ptr(),
            lengths.numel(),
            # dtype and strides distinguish views that share the same
            # storage start, element count and version counter.
            lengths.dtype,
            tuple(lengths.stride()),
            lengths._version,
        )
        cached = _LENGTH_STATS_CACHE.get(key)
        if cached is not None:
            return cached[1]

    stats_tensor = torch.stack((lengths.min(), lengths.max(), lengths.sum())).cpu()
    stats = tuple(int(value) for value in stats_tensor.tolist())
    if key is not None:
        if len(_LENGTH_STATS_CACHE) >= _LENGTH_STATS_CACHE_LIMIT:
            _LENGTH_STATS_CACHE.clear()
        _LENGTH_STATS_CACHE[key] = (lengths, stats)
    return stats
def _compute_dtype(dtype):
    """Return the dtype to compute in: float32 for fp16/bf16, else ``dtype``."""
    is_half_precision = dtype == torch.float16 or dtype == torch.bfloat16
    return torch.float32 if is_half_precision else dtype
class CtcLossFunction(torch.autograd.Function):
    """Autograd function computing CTC loss with the Triton kernels above.

    ``forward`` validates and normalises the inputs, then picks one of three
    kernels:

    * a fused "full length + reduce" kernel when no gradient is required,
      the input is batched, ``zero_infinity`` is off, the reduction is
      mean/sum and every input length equals ``T``;
    * a two-row scratch kernel for the remaining no-gradient cases;
    * the full-lattice kernel, whose ``log_alpha`` is saved for ``backward``.

    ``backward`` returns a gradient for ``log_probs`` only.
    """

    @staticmethod
    def _finalize_loss(
        raw_neg_log_likelihood,
        target_lengths,
        reduction,
        zero_infinity,
        unbatched,
        original_dtype,
    ):
        """Apply zero_infinity, the reduction and the output dtype cast."""
        nll = raw_neg_log_likelihood
        if zero_infinity:
            nll = torch.where(
                torch.isinf(nll),
                torch.zeros((), dtype=nll.dtype, device=nll.device),
                nll,
            )

        if reduction == _REDUCTION_NONE:
            output = nll.squeeze(0) if unbatched else nll
        elif reduction == _REDUCTION_SUM:
            output = nll.sum()
        else:
            # Mean: per-sample loss divided by its target length (at least
            # 1), then averaged over the batch.
            output = (nll / target_lengths.clamp_min(1)).mean()

        if output.dtype != original_dtype:
            output = output.to(original_dtype)
        return output

    @staticmethod
    def forward(
        ctx,
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        blank=0,
        reduction="mean",
        zero_infinity=False,
    ):
        """Compute the CTC loss.

        Args:
            log_probs: floating tensor, ``(T, C)`` or ``(T, N, C)``.
            targets: 1-D concatenated or 2-D padded labels (integer or
                floating dtype; floating values are cast to int64).
            input_lengths: per-sample input lengths, each in ``[0, T]``.
            target_lengths: per-sample target lengths, non-negative.
            blank: blank label index, in ``[0, C)``.
            reduction: ``"none"``, ``"mean"``, ``"sum"`` or the matching
                integer code.
            zero_infinity: replace infinite losses by zero.

        Raises:
            ValueError: on an invalid reduction.
            RuntimeError: on invalid shapes, dtypes or lengths.
        """
        reduction = _reduction_enum(reduction)
        if reduction not in (_REDUCTION_NONE, _REDUCTION_MEAN, _REDUCTION_SUM):
            raise ValueError(f"ctc_loss got invalid reduction enum {reduction}")

        if log_probs.ndim not in (2, 3):
            raise RuntimeError(
                "ctc_loss expects log_probs to be a 2D or 3D tensor, "
                f"but got {log_probs.ndim}D"
            )
        if not torch.is_floating_point(log_probs):
            raise RuntimeError(f'"ctc_loss" not implemented for {log_probs.dtype}')
        if blank < 0 or blank >= log_probs.shape[-1]:
            raise RuntimeError("blank must be in label range")

        device = log_probs.device
        original_dtype = log_probs.dtype
        compute_dtype = _compute_dtype(original_dtype)
        unbatched = log_probs.ndim == 2
        batch_size = 1 if unbatched else log_probs.shape[1]

        # The kernels index log_probs as a contiguous (T, N, C) tensor.
        work_log_probs = log_probs.unsqueeze(1) if unbatched else log_probs
        work_log_probs = work_log_probs.contiguous()
        if work_log_probs.dtype != compute_dtype:
            work_log_probs = work_log_probs.to(compute_dtype)
        seq_len = work_log_probs.shape[0]
        num_classes = work_log_probs.shape[2]

        if torch.is_floating_point(targets):
            work_targets = targets.to(dtype=torch.long)
        elif _is_integral_dtype(targets.dtype):
            work_targets = targets
        else:
            raise RuntimeError("ctc_loss targets must be integral or floating point")
        # The kernels dereference targets on log_probs' device, so move them
        # there just like the length tensors below.
        work_targets = work_targets.to(device=device).contiguous()

        work_input_lengths = _lengths_to_tensor(input_lengths, device, "input_lengths")
        work_target_lengths = _lengths_to_tensor(
            target_lengths, device, "target_lengths"
        )
        if work_input_lengths.numel() != batch_size:
            raise RuntimeError(
                f"ctc_loss expected input_lengths to have size {batch_size}, "
                f"but got {work_input_lengths.numel()}"
            )
        if work_target_lengths.numel() != batch_size:
            raise RuntimeError(
                f"ctc_loss expected target_lengths to have size {batch_size}, "
                f"but got {work_target_lengths.numel()}"
            )
        min_input_length, max_input_length, _ = _length_stats(work_input_lengths)
        min_target_length, max_target, total_target_length = _length_stats(
            work_target_lengths
        )
        if min_input_length < 0 or max_input_length > seq_len:
            raise RuntimeError("ctc_loss input_lengths must be in [0, T]")
        if min_target_length < 0:
            raise RuntimeError("ctc_loss target_lengths must be non-negative")

        state_count_max = 2 * max_target + 1
        target_stride = max_target
        if work_targets.ndim == 1:
            target_1d = True
            if total_target_length != work_targets.numel():
                raise RuntimeError(
                    "ctc_loss expected concatenated targets length to equal "
                    "sum(target_lengths)"
                )
            # Exclusive prefix sum: start of each sample in the flat targets.
            work_target_offsets = (
                work_target_lengths.cumsum(0) - work_target_lengths
            ).contiguous()
        elif work_targets.ndim == 2:
            target_1d = False
            # The kernels read row `batch` of the padded targets; fewer rows
            # than samples would be an out-of-bounds read.
            if work_targets.shape[0] < batch_size:
                raise RuntimeError(
                    f"ctc_loss expected padded targets to have {batch_size} "
                    f"rows, but got {work_targets.shape[0]}"
                )
            if max_target > work_targets.shape[1]:
                raise RuntimeError(
                    "ctc_loss target_lengths cannot exceed padded target width"
                )
            target_stride = work_targets.shape[1]
            # Not read by the kernels on the 2-D path; passed as a placeholder.
            work_target_offsets = work_target_lengths
        else:
            raise RuntimeError(
                "ctc_loss expects targets to be a 1D concatenated tensor or a "
                f"2D padded tensor, but got {work_targets.ndim}D"
            )

        needs_log_probs_grad = ctx.needs_input_grad[0]
        block_s = triton.next_power_of_2(state_count_max)

        if not needs_log_probs_grad:
            use_fused_reduce = (
                not unbatched
                and not zero_infinity
                and reduction in (_REDUCTION_MEAN, _REDUCTION_SUM)
                and min_input_length == seq_len
                and seq_len > 0
            )
            scratch_alpha = torch.empty(
                (batch_size, 2, state_count_max),
                dtype=torch.float32,
                device=device,
            )
            if use_fused_reduce:
                contrib = torch.empty(
                    (batch_size,), dtype=torch.float32, device=device
                )
                with torch_device_fn.device(device):
                    _ctc_loss_forward_full_length_reduce_kernel[(batch_size,)](
                        work_log_probs,
                        work_targets,
                        work_target_lengths,
                        work_target_offsets,
                        contrib,
                        scratch_alpha,
                        seq_len,
                        batch_size,
                        num_classes,
                        target_stride,
                        state_count_max,
                        blank,
                        target_1d,
                        reduction,
                        block_s,
                    )
                output = contrib.sum()
                if output.dtype != original_dtype:
                    output = output.to(original_dtype)
                return output

            raw_neg_log_likelihood = torch.empty(
                (batch_size,), dtype=torch.float32, device=device
            )
            with torch_device_fn.device(device):
                _ctc_loss_forward_no_grad_kernel[(batch_size,)](
                    work_log_probs,
                    work_targets,
                    work_input_lengths,
                    work_target_lengths,
                    work_target_offsets,
                    raw_neg_log_likelihood,
                    scratch_alpha,
                    seq_len,
                    batch_size,
                    num_classes,
                    target_stride,
                    state_count_max,
                    blank,
                    target_1d,
                    block_s,
                )
            return CtcLossFunction._finalize_loss(
                raw_neg_log_likelihood,
                work_target_lengths,
                reduction,
                zero_infinity,
                unbatched,
                original_dtype,
            )

        raw_neg_log_likelihood = torch.empty(
            (batch_size,), dtype=torch.float32, device=device
        )
        log_alpha = torch.empty(
            (batch_size, seq_len, state_count_max),
            dtype=torch.float32,
            device=device,
        )
        with torch_device_fn.device(device):
            _ctc_loss_forward_kernel[(batch_size,)](
                work_log_probs,
                work_targets,
                work_input_lengths,
                work_target_lengths,
                work_target_offsets,
                raw_neg_log_likelihood,
                log_alpha,
                seq_len,
                batch_size,
                num_classes,
                target_stride,
                state_count_max,
                blank,
                target_1d,
                block_s,
            )
        output = CtcLossFunction._finalize_loss(
            raw_neg_log_likelihood,
            work_target_lengths,
            reduction,
            zero_infinity,
            unbatched,
            original_dtype,
        )

        # backward needs the raw (un-zeroed) losses to detect infinite ones.
        ctx.save_for_backward(
            work_log_probs,
            work_targets,
            work_input_lengths,
            work_target_lengths,
            work_target_offsets,
            raw_neg_log_likelihood,
            log_alpha,
        )
        ctx.blank = blank
        ctx.reduction = reduction
        ctx.zero_infinity = zero_infinity
        ctx.unbatched = unbatched
        ctx.batch_size = batch_size
        ctx.original_dtype = original_dtype
        ctx.max_target = target_stride
        ctx.state_count_max = state_count_max
        ctx.target_1d = target_1d

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``log_probs`` (None for other inputs)."""
        (
            work_log_probs,
            work_targets,
            work_input_lengths,
            work_target_lengths,
            work_target_offsets,
            neg_log_likelihood,
            log_alpha,
        ) = ctx.saved_tensors

        grad_output = grad_output.contiguous()
        device = work_log_probs.device
        seq_len = work_log_probs.shape[0]
        num_classes = work_log_probs.shape[2]

        grad_log_probs = torch.empty_like(work_log_probs)
        total = work_log_probs.numel()
        block = 256
        scratch_beta = torch.empty(
            (ctx.batch_size, 2, ctx.state_count_max),
            dtype=torch.float32,
            device=device,
        )
        block_s = triton.next_power_of_2(ctx.state_count_max)

        # Both kernels are launched under the device guard so they run on
        # the device that owns the saved tensors.
        with torch_device_fn.device(device):
            # Pass 1: grad = exp(log_probs) * scale (with NaN/zero handling).
            _ctc_loss_init_grad_kernel[(triton.cdiv(total, block),)](
                work_log_probs,
                work_input_lengths,
                work_target_lengths,
                neg_log_likelihood,
                grad_output,
                grad_log_probs,
                total,
                seq_len,
                ctx.batch_size,
                num_classes,
                ctx.reduction,
                ctx.zero_infinity,
                block,
            )
            # Pass 2: subtract scale * state posterior at each state's label.
            _ctc_loss_backward_kernel[(ctx.batch_size,)](
                work_log_probs,
                work_targets,
                work_input_lengths,
                work_target_lengths,
                work_target_offsets,
                neg_log_likelihood,
                log_alpha,
                grad_output,
                grad_log_probs,
                scratch_beta,
                seq_len,
                ctx.batch_size,
                num_classes,
                ctx.max_target,
                ctx.state_count_max,
                ctx.blank,
                ctx.target_1d,
                ctx.reduction,
                ctx.zero_infinity,
                block_s,
            )

        if ctx.unbatched:
            grad_log_probs = grad_log_probs.squeeze(1)
        if grad_log_probs.dtype != ctx.original_dtype:
            grad_log_probs = grad_log_probs.to(ctx.original_dtype)

        return grad_log_probs, None, None, None, None, None, None
def ctc_loss(
    log_probs,
    targets,
    input_lengths,
    target_lengths,
    blank=0,
    reduction="mean",
    zero_infinity=False,
):
    """Compute CTC loss by delegating to ``CtcLossFunction.apply``.

    All arguments are forwarded positionally and unchanged; a debug message
    is logged on every call.
    """
    logger.debug("GEMS CTC LOSS")
    forwarded = (
        log_probs,
        targets,
        input_lengths,
        target_lengths,
        blank,
        reduction,
        zero_infinity,
    )
    return CtcLossFunction.apply(*forwarded)