Coverage for src/flag_gems/ops/upsample_linear1d_backward.py: 31%
71 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7logger = logging.getLogger(__name__)
@triton.jit
def upsample_linear1d_backward_kernel(
    grad_out_ptr,
    grad_in_ptr,
    n,
    c,
    in_w,
    out_w,
    go_stride_n,
    go_stride_c,
    go_stride_w,
    gi_stride_n,
    gi_stride_c,
    gi_stride_w,
    align_corners: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Gather-style backward pass of 1-D linear upsampling.

    Each lane owns one element of the (n, c, in_w) input gradient.  It
    projects its input position into output space, walks over the output
    positions around that projection, recomputes the forward interpolation
    taps of each of them and accumulates the share of the output gradient
    that the forward pass would have taken from this input element.  No
    atomics are needed because every lane writes only its own element.

    Args:
        grad_out_ptr: pointer to the output gradient, indexed as
            (n, c, out_w) through the ``go_stride_*`` strides.
        grad_in_ptr: pointer to the input gradient to fill, indexed as
            (n, c, in_w) through the ``gi_stride_*`` strides.
        n, c: sizes of the two leading dimensions.
        in_w, out_w: widths of the input and output signals.
        go_stride_n, go_stride_c, go_stride_w: element strides of grad_out.
        gi_stride_n, gi_stride_c, gi_stride_w: element strides of grad_in.
        align_corners: compile-time flag selecting the coordinate mapping.
        BLOCK: number of input-gradient elements handled per program.
    """
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)

    total = n * c * in_w
    mask = offs < total

    # Decompose the flat index into (n_idx, c_idx, x_in).
    x_in = offs % in_w
    tmp = offs // in_w
    c_idx = tmp % c
    n_idx = tmp // c

    x_in_f = x_in.to(tl.float32)
    in_w_f = tl.cast(in_w, tl.float32)
    out_w_f = tl.cast(out_w, tl.float32)

    # `center` is the output coordinate this input sample projects onto.
    # `radius` bounds how far (in output samples) a contributing output can
    # lie from floor(center): an input sample is touched by roughly
    # 2 * (out/in) outputs, so a fixed 5-tap window would silently drop
    # gradient for upsampling ratios above 2.  floor(ratio) + 1 covers the
    # support; the extra +1 absorbs float rounding of `center`.  Visiting too
    # many outputs is harmless because the tap test below filters exactly.
    if align_corners:
        if in_w > 1:
            center = x_in_f * (out_w_f - 1.0) / (in_w_f - 1.0)
        else:
            center = tl.zeros([BLOCK], dtype=tl.float32)
        # tl.maximum guards the in_w == 1 case, where every output reads
        # input 0 and the radius becomes out_w + 1.
        radius = (out_w - 1) // tl.maximum(in_w - 1, 1) + 2
    else:
        center = (x_in_f + 0.5) * out_w_f / in_w_f - 0.5
        radius = out_w // tl.maximum(in_w, 1) + 2

    base = tl.floor(center).to(tl.int32)

    go_base = grad_out_ptr + n_idx * go_stride_n + c_idx * go_stride_c

    acc = tl.zeros([BLOCK], dtype=tl.float32)

    for i in range(-radius, radius + 1):
        x_out = base + i
        valid = (x_out >= 0) & (x_out < out_w)
        x_out_f = x_out.to(tl.float32)

        # Forward mapping: source coordinate read by output sample x_out.
        if align_corners:
            if out_w > 1:
                x_real = x_out_f * (in_w_f - 1.0) / (out_w_f - 1.0)
            else:
                x_real = tl.zeros([BLOCK], dtype=tl.float32)
        else:
            x_real = (x_out_f + 0.5) * in_w_f / out_w_f - 0.5

        x0_f = tl.floor(x_real)
        w1 = x_real - x0_f
        w0 = 1.0 - w1

        # Clamp both taps into [0, in_w - 1].  When clamping makes them
        # coincide, the shared tap receives w0 + w1 == 1.
        x0_i = tl.maximum(x0_f, 0.0).to(tl.int32)
        x1_i = tl.minimum(x0_f + 1.0, in_w_f - 1.0).to(tl.int32)

        g = tl.load(
            go_base + x_out * go_stride_w,
            mask=mask & valid,
            other=0.0,
        ).to(tl.float32)

        # Gating the weights with `valid` keeps non-finite weights computed
        # for out-of-range outputs from reaching the accumulator via 0 * NaN.
        hit0 = valid & (x_in == x0_i)
        hit1 = valid & (x_in == x1_i)
        weight = tl.where(hit0, w0, 0.0) + tl.where(hit1, w1, 0.0)
        acc += g * weight

    gi_ptr = (
        grad_in_ptr + n_idx * gi_stride_n + c_idx * gi_stride_c + x_in * gi_stride_w
    )
    tl.store(gi_ptr, acc, mask=mask)
def upsample_linear1d_backward(
    grad_output: torch.Tensor,
    output_size,
    input_size,
    align_corners: bool,
    scale_factors=None,
) -> torch.Tensor:
    """Compute the input gradient of 1-D linear upsampling.

    Args:
        grad_output: gradient w.r.t. the upsampled output; its last dimension
            must equal the output width.
        output_size: sequence whose first entry is the output width, or
            ``None`` to derive the width from ``scale_factors``.
        input_size: shape of the forward input, given as ``(n, c, w)``,
            ``(n, w)`` or ``(w,)``.  Missing leading dimensions are treated
            as size 1.
        align_corners: coordinate-mapping flag forwarded to the kernel.
        scale_factors: sequence whose first entry is the width scale; only
            consulted when ``output_size`` is ``None``.

    Returns:
        A newly allocated ``(n, c, in_w)`` tensor with the dtype and device
        of ``grad_output``.  It is 3-D even when ``input_size`` has fewer
        than three entries.

    Raises:
        ValueError: if ``input_size`` has an unsupported length, if neither
            ``output_size`` nor ``scale_factors`` is given, or if the last
            dimension of ``grad_output`` does not match the output width.

    NOTE(review): the kernel derives its coordinate mapping from
    ``in_w / out_w`` only.  The reference implementation is believed to use
    the user-supplied scale factor directly when one is given, which can
    differ when ``in_w * scale`` is not an integer -- confirm before relying
    on the ``scale_factors`` path.
    """
    logger.debug("GEMS UPSAMPLE_LINEAR1D_BACKWARD")

    rank = len(input_size)
    if rank == 3:
        n, c, in_w = input_size
    elif rank == 2:
        n, c, in_w = input_size[0], 1, input_size[1]
    elif rank == 1:
        n, c, in_w = 1, 1, input_size[0]
    else:
        raise ValueError(
            f"input_size must have 1, 2 or 3 entries, got {rank}"
        )

    # Explicit raises instead of asserts: asserts vanish under `python -O`.
    if output_size is not None:
        out_w = output_size[0]
    elif scale_factors is not None:
        out_w = int(in_w * scale_factors[0])
    else:
        raise ValueError("either output_size or scale_factors must be provided")

    if grad_output.shape[-1] != out_w:
        raise ValueError(
            f"grad_output has width {grad_output.shape[-1]}, expected {out_w}"
        )

    grad_out_3d = grad_output.contiguous().view(n, c, out_w)

    grad_in = torch.zeros(
        (n, c, in_w),
        device=grad_output.device,
        dtype=grad_output.dtype,
    )

    # Nothing to compute for empty problems.  Returning here avoids a
    # zero-sized launch grid and a division by zero in the kernel's
    # coordinate mapping when out_w == 0; the gradient is all zeros.
    if grad_in.numel() == 0 or out_w == 0:
        return grad_in

    go_stride_n, go_stride_c, go_stride_w = grad_out_3d.stride()
    gi_stride_n, gi_stride_c, gi_stride_w = grad_in.stride()

    BLOCK = 512
    # One lane per input-gradient element.
    grid = (triton.cdiv(n * c * in_w, BLOCK),)

    upsample_linear1d_backward_kernel[grid](
        grad_out_3d,
        grad_in,
        n,
        c,
        in_w,
        out_w,
        go_stride_n,
        go_stride_c,
        go_stride_w,
        gi_stride_n,
        gi_stride_c,
        gi_stride_w,
        align_corners,
        BLOCK=BLOCK,
    )

    return grad_in