Coverage for src/flag_gems/ops/pixel_shuffle.py: 61%
46 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
10logger = logging.getLogger(__name__)
# Pixel Shuffle: (N, C*r^2, H, W) -> (N, C, H*r, W*r)
# Direct index mapping kernel - each output element reads from the correct
# input position without intermediate tensors.
#
# Arguments:
#   in_ptr / out_ptr   base pointers of the input / output buffers; both are
#                      addressed with plain linear (row-major) indices.
#   n_elements         total number of output elements; lanes at or beyond
#                      this count are masked off.
#   C, H, W            channel / height / width extents of the input.
#   R                  upscale factor r.
#   C_out, H_out, W_out  channel / height / width extents of the output.
#   BLOCK_SIZE         number of output elements handled per program
#                      (selected by the autotuner, keyed on n_elements).
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
        triton.Config({"BLOCK_SIZE": 2048}),
    ],
    key=["n_elements"],
)
@triton.jit
def pixel_shuffle_kernel(
    in_ptr,
    out_ptr,
    n_elements,
    C,
    H,
    W,
    R,
    C_out,
    H_out,
    W_out,
    BLOCK_SIZE: tl.constexpr,
):
    # program_id is 32-bit; promote it before scaling by BLOCK_SIZE so the
    # linear offsets (and everything derived from them) are 64-bit and do not
    # wrap around for tensors with 2**31 or more elements.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Decompose the linear output offset, layout (N, C_out, H_out, W_out).
    ow = offsets % W_out
    tmp = offsets // W_out
    oh = tmp % H_out
    tmp2 = tmp // H_out
    c_out = tmp2 % C_out
    n = tmp2 // C_out

    # Map to input: each output pixel (oh, ow) comes from input pixel
    # (oh // R, ow // R); the remainders (dh, dw) pick the input channel.
    h_in = oh // R
    dh = oh % R
    w_in = ow // R
    dw = ow % R

    # Input channel: c_in = c_out * R * R + dh * R + dw
    c_in = (c_out * R + dh) * R + dw

    # Input linear index: n*C*H*W + c_in*H*W + h_in*W + w_in, written in
    # Horner form so every product has a 64-bit operand (a bare C * H * W
    # would be evaluated in 32-bit and could overflow).
    in_idx = ((n * C + c_in) * H + h_in) * W + w_in

    val = tl.load(in_ptr + in_idx, mask=mask)
    tl.store(out_ptr + offsets, val, mask=mask)
def pixel_shuffle(input, upscale_factor):
    """Rearrange a ``(*, C * r^2, H, W)`` tensor into ``(*, C, H * r, W * r)``.

    ``*`` stands for zero or more leading batch dimensions; they are carried
    over to the output unchanged.

    Args:
        input: Tensor with at least 3 dimensions whose third-from-last
            dimension is divisible by ``upscale_factor ** 2``. It is made
            contiguous before the kernel is launched.
        upscale_factor: Positive upscale factor ``r`` (converted with ``int``).

    Returns:
        A newly allocated tensor of shape ``(*, C, H * r, W * r)`` with the
        same dtype and device as ``input``.

    Raises:
        RuntimeError: If ``upscale_factor`` is not positive, ``input`` has
            fewer than 3 dimensions, or the channel dimension is not
            divisible by ``upscale_factor ** 2``.
    """
    logger.debug("GEMS PIXEL_SHUFFLE")
    r = int(upscale_factor)

    # Validate with explicit raises: ``assert`` statements are stripped under
    # ``python -O``, which would let malformed inputs through silently.
    if r <= 0:
        raise RuntimeError(
            f"pixel_shuffle expects a positive upscale_factor, but got {r}"
        )
    if input.ndim < 3:
        raise RuntimeError(
            "pixel_shuffle expects input to have at least 3 dimensions, "
            f"but got input with {input.ndim} dimension(s)"
        )

    *batch_dims, C, H, W = input.shape
    if C % (r * r) != 0:
        raise RuntimeError(
            "pixel_shuffle expects its input's 'channel' dimension to be "
            f"divisible by the square of upscale_factor, but input.size(-3)={C} "
            f"is not divisible by {r * r}"
        )

    C_out = C // (r * r)
    H_out = H * r
    W_out = W * r

    input = input.contiguous()
    output = torch.empty(
        (*batch_dims, C_out, H_out, W_out),
        device=input.device,
        dtype=input.dtype,
    )

    n_elements = output.numel()
    if n_elements == 0:
        # Nothing to copy; also avoids launching an empty grid.
        return output

    # The kernel derives the batch index from the linear output offset, so all
    # leading batch dimensions of the contiguous tensors behave as one
    # flattened batch axis and need no explicit reshape here.
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(input.device):
        pixel_shuffle_kernel[grid](
            input,
            output,
            n_elements,
            C,
            H,
            W,
            r,
            C_out,
            H_out,
            W_out,
        )
    return output