Coverage for src/flag_gems/ops/upsample_trilinear3d.py: 36%

107 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3from typing import Optional, Tuple 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import device as runtime_device 

10from flag_gems.runtime import torch_device_fn 

11 

# Module-scoped logger named after this module; the public op emits one
# debug record per call through it.
logger = logging.getLogger(name=__name__)

13 

14 

# Trilinear resampling kernel.
#
# Launch grid: axis 0 enumerates the NC flattened (batch, channel) slices and
# axis 1 tiles the flattened OD * OH * OW output volume in chunks of
# BLOCK_SIZE elements. Input and output are addressed as contiguous
# (NC, ID, IH, IW) / (NC, OD, OH, OW) buffers.
#
# For every output voxel the source coordinate along each axis is
# `dst * scale + bias`, clamped to [0, size - 1]; the host wrapper encodes the
# align_corners convention in scale_* / bias_*. NC is part of the signature
# but is not read here. Interpolation is done in float32 and the result is
# cast to the output element type on store.
@triton.jit
def upsample_trilinear3d_kernel(
    output_ptr,
    input_ptr,
    NC,
    OD,
    OH,
    OW,
    ID,
    IH,
    IW,
    scale_d,
    scale_h,
    scale_w,
    bias_d,
    bias_h,
    bias_w,
    BLOCK_SIZE: tl.constexpr,
):
    pid_nc = tl.program_id(0)
    pid_spatial = tl.program_id(1)

    total_spatial = OD * OH * OW
    idx = pid_spatial * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = idx < total_spatial

    # Flat output index -> (od, oh, ow).
    ow = idx % OW
    oh = (idx // OW) % OH
    od = idx // (OW * OH)

    # Map output coordinates into the input volume.
    src_d = od.to(tl.float32) * scale_d + bias_d
    src_h = oh.to(tl.float32) * scale_h + bias_h
    src_w = ow.to(tl.float32) * scale_w + bias_w

    # Clamp into the input so the floor index is always valid; the +1
    # neighbour is clamped separately below.
    src_d = tl.maximum(0.0, tl.minimum(src_d, ID - 1.0))
    src_h = tl.maximum(0.0, tl.minimum(src_h, IH - 1.0))
    src_w = tl.maximum(0.0, tl.minimum(src_w, IW - 1.0))

    # Lower corner (src >= 0 here, so floor == truncation) and upper corner.
    id0 = tl.floor(src_d).to(tl.int32)
    ih0 = tl.floor(src_h).to(tl.int32)
    iw0 = tl.floor(src_w).to(tl.int32)
    id1 = tl.minimum(id0 + 1, ID - 1)
    ih1 = tl.minimum(ih0 + 1, IH - 1)
    iw1 = tl.minimum(iw0 + 1, IW - 1)

    # Fractional distance from the lower corner = weight of the upper corner.
    td = src_d - id0.to(tl.float32)
    th = src_h - ih0.to(tl.float32)
    tw = src_w - iw0.to(tl.float32)

    # Per-slice base pointers. The slice index is widened to int64 because
    # pid_nc * (ID * IH * IW) and pid_nc * (OD * OH * OW) overflow int32 once
    # a tensor holds more than 2**31 elements; offsets inside a single slice
    # remain int32.
    nc = pid_nc.to(tl.int64)
    in_slice = input_ptr + nc * ID * IH * IW
    out_slice = output_ptr + nc * total_spatial

    # Depth/height part of the offset, shared by the two width neighbours.
    plane = IH * IW
    row_d0h0 = id0 * plane + ih0 * IW
    row_d0h1 = id0 * plane + ih1 * IW
    row_d1h0 = id1 * plane + ih0 * IW
    row_d1h1 = id1 * plane + ih1 * IW

    # Gather the 8 corners of the enclosing cube (x<d><h><w>) in float32.
    x000 = tl.load(in_slice + row_d0h0 + iw0, mask=mask).to(tl.float32)
    x001 = tl.load(in_slice + row_d0h0 + iw1, mask=mask).to(tl.float32)
    x010 = tl.load(in_slice + row_d0h1 + iw0, mask=mask).to(tl.float32)
    x011 = tl.load(in_slice + row_d0h1 + iw1, mask=mask).to(tl.float32)
    x100 = tl.load(in_slice + row_d1h0 + iw0, mask=mask).to(tl.float32)
    x101 = tl.load(in_slice + row_d1h0 + iw1, mask=mask).to(tl.float32)
    x110 = tl.load(in_slice + row_d1h1 + iw0, mask=mask).to(tl.float32)
    x111 = tl.load(in_slice + row_d1h1 + iw1, mask=mask).to(tl.float32)

    # Interpolate along depth, then height, then width.
    x00 = (1.0 - td) * x000 + td * x100
    x01 = (1.0 - td) * x001 + td * x101
    x10 = (1.0 - td) * x010 + td * x110
    x11 = (1.0 - td) * x011 + td * x111

    x0 = (1.0 - th) * x00 + th * x10
    x1 = (1.0 - th) * x01 + th * x11

    out = (1.0 - tw) * x0 + tw * x1

    # Cast explicitly to the destination element type (fp16/bf16/fp32/...).
    tl.store(out_slice + idx, out.to(output_ptr.dtype.element_ty), mask=mask)

131 

132 

def upsample_trilinear3d(
    self: torch.Tensor,
    output_size: Tuple[int, int, int],
    align_corners: bool,
    scales_d: Optional[float] = None,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Trilinearly upsample a 5D ``(N, C, D, H, W)`` tensor with Triton.

    Args:
        self: 5D input tensor on the FlagGems runtime device.
        output_size: Target ``(OD, OH, OW)`` spatial size.
        align_corners: If True, input and output corner voxels are aligned
            and ``scales_*`` are ignored; otherwise the half-pixel
            convention ``src = (dst + 0.5) * scale - 0.5`` is used.
        scales_d, scales_h, scales_w: Optional output/input scale factors,
            used only when ``align_corners`` is False. ``None`` or a
            non-positive value falls back to ``in_size / out_size``.

    Returns:
        A new contiguous tensor of shape ``(N, C, OD, OH, OW)`` with the
        input's dtype. An empty output is returned without a kernel launch.

    Raises:
        AssertionError: if the input is on another device or is not 5D.
        ValueError: if the output is non-empty but an input spatial
            dimension is zero.
    """
    logger.debug("GEMS UPSAMPLE_TRILINEAR3D")
    assert (
        self.device.type == runtime_device.name
    ), f"Expected device {runtime_device.name}, got {self.device.type}"
    assert self.ndim == 5, f"Input must be 5D (NCDHW), got {self.ndim}D"

    N, C, ID, IH, IW = self.shape
    OD, OH, OW = output_size
    NC = N * C

    out = torch.empty((NC, OD, OH, OW), device=self.device, dtype=self.dtype)

    # Return before computing scales: with align_corners=False an output
    # dimension of 0 would otherwise divide by zero in in_sz / out_sz.
    if out.numel() == 0:
        return out.view(N, C, OD, OH, OW)

    # The kernel clamps indices to size - 1, which is meaningless (and reads
    # out of bounds) when an input spatial dimension is empty.
    if ID <= 0 or IH <= 0 or IW <= 0:
        raise ValueError(
            "upsample_trilinear3d: input spatial sizes must be greater than 0 "
            f"for a non-empty output, got (D, H, W) = ({ID}, {IH}, {IW})"
        )

    def calculate_scale_and_bias(in_sz, out_sz, scale):
        # Return (scale, bias) such that src = dst * scale + bias.
        if align_corners:
            scale_val = (in_sz - 1.0) / (out_sz - 1.0) if out_sz > 1 else 0.0
            return scale_val, 0.0
        # Like PyTorch, honour a user-provided scale only when it is strictly
        # positive; a zero scale previously raised ZeroDivisionError.
        if scale is not None and scale > 0:
            real_scale = 1.0 / scale
        else:
            real_scale = in_sz / out_sz
        # Half-pixel centres: (dst + 0.5) * real_scale - 0.5.
        return real_scale, 0.5 * real_scale - 0.5

    scale_d, bias_d = calculate_scale_and_bias(ID, OD, scales_d)
    scale_h, bias_h = calculate_scale_and_bias(IH, OH, scales_h)
    scale_w, bias_w = calculate_scale_and_bias(IW, OW, scales_w)

    # The kernel addresses the input as a contiguous (NC, ID, IH, IW) buffer.
    inp = self.reshape(NC, ID, IH, IW).contiguous()

    # BLOCK_SIZE is the number of output elements per program (not threads).
    # Spatial tiles live on grid axis 1, which CUDA caps at 65535, so grow the
    # power-of-two tile for very large outputs (e.g. 256**3 needs 65536 tiles
    # of 256 elements).
    total_spatial = OD * OH * OW
    block_size = max(256, triton.next_power_of_2(triton.cdiv(total_spatial, 65535)))
    grid = (NC, triton.cdiv(total_spatial, block_size))

    with torch_device_fn.device(self.device):
        upsample_trilinear3d_kernel[grid](
            out,
            inp,
            NC,
            OD,
            OH,
            OW,
            ID,
            IH,
            IW,
            scale_d,
            scale_h,
            scale_w,
            bias_d,
            bias_h,
            bias_w,
            BLOCK_SIZE=block_size,
        )

    return out.view(N, C, OD, OH, OW)