Coverage for src/flag_gems/ops/pixel_shuffle.py: 61%

46 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9 

10logger = logging.getLogger(__name__) 

11 

12 

# Pixel Shuffle: (N, C*r^2, H, W) -> (N, C, H*r, W*r)
# Direct index mapping kernel - each output element reads from the correct
# input position without intermediate tensors.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 256}),
        triton.Config({"BLOCK_SIZE": 512}),
        triton.Config({"BLOCK_SIZE": 1024}),
        triton.Config({"BLOCK_SIZE": 2048}),
    ],
    key=["n_elements"],
)
@triton.jit
def pixel_shuffle_kernel(
    in_ptr,
    out_ptr,
    n_elements,
    C,
    H,
    W,
    R,
    C_out,
    H_out,
    W_out,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one BLOCK_SIZE-wide slice of the pixel-shuffled output.

    Each program handles the flat output indices
    ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``; indices at or beyond
    ``n_elements`` are masked out of both the load and the store.

    Args:
        in_ptr: Base pointer of the input, addressed with flat row-major
            ``(n, C, H, W)`` indices (assumes a contiguous input; the host
            wrapper in this file calls ``.contiguous()`` before launching).
        out_ptr: Base pointer of the output, addressed with the flat
            ``(n, C_out, H_out, W_out)`` index.
        n_elements: Total number of output elements.
        C, H, W: Input channel count, height and width.
        R: Upscale factor.
        C_out, H_out, W_out: Output channel count, height and width.
        BLOCK_SIZE: Elements handled per program (chosen by the autotuner).
    """
    # Promote the program id to 64 bits before scaling it: program_id and
    # tl.arange are 32-bit, so the flat offsets would wrap around for
    # tensors holding 2**31 or more elements.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Decompose the flat output index, layout (N, C_out, H_out, W_out).
    ow = offsets % W_out
    tmp = offsets // W_out
    oh = tmp % H_out
    tmp2 = tmp // H_out
    c_out = tmp2 % C_out
    n = tmp2 // C_out

    # Split each output coordinate into the input coordinate and the
    # position inside the R x R sub-pixel cell.
    h_in = oh // R
    dh = oh % R
    w_in = ow // R
    dw = ow % R

    # Input channel: c_in = c_out * R * R + dh * R + dw
    c_in = c_out * R * R + dh * R + dw

    # Input linear index. The products are evaluated left to right starting
    # from a 64-bit operand so that C * H * W and H * W are never formed in
    # 32-bit scalar arithmetic.
    in_idx = n * C * H * W + c_in * H * W + h_in * W + w_in

    val = tl.load(in_ptr + in_idx, mask=mask)
    tl.store(out_ptr + offsets, val, mask=mask)

66 

67 

def pixel_shuffle(input, upscale_factor):
    """Rearrange ``(*, C*r^2, H, W)`` into ``(*, C, H*r, W*r)``.

    The input is made contiguous, an output tensor of the same dtype and
    device is allocated, and ``pixel_shuffle_kernel`` is launched over the
    flat output inside ``torch_device_fn.device(input.device)``.

    Args:
        input: Tensor with at least three dimensions whose last three are
            ``(C*r^2, H, W)``. Any leading dimensions are treated as batch
            dimensions and carried over to the output unchanged.
        upscale_factor: Spatial upscale factor ``r``; converted with
            ``int()`` and required to be positive.

    Returns:
        A new tensor of shape ``(*, C, H*r, W*r)``. When the output has no
        elements it is returned without launching the kernel.

    Raises:
        AssertionError: If ``upscale_factor`` is not positive, ``input`` has
            fewer than three dimensions, or the channel dimension is not
            divisible by ``r * r``.
    """
    logger.debug("GEMS PIXEL_SHUFFLE")
    r = int(upscale_factor)

    # The checks are raised explicitly rather than written as ``assert`` so
    # they are not stripped under ``python -O``. AssertionError is kept so
    # callers observe the same exception type as before.
    if r <= 0:
        raise AssertionError(
            f"pixel_shuffle expects a positive upscale_factor, but got {r}"
        )
    if input.ndim < 3:
        raise AssertionError(
            "pixel_shuffle expects input with at least 3 dimensions, "
            f"but got {input.ndim}"
        )

    # Leading dimensions are batch dimensions. The kernel only needs the
    # per-sample sizes, so for a contiguous tensor they flatten implicitly.
    *batch_shape, C, H, W = input.shape
    if C % (r * r) != 0:
        raise AssertionError(
            f"pixel_shuffle expects the channel dimension ({C}) to be "
            f"divisible by upscale_factor squared ({r * r})"
        )

    C_out = C // (r * r)
    H_out = H * r
    W_out = W * r

    # The kernel addresses the input with flat row-major indices.
    input = input.contiguous()
    output = torch.empty(
        (*batch_shape, C_out, H_out, W_out),
        device=input.device,
        dtype=input.dtype,
    )

    n_elements = output.numel()
    if n_elements == 0:
        # Nothing to copy; also avoids launching an empty grid.
        return output

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(input.device):
        pixel_shuffle_kernel[grid](
            input,
            output,
            n_elements,
            C,
            H,
            W,
            r,
            C_out,
            H_out,
            W_out,
        )
    return output