Coverage for src/flag_gems/ops/upsample_trilinear3d.py: 36%
107 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
3from typing import Optional, Tuple
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import device as runtime_device
10from flag_gems.runtime import torch_device_fn
12logger = logging.getLogger(__name__)
@triton.jit
def upsample_trilinear3d_kernel(
    output_ptr,
    input_ptr,
    NC,
    OD,
    OH,
    OW,
    ID,
    IH,
    IW,
    scale_d,
    scale_h,
    scale_w,
    bias_d,
    bias_h,
    bias_w,
    BLOCK_SIZE: tl.constexpr,
):
    """Trilinear upsampling of one (n, c) slice per program along grid axis 0.

    Each program handles BLOCK_SIZE consecutive output voxels (flat index
    over OD * OH * OW) of slice ``program_id(0)``. The input is addressed as
    a dense (NC, ID, IH, IW) buffer and the output as a dense
    (NC, OD, OH, OW) buffer.

    For every output coordinate ``o`` along an axis the source coordinate is
    ``o * scale + bias``, clamped to ``[0, in_size - 1]``. The eight
    neighbouring input values are blended in float32.

    ``NC`` is accepted for signature compatibility but is not read here.
    """
    # Widen the slice index to int64 so that the per-slice base offsets
    # (pid_nc * ID * IH * IW and pid_nc * OD * OH * OW) cannot wrap around
    # int32 for tensors holding more than 2**31 elements.
    pid_nc = tl.program_id(0).to(tl.int64)
    pid_spatial = tl.program_id(1)

    # Flat position inside one output slice; kept in int32, which assumes a
    # single slice has fewer than 2**31 voxels.
    idx = pid_spatial * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    total_spatial = OD * OH * OW
    mask = idx < total_spatial

    # Decompose the flat index into (od, oh, ow).
    ow = idx % OW
    oh = (idx // OW) % OH
    od = idx // (OW * OH)

    # Map output coordinates to (fractional) source coordinates.
    src_d = od.to(tl.float32) * scale_d + bias_d
    src_h = oh.to(tl.float32) * scale_h + bias_h
    src_w = ow.to(tl.float32) * scale_w + bias_w

    # Clamp to the valid input range. This also keeps the addresses of
    # masked-off lanes (idx >= total_spatial) inside the input slice.
    src_d = tl.maximum(0.0, tl.minimum(src_d, ID - 1.0))
    src_h = tl.maximum(0.0, tl.minimum(src_h, IH - 1.0))
    src_w = tl.maximum(0.0, tl.minimum(src_w, IW - 1.0))

    # Lower neighbour index along each axis.
    id0 = tl.floor(src_d).to(tl.int32)
    ih0 = tl.floor(src_h).to(tl.int32)
    iw0 = tl.floor(src_w).to(tl.int32)

    # Upper neighbour, saturated at the last valid index.
    id1 = tl.minimum(id0 + 1, ID - 1)
    ih1 = tl.minimum(ih0 + 1, IH - 1)
    iw1 = tl.minimum(iw0 + 1, IW - 1)

    # Fractional distance from the lower neighbour = weight of the upper one.
    td = src_d - id0.to(tl.float32)
    th = src_h - ih0.to(tl.float32)
    tw = src_w - iw0.to(tl.float32)

    wd0 = 1.0 - td
    wd1 = td
    wh0 = 1.0 - th
    wh1 = th
    ww0 = 1.0 - tw
    ww1 = tw

    # Strides of the dense (NC, ID, IH, IW) input layout.
    d_stride_in = IH * IW
    h_stride_in = IW

    # Start of this program's input slice (int64 because pid_nc is int64).
    in_offset_base = pid_nc * ID * IH * IW

    # Offsets of the two depth planes and two rows shared by the corners.
    plane0 = in_offset_base + id0 * d_stride_in
    plane1 = in_offset_base + id1 * d_stride_in
    row0 = ih0 * h_stride_in
    row1 = ih1 * h_stride_in

    # Load the 8 cube corners (named x<d><h><w>) and promote to float32.
    x000 = tl.load(input_ptr + plane0 + row0 + iw0, mask=mask).to(tl.float32)
    x001 = tl.load(input_ptr + plane0 + row0 + iw1, mask=mask).to(tl.float32)
    x010 = tl.load(input_ptr + plane0 + row1 + iw0, mask=mask).to(tl.float32)
    x011 = tl.load(input_ptr + plane0 + row1 + iw1, mask=mask).to(tl.float32)
    x100 = tl.load(input_ptr + plane1 + row0 + iw0, mask=mask).to(tl.float32)
    x101 = tl.load(input_ptr + plane1 + row0 + iw1, mask=mask).to(tl.float32)
    x110 = tl.load(input_ptr + plane1 + row1 + iw0, mask=mask).to(tl.float32)
    x111 = tl.load(input_ptr + plane1 + row1 + iw1, mask=mask).to(tl.float32)

    # Trilinear interpolation: depth first ...
    x00 = wd0 * x000 + wd1 * x100
    x01 = wd0 * x001 + wd1 * x101
    x10 = wd0 * x010 + wd1 * x110
    x11 = wd0 * x011 + wd1 * x111

    # ... then height ...
    x0 = wh0 * x00 + wh1 * x10
    x1 = wh0 * x01 + wh1 * x11

    # ... and finally width.
    out = ww0 * x0 + ww1 * x1

    # Cast back to the storage dtype of the output buffer before storing.
    out = out.to(output_ptr.dtype.element_ty)

    out_offset = pid_nc * total_spatial + idx
    tl.store(output_ptr + out_offset, out, mask=mask)
def upsample_trilinear3d(
    self: torch.Tensor,
    output_size: Tuple[int, int, int],
    align_corners: bool,
    scales_d: Optional[float] = None,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Upsample a 5D (N, C, D, H, W) tensor with trilinear interpolation.

    Args:
        self: Input tensor; must be 5D and live on the runtime device.
        output_size: Target spatial size ``(OD, OH, OW)``.
        align_corners: If true, the first and last samples of input and
            output are aligned along each axis; otherwise half-pixel
            centres are used.
        scales_d, scales_h, scales_w: Optional per-axis scale factors. Only
            used when ``align_corners`` is false and the value is positive;
            otherwise the ratio ``input_size / output_size`` is used.

    Returns:
        A new tensor of shape ``(N, C, OD, OH, OW)`` with the dtype and
        device of ``self``.

    Raises:
        AssertionError: If the device or rank is wrong, or if the output is
            non-empty while an input spatial dimension is zero.
    """
    logger.debug("GEMS UPSAMPLE_TRILINEAR3D")
    assert (
        self.device.type == runtime_device.name
    ), f"Expected device {runtime_device.name}, got {self.device.type}"
    assert self.ndim == 5, f"Input must be 5D (NCDHW), got {self.ndim}D"

    N, C, ID, IH, IW = self.shape
    OD, OH, OW = output_size
    NC = N * C

    out = torch.empty((NC, OD, OH, OW), device=self.device, dtype=self.dtype)

    # Handle empty outputs before any scale is computed: a zero output
    # dimension would otherwise divide by zero in the size ratio below.
    if out.numel() == 0:
        return out.view(N, C, OD, OH, OW)

    # With a non-empty output, an empty input axis would make the kernel
    # compute a -1 neighbour index and read out of bounds.
    assert (
        ID > 0 and IH > 0 and IW > 0
    ), f"Input spatial sizes must be > 0 for a non-empty output, got {(ID, IH, IW)}"

    def calculate_scale_and_bias(in_sz, out_sz, scale):
        """Return (scale, bias) such that src = dst * scale + bias."""
        if align_corners:
            scale_val = (in_sz - 1.0) / (out_sz - 1.0) if out_sz > 1 else 0.0
            return scale_val, 0.0

        # Ignore missing or non-positive user scales and fall back to the
        # size ratio (also avoids dividing by a zero scale).
        if scale is not None and scale > 0:
            real_scale = 1.0 / scale
        else:
            real_scale = in_sz / out_sz

        # Half-pixel mapping: src = (dst + 0.5) * real_scale - 0.5.
        return real_scale, 0.5 * real_scale - 0.5

    scale_d, bias_d = calculate_scale_and_bias(ID, OD, scales_d)
    scale_h, bias_h = calculate_scale_and_bias(IH, OH, scales_h)
    scale_w, bias_w = calculate_scale_and_bias(IW, OW, scales_w)

    # The kernel addresses the input as a dense (NC, ID, IH, IW) buffer.
    inp = self.reshape(NC, ID, IH, IW).contiguous()

    # Output elements processed by one kernel program; a single constant
    # keeps the launch grid and the kernel's BLOCK_SIZE in sync.
    BLOCK_SIZE = 256
    total_spatial = OD * OH * OW
    # NOTE(review): grid axis 1 is limited to 65535 on CUDA devices, so
    # outputs with more than 65535 * BLOCK_SIZE voxels per (n, c) slice
    # would need a different launch layout - confirm for target backends.
    grid = (NC, triton.cdiv(total_spatial, BLOCK_SIZE))

    with torch_device_fn.device(self.device):
        upsample_trilinear3d_kernel[grid](
            out,
            inp,
            NC,
            OD,
            OH,
            OW,
            ID,
            IH,
            IW,
            scale_d,
            scale_h,
            scale_w,
            bias_d,
            bias_h,
            bias_w,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    return out.view(N, C, OD, OH, OW)