Coverage for src/flag_gems/ops/rot90.py: 37%

98 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10 

11logger = logging.getLogger(__name__) 

12 

13 

@triton.autotune(configs=runtime.get_tuned_config("rot90"), key=["n_elements"])
@triton.jit
def rot90_kernel_2d(
    in_ptr,
    out_ptr,
    n_elements,
    M,
    N,
    k_norm,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Rotate a tensor by k_norm quarter turns in the plane of its first two dims.

    Both buffers are addressed with flat row-major offsets, i.e. they are
    treated as contiguous.

    Input shape:  [M, N, D2, D3, ...]
    Output shape for k_norm = 1, 3: [N, M, D2, D3, ...]
    Output shape for k_norm = 0, 2: [M, N, D2, D3, ...]

    The trailing dims D2, D3, ... are not rotated. They are handled as one
    flattened "inner" axis of size n_elements // (M * N) whose index is
    carried over unchanged from the output element to the input element.

    Index mapping (i = output dim 0, j = output dim 1):
    - k_norm = 0 (identity):              out[i, j] = in[i, j]
    - k_norm = 1 (axis 0 towards axis 1): out[i, j] = in[j, N-1-i]
    - k_norm = 2 (half turn):             out[i, j] = in[M-1-i, N-1-j]
    - k_norm = 3 (axis 1 towards axis 0): out[i, j] = in[M-1-j, i]

    In the usual row/column picture k_norm = 1 is a counter-clockwise quarter
    turn and k_norm = 3 a clockwise one.

    Args:
        in_ptr: pointer to the input buffer.
        out_ptr: pointer to the output buffer.
        n_elements: total number of elements (same for input and output).
        M: size of input dim 0.
        N: size of input dim 1.
        k_norm: number of quarter turns, expected to be 0, 1, 2 or 3; any
            other value takes the k_norm = 3 branch.
        BLOCK_SIZE: number of output elements handled by one program.
    """
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Size of the flattened trailing (non-rotated) dims; 1 for a 2-D tensor.
    inner = n_elements // (M * N)
    # Row stride of the input, in elements.
    in_stride_0 = N * inner

    m_minus_1 = M - 1
    n_minus_1 = N - 1

    if k_norm == 0:
        # Identity - output shape [M, N, ...]
        out_stride_0 = N * inner
        out_dim0 = offsets // out_stride_0
        remainder = offsets % out_stride_0
        out_dim1 = remainder // inner
        inner_idx = remainder % inner

        in_dim0 = out_dim0
        in_dim1 = out_dim1

    elif k_norm == 1:
        # One quarter turn - output shape [N, M, ...]
        out_stride_0 = M * inner
        out_dim0 = offsets // out_stride_0
        remainder = offsets % out_stride_0
        out_dim1 = remainder // inner
        inner_idx = remainder % inner

        # out[i, j] = in[j, N-1-i]
        in_dim0 = out_dim1
        in_dim1 = n_minus_1 - out_dim0

    elif k_norm == 2:
        # Half turn - output shape [M, N, ...]
        out_stride_0 = N * inner
        out_dim0 = offsets // out_stride_0
        remainder = offsets % out_stride_0
        out_dim1 = remainder // inner
        inner_idx = remainder % inner

        # out[i, j] = in[M-1-i, N-1-j]
        in_dim0 = m_minus_1 - out_dim0
        in_dim1 = n_minus_1 - out_dim1

    else:  # k_norm == 3
        # Three quarter turns - output shape [N, M, ...]
        out_stride_0 = M * inner
        out_dim0 = offsets // out_stride_0
        remainder = offsets % out_stride_0
        out_dim1 = remainder // inner
        inner_idx = remainder % inner

        # out[i, j] = in[M-1-j, i]
        in_dim0 = m_minus_1 - out_dim1
        in_dim1 = out_dim0

    # Lanes past n_elements may compute meaningless offsets; the mask keeps
    # them from being loaded or stored.
    in_offset = in_dim0 * in_stride_0 + in_dim1 * inner + inner_idx

    x = tl.load(in_ptr + in_offset, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)

106 

107 

def rot90_2d(inp, k, dims, out):
    """Launch rot90_kernel_2d to rotate ``inp`` and write the result to ``out``.

    The kernel always rotates in the plane of the first two dims and reads
    both buffers through flat row-major offsets.

    Args:
        inp: input tensor. A non-contiguous tensor is copied to a contiguous
            one first, because the kernel would otherwise read a strided
            view through the wrong offsets.
        k: number of quarter turns; any integer, reduced modulo 4 here.
        dims: pair of dim indices used to read M and N from ``inp.shape``.
            The kernel itself rotates dims 0 and 1, so callers are expected
            to pass ``[0, 1]``.
        out: pre-allocated output tensor, written in place. It is handed to
            the kernel as-is, so it is assumed to be contiguous.

    Returns:
        None. Returns early without launching when ``out`` has no elements.
    """
    M = inp.shape[dims[0]]
    N = inp.shape[dims[1]]
    n_elements = out.numel()
    if n_elements == 0:
        return

    # The kernel addresses memory as if the tensor were contiguous.
    if not inp.is_contiguous():
        inp = inp.contiguous()

    # Python's % with a positive modulus is already non-negative.
    k_norm = k % 4

    grid = lambda META: (triton.cdiv(n_elements, META["BLOCK_SIZE"]),)
    with torch_device_fn.device(inp.device):
        rot90_kernel_2d[grid](
            inp,
            out,
            n_elements,
            M,
            N,
            k_norm,
        )

129 

130 

def rot90(input, k=1, dims=(0, 1)):
    """
    Rotate an n-D tensor by 90 degrees in the plane specified by dims.

    For k > 0 the rotation goes from dims[0] towards dims[1]; for k < 0 it
    goes the other way. k is reduced modulo 4.

    Args:
        input: the input tensor (at least 2-D). A non-contiguous input is
            copied to a contiguous tensor before rotating.
        k: number of times to rotate (default: 1)
        dims: the two distinct axes spanning the rotation plane
            (default: (0, 1)). Negative indices count from the last dim.

    Returns:
        A new tensor holding the rotated data. For odd k the sizes of
        dims[0] and dims[1] are swapped relative to the input.

    Raises:
        RuntimeError: if dims does not hold exactly two entries, the input
            has fewer than two dims, a dim is out of range, or both dims
            refer to the same axis.

    Note:
        All index arithmetic is delegated to rot90_2d; results for inputs
        with more than two dims depend on that kernel path handling the
        trailing dims.
    """
    logger.debug("GEMS ROT90")

    ndim = input.ndim
    if len(dims) != 2:
        raise RuntimeError(
            f"expected total rotation dims == 2, but got dims = {len(dims)}"
        )
    if ndim < 2:
        raise RuntimeError(
            f"expected total dims >= 2, but got total dims = {ndim}"
        )
    for pos, d in enumerate(dims):
        if not -ndim <= d < ndim:
            raise RuntimeError(
                f"Rotation dim{pos} out of range, dim{pos} = {d}"
            )

    # Map negative indices onto 0..ndim-1 so the comparisons below are valid.
    dim0 = dims[0] % ndim
    dim1 = dims[1] % ndim
    if dim0 == dim1:
        raise RuntimeError(
            "expected rotation dims to be different, "
            f"but got dim0 = {dims[0]} and dim1 = {dims[1]}"
        )

    x = input
    if not x.is_contiguous():
        x = x.contiguous()

    # Normalize k to 0, 1, 2, 3
    k_norm = k % 4

    # An odd number of quarter turns swaps the sizes of the two rotated dims.
    out_shape = list(x.shape)
    if k_norm % 2 == 1:
        out_shape[dim0], out_shape[dim1] = out_shape[dim1], out_shape[dim0]

    if dim0 == 0 and dim1 == 1:
        # Direct path: the rotation plane is already the leading two dims.
        out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
        rot90_2d(x, k_norm, [0, 1], out)
        return out

    # General case: move the rotation plane to the front, rotate there, then
    # move the dims back to their original positions.
    perm = [dim0, dim1]
    perm.extend(d for d in range(ndim) if d != dim0 and d != dim1)

    # inverse_perm[d] is the position of original dim d inside perm.
    inverse_perm = [0] * ndim
    for pos, d in enumerate(perm):
        inverse_perm[d] = pos

    # The kernel needs contiguous memory, so materialize the permuted view.
    x_front = x.permute(perm).contiguous()
    # The scratch output lives in the permuted layout, so its shape must be
    # permuted as well.
    front_shape = [out_shape[d] for d in perm]
    out_front = torch.empty(front_shape, device=x.device, dtype=x.dtype)
    rot90_2d(x_front, k_norm, [0, 1], out_front)

    return out_front.permute(inverse_perm).contiguous()