Coverage for src/flag_gems/runtime/backend/_sunrise/ops/triu.py: 0%
108 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as ext
12logger = logging.getLogger(__name__)
@libentry()
@triton.autotune(configs=runtime.get_tuned_config("triu"), key=["M", "N"])
@triton.jit(do_not_specialize=["diagonal"])
def triu_kernel(
    X,
    Y,
    M,
    N,
    diagonal,
    M_BLOCK_SIZE: tl.constexpr,
    N_BLOCK_SIZE: tl.constexpr,
):
    # Upper-triangular copy of a single M x N matrix.
    #
    # Element (r, c) is addressed at offset ``r * N + c`` from both X (source)
    # and Y (destination).  Each program owns M_BLOCK_SIZE consecutive rows and
    # sweeps across the columns in steps of N_BLOCK_SIZE.  An element is kept
    # when ``r + diagonal <= c``; every other in-bounds element is written as 0.
    block_index = ext.program_id(0)
    row_ids = block_index * M_BLOCK_SIZE + tl.arange(0, M_BLOCK_SIZE)[:, None]
    row_in_bounds = row_ids < M

    # Per-row base pointers, shared by every column tile below.
    row_offsets = row_ids * N
    src_rows = X + row_offsets
    dst_rows = Y + row_offsets

    col_lane = tl.arange(0, N_BLOCK_SIZE)[None, :]
    for col_start in range(0, N, N_BLOCK_SIZE):
        col_ids = col_start + col_lane
        # Rows past M and columns past N are neither loaded nor stored.
        in_bounds = row_in_bounds & (col_ids < N)

        values = tl.load(src_rows + col_ids, mask=in_bounds, other=0.0)
        kept = tl.where(row_ids + diagonal <= col_ids, values, 0.0)
        tl.store(dst_rows + col_ids, kept, mask=in_bounds)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("triu_batch"),
    key=["batch", "MN", "N", "diagonal"],
)
@triton.jit(do_not_specialize=["diagonal"])
def triu_batch_kernel(
    X,
    Y,
    batch,
    MN,
    N,
    diagonal,
    BATCH_BLOCK_SIZE: tl.constexpr,
    MN_BLOCK_SIZE: tl.constexpr,
):
    # Upper-triangular copy over a stack of matrices viewed as (batch, MN).
    #
    # Matrix ``b`` is addressed at offset ``b * MN`` from X / Y and its elements
    # at consecutive flat offsets 0..MN-1.  Grid axis 0 tiles the batch in
    # BATCH_BLOCK_SIZE chunks; grid axis 1 tiles the flattened matrix in
    # MN_BLOCK_SIZE chunks.
    batch_block = ext.program_id(0)
    flat_block = ext.program_id(1)

    batch_ids = batch_block * BATCH_BLOCK_SIZE + tl.arange(0, BATCH_BLOCK_SIZE)[:, None]
    batch_in_bounds = batch_ids < batch
    batch_offsets = batch_ids * MN
    src_base = X + batch_offsets
    dst_base = Y + batch_offsets

    flat_ids = flat_block * MN_BLOCK_SIZE + tl.arange(0, MN_BLOCK_SIZE)[None, :]
    in_bounds = batch_in_bounds & (flat_ids < MN)

    values = tl.load(src_base + flat_ids, mask=in_bounds, other=0.0)

    # Recover (row, col) inside the matrix from the flat index; N is the row length.
    row_ids = flat_ids // N
    col_ids = flat_ids % N
    kept = tl.where(row_ids + diagonal <= col_ids, values, 0.0)
    tl.store(dst_base + flat_ids, kept, mask=in_bounds)
def _check_batch_contiguous(tensor, allow_zero_stride=True):
    """Decide whether ``tensor`` can be used as-is or needs a contiguous copy.

    Returns a ``(usable, tensor_to_use)`` pair.  ``usable`` is ``True`` when
    the input is returned unchanged and ``False`` when ``tensor.contiguous()``
    is returned in its place.

    The checks run in this order:

    1. A tensor reporting ``is_contiguous()`` is accepted immediately.
    2. With two or more dims, the last two must be laid out row-major
       (``stride(-1) == 1`` and ``stride(-2) == size(-1)``), otherwise a
       contiguous copy is returned.
    3. With ``allow_zero_stride`` set and at most three dims, the tensor is
       accepted here; the stride of the leading dim is not inspected.
    4. Otherwise every leading dim, innermost first, must have exactly the
       stride of a packed layout.  With ``allow_zero_stride`` set, dim 0 is
       exempt when its stride is 0 or its size is 1.
    """
    if tensor.is_contiguous():
        return True, tensor

    ndim = tensor.dim()

    if ndim >= 2:
        row_major = tensor.stride(-1) == 1 and tensor.stride(-2) == tensor.size(-1)
        if not row_major:
            return False, tensor.contiguous()

    if allow_zero_stride and ndim <= 3:
        return True, tensor

    # Stride a packed layout would give to the innermost leading dim.
    required = tensor.size(-1) * tensor.size(-2)
    for axis in reversed(range(ndim - 2)):
        exempt = (
            allow_zero_stride
            and axis == 0
            and (tensor.stride(axis) == 0 or tensor.size(axis) == 1)
        )
        if exempt:
            continue

        if tensor.stride(axis) != required:
            return False, tensor.contiguous()

        required *= tensor.size(axis)

    return True, tensor
def triu(A, diagonal=0):
    """Return a new tensor holding the upper-triangular part of ``A``.

    For each matrix formed by the last two dims of ``A``, element ``(r, c)``
    is kept when ``r + diagonal <= c`` and set to zero otherwise.

    Args:
        A: Input tensor with at least two dimensions.
        diagonal: Diagonal offset; converted with ``int()`` before launch.

    Returns:
        A freshly allocated contiguous tensor with ``A``'s shape, dtype and
        device.

    Raises:
        AssertionError: If ``A`` has fewer than two dimensions.
    """
    logger.debug("GEMS TRIU")
    assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions"
    diagonal = int(diagonal)

    out = torch.empty(
        A.shape, dtype=A.dtype, device=A.device, memory_format=torch.contiguous_format
    )
    # Nothing to compute for empty inputs; this also avoids a zero divisor
    # when deriving the batch count below.
    if A.numel() == 0:
        return out

    # The kernels index memory as a packed (batch, M, N) block, so fall back
    # to a contiguous copy whenever the input is not laid out that way.
    _, A_input = _check_batch_contiguous(A, allow_zero_stride=False)

    M, N = A_input.shape[-2:]

    with torch_device_fn.device(A_input.device):
        if A_input.dim() == 2:
            grid = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),)
            triu_kernel[grid](A_input, out, M, N, diagonal)
        else:
            # Integer division: exact for any size, unlike float division.
            batch = A_input.numel() // (M * N)
            B = A_input.view(batch, -1)
            grid = lambda meta: (
                triton.cdiv(batch, meta["BATCH_BLOCK_SIZE"]),
                triton.cdiv(M * N, meta["MN_BLOCK_SIZE"]),
            )
            triu_batch_kernel[grid](B, out, batch, M * N, N, diagonal)

    return out
def triu_(A, diagonal=0):
    """Zero the strictly-lower part of ``A`` in place and return ``A``.

    For each matrix formed by the last two dims of ``A``, element ``(r, c)``
    is kept when ``r + diagonal <= c`` and set to zero otherwise.

    Args:
        A: Tensor with at least two dimensions; modified in place.
        diagonal: Diagonal offset; converted with ``int()`` before launch.

    Returns:
        ``A`` itself.

    Raises:
        AssertionError: If ``A`` has fewer than two dimensions.
    """
    logger.debug("GEMS TRIU_ (inplace)")

    assert len(A.shape) > 1, "Input tensor must have at least 2 dimensions"
    diagonal = int(diagonal)
    M, N = A.shape[-2:]

    # Nothing to do for empty inputs; this also avoids a zero divisor when
    # deriving the batch count below.
    if A.numel() == 0:
        return A

    # The kernels address matrix ``b`` at offset ``b * M * N``, so A may only
    # be processed directly when its leading dims are packed.  The strict
    # check sends any other layout (a strided batch dim such as ``x[::2]``,
    # or a stride-0 expanded dim) through the temporary-buffer path instead
    # of letting the kernel touch the wrong memory.
    can_use_directly, src = _check_batch_contiguous(A, allow_zero_stride=False)

    if can_use_directly:
        dst = A
    else:
        logger.debug(
            "Input tensor does not satisfy contiguity requirements, "
            "using temporary tensor for computation"
        )
        dst = torch.empty_like(src, memory_format=torch.contiguous_format)

    with torch_device_fn.device(A.device):
        if A.dim() == 2:
            grid = lambda meta: (triton.cdiv(M, meta["M_BLOCK_SIZE"]),)
            triu_kernel[grid](src, dst, M, N, diagonal)
        else:
            # Integer division: exact for any size, unlike float division.
            batch = A.numel() // (M * N)
            grid = lambda meta: (
                triton.cdiv(batch, meta["BATCH_BLOCK_SIZE"]),
                triton.cdiv(M * N, meta["MN_BLOCK_SIZE"]),
            )
            triu_batch_kernel[grid](
                src.view(batch, -1), dst.view(batch, -1), batch, M * N, N, diagonal
            )

    if not can_use_directly:
        # Scatter the packed result back through A's own strides.
        A.copy_(dst)

    return A