Coverage for src/flag_gems/runtime/backend/_sunrise/ops/upsample_nearest2d.py: 0%

52 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2from typing import Optional, Tuple 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import device 

9 

# Rebind the imported ``device`` object to its ``name`` attribute; from here on
# ``device`` is that value, which upsample_nearest2d compares against
# ``input.device.type``.
device = device.name

# Child of the "flag_gems" logger, named after this module with any leading
# dots stripped from ``__name__``.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

def configs():
    """Build the autotuning search space for the upsample kernel.

    Returns:
        A list with one ``triton.Config`` for every pairing of a
        ``BLOCK_SIZE`` in (128, 256, 512, 1024) with a ``num_warps`` in
        (4, 8, 16, 32). Block size varies slowest, so all warp counts for
        one block size are listed before the next block size starts.
    """
    candidates = []
    for block_size in (128, 256, 512, 1024):
        for num_warps in (4, 8, 16, 32):
            candidates.append(
                triton.Config({"BLOCK_SIZE": block_size}, num_warps=num_warps)
            )
    return candidates

21 

22 

# Nearest-neighbour 2-D upsampling kernel.
#
# Each program handles BLOCK_SIZE consecutive flat output indices, unflattens
# them into (n, c, oh, ow), picks the matching source (ih, iw) and copies one
# element from ptr_i to ptr_o using the supplied per-dimension strides.
#
# SAME_H / SAME_W are derived by the heuristics below (output extent equals
# input extent) and let the kernel skip the scale computation on that axis.
@triton.autotune(configs=configs(), key=["N", "C", "OH", "OW"])
@triton.heuristics(
    {
        "SAME_H": lambda meta: meta["OH"] == meta["IH"],
        "SAME_W": lambda meta: meta["OW"] == meta["IW"],
    }
)
@triton.jit
def upsample_nearest2d_kernel(
    # output / input base pointers
    ptr_o,
    ptr_i,
    # output strides for the (n, c, h, w) dimensions
    sno,
    sco,
    sho,
    swo,
    # input strides for the (n, c, h, w) dimensions
    sni,
    sci,
    shi,
    swi,
    # problem extents
    N,
    C,
    OH,
    OW,
    IH,
    IW,
    # multipliers that map an output coordinate to an input coordinate
    reciprocal_scale_h,
    reciprocal_scale_w,
    BLOCK_SIZE: tl.constexpr,
    SAME_H: tl.constexpr,
    SAME_W: tl.constexpr,
):
    block_id = tl.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Unflatten the linear index, innermost dimension (width) first.
    # The final "% N" wraps lanes that fall past the last element back onto
    # valid coordinates, which is why the load/store below carry no mask.
    row = idx // OW
    ow = idx % OW
    oh = row % OH
    plane = row // OH
    c = plane % C
    n = plane // C % N

    if SAME_H:
        ih = oh
    else:
        # tl.floor() is not available in Triton 2.3.1, so truncate via an
        # int32 cast instead, then clamp to the last valid input row.
        ih = tl.minimum((oh * reciprocal_scale_h).to(tl.int32), IH - 1)
    if SAME_W:
        iw = ow
    else:
        # Same truncate-and-clamp mapping for the column index.
        iw = tl.minimum((ow * reciprocal_scale_w).to(tl.int32), IW - 1)

    dst_offset = n * sno + c * sco + oh * sho + ow * swo
    src_offset = n * sni + c * sci + ih * shi + iw * swi
    value = tl.load(ptr_i + src_offset)
    tl.store(ptr_o + dst_offset, value)

73 

74 

def _reciprocal_scale(scale: Optional[float], in_size: int, out_size: int) -> float:
    """Return the multiplier that maps an output index to an input index."""
    # A non-positive scale is treated as "not provided", mirroring ATen's
    # compute_scales_value; this also avoids dividing by zero and producing
    # negative source indices.
    if scale is not None and scale > 0:
        return 1 / scale
    return in_size / out_size


def upsample_nearest2d(
    input: torch.Tensor,
    output_size: Tuple[int, int],
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
) -> torch.Tensor:
    """Resize a 4-D tensor to ``output_size`` with nearest-neighbour sampling.

    Args:
        input: Tensor with exactly four dimensions, unpacked as
            ``(N, C, IH, IW)``. Its device type must equal the module-level
            ``device`` value.
        output_size: The ``(OH, OW)`` spatial size of the result.
        scales_h: Optional height scale. When given and positive, ``1 / scales_h``
            maps output rows to input rows; otherwise ``IH / OH`` is used.
        scales_w: Optional width scale, handled like ``scales_h`` with
            ``IW / OW`` as the fallback.

    Returns:
        A newly allocated tensor of shape ``(N, C, OH, OW)`` with the same
        dtype and device as ``input``.

    Raises:
        AssertionError: If the device type does not match, ``input`` is not
            4-D, or ``output_size`` does not have two entries.
        ZeroDivisionError: If an output extent is 0 and no positive scale was
            supplied for that axis.
    """
    # Use the module logger (child of "flag_gems") rather than the root logger.
    logger.debug("GEMS UPSAMPLE NEAREST2D")
    assert input.device.type == device
    assert input.ndim == 4, "The ndim of input must be 4"
    assert len(output_size) == 2, "The len of output_size must be 2"
    OH, OW = output_size
    N, C, IH, IW = input.shape
    reciprocal_scale_h = _reciprocal_scale(scales_h, IH, OH)
    reciprocal_scale_w = _reciprocal_scale(scales_w, IW, OW)
    # allocate output
    output = torch.empty((N, C, OH, OW), device=input.device, dtype=input.dtype)
    # One kernel lane per output element; the autotuner picks BLOCK_SIZE.
    total_threads = N * C * OH * OW
    sno, sco, sho, swo = output.stride()
    sni, sci, shi, swi = input.stride()
    grid = lambda META: (triton.cdiv(total_threads, META["BLOCK_SIZE"]),)
    upsample_nearest2d_kernel[grid](
        output,
        input,
        sno,
        sco,
        sho,
        swo,
        sni,
        sci,
        shi,
        swi,
        N,
        C,
        OH,
        OW,
        IH,
        IW,
        reciprocal_scale_h,
        reciprocal_scale_w,
    )
    return output