Coverage for src/flag_gems/ops/tril.py: 49%
424 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
9logger = logging.getLogger(__name__)
@triton.jit
def _tril_tile_kernel(
    in_ptr,
    out_ptr,
    diag: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Write the lower-triangular part of one BLOCK_M x BLOCK_N tile.

    Grid axes are (row tile, column tile, matrix index). Both buffers are
    addressed as densely packed row-major matrices: element (b, r, c) is at
    offset ``b * M * N + r * N + c``. An element is copied when
    ``c <= r + diag`` and written as zero otherwise.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # 64-bit matrix index: b * M * N wraps around in int32 once the tensor
    # holds 2**31 or more elements.
    pid_b = tl.program_id(2).to(tl.int64)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    mask = (offs_m < M) & (offs_n < N)

    base = pid_b * (M * N)
    # Row offsets are widened before the multiply for the same reason.
    idxs = base + offs_m.to(tl.int64) * N + offs_n

    row_start = pid_m * BLOCK_M
    row_end = row_start + BLOCK_M - 1
    col_start = pid_n * BLOCK_N
    col_end = col_start + BLOCK_N - 1

    # Every column of the tile lies right of the diagonal: store zeros
    # without reading the input.
    if col_start > row_end + diag:
        tl.store(out_ptr + idxs, 0.0, mask=mask)
        return

    # Every column of the tile lies on or left of the diagonal: plain copy.
    if col_end <= row_start + diag:
        x = tl.load(in_ptr + idxs, mask=mask, other=0.0)
        tl.store(out_ptr + idxs, x, mask=mask)
        return

    # The diagonal crosses the tile: dropped elements are never loaded and
    # receive the ``other`` value (zero).
    keep = offs_n <= (offs_m + diag)
    x = tl.load(in_ptr + idxs, mask=mask & keep, other=0.0)
    tl.store(out_ptr + idxs, x, mask=mask)
@triton.jit
def _tril_rows_kernel(
    in_ptr,
    out_ptr,
    diag,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Process BLOCK_M full rows of one matrix, BLOCK_N columns at a time.

    Grid axes are (row block, matrix index). Both buffers are addressed as
    densely packed row-major matrices. An element (r, c) is copied when
    ``c <= r + diag`` and written as zero otherwise.
    """
    pid_m = tl.program_id(0)
    # 64-bit matrix index: b * M * N wraps around in int32 once the tensor
    # holds 2**31 or more elements.
    pid_b = tl.program_id(1).to(tl.int64)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_mask = offs_m < M
    base = pid_b * (M * N)
    # Row offsets are widened before the multiply for the same reason.
    row_base = base + offs_m.to(tl.int64) * N
    row_start = pid_m * BLOCK_M
    row_end = row_start + BLOCK_M - 1

    for col_start in range(0, N, BLOCK_N):
        offs_n = col_start + tl.arange(0, BLOCK_N)[None, :]
        mask = row_mask & (offs_n < N)
        idxs = row_base + offs_n

        col_end = col_start + BLOCK_N - 1
        if col_start > row_end + diag:
            # Column block entirely right of the diagonal: zeros only.
            tl.store(out_ptr + idxs, 0.0, mask=mask)
        elif col_end <= row_start + diag:
            # Column block entirely on or left of the diagonal: plain copy.
            x = tl.load(in_ptr + idxs, mask=mask, other=0.0)
            tl.store(out_ptr + idxs, x, mask=mask)
        else:
            # Diagonal crosses this block: dropped elements are not loaded
            # and receive zero.
            keep = offs_n <= (offs_m + diag)
            x = tl.load(in_ptr + idxs, mask=mask & keep, other=0.0)
            tl.store(out_ptr + idxs, x, mask=mask)
@triton.jit
def _tril_flat_kernel(
    in_ptr,
    out_ptr,
    total,
    diag,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Apply tril over a flat 1-D range of ``total`` elements.

    Each program handles BLOCK_SIZE consecutive flat offsets and recovers the
    (row, column) of every element from its offset within an M x N matrix.
    An element is copied when ``col <= row + diag`` and written as zero
    otherwise.
    """
    # 64-bit program index: pid * BLOCK_SIZE wraps around in int32 once
    # ``total`` reaches 2**31, which would also break the ``< total`` mask.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < total

    matrix_offsets = offsets % (M * N)
    rows = matrix_offsets // N
    cols = matrix_offsets - rows * N
    keep = cols <= rows + diag

    # Dropped elements are not loaded and receive the ``other`` value (zero).
    x = tl.load(in_ptr + offsets, mask=mask & keep, other=0.0)
    tl.store(out_ptr + offsets, x, mask=mask)
@triton.jit
def _tril_exact_row_kernel(
    in_ptr,
    out_ptr,
    diag,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Process exactly one matrix row per program.

    Grid axes are (row, matrix index). The store carries no bounds mask, so
    the caller must launch with ``BLOCK_N == N``. Column ``c`` of row ``r``
    is copied when ``c <= r + diag`` and written as zero otherwise.
    """
    pid_m = tl.program_id(0)
    # 64-bit matrix index: b * M * N wraps around in int32 once the tensor
    # holds 2**31 or more elements.
    pid_b = tl.program_id(1).to(tl.int64)

    offs_n = tl.arange(0, BLOCK_N)
    # The row offset is widened before the multiply for the same reason.
    idxs = pid_b * (M * N) + pid_m.to(tl.int64) * N + offs_n
    keep = offs_n <= pid_m + diag
    x = tl.load(in_ptr + idxs, mask=keep, other=0.0)
    tl.store(out_ptr + idxs, x)
@triton.jit
def _tril_exact_diag0_tile_kernel(
    in_ptr,
    out_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Tile kernel specialised for ``diagonal == 0`` with no bounds masks.

    Grid axes are (row tile, column tile, matrix index). Loads and stores are
    unmasked against M and N, so the caller must pick BLOCK_M / BLOCK_N that
    divide M / N exactly. Element (r, c) is copied when ``c <= r`` and
    written as zero otherwise.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # 64-bit matrix index: b * M * N wraps around in int32 once the tensor
    # holds 2**31 or more elements.
    pid_b = tl.program_id(2).to(tl.int64)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    idxs = pid_b * (M * N) + offs_m.to(tl.int64) * N + offs_n

    row_start = pid_m * BLOCK_M
    row_end = row_start + BLOCK_M - 1
    col_start = pid_n * BLOCK_N
    col_end = col_start + BLOCK_N - 1

    # Tile entirely right of the main diagonal: zeros only.
    if col_start > row_end:
        tl.store(out_ptr + idxs, 0.0)
        return

    # Tile entirely on or left of the main diagonal: plain copy.
    if col_end <= row_start:
        x = tl.load(in_ptr + idxs)
        tl.store(out_ptr + idxs, x)
        return

    # Diagonal crosses the tile: dropped elements are not loaded and
    # receive zero.
    keep = offs_n <= offs_m
    x = tl.load(in_ptr + idxs, mask=keep, other=0.0)
    tl.store(out_ptr + idxs, x)
@triton.jit
def _tril_inplace_zero_tile_kernel(
    ptr,
    diag: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Zero the elements above the ``diag`` diagonal in place.

    Grid axes are (row tile, column tile, matrix index). The buffer is
    addressed as densely packed row-major matrices. Only elements with
    ``c > r + diag`` are written; kept elements are never read or stored.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # 64-bit matrix index: b * M * N wraps around in int32 once the tensor
    # holds 2**31 or more elements.
    pid_b = tl.program_id(2).to(tl.int64)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    mask = (offs_m < M) & (offs_n < N)
    idxs = pid_b * (M * N) + offs_m.to(tl.int64) * N + offs_n

    row_start = pid_m * BLOCK_M
    col_end = pid_n * BLOCK_N + BLOCK_N - 1
    # Tile entirely on or left of the diagonal: nothing to zero.
    if col_end <= row_start + diag:
        return

    row_end = row_start + BLOCK_M - 1
    col_start = pid_n * BLOCK_N
    # Tile entirely right of the diagonal: zero every in-bounds element.
    if col_start > row_end + diag:
        tl.store(ptr + idxs, 0.0, mask=mask)
        return

    # Diagonal crosses the tile: zero only the part right of it.
    zero = offs_n > offs_m + diag
    tl.store(ptr + idxs, 0.0, mask=mask & zero)
@triton.jit
def _tril_inplace_zero_strided_tile_kernel(
    ptr,
    diag: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    B0: tl.constexpr,
    B1: tl.constexpr,
    B2: tl.constexpr,
    B3: tl.constexpr,
    B4: tl.constexpr,
    B5: tl.constexpr,
    S0: tl.constexpr,
    S1: tl.constexpr,
    S2: tl.constexpr,
    S3: tl.constexpr,
    S4: tl.constexpr,
    S5: tl.constexpr,
    STRIDE_M: tl.constexpr,
    STRIDE_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Zero the elements above the ``diag`` diagonal of a strided tensor.

    Grid axes are (row tile, column tile, flat batch index). ``B0..B5`` are
    the sizes and ``S0..S5`` the element strides of up to six batch
    dimensions; ``STRIDE_M`` / ``STRIDE_N`` are the row and column strides.
    Only elements with ``c > r + diag`` are written.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # 64-bit batch index so the stride-weighted offsets below cannot wrap
    # around in int32 for large underlying storages.
    pid_b = tl.program_id(2).to(tl.int64)

    # Unravel the flat batch index into per-dimension coordinates, peeling
    # off the last batch dimension first.
    b = pid_b
    i5 = b % B5
    b = b // B5
    i4 = b % B4
    b = b // B4
    i3 = b % B3
    b = b // B3
    i2 = b % B2
    b = b // B2
    i1 = b % B1
    i0 = b // B1
    batch_offset = i0 * S0 + i1 * S1 + i2 * S2 + i3 * S3 + i4 * S4 + i5 * S5

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    mask = (offs_m < M) & (offs_n < N)
    idxs = (
        batch_offset
        + offs_m.to(tl.int64) * STRIDE_M
        + offs_n.to(tl.int64) * STRIDE_N
    )

    row_start = pid_m * BLOCK_M
    col_end = pid_n * BLOCK_N + BLOCK_N - 1
    # Tile entirely on or left of the diagonal: nothing to zero.
    if col_end <= row_start + diag:
        return

    row_end = row_start + BLOCK_M - 1
    col_start = pid_n * BLOCK_N
    # Tile entirely right of the diagonal: zero every in-bounds element.
    if col_start > row_end + diag:
        tl.store(ptr + idxs, 0.0, mask=mask)
        return

    # Diagonal crosses the tile: zero only the part right of it.
    zero = offs_n > offs_m + diag
    tl.store(ptr + idxs, 0.0, mask=mask & zero)
@triton.jit
def _tril_strided_out_tile_kernel(
    in_ptr,
    out_ptr,
    diag,
    M: tl.constexpr,
    N: tl.constexpr,
    B0: tl.constexpr,
    B1: tl.constexpr,
    B2: tl.constexpr,
    B3: tl.constexpr,
    B4: tl.constexpr,
    B5: tl.constexpr,
    S0: tl.constexpr,
    S1: tl.constexpr,
    S2: tl.constexpr,
    S3: tl.constexpr,
    S4: tl.constexpr,
    S5: tl.constexpr,
    STRIDE_M: tl.constexpr,
    STRIDE_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Tile kernel reading a packed input and writing a strided output.

    Grid axes are (row tile, column tile, flat batch index). The input is
    addressed as densely packed row-major matrices; the output is addressed
    through ``B0..B5`` (batch sizes), ``S0..S5`` (batch strides) and
    ``STRIDE_M`` / ``STRIDE_N``. Element (r, c) is copied when
    ``c <= r + diag`` and written as zero otherwise.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # 64-bit batch index so neither b * M * N nor the stride-weighted output
    # offset can wrap around in int32.
    pid_b = tl.program_id(2).to(tl.int64)

    # Unravel the flat batch index into per-dimension coordinates, peeling
    # off the last batch dimension first.
    b = pid_b
    i5 = b % B5
    b = b // B5
    i4 = b % B4
    b = b // B4
    i3 = b % B3
    b = b // B3
    i2 = b % B2
    b = b // B2
    i1 = b % B1
    i0 = b // B1
    out_batch_offset = i0 * S0 + i1 * S1 + i2 * S2 + i3 * S3 + i4 * S4 + i5 * S5

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    mask = (offs_m < M) & (offs_n < N)
    rows64 = offs_m.to(tl.int64)
    in_idxs = pid_b * (M * N) + rows64 * N + offs_n
    out_idxs = out_batch_offset + rows64 * STRIDE_M + offs_n.to(tl.int64) * STRIDE_N

    row_start = pid_m * BLOCK_M
    row_end = row_start + BLOCK_M - 1
    col_start = pid_n * BLOCK_N
    col_end = col_start + BLOCK_N - 1

    # Tile entirely right of the diagonal: zeros only, input is not read.
    if col_start > row_end + diag:
        tl.store(out_ptr + out_idxs, 0.0, mask=mask)
        return

    # Tile entirely on or left of the diagonal: plain copy.
    if col_end <= row_start + diag:
        x = tl.load(in_ptr + in_idxs, mask=mask, other=0.0)
        tl.store(out_ptr + out_idxs, x, mask=mask)
        return

    # Diagonal crosses the tile: dropped elements are not loaded and
    # receive zero.
    keep = offs_n <= (offs_m + diag)
    x = tl.load(in_ptr + in_idxs, mask=mask & keep, other=0.0)
    tl.store(out_ptr + out_idxs, x, mask=mask)
def _check_input(input: torch.Tensor):
    """Raise ``RuntimeError`` unless ``input`` has at least two dimensions."""
    if input.dim() >= 2:
        return
    raise RuntimeError("tril: input tensor must have at least 2 dimensions")
def _empty_contiguous_like(input: torch.Tensor):
    """Allocate an uninitialised tensor shaped like ``input``.

    A contiguous input keeps ``torch.empty_like``'s default (preserve)
    memory format; any other input gets a ``torch.contiguous_format`` result.
    """
    if not input.is_contiguous():
        return torch.empty_like(input, memory_format=torch.contiguous_format)
    return torch.empty_like(input)
def _zero_out(out: torch.Tensor):
    """Set every element of ``out`` to zero in place and return ``out``.

    Empty tensors are returned untouched. Contiguous tensors go through
    ``zero_()``, everything else through ``fill_(0)``.
    """
    if out.numel() == 0:
        return out
    return out.zero_() if out.is_contiguous() else out.fill_(0)
def _is_power_of_2(value: int):
    """Return True when ``value`` is a positive integer power of two."""
    # value & -value isolates the lowest set bit; it equals value only when
    # that is the sole bit set.
    return value > 0 and (value & -value) == value
def _has_internal_overlap_from_strides(tensor: torch.Tensor):
    """Report whether the strides may map two indices to one element.

    Dimensions of size <= 1 are ignored. The remaining dimensions are visited
    in ascending stride order; a dimension whose stride is smaller than the
    span covered by the previous dimensions is reported as overlapping.
    """
    dims = [
        (stride, size)
        for size, stride in zip(tensor.shape, tensor.stride())
        if size > 1
    ]
    dims.sort()

    covered = 1  # number of elements spanned by the dimensions seen so far
    for stride, size in dims:
        if stride < covered:
            return True
        covered += stride * (size - 1)
    return False
def _tensors_overlap(left: torch.Tensor, right: torch.Tensor):
    """Return the result of ``torch._C._overlaps(left, right)``.

    If that lookup or call raises ``AttributeError`` (for example the private
    helper is missing from this torch build), return True so callers treat
    the tensors as overlapping.
    """
    try:
        verdict = torch._C._overlaps(left, right)
    except AttributeError:
        verdict = True
    return verdict
def _can_use_strided_out_kernel(input: torch.Tensor, out: torch.Tensor):
    """Decide whether the strided-output kernel may write ``out`` directly.

    Returns False when ``out`` is contiguous or empty, has more than six
    batch dimensions, has strides that may self-overlap, or overlaps a
    contiguous ``input``. A non-contiguous ``input`` is not checked for
    overlap with ``out``.
    """
    if out.is_contiguous() or out.numel() == 0:
        return False

    batch_dims = out.dim() - 2
    if batch_dims > 6 or _has_internal_overlap_from_strides(out):
        return False

    aliases_input = input.is_contiguous() and _tensors_overlap(input, out)
    return not aliases_input
# Thresholds read by the kernel-selection helpers below.
# Row widths (N) eligible for the one-program-per-row kernel.
_WIDE_EXACT_ROW_MIN_N = 2048
_WIDE_EXACT_ROW_MAX_N = 8192
# Minimum number of rows (M * batch) required when M is below the
# "always" threshold.
_WIDE_EXACT_ROW_MIN_ROWS = 256
# From this M upwards the exact-row kernel is chosen regardless of batch.
_WIDE_EXACT_ROW_ALWAYS_ROW_M = 512
# Minimum batch size for the tiny-matrix tile path.
_TINY_BATCHED_TILE_MIN_BATCH = 128
def _use_wide_exact_row(M: int, N: int, batch: int):
    """Decide whether to use the one-program-per-row kernel.

    One exact-row program covers one matrix row with BLOCK_N == N. It is used
    for wide power-of-two rows, where it avoids the flat kernel's div/mod
    indexing, but only when there are enough row programs to keep occupancy
    reasonable.
    """
    width_in_range = _WIDE_EXACT_ROW_MIN_N <= N <= _WIDE_EXACT_ROW_MAX_N
    if not (width_in_range and _is_power_of_2(N)):
        return False

    if M >= _WIDE_EXACT_ROW_ALWAYS_ROW_M:
        return True

    total_rows = M * batch
    return N <= 4096 and total_rows >= _WIDE_EXACT_ROW_MIN_ROWS
def _use_tiny_batched_tile(M: int, N: int, batch: int):
    """Return True for large batches of matrices no bigger than 32 x 32."""
    if batch < _TINY_BATCHED_TILE_MIN_BATCH:
        return False
    return M <= 32 and N <= 32
def _wide_exact_row_warps(N: int):
    """Return the warp count for the exact-row kernel: 2 up to N=4096, else 4."""
    return 2 if N <= 4096 else 4
def _launch_tile(
    input: torch.Tensor,
    out: torch.Tensor,
    diagonal: int,
    block_m: int = 32,
    block_n: int = 32,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Launch ``_tril_tile_kernel`` on a (row tiles, column tiles, batch) grid.

    Returns ``out``; an empty ``input`` returns ``out`` without launching.
    The matrix shape comes from the last two dimensions of ``input`` and all
    leading dimensions are folded into the batch count.
    """
    rows, cols = input.shape[-2:]
    n_elements = input.numel()
    if n_elements == 0:
        return out

    n_matrices = n_elements // (rows * cols)
    launch_grid = (
        triton.cdiv(rows, block_m),
        triton.cdiv(cols, block_n),
        n_matrices,
    )
    with torch_device_fn.device(input.device):
        _tril_tile_kernel[launch_grid](
            input,
            out,
            int(diagonal),
            rows,
            cols,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    return out
def _launch_flat(
    input: torch.Tensor,
    out: torch.Tensor,
    diagonal: int,
    block_size: int = 1024,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Launch ``_tril_flat_kernel`` on a 1-D grid over all elements.

    Returns ``out``; an empty ``input`` returns ``out`` without launching.
    Each program covers ``block_size`` consecutive flat offsets.
    """
    rows, cols = input.shape[-2:]
    n_elements = input.numel()
    if n_elements == 0:
        return out

    launch_grid = (triton.cdiv(n_elements, block_size),)
    with torch_device_fn.device(input.device):
        _tril_flat_kernel[launch_grid](
            input,
            out,
            n_elements,
            int(diagonal),
            rows,
            cols,
            BLOCK_SIZE=block_size,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    return out
def _launch_rows(
    input: torch.Tensor,
    out: torch.Tensor,
    diagonal: int,
    block_m: int = 32,
    block_n: int = 64,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Launch ``_tril_rows_kernel`` on a (row blocks, batch) grid.

    Returns ``out``; an empty ``input`` returns ``out`` without launching.
    Each program handles ``block_m`` full rows and walks the columns in
    steps of ``block_n`` inside the kernel.
    """
    rows, cols = input.shape[-2:]
    n_elements = input.numel()
    if n_elements == 0:
        return out

    n_matrices = n_elements // (rows * cols)
    launch_grid = (triton.cdiv(rows, block_m), n_matrices)
    with torch_device_fn.device(input.device):
        _tril_rows_kernel[launch_grid](
            input,
            out,
            int(diagonal),
            rows,
            cols,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    return out
def _launch_exact_row(
    input: torch.Tensor,
    out: torch.Tensor,
    diagonal: int,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Launch ``_tril_exact_row_kernel`` with one program per matrix row.

    Returns ``out``; an empty ``input`` returns ``out`` without launching.
    The kernel block width is set to the full row width N.
    """
    rows, cols = input.shape[-2:]
    n_elements = input.numel()
    if n_elements == 0:
        return out

    n_matrices = n_elements // (rows * cols)
    launch_grid = (rows, n_matrices)
    with torch_device_fn.device(input.device):
        _tril_exact_row_kernel[launch_grid](
            input,
            out,
            int(diagonal),
            rows,
            cols,
            BLOCK_N=cols,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    return out
def _launch_exact_diag0_tile(
    input: torch.Tensor,
    out: torch.Tensor,
    block_m: int,
    block_n: int,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Launch ``_tril_exact_diag0_tile_kernel`` (diagonal fixed at 0).

    Returns ``out``; an empty ``input`` returns ``out`` without launching.
    The grid is (row tiles, column tiles, batch).
    """
    rows, cols = input.shape[-2:]
    n_elements = input.numel()
    if n_elements == 0:
        return out

    n_matrices = n_elements // (rows * cols)
    launch_grid = (
        triton.cdiv(rows, block_m),
        triton.cdiv(cols, block_n),
        n_matrices,
    )
    with torch_device_fn.device(input.device):
        _tril_exact_diag0_tile_kernel[launch_grid](
            input,
            out,
            rows,
            cols,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    return out
def _launch_tril_inplace_contiguous(
    input: torch.Tensor,
    diagonal: int,
    block_m: int = 16,
    block_n: int = 64,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Zero the part of a contiguous tensor above ``diagonal`` in place.

    Only rows that contain at least one element to zero are covered by the
    grid. Returns ``input``; nothing is launched for an empty tensor or when
    no row needs zeroing. ``input`` is expected to be contiguous.
    """
    M, N = input.shape[-2:]
    if input.numel() == 0:
        return input

    # Row r has an element right of the diagonal only when r < N - 1 - diagonal.
    active_rows = min(M, max(0, N - 1 - diagonal))
    if active_rows == 0:
        return input

    batch = input.numel() // (M * N)

    # The batch count is the third grid axis, which CUDA devices limit to
    # 65535 blocks; larger batches are processed in several launches.
    max_grid_batch = 65535
    if batch <= max_grid_batch:
        chunks = [input]
    else:
        matrices = input.view(batch, M, N)
        chunks = [
            matrices[start : start + max_grid_batch]
            for start in range(0, batch, max_grid_batch)
        ]

    row_tiles = triton.cdiv(active_rows, block_m)
    col_tiles = triton.cdiv(N, block_n)
    with torch_device_fn.device(input.device):
        for chunk in chunks:
            grid = (row_tiles, col_tiles, chunk.numel() // (M * N))
            _tril_inplace_zero_tile_kernel[grid](
                chunk,
                int(diagonal),
                M,
                N,
                BLOCK_M=block_m,
                BLOCK_N=block_n,
                num_warps=num_warps,
                num_stages=num_stages,
            )
    return input
def _launch_tril_inplace_strided(
    input: torch.Tensor,
    diagonal: int,
    block_m: int = 16,
    block_n: int = 64,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Zero the part of a strided tensor above ``diagonal`` in place.

    Up to six batch dimensions are handled directly through their strides.
    Tensors with more batch dimensions are computed into a contiguous
    temporary and copied back. Returns ``input``.
    """
    M, N = input.shape[-2:]
    if input.numel() == 0:
        return input

    # Row r has an element right of the diagonal only when r < N - 1 - diagonal.
    active_rows = min(M, max(0, N - 1 - diagonal))
    if active_rows == 0:
        return input

    batch_shape = list(input.shape[:-2])
    batch_strides = list(input.stride()[:-2])
    batch = 1
    for size in batch_shape:
        batch *= size

    if len(batch_shape) > 6:
        tmp = _empty_contiguous_like(input)
        _launch_tril(input, tmp, diagonal)
        input.copy_(tmp)
        return input

    # The batch count is the third grid axis, which CUDA devices limit to
    # 65535 blocks. Larger batches are split along their largest batch
    # dimension; the narrowed views share storage with ``input``, so the
    # recursive calls still update it in place.
    max_grid_batch = 65535
    if batch > max_grid_batch:
        split_dim = max(range(len(batch_shape)), key=batch_shape.__getitem__)
        split_size = batch_shape[split_dim]
        step = max(1, max_grid_batch // (batch // split_size))
        for start in range(0, split_size, step):
            _launch_tril_inplace_strided(
                input.narrow(split_dim, start, min(step, split_size - start)),
                diagonal,
                block_m=block_m,
                block_n=block_n,
                num_warps=num_warps,
                num_stages=num_stages,
            )
        return input

    # Pad to exactly six batch dimensions; size 1 / stride 0 entries do not
    # contribute to the offset computed in the kernel.
    batch_shape.extend([1] * (6 - len(batch_shape)))
    batch_strides.extend([0] * (6 - len(batch_strides)))
    stride_m, stride_n = input.stride()[-2:]

    grid = (triton.cdiv(active_rows, block_m), triton.cdiv(N, block_n), batch)
    with torch_device_fn.device(input.device):
        _tril_inplace_zero_strided_tile_kernel[grid](
            input,
            int(diagonal),
            M,
            N,
            B0=batch_shape[0],
            B1=batch_shape[1],
            B2=batch_shape[2],
            B3=batch_shape[3],
            B4=batch_shape[4],
            B5=batch_shape[5],
            S0=batch_strides[0],
            S1=batch_strides[1],
            S2=batch_strides[2],
            S3=batch_strides[3],
            S4=batch_strides[4],
            S5=batch_strides[5],
            STRIDE_M=stride_m,
            STRIDE_N=stride_n,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    return input
def _launch_tril_strided_out(
    input: torch.Tensor,
    out: torch.Tensor,
    diagonal: int,
    block_m: int = 32,
    block_n: int = 64,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Compute tril of ``input`` straight into a strided ``out``.

    ``input`` is made contiguous if necessary; ``out`` is addressed through
    its own strides and may have at most six batch dimensions. ``out`` is
    expected to have the same shape as ``input``. Returns ``out``.
    """
    M, N = input.shape[-2:]
    if input.numel() == 0:
        return out

    input_to_use = input if input.is_contiguous() else input.contiguous()
    batch_shape = list(out.shape[:-2])
    batch_strides = list(out.stride()[:-2])
    batch = 1
    for size in batch_shape:
        batch *= size

    # The batch count is the third grid axis, which CUDA devices limit to
    # 65535 blocks. Larger batches are split along their largest batch
    # dimension, narrowing input and output identically.
    max_grid_batch = 65535
    if batch > max_grid_batch:
        split_dim = max(range(len(batch_shape)), key=batch_shape.__getitem__)
        split_size = batch_shape[split_dim]
        step = max(1, max_grid_batch // (batch // split_size))
        for start in range(0, split_size, step):
            length = min(step, split_size - start)
            _launch_tril_strided_out(
                input_to_use.narrow(split_dim, start, length),
                out.narrow(split_dim, start, length),
                diagonal,
                block_m=block_m,
                block_n=block_n,
                num_warps=num_warps,
                num_stages=num_stages,
            )
        return out

    # Pad to exactly six batch dimensions; size 1 / stride 0 entries do not
    # contribute to the offset computed in the kernel.
    batch_shape.extend([1] * (6 - len(batch_shape)))
    batch_strides.extend([0] * (6 - len(batch_strides)))
    stride_m, stride_n = out.stride()[-2:]

    grid = (triton.cdiv(M, block_m), triton.cdiv(N, block_n), batch)
    with torch_device_fn.device(input.device):
        _tril_strided_out_tile_kernel[grid](
            input_to_use,
            out,
            int(diagonal),
            M,
            N,
            B0=batch_shape[0],
            B1=batch_shape[1],
            B2=batch_shape[2],
            B3=batch_shape[3],
            B4=batch_shape[4],
            B5=batch_shape[5],
            S0=batch_strides[0],
            S1=batch_strides[1],
            S2=batch_strides[2],
            S3=batch_strides[3],
            S4=batch_strides[4],
            S5=batch_strides[5],
            STRIDE_M=stride_m,
            STRIDE_N=stride_n,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    return out
def _launch_tril(input: torch.Tensor, out: torch.Tensor, diagonal: int):
    """Compute ``tril(input, diagonal)`` into ``out`` and return ``out``.

    Trivial cases are answered without a kernel: an empty input, a diagonal
    low enough that every element is dropped, or high enough that every
    element is kept. Otherwise a kernel and launch configuration are picked
    from the matrix shape, batch count and diagonal. The kernel paths read a
    contiguous copy of ``input`` and address ``out`` as a contiguous buffer.
    """
    M, N = input.shape[-2:]
    if input.numel() == 0:
        return out

    if diagonal <= -M:
        # No element satisfies col <= row + diagonal.
        return _zero_out(out)
    if diagonal >= N - 1:
        # Every element satisfies col <= row + diagonal.
        out.copy_(input)
        return out

    input_to_use = input if input.is_contiguous() else input.contiguous()
    batch = input_to_use.numel() // (M * N)

    # Every multi-dimensional grid below places the batch count on the second
    # or third grid axis, which CUDA devices limit to 65535 blocks. Larger
    # batches are processed as consecutive slices of the flattened batch.
    max_grid_batch = 65535
    if batch > max_grid_batch:
        src = input_to_use.view(batch, M, N)
        dst = out.view(batch, M, N)
        for start in range(0, batch, max_grid_batch):
            stop = start + max_grid_batch
            _launch_tril(src[start:stop], dst[start:stop], diagonal)
        return out

    if _use_wide_exact_row(M, N, batch):
        return _launch_exact_row(
            input_to_use,
            out,
            diagonal,
            num_warps=_wide_exact_row_warps(N),
        )
    # Shape-specific configurations for the unmasked diagonal-0 tile kernel;
    # both block shapes divide the matrix shape exactly.
    if batch == 1 and M == 1024 and N == 1024 and diagonal == 0:
        return _launch_exact_diag0_tile(
            input_to_use,
            out,
            block_m=32,
            block_n=64,
            num_warps=4,
        )
    if M == 512 and N == 512 and diagonal == 0:
        return _launch_exact_diag0_tile(
            input_to_use,
            out,
            block_m=16,
            block_n=128,
            num_warps=4,
        )
    if _use_tiny_batched_tile(M, N, batch):
        return _launch_tile(
            input_to_use,
            out,
            diagonal,
            block_m=16,
            block_n=64,
            num_warps=2,
        )
    if M <= 64 and N <= 64:
        return _launch_rows(
            input_to_use,
            out,
            diagonal,
            block_m=2,
            block_n=64,
            num_warps=1,
        )
    if N >= 2048:
        return _launch_flat(
            input_to_use,
            out,
            diagonal,
            block_size=4096,
            num_warps=4,
        )
    if batch > 1:
        if M >= 256 and N >= 256:
            return _launch_tile(
                input_to_use,
                out,
                diagonal,
                block_m=16,
                block_n=64,
                num_warps=4,
            )
        return _launch_rows(
            input_to_use,
            out,
            diagonal,
            block_m=8,
            block_n=512,
            num_warps=1,
        )
    if N >= 512:
        return _launch_tile(
            input_to_use,
            out,
            diagonal,
            block_m=64,
            block_n=64,
            num_warps=4,
        )
    if M == 256 and N == 256:
        return _launch_rows(
            input_to_use,
            out,
            diagonal,
            block_m=8,
            block_n=256,
            num_warps=2,
        )
    return _launch_rows(
        input_to_use,
        out,
        diagonal,
        block_m=8,
        block_n=512,
        num_warps=1,
    )
def tril(input: torch.Tensor, diagonal: int = 0):
    """Return a new tensor holding the lower-triangular part of ``input``.

    Elements with ``col > row + diagonal`` in the last two dimensions are
    zero in the result. Raises ``RuntimeError`` if ``input`` has fewer than
    two dimensions.
    """
    logger.debug("GEMS TRIL")
    _check_input(input)

    result = _empty_contiguous_like(input)
    offset = int(diagonal)
    return _launch_tril(input, result, offset)
def tril_(input: torch.Tensor, diagonal: int = 0):
    """Zero the elements of ``input`` above ``diagonal`` in place.

    Returns ``input``. Raises ``RuntimeError`` if ``input`` has fewer than
    two dimensions.
    """
    logger.debug("GEMS TRIL_")
    _check_input(input)

    diagonal = int(diagonal)
    if input.numel() == 0:
        return input

    rows, cols = input.shape[-2:]
    if diagonal >= cols - 1:
        # Every element is kept; nothing to write.
        return input
    if diagonal <= -rows:
        # Every element is dropped.
        return _zero_out(input)

    launcher = (
        _launch_tril_inplace_contiguous
        if input.is_contiguous()
        else _launch_tril_inplace_strided
    )
    return launcher(input, diagonal)
def tril_out(input: torch.Tensor, diagonal: int = 0, *, out: torch.Tensor = None):
    """Write the lower-triangular part of ``input`` into ``out``.

    With ``out=None`` this behaves like ``tril``. Otherwise ``out`` must
    match ``input`` in dtype and device (``RuntimeError`` if not) and is
    resized to ``input``'s shape when the shapes differ. Returns ``out``.
    """
    logger.debug("GEMS TRIL.OUT")

    if out is None:
        return tril(input, diagonal)

    _check_input(input)
    if out.dtype != input.dtype:
        raise RuntimeError(
            f"Expected out tensor to have dtype {input.dtype}, but got {out.dtype} instead"
        )
    if out.device != input.device:
        raise RuntimeError(
            f"Expected out tensor to be on device {input.device}, but got {out.device} instead"
        )
    if out.shape != input.shape:
        out.resize_(input.shape)

    # A contiguous destination can be written by the regular kernels.
    if out.is_contiguous():
        return _launch_tril(input, out, int(diagonal))

    if input.numel() == 0:
        return out
    rows, cols = input.shape[-2:]
    if diagonal <= -rows:
        return _zero_out(out)
    if diagonal >= cols - 1:
        out.copy_(input)
        return out

    if not _can_use_strided_out_kernel(input, out):
        # Fall back to a contiguous temporary followed by a strided copy.
        scratch = _empty_contiguous_like(input)
        _launch_tril(input, scratch, int(diagonal))
        out.copy_(scratch)
        return out

    n_matrices = input.numel() // (rows * cols)
    if rows <= 64 and cols <= 64:
        tile_rows, warps = 16, 2
    elif n_matrices > 1 and rows >= 256 and cols >= 256:
        tile_rows, warps = 16, 4
    else:
        tile_rows, warps = 32, 4
    return _launch_tril_strided_out(
        input,
        out,
        int(diagonal),
        block_m=tile_rows,
        block_n=64,
        num_warps=warps,
    )