Coverage for src/flag_gems/runtime/backend/_sunrise/ops/upsample_nearest2d.py: 0%
52 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import device
# Rebind the imported runtime ``device`` object to its ``name`` attribute; the
# host wrapper in this module compares ``input.device.type`` against it.
device = device.name
# Logger named after this module.
logger = logging.getLogger(__name__)
def configs():
    """Return the autotuning candidates for the upsample kernel.

    One ``triton.Config`` is produced for every pairing of a ``BLOCK_SIZE``
    value with a ``num_warps`` value, block sizes varying slowest.
    """
    block_sizes = (128, 256, 512, 1024)
    warp_counts = (4, 8, 16, 32)

    candidates = []
    for block_size in block_sizes:
        for num_warps in warp_counts:
            candidates.append(
                triton.Config({"BLOCK_SIZE": block_size}, num_warps=num_warps)
            )
    return candidates
@triton.autotune(configs=configs(), key=["N", "C", "OH", "OW"])
@triton.heuristics(
    {
        "SAME_H": lambda args: args["OH"] == args["IH"],
        "SAME_W": lambda args: args["OW"] == args["IW"],
    }
)
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,
    ptr_i,
    sno,
    sco,
    sho,
    swo,
    sni,
    sci,
    shi,
    swi,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
):
    """Copy each output element from its nearest-neighbour source element.

    Every program handles ``BLOCK_SIZE`` consecutive flat output indices,
    which are decomposed into ``(n, c, oh, ow)`` coordinates.

    Args:
        ptr_o, ptr_i: base pointers of the output and input tensors.
        sno, sco, sho, swo: output strides for the n, c, h and w coordinates.
        sni, sci, shi, swi: input strides for the n, c, h and w coordinates.
        N, C: batch and channel extents used to decompose the flat index.
        OH, OW: output height and width.
        IH, IW: input height and width.
        reciprocal_scale_h, reciprocal_scale_w: factors that multiply an
            output row/column to obtain the source row/column.
        BLOCK_SIZE: number of flat indices per program (autotuned).
        SAME_H, SAME_W: set by the heuristics when ``OH == IH`` / ``OW == IW``;
            the corresponding coordinate is then reused without scaling.

    No mask is applied to the load or the store: lanes whose flat index runs
    past the last element have their batch coordinate wrapped by ``% N``, so
    they address an in-range element again.
    """
    block_start = tl.program_id(axis=0) * BLOCK_SIZE
    flat = block_start + tl.arange(0, BLOCK_SIZE)

    # Peel the coordinates off the flat index, innermost (width) first.
    out_col = flat % OW
    rows = flat // OW
    out_row = rows % OH
    planes = rows // OH
    chan = planes % C
    batch = planes // C % N

    if SAME_H:
        src_row = out_row
    else:
        # tl.floor() cannot be found in Triton 2.3.1, so the product is
        # truncated through an int32 cast, then clamped to the last input row.
        src_row = tl.minimum((out_row * reciprocal_scale_h).to(tl.int32), IH - 1)

    if SAME_W:
        src_col = out_col
    else:
        # Same truncate-and-clamp scheme for the column.
        src_col = tl.minimum((out_col * reciprocal_scale_w).to(tl.int32), IW - 1)

    dst = batch * sno + chan * sco + out_row * sho + out_col * swo
    src = batch * sni + chan * sci + src_row * shi + src_col * swi

    value = tl.load(ptr_i + src)
    tl.store(ptr_o + dst, value)
def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Nearest-neighbour resize of a 4-D tensor to ``output_size``.

    Args:
        input: 4-D tensor unpacked as ``(N, C, IH, IW)``; its device type must
            equal the module-level ``device`` name.
        output_size: ``(OH, OW)`` of the result.
        scales_h: optional height scale factor. When given and positive, the
            source row of output row ``oh`` is derived from ``oh / scales_h``;
            otherwise the ratio ``IH / OH`` is used.
        scales_w: optional width scale factor, handled like ``scales_h``.

    Returns:
        A newly allocated ``(N, C, OH, OW)`` tensor with the dtype and device
        of ``input``.

    Raises:
        AssertionError: if ``input`` is on another device type, is not 4-D,
            or ``output_size`` does not have exactly two entries.
    """
    # Use the module logger (not the root logger) so the message carries this
    # module's name and obeys per-module log configuration.
    logger.debug("GEMS UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"
    OH, OW = output_size
    N, C, IH, IW = input.shape

    # A user-supplied scale is only honoured when it is positive; otherwise
    # fall back to the size ratio. This mirrors ATen's compute_scales_value and
    # avoids a ZeroDivisionError for 0.0 and negative source indices for
    # negative scales.
    if scales_h is not None and scales_h > 0:
        reciprocal_scale_h = 1 / scales_h
    else:
        reciprocal_scale_h = IH / OH
    if scales_w is not None and scales_w > 0:
        reciprocal_scale_w = 1 / scales_w
    else:
        reciprocal_scale_w = IW / OW

    # allocate output
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    total_threads = N * C * OH * OW
    sno, sco, sho, swo = output.stride()
    sni, sci, shi, swi = input.stride()

    # One program per BLOCK_SIZE output elements; BLOCK_SIZE is chosen by the
    # kernel's autotuner, hence the callable grid.
    grid = lambda META: (triton.cdiv(total_threads, META["BLOCK_SIZE"]),)
    upsample_nearest2d_kernel[grid](
        output,
        input,
        sno,
        sco,
        sho,
        swo,
        sni,
        sci,
        shi,
        swi,
        N,
        C,
        OH,
        OW,
        IH,
        IW,
        reciprocal_scale_h,
        reciprocal_scale_w,
    )
    return output