Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/upsample_linear1d.py: 0%
55 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
10logger = logging.getLogger(__name__)
# Linear 1-D upsampling kernel operating on a flattened [NC, W] layout.
#
# Launch layout: program axis 0 selects one (n, c) row, program axis 1 selects
# a BLOCK_SIZE-wide tile of output columns in that row.
#
# Arguments:
#   input_ptr  - base pointer of the input, addressed as row * W_in + column.
#   output_ptr - base pointer of the output, addressed as row * W_out + column.
#   NC         - number of rows; accepted for interface parity, not read here.
#   W_in/W_out - input / output row widths.
#   scale/bias - affine map from an output column to a source coordinate:
#                src = column * scale + bias.
#   BLOCK_SIZE - compile-time tile width along the output row.
@triton.jit
def upsample_linear1d_kernel(
    input_ptr,
    output_ptr,
    NC,
    W_in,
    W_out,
    scale,
    bias,
    BLOCK_SIZE: tl.constexpr,
):
    row = tl.program_id(0)
    tile = tl.program_id(1)

    in_row_start = row * W_in
    out_row_start = row * W_out

    # Column indices are wrapped with a modulo instead of being masked, so
    # every lane always targets a column in [0, W_out). Per the original
    # author's note: on KunlunXin a masked tl.store does not suppress writes
    # of masked-out lanes unless TRITONXPU_STORE_MASK_SIM=1, which corrupts
    # the neighbouring row. With the wrap, lanes past the end of the last
    # tile merely recompute and rewrite valid columns of the same row.
    out_col = (tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)) % W_out

    # Map each output column to a source coordinate, then pin it to the
    # valid input range [0, W_in - 1].
    pos = out_col.to(tl.float32) * scale + bias
    pos = tl.maximum(0.0, tl.minimum(pos, W_in - 1.0))

    # pos is non-negative here, so truncating to int is the same as floor.
    left = pos.to(tl.int32)
    right = tl.minimum(left + 1, W_in - 1)

    # Fractional distance from the left neighbour; it is the right weight.
    frac = pos - left.to(tl.float32)

    # Unmasked loads: both neighbour indices were clamped into the row above.
    v_left = tl.load(input_ptr + in_row_start + left)
    v_right = tl.load(input_ptr + in_row_start + right)

    # Blend in float32, then cast back to the element type that was loaded.
    blended = (1.0 - frac) * v_left.to(tl.float32) + frac * v_right.to(tl.float32)
    tl.store(output_ptr + out_row_start + out_col, blended.to(v_left.dtype))
def upsample_linear1d(
    self: torch.Tensor,
    output_size,
    align_corners: bool,
    scales: float = None,
):
    """Linearly upsample a 3-D tensor along its last dimension.

    The input is flattened to ``[N * C, W_in]`` and one kernel program row is
    launched per ``(n, c)`` pair.

    Args:
        self: Input tensor with three dimensions ``[N, C, W_in]``.
        output_size: Target width, either a scalar or a list/tuple whose
            first element is the width. May be ``None`` if ``scales`` is set.
        align_corners: If true, the first and last output samples map exactly
            onto the first and last input samples. Otherwise sample centres
            are aligned (``src = (dst + 0.5) * scale - 0.5``).
        scales: Optional width scale factor. Used to derive the output width
            when ``output_size`` is ``None`` and, when positive, to derive the
            coordinate scale for ``align_corners=False``.

    Returns:
        A tensor of shape ``[N, C, W_out]`` with the input's dtype and device.

    Raises:
        AssertionError: If the input is not 3-D, or if neither
            ``output_size`` nor ``scales`` is provided.
        ValueError: If the input width is zero while the requested output is
            non-empty (there is nothing to interpolate from).
    """
    logger.debug("GEMS_KUNLUNXIN UPSAMPL_LINEAR1D")
    assert self.ndim == 3, "Input must be [N, C, W]"

    N, C, W_in = self.shape
    NC = N * C

    if output_size is not None:
        W_out = int(
            output_size[0] if isinstance(output_size, (list, tuple)) else output_size
        )
    else:
        assert (
            scales is not None
        ), "scales must be specified if output_size is not provided."
        W_out = int(math.floor(W_in * scales))

    out = torch.empty((NC, W_out), device=self.device, dtype=self.dtype)

    # Nothing to compute for an empty result. Returning here also avoids a
    # zero-sized launch grid and the kernel's `% W_out` with W_out == 0.
    if out.numel() == 0:
        return out.view(N, C, W_out)

    # The kernel clamps source indices to [0, W_in - 1]; with W_in == 0 that
    # range is empty and the loads would address memory before the row.
    if W_in == 0:
        raise ValueError(
            "Input width must be greater than 0 to produce a non-empty output, "
            f"but got input shape {tuple(self.shape)} and output width {W_out}."
        )

    inp = self.contiguous().view(NC, W_in)

    if align_corners:
        # End points map onto end points; a single output sample reads x[0].
        scale_val = (W_in - 1.0) / (W_out - 1.0) if W_out > 1 else 0.0
        bias_val = 0.0
    else:
        # A user-supplied scale is honoured only when positive; otherwise the
        # ratio of the sizes is used (same fallback as ATen's
        # compute_scales_value), which also avoids dividing by zero.
        if scales is not None and scales > 0:
            real_scale = 1.0 / scales
        else:
            real_scale = W_in / W_out

        # src = (dst + 0.5) * real_scale - 0.5, rewritten as dst * scale + bias.
        scale_val = real_scale
        bias_val = 0.5 * real_scale - 0.5

    BLOCK_SIZE = 256
    grid = (NC, triton.cdiv(W_out, BLOCK_SIZE))

    with torch_device_fn.device(self.device):
        upsample_linear1d_kernel[grid](
            inp,
            out,
            NC,
            W_in,
            W_out,
            scale_val,
            bias_val,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    return out.view(N, C, W_out)