Coverage for src/flag_gems/ops/affine_grid_generator.py: 40%

52 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10 

# Module-level logger, named after this module's import path.
logger = logging.getLogger(name=__name__)

12 

13 

# Triton kernel: each program fills BLOCK_SIZE consecutive scalars of the
# logically flattened output grid of shape [N, H, W, 2] (component index
# fastest-varying).
#
# For lane (n, h, w, c) it writes
#     grid[n, h, w, c] = theta[n, c, 0] * x + theta[n, c, 1] * y + theta[n, c, 2]
# where x / y are column w / row h normalized to [-1, 1] (see below).
# Every tensor access goes through the explicit strides, so neither theta nor
# output has to be contiguous. The arithmetic is carried out in fp32.
@libentry()
@triton.jit
def affine_grid_generator_kernel(
    output_ptr,
    theta_ptr,
    N,
    H,
    W,
    align_corners,
    OUTPUT_STRIDE0,
    OUTPUT_STRIDE1,
    OUTPUT_STRIDE2,
    OUTPUT_STRIDE3,
    THETA_STRIDE0,
    THETA_STRIDE1,
    THETA_STRIDE2,
    BLOCK_SIZE: tl.constexpr,
):
    # output has shape [N, H, W, 2]; theta has shape [N, 2, 3]
    pid = tle.program_id(0)
    num_tasks = N * H * W * 2

    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # The last program may run past the end of the output, so every load and
    # store below is masked.
    mask = idx < num_tasks

    # Unflatten idx -> (n, h, w, c); c == 0 is the x component, c == 1 is y.
    c = idx % 2
    tmp = idx // 2
    w = tmp % W
    tmp = tmp // W
    h = tmp % H
    n = tmp // H

    # A lane only needs row c of theta[n]. The mask keeps tail lanes, whose n
    # can be >= N, from reading past the end of theta.
    theta_row = theta_ptr + n * THETA_STRIDE0 + c * THETA_STRIDE1
    coef_x = tl.load(theta_row, mask=mask, other=0.0).to(tl.float32)
    coef_y = tl.load(theta_row + THETA_STRIDE2, mask=mask, other=0.0).to(tl.float32)
    coef_1 = tl.load(theta_row + 2 * THETA_STRIDE2, mask=mask, other=0.0).to(
        tl.float32
    )

    # Promote indices and sizes to fp32.
    h_f = h.to(tl.float32)
    w_f = w.to(tl.float32)
    H_f = H * 1.0
    W_f = W * 1.0

    if align_corners:
        # Index 0 -> -1 and index size-1 -> +1:
        #     2 * i / (size - 1) - 1 == (2 * i - (size - 1)) / (size - 1)
        # The divisor is clamped to 1 so that a size-1 axis yields 0, as
        # torch.nn.functional.affine_grid does, instead of 0 / 0 = NaN.
        norm_x = (2.0 * w_f - (W_f - 1.0)) / tl.maximum(W_f - 1.0, 1.0)
        norm_y = (2.0 * h_f - (H_f - 1.0)) / tl.maximum(H_f - 1.0, 1.0)
    else:
        # Index i -> (2 * i + 1) / size - 1.
        norm_x = (2.0 * w_f + 1.0) / W_f - 1.0
        norm_y = (2.0 * h_f + 1.0) / H_f - 1.0

    # Affine transform of (x, y, 1) with row c of theta[n].
    result = coef_x * norm_x + coef_y * norm_y + coef_1

    output_offset = (
        n * OUTPUT_STRIDE0
        + h * OUTPUT_STRIDE1
        + w * OUTPUT_STRIDE2
        + c * OUTPUT_STRIDE3
    )
    tl.store(output_ptr + output_offset, result, mask=mask)

109 

110 

def affine_grid_generator(
    theta: torch.Tensor, size: torch.Size, align_corners: bool
) -> torch.Tensor:
    """Build a 2-D affine sampling grid with a Triton kernel.

    For every batch ``n``, output pixel ``(h, w)`` and component ``c``
    (0 = x, 1 = y)::

        grid[n, h, w, c] = theta[n, c, 0] * x_w + theta[n, c, 1] * y_h + theta[n, c, 2]

    where ``x_w`` / ``y_h`` are column ``w`` / row ``h`` normalized to
    ``[-1, 1]``. ``align_corners`` selects the normalization, which is
    computed in ``affine_grid_generator_kernel``.

    Args:
        theta: batch of 2x3 affine matrices, shape ``(N, 2, 3)``.
        size: target size ``[N, C, H, W]``; ``C`` is not used.
        align_corners: normalization mode forwarded to the kernel.

    Returns:
        A new tensor of shape ``(N, H, W, 2)`` with ``theta``'s dtype and
        device.

    Raises:
        AssertionError: if ``size`` is not 4-D or ``theta`` is not
            ``(N, 2, 3)``.
    """
    logger.debug("GEMS AFFINE_GRID_GENERATOR")

    # NOTE(review): asserts are stripped under ``python -O``; they are kept so
    # the exception type seen by existing callers does not change.
    assert len(size) == 4, f"size must be 4D [N, C, H, W], got {len(size)} dims"
    N, C, H, W = size
    assert theta.shape == (N, 2, 3), f"theta must be shape (N, 2, 3), got {theta.shape}"

    output = torch.empty((N, H, W, 2), dtype=theta.dtype, device=theta.device)

    num_tasks = output.numel()  # == N * H * W * 2
    if num_tasks == 0:
        # Nothing to write: skip kernel compilation and a zero-sized launch.
        return output

    BLOCK_SIZE = 128
    grid = (triton.cdiv(num_tasks, BLOCK_SIZE),)

    affine_grid_generator_kernel[grid](
        output,
        theta,
        N,
        H,
        W,
        align_corners,
        output.stride(0),
        output.stride(1),
        output.stride(2),
        output.stride(3),
        theta.stride(0),
        theta.stride(1),
        theta.stride(2),
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output