Coverage for src/flag_gems/runtime/backend/_sunrise/ops/upsample_nearest2d.py: 0%
52 statements
Report generated by coverage.py v7.6.9 at 2026-06-05 07:36 +0800.
1import logging
2from typing import Optional, Tuple
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import device
# From here on ``device`` is the backend's name string rather than the
# imported runtime handle; upsample_nearest2d compares input.device.type to it.
device = device.name

# Per-module child of the package-wide "flag_gems" logger.
_parent_logger = logging.getLogger("flag_gems")
logger = _parent_logger.getChild(__name__.lstrip("."))
def configs():
    """Return the autotuning candidates for the nearest-2D upsample kernel.

    Every combination of four ``BLOCK_SIZE`` values and four warp counts is
    produced (16 configs), with the block size varying slowest.
    """
    candidates = []
    for block_size in (128, 256, 512, 1024):
        for num_warps in (4, 8, 16, 32):
            candidates.append(
                triton.Config({"BLOCK_SIZE": block_size}, num_warps=num_warps)
            )
    return candidates
@triton.autotune(configs=configs(), key=["N", "C", "OH", "OW"])
@triton.heuristics(
    {
        # When an axis is not resized, the kernel can skip the scale
        # multiply and clamp for that axis.
        "SAME_H": lambda args: args["OH"] == args["IH"],
        "SAME_W": lambda args: args["OW"] == args["IW"],
    }
)
@triton.jit
def upsample_nearest2d_kernel(
    ptr_o,
    ptr_i,
    sno,
    sco,
    sho,
    swo,
    sni,
    sci,
    shi,
    swi,
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
):
    """Gather nearest-neighbour source pixels into an (N, C, OH, OW) output.

    Each program covers BLOCK_SIZE consecutive positions of the logical
    output index space, flattened in (n, c, oh, ow) row-major order. Both
    tensors are addressed through explicit element strides (``s*o`` for the
    output, ``s*i`` for the input), so neither needs to be contiguous.

    The source row is ``min(trunc(oh * reciprocal_scale_h), IH - 1)``
    unless SAME_H marks the height as unchanged (then it is ``oh``); the
    source column is computed the same way from the width parameters.
    """
    pid = tl.program_id(axis=0)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # NOTE(review): pid/idx are presumably 32-bit here, so outputs with more
    # than 2**31 - 1 elements may overflow the flat index. Confirm whether
    # int64 indexing is needed on this backend.

    # Peel the flat index apart one dimension at a time, reusing each
    # quotient instead of recomputing the whole division chain.
    ow = idx % OW
    row = idx // OW  # flat (n, c, oh) row index
    oh = row % OH
    plane = row // OH  # flat (n, c) plane index
    c = plane % C
    n = plane // C
    # The last program usually overhangs the end of the output. `n < N` is
    # equivalent to `idx < N * C * OH * OW` but never forms that product,
    # which could overflow. Masked-off lanes neither read nor write.
    mask = n < N

    if SAME_H:
        ih = oh
    else:
        # tl.floor() cannot be found in Triton 2.3.1. For non-negative
        # scales, truncating toward zero gives the same result as floor.
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)
    if SAME_W:
        iw = ow
    else:
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)

    offset_o = n * sno + c * sco + oh * sho + ow * swo
    offset_i = n * sni + c * sci + ih * shi + iw * swi
    data = tl.load(ptr_i + offset_i, mask=mask)
    tl.store(ptr_o + offset_o, data, mask=mask)
75def upsample_nearest2d(
76 input: torch.Tensor,
77 output_size: Tuple[int],
78 scales_h: Optional[float] = None,
79 scales_w: Optional[float] = None,
80) -> torch.Tensor:
81 logging.debug("GEMS UPSAMPLE NEAREST2D")
82 assert input.device.type == device
83 assert input.ndim == 4, "The ndim of input must be 4"
84 assert len(output_size) == 2, "The len of output_size must be 2"
85 OH, OW = output_size
86 N, C, IH, IW = input.shape
87 if scales_h is not None:
88 reciprocal_scale_h = 1 / scales_h
89 else:
90 reciprocal_scale_h = IH / OH
91 if scales_w is not None:
92 reciprocal_scale_w = 1 / scales_w
93 else:
94 reciprocal_scale_w = IW / OW
95 # allocate output
96 output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
97 total_threads = N * C * OH * OW
98 sno, sco, sho, swo = output.stride()
99 sni, sci, shi, swi = input.stride()
100 grid = lambda META: (triton.cdiv(total_threads, META["BLOCK_SIZE"]),)
101 upsample_nearest2d_kernel[grid](
102 output,
103 input,
104 sno,
105 sco,
106 sho,
107 swo,
108 sni,
109 sci,
110 shi,
111 swi,
112 N,
113 C,
114 OH,
115 OW,
116 IH,
117 IW,
118 reciprocal_scale_h,
119 reciprocal_scale_w,
120 )
121 return output