Coverage for src/flag_gems/ops/rot90.py: 37%
98 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
# Per-module logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@triton.autotune(configs=runtime.get_tuned_config("rot90"), key=["n_elements"])
@triton.jit
def rot90_kernel_2d(
    in_ptr,
    out_ptr,
    n_elements,
    M,
    N,
    k_norm: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Rotate a contiguous tensor by ``k_norm`` quarter turns in the plane [0, 1].

    The input is viewed as a contiguous ``[M, N, inner]`` array. Here
    ``inner`` is the product of all trailing (non-rotated) dims, computed
    as ``n_elements // (M * N)``. The output is written as a contiguous
    ``[M, N, inner]`` array for k_norm = 0, 2 and as ``[N, M, inner]`` for
    k_norm = 1, 3. Each program handles BLOCK_SIZE consecutive output
    elements.

    Index mapping for output position (i, j, r). This is the torch.rot90
    convention: positive k rotates from dim 0 towards dim 1.
    - k=0: out[i, j, r] = in[i, j, r]
    - k=1: out[i, j, r] = in[j, N-1-i, r]
    - k=2: out[i, j, r] = in[M-1-i, N-1-j, r]
    - k=3: out[i, j, r] = in[M-1-j, i, r]

    k_norm must already be reduced to 0..3. It is a compile-time constant,
    so each rotation amount gets its own specialized, branch-free kernel.
    The caller must not launch with n_elements == 0, because
    ``inner = n_elements // (M * N)`` would then divide by zero.
    """
    pid = tl.program_id(axis=0)
    # NOTE(review): offsets are formed with int32 arithmetic; tensors with
    # >= 2**31 elements would overflow -- confirm whether int64 is needed.
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Trailing dims ride along unchanged with each (row, col) pair of the
    # rotated plane, so split the flat output offset into (plane, r).
    inner = n_elements // (M * N)
    r = offsets % inner
    plane = offsets // inner

    if k_norm == 0:
        # Output plane is [M, N].
        out_row = plane // N
        out_col = plane % N
        in_row = out_row
        in_col = out_col
    elif k_norm == 1:
        # Output plane is [N, M]: out[i, j] = in[j, N-1-i].
        out_row = plane // M
        out_col = plane % M
        in_row = out_col
        in_col = N - 1 - out_row
    elif k_norm == 2:
        # Output plane is [M, N]: out[i, j] = in[M-1-i, N-1-j].
        out_row = plane // N
        out_col = plane % N
        in_row = M - 1 - out_row
        in_col = N - 1 - out_col
    else:  # k_norm == 3
        # Output plane is [N, M]: out[i, j] = in[M-1-j, i].
        out_row = plane // M
        out_col = plane % M
        in_row = M - 1 - out_col
        in_col = out_row

    # Input is contiguous [M, N, inner].
    in_offset = (in_row * N + in_col) * inner + r
    x = tl.load(in_ptr + in_offset, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)
def rot90_2d(inp, k, dims, out):
    """Rotate ``inp`` by ``k`` quarter turns in the plane of dims 0 and 1.

    Launches ``rot90_kernel_2d``. The kernel addresses both tensors through
    flat offsets, so it assumes a contiguous row-major layout.

    Args:
        inp: input tensor. A strided (non-contiguous) input is made
            contiguous here before launch.
        k: number of quarter turns; any integer, reduced modulo 4.
        dims: rotation plane. The kernel implements the plane [0, 1];
            ``dims`` is only used to read M and N from ``inp.shape``.
        out: preallocated contiguous output tensor, written in place.
            Expected shape is ``[M, N, ...]`` for even k and ``[N, M, ...]``
            for odd k.
    """
    M = inp.shape[dims[0]]
    N = inp.shape[dims[1]]
    n_elements = out.numel()
    if n_elements == 0:
        # Nothing to write; also avoids a zero-sized launch grid.
        return

    # The kernel computes flat input offsets for a contiguous [M, N, ...]
    # layout. A strided view (e.g. from permute) must be materialized first.
    # This is a no-op when inp is already contiguous.
    inp = inp.contiguous()

    # Python's % already returns a value in [0, 3], even for negative k.
    k_norm = k % 4

    grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),)
    with torch_device_fn.device(inp.device):
        rot90_kernel_2d[grid](
            inp,
            out,
            n_elements,
            M,
            N,
            k_norm,
        )
def rot90(input, k=1, dims=(0, 1)):
    """
    Rotate an n-D tensor by 90 degrees in the plane specified by dims.

    Follows torch.rot90. For k > 0 the rotation goes from ``dims[0]``
    towards ``dims[1]``; for k < 0 it goes the opposite way.

    Args:
        input: the input tensor; must have at least 2 dims
        k: number of times to rotate (default: 1); reduced modulo 4
        dims: the two distinct axes spanning the rotation plane
            (default: (0, 1)); negative values count from the end

    Returns:
        A new contiguous tensor holding the rotated data. For odd k, the
        sizes of ``dims[0]`` and ``dims[1]`` are swapped.

    Raises:
        RuntimeError: if ``dims`` does not hold exactly two distinct,
            in-range axes, or if ``input`` has fewer than 2 dims.
    """
    logger.debug("GEMS ROT90")
    ndim = input.ndim

    # Same checks, in the same order, as torch.rot90.
    if len(dims) != 2:
        raise RuntimeError(
            f"expected total rotation dims == 2, but got dims = {len(dims)}"
        )
    if ndim < 2:
        raise RuntimeError(f"expected total dims >= 2, but got total dims = {ndim}")
    dim0, dim1 = dims
    if dim0 == dim1 or abs(dim0 - dim1) == ndim:
        raise RuntimeError(
            "expected rotation dims to be different, "
            f"but got dim0 = {dim0} and dim1 = {dim1}"
        )
    if not -ndim <= dim0 < ndim:
        raise RuntimeError(f"Rotation dim0 out of range, dim0 = {dim0}")
    if not -ndim <= dim1 < ndim:
        raise RuntimeError(f"Rotation dim1 out of range, dim1 = {dim1}")

    # Map negative dims to their non-negative equivalents.
    dim0 %= ndim
    dim1 %= ndim
    k_norm = k % 4

    # The 2-D launcher rotates the leading two dims of a contiguous tensor.
    # Move the rotation plane to the front: [dims[0], dims[1], rest...].
    # For dims == (0, 1) this is the identity permutation. contiguous() then
    # does not copy an already-contiguous input.
    perm = [dim0, dim1] + [d for d in range(ndim) if d != dim0 and d != dim1]
    x = input.permute(perm).contiguous()

    # The output buffer is allocated in the permuted (plane-first) layout.
    out_shape = list(x.shape)
    if k_norm % 2 == 1:
        # Quarter and three-quarter turns swap the sizes of the plane dims.
        out_shape[0], out_shape[1] = out_shape[1], out_shape[0]
    out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    rot90_2d(x, k_norm, [0, 1], out)

    # Undo the permutation: inverse_perm[perm[i]] = i.
    inverse_perm = [0] * ndim
    for pos, d in enumerate(perm):
        inverse_perm[d] = pos
    return out.permute(inverse_perm).contiguous()