Coverage for src/flag_gems/ops/affine_grid_generator.py: 40%
52 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import libentry
9from flag_gems.utils import triton_lang_extension as tle
11logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def affine_grid_generator_kernel(
    output_ptr,
    theta_ptr,
    N,
    H,
    W,
    align_corners,
    OUTPUT_STRIDE0,
    OUTPUT_STRIDE1,
    OUTPUT_STRIDE2,
    OUTPUT_STRIDE3,
    THETA_STRIDE0,
    THETA_STRIDE1,
    THETA_STRIDE2,
    BLOCK_SIZE: tl.constexpr,
):
    """Write a 2D affine sampling grid, one output element per lane.

    The output is indexed as [N, H, W, 2] and theta as [N, 2, 3]. For every
    (n, h, w, c) the kernel stores

        theta[n, c, 0] * x + theta[n, c, 1] * y + theta[n, c, 2]

    where x and y are the normalized coordinates of column ``w`` and row
    ``h``.

    Args:
        output_ptr: Base pointer of the output grid.
        theta_ptr: Base pointer of the affine matrices.
        N, H, W: Batch size, grid height and grid width.
        align_corners: Selects the normalization formula (see below).
        OUTPUT_STRIDE0..3: Element strides of the output along n, h, w, c.
        THETA_STRIDE0..2: Element strides of theta along its three dims.
        BLOCK_SIZE: Number of output elements handled by one program.
    """
    pid = tle.program_id(0)
    num_tasks = N * H * W * 2

    # Programs that start past the end of the flattened output do nothing.
    if pid * BLOCK_SIZE >= num_tasks:
        return

    # Flat indices of the output elements owned by this program.
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = idx < num_tasks

    # Unflatten to (n, h, w, c): c varies fastest, then w, then h, then n.
    c = idx % 2  # 0 -> x output, 1 -> y output
    pixel = idx // 2
    w = pixel % W
    pixel = pixel // W
    h = pixel % H
    n = pixel // H

    # Row c of theta[n] holds exactly the three coefficients this lane needs.
    # The loads are masked: lanes with idx >= num_tasks decode to n >= N and
    # would otherwise read past the end of theta.
    row_ptr = theta_ptr + n * THETA_STRIDE0 + c * THETA_STRIDE1
    coef_x = tl.load(row_ptr, mask=mask, other=0.0).to(tl.float32)
    coef_y = tl.load(row_ptr + THETA_STRIDE2, mask=mask, other=0.0).to(tl.float32)
    shift = tl.load(row_ptr + 2 * THETA_STRIDE2, mask=mask, other=0.0).to(tl.float32)

    # Normalized coordinates:
    #   align_corners=True:  2 * coord / (size - 1) - 1
    #   align_corners=False: (2 * coord + 1) / size - 1
    # H and W are used through plain arithmetic rather than ``.to()``.
    # NOTE(review): Triton may specialize integer arguments equal to 1 into
    # compile-time constants, which have no ``.to`` method - confirm against
    # the Triton version in use.
    col = w.to(tl.float32)
    row = h.to(tl.float32)

    if align_corners:
        # A single-step axis has size - 1 == 0; the formula would yield
        # 0 / 0 = NaN. Use coordinate 0 there instead, matching the value the
        # align_corners=False formula gives for size 1.
        if W > 1:
            norm_x = 2.0 * col / (W - 1) - 1.0
        else:
            norm_x = col * 0.0
        if H > 1:
            norm_y = 2.0 * row / (H - 1) - 1.0
        else:
            norm_y = row * 0.0
    else:
        norm_x = (2.0 * col + 1.0) / W - 1.0
        norm_y = (2.0 * row + 1.0) / H - 1.0

    # grid[n, h, w, c] = theta[n, c, 0] * x + theta[n, c, 1] * y + theta[n, c, 2]
    result = coef_x * norm_x + coef_y * norm_y + shift

    output_offset = (
        n * OUTPUT_STRIDE0
        + h * OUTPUT_STRIDE1
        + w * OUTPUT_STRIDE2
        + c * OUTPUT_STRIDE3
    )
    tl.store(output_ptr + output_offset, result, mask=mask)
def affine_grid_generator(
    theta: torch.Tensor, size: torch.Size, align_corners: bool
) -> torch.Tensor:
    """Build a 2D affine sampling grid from a batch of 2x3 matrices.

    Args:
        theta: Affine matrices of shape (N, 2, 3).
        size: Target size as four values [N, C, H, W]; only N, H and W are
            used.
        align_corners: Forwarded to the kernel to select how pixel indices
            are normalized.

    Returns:
        A tensor of shape (N, H, W, 2) with the dtype and device of
        ``theta``.

    Raises:
        AssertionError: If ``size`` does not have four entries or ``theta``
            is not of shape (N, 2, 3).
    """
    logger.debug("GEMS AFFINE_GRID_GENERATOR")

    assert len(size) == 4, f"size must be 4D [N, C, H, W], got {len(size)} dims"
    N, _, H, W = size
    assert theta.shape == (N, 2, 3), f"theta must be shape (N, 2, 3), got {theta.shape}"

    # Output shape is [N, H, W, 2]
    output = torch.empty((N, H, W, 2), dtype=theta.dtype, device=theta.device)

    # One task per output element; an empty grid needs no kernel launch.
    num_tasks = output.numel()
    if num_tasks == 0:
        return output

    BLOCK_SIZE = 128
    grid = (triton.cdiv(num_tasks, BLOCK_SIZE),)

    affine_grid_generator_kernel[grid](
        output,
        theta,
        N,
        H,
        W,
        align_corners,
        output.stride(0),
        output.stride(1),
        output.stride(2),
        output.stride(3),
        theta.stride(0),
        theta.stride(1),
        theta.stride(2),
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return output