Coverage for src/flag_gems/runtime/backend/_sunrise/ops/upsample_linear1d.py: 0%
56 statements
Generated by coverage.py v7.6.9 on 2026-06-10 07:09 +0800.
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8import flag_gems
logger = logging.getLogger(__name__)


# Linear (1-D) interpolation kernel.
#
# The input is addressed as NC rows of W_in elements and the output as NC
# rows of W_out elements (row-major, no padding). Program (pid_nc, pid_w)
# handles output columns [pid_w * BLOCK_SIZE, (pid_w + 1) * BLOCK_SIZE) of
# row pid_nc.
#
# Output column w samples the input at the fractional coordinate
#     src = w * scale + bias
# clamped to [0, W_in - 1]. The result blends the neighbouring input elements
# floor(src) and min(floor(src) + 1, W_in - 1) with weights (1 - t, t), where
# t = src - floor(src). Blending is done in fp32 and the result is cast back
# to the input element dtype.
@triton.jit
def upsample_linear1d_kernel(
    input_ptr,
    output_ptr,
    NC,
    W_in,
    W_out,
    scale,
    bias,
    BLOCK_SIZE: tl.constexpr,
):
    pid_nc = tl.program_id(0)
    pid_w = tl.program_id(1)

    # Widen the row index before multiplying: program ids are 32-bit, so
    # pid_nc * W_in (or * W_out) overflows once a tensor reaches 2**31
    # elements and the loads/stores below would address the wrong memory.
    row = pid_nc.to(tl.int64)
    base_in = row * W_in
    base_out = row * W_out

    offs_w = pid_w * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = (pid_nc < NC) & (offs_w < W_out)

    # Fractional source coordinate, clamped to the valid input range. The
    # clamp keeps both gathers in bounds (including masked-off tail lanes),
    # provided W_in >= 1.
    src = offs_w.to(tl.float32) * scale + bias
    src = tl.maximum(0.0, tl.minimum(src, W_in - 1.0))

    lower = tl.floor(src).to(tl.int32)
    upper = tl.minimum(lower + 1, W_in - 1)

    t = src - lower.to(tl.float32)
    w0 = 1.0 - t
    w1 = t

    x0 = tl.load(input_ptr + base_in + lower, mask=mask)
    x1 = tl.load(input_ptr + base_in + upper, mask=mask)

    out = w0 * x0.to(tl.float32) + w1 * x1.to(tl.float32)
    tl.store(output_ptr + base_out + offs_w, out.to(x0.dtype), mask=mask)
def upsample_linear1d(
    self: torch.Tensor,
    output_size,
    align_corners: bool,
    scales: float = None,
):
    """Linearly resample a ``[N, C, W]`` tensor along its last dimension.

    Args:
        self: Input tensor of shape ``[N, C, W_in]`` on the FlagGems device.
        output_size: Target width, either an int or a list/tuple whose first
            element is the width. When ``None``, the width is derived as
            ``floor(W_in * scales)``.
        align_corners: If True, the first and last output samples map exactly
            onto the first and last input samples, and ``scales`` does not
            affect the coordinate mapping.
        scales: Optional output/input scale factor. With
            ``align_corners=False``, a positive ``scales`` sets the source step
            to ``1 / scales``. Otherwise (None or <= 0) the step is
            ``W_in / W_out``, matching PyTorch's handling of this argument.

    Returns:
        A new tensor of shape ``[N, C, W_out]`` with the input's dtype and
        device.

    Raises:
        AssertionError: if the input is not 3-D, is on a different device
            type, or neither ``output_size`` nor ``scales`` is given.
        RuntimeError: if the input width is 0 while the output is non-empty.
    """
    logger.debug("GEMS UPSAMPLE LINEAR1D OPTIMIZED")
    assert self.ndim == 3, "Input must be [N, C, W]"
    assert self.device.type == flag_gems.device

    N, C, W_in = self.shape
    NC = N * C

    if output_size is not None:
        W_out = int(
            output_size[0] if isinstance(output_size, (list, tuple)) else output_size
        )
    else:
        assert scales is not None
        W_out = int(math.floor(W_in * scales))

    out = torch.empty((NC, W_out), device=self.device, dtype=self.dtype)

    # Nothing to compute. Returning early also avoids W_in / W_out dividing
    # by zero below.
    if NC == 0 or W_out == 0:
        return out.view(N, C, W_out)

    # The kernel clamps source indices to [0, W_in - 1]. With W_in == 0 that
    # range is empty and the gathers would read before the row start.
    if W_in == 0:
        raise RuntimeError(
            f"upsample_linear1d: input width must be > 0 for a non-empty "
            f"output, got W_in={W_in}, W_out={W_out}"
        )

    inp = self.contiguous().view(NC, W_in)

    if align_corners:
        scale_val = (W_in - 1.0) / (W_out - 1.0) if W_out > 1 else 0.0
        bias_val = 0.0
    else:
        # Like PyTorch, honour a user-supplied scale only when it is positive.
        # A 0 or negative value previously raised ZeroDivisionError or
        # produced a negative step.
        if scales is not None and scales > 0:
            real_scale = 1.0 / scales
        else:
            real_scale = W_in / W_out
        # Half-pixel mapping: src = (w + 0.5) * real_scale - 0.5.
        scale_val = real_scale
        bias_val = 0.5 * real_scale - 0.5

    BLOCK_SIZE = 256
    grid = (NC, triton.cdiv(W_out, BLOCK_SIZE))

    upsample_linear1d_kernel[grid](
        inp,
        out,
        NC,
        W_in,
        W_out,
        scale_val,
        bias_val,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    return out.view(N, C, W_out)