Coverage for src/flag_gems/runtime/backend/_sunrise/ops/upsample_linear1d.py: 0%
55 statements
coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
# Module-level logger; upsample_linear1d emits a debug message through it on each call.
logger = logging.getLogger(__name__)
@triton.jit
def upsample_linear1d_kernel(
    input_ptr,
    output_ptr,
    NC,
    W_in,
    W_out,
    scale,
    bias,
    BLOCK_SIZE: tl.constexpr,
):
    """Linearly interpolate one BLOCK_SIZE-wide tile of one output row.

    The launch grid is 2-D: axis 0 selects a row of the flattened input
    (row stride ``W_in``) and output (row stride ``W_out``); axis 1 selects a
    tile of output columns. Both buffers are therefore addressed as dense
    row-major rows.

    For output column ``x`` the source coordinate is ``x * scale + bias``,
    clamped to ``[0, W_in - 1]``. The result is the weighted blend of the two
    neighbouring input samples. Arithmetic is done in float32 and cast back to
    the input dtype (NOTE(review): float64 inputs lose precision here).
    """
    # Promote the row index to int64 before scaling by the row length:
    # ``row * W_in`` overflows int32 once the tensor holds >= 2**31 elements.
    # The base pointer is advanced once per program, so per-element column
    # offsets below can stay int32.
    row = tl.program_id(0).to(tl.int64)
    col_block = tl.program_id(1)

    in_row_ptr = input_ptr + row * W_in
    out_row_ptr = output_ptr + row * W_out

    cols = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = (row < NC) & (cols < W_out)

    # Source coordinate for each output column, clamped into the valid range.
    src = cols.to(tl.float32) * scale + bias
    src = tl.maximum(0.0, tl.minimum(src, W_in - 1.0))

    lower = tl.floor(src).to(tl.int32)
    # At the right edge both neighbours collapse onto the last sample.
    upper = tl.minimum(lower + 1, W_in - 1)

    t = src - lower.to(tl.float32)
    w0 = 1.0 - t
    w1 = t

    x0 = tl.load(in_row_ptr + lower, mask=mask)
    x1 = tl.load(in_row_ptr + upper, mask=mask)

    out = w0 * x0.to(tl.float32) + w1 * x1.to(tl.float32)
    tl.store(out_row_ptr + cols, out.to(x0.dtype), mask=mask)
def upsample_linear1d(
    self: torch.Tensor,
    output_size,
    align_corners: bool,
    scales: float = None,
):
    """Linearly interpolate a ``[N, C, W_in]`` tensor along its last axis.

    Args:
        self: 3-D input tensor; must satisfy ``self.is_ptpu``.
        output_size: target width, either an int or a list/tuple whose first
            element is the width. May be ``None`` when ``scales`` is given.
        align_corners: if True, output endpoints map exactly onto input
            endpoints; otherwise a half-pixel mapping is used.
        scales: optional width scale factor. Determines the output width when
            ``output_size`` is None and, when positive, drives the half-pixel
            coordinate mapping if ``align_corners`` is False.

    Returns:
        A new tensor of shape ``[N, C, W_out]`` with ``self``'s dtype and
        device.

    Raises:
        AssertionError: if ``self`` is not 3-D, is not a ptpu tensor, or both
            ``output_size`` and ``scales`` are None.
        RuntimeError: if the input or output width is not positive.
    """
    logger.debug("GEMS UPSAMPLE LINEAR1D OPTIMIZED")
    assert self.ndim == 3, "Input must be [N, C, W]"
    assert self.is_ptpu

    N, C, W_in = self.shape
    NC = N * C

    if output_size is not None:
        W_out = int(
            output_size[0] if isinstance(output_size, (list, tuple)) else output_size
        )
    else:
        assert scales is not None
        W_out = int(math.floor(W_in * scales))

    # With W_in == 0 the kernel's clamp range [0, W_in - 1] is empty and it
    # would read out of bounds; W_out <= 0 has no valid output shape (and would
    # divide by zero below). Reject both up front, as ATen does.
    if W_in <= 0 or W_out <= 0:
        raise RuntimeError(
            "Input and output sizes should be greater than 0, but got "
            f"input (W: {W_in}) output (W: {W_out})"
        )

    out = torch.empty((NC, W_out), device=self.device, dtype=self.dtype)
    if NC == 0:
        # Empty batch/channels: nothing to compute, and a grid with a zero
        # dimension is not guaranteed to be launchable on every backend.
        return out.view(N, C, W_out)

    inp = self.contiguous().view(NC, W_in)

    if align_corners:
        # Map the first/last output samples onto the first/last input samples.
        scale_val = (W_in - 1.0) / (W_out - 1.0) if W_out > 1 else 0.0
        bias_val = 0.0
    else:
        # Half-pixel mapping: src = (dst + 0.5) * real_scale - 0.5.
        # A user-supplied scale is honoured only when positive (matching
        # ATen); otherwise fall back to the input/output size ratio.
        if scales is not None and scales > 0:
            real_scale = 1.0 / scales
        else:
            real_scale = W_in / W_out
        scale_val = real_scale
        bias_val = 0.5 * real_scale - 0.5

    BLOCK_SIZE = 256
    # One program per (row, column tile).
    grid = (NC, triton.cdiv(W_out, BLOCK_SIZE))

    upsample_linear1d_kernel[grid](
        inp,
        out,
        NC,
        W_in,
        W_out,
        scale_val,
        bias_val,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out.view(N, C, W_out)