Coverage for src/flag_gems/ops/aminmax.py: 51%

120 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as ext 

12from flag_gems.utils.limits import get_dtype_max, get_dtype_min 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@libentry()
@triton.jit
def aminmax_kernel_1(
    inp,
    min_out,
    max_out,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First reduction stage: write one partial min and max per program.

    Program ``p`` covers the ``BLOCK_SIZE`` flat offsets starting at
    ``p * BLOCK_SIZE``, reduces the ones that are ``< M`` and stores the two
    results at ``min_out + p`` and ``max_out + p``.
    """
    block_id = ext.program_id(0)

    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M
    src = inp + idx

    # Out-of-range lanes are filled with a value that cannot win the
    # reduction: the dtype's max for the min pass, its min for the max pass.
    elem_ty = inp.type.element_ty
    min_pad = get_dtype_max(elem_ty)
    max_pad = get_dtype_min(elem_ty)
    lo_lanes = tl.load(src, mask=in_bounds, other=min_pad)
    hi_lanes = tl.load(src, mask=in_bounds, other=max_pad)

    tl.store(min_out + block_id, tl.min(lo_lanes))
    tl.store(max_out + block_id, tl.max(hi_lanes))

43 

44 

@libentry()
@triton.jit
def aminmax_kernel_2(
    min_inp, max_inp, min_out, max_out, mid_size, BLOCK_MID: tl.constexpr
):
    """Second reduction stage: collapse the partial results to one min/max.

    Reads the first ``mid_size`` entries of ``min_inp`` and ``max_inp`` in a
    single block of ``BLOCK_MID`` lanes and stores the overall minimum at
    ``min_out`` and the overall maximum at ``max_out``.
    """
    idx = tl.arange(0, BLOCK_MID)
    valid = idx < mid_size

    # Lanes past mid_size are padded with a value that cannot win the
    # corresponding reduction.
    min_pad = get_dtype_max(min_inp.type.element_ty)
    max_pad = get_dtype_min(max_inp.type.element_ty)
    partial_min = tl.load(min_inp + idx, mask=valid, other=min_pad)
    partial_max = tl.load(max_inp + idx, mask=valid, other=max_pad)

    tl.store(min_out, tl.min(partial_min))
    tl.store(max_out, tl.max(partial_max))

64 

65 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def aminmax_kernel(
    inp,
    min_out,
    max_out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise min/max: reduce each length-``N`` row of an ``M x N`` input.

    ``inp`` is addressed as ``inp + row * N + col``. Each program handles
    ``BLOCK_M`` rows, walks the columns in steps of ``BLOCK_N`` and writes one
    minimum and one maximum per row to ``min_out + row`` / ``max_out + row``.
    """
    elem_ty = inp.type.element_ty
    lowest = get_dtype_min(elem_ty)
    highest = get_dtype_max(elem_ty)

    # Rows owned by this program, shaped [BLOCK_M, 1] so they broadcast
    # against the [1, BLOCK_N] column offsets below.
    row_block = ext.program_id(0)
    row_idx = row_block * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ok = row_idx < M
    row_base = inp + row_idx * N
    min_dst = min_out + row_idx
    max_dst = max_out + row_idx

    # bfloat16 inputs are accumulated in float32; other dtypes keep their own.
    acc_ty = tl.float32 if elem_ty is tl.bfloat16 else elem_ty
    running_min = tl.full([BLOCK_M, BLOCK_N], value=highest, dtype=acc_ty)
    running_max = tl.full([BLOCK_M, BLOCK_N], value=lowest, dtype=acc_ty)
    for start in range(0, N, BLOCK_N):
        col_idx = start + tl.arange(0, BLOCK_N)[None, :]
        valid = row_ok & (col_idx < N)
        tile = tl.load(row_base + col_idx, mask=valid, other=lowest)
        # Masked-off lanes keep their previous accumulator value, so the
        # `other` fill used by the load never reaches the result.
        running_min = tl.where(valid, tl.minimum(running_min, tile), running_min)
        running_max = tl.where(valid, tl.maximum(running_max, tile), running_max)

    row_min = tl.min(running_min, axis=1)[:, None]
    row_max = tl.max(running_max, axis=1)[:, None]
    tl.store(min_dst, row_min, row_ok)
    tl.store(max_dst, row_max, row_ok)

107 

108 

def aminmax(inp, dim=None, keepdim=False, *, out=None):
    """Compute the minimum and maximum of ``inp`` in one call.

    Args:
        inp: Input tensor.
        dim: ``None`` to reduce over every element, or an int / sequence of
            ints naming the dimension(s) to reduce.
        keepdim: When true the reduced dimensions are kept with size 1;
            otherwise they are squeezed out of the results.
        out: Optional destination. A tuple is read as ``(min, max)``; any
            other value is used for both results.

    Returns:
        A ``(min, max)`` tuple of tensors.

    Raises:
        AssertionError: If any entry of ``dim`` lies outside
            ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS AMINMAX")
    dtype = inp.dtype

    if dim is None:
        # Two-stage reduction: stage one produces `mid_size` partial results
        # (one per block of `block_size` elements), stage two folds them in a
        # single program. block_size is about sqrt(M), rounded up to a power
        # of two.
        M = inp.numel()
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)
        min_mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        max_mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)

        if out is not None:
            min_out = out[0] if isinstance(out, tuple) else out
            max_out = out[1] if isinstance(out, tuple) else out
            if not keepdim:
                min_out = min_out.squeeze()
                max_out = max_out.squeeze()
        elif not keepdim:
            min_out = torch.empty([], dtype=dtype, device=inp.device)
            max_out = torch.empty([], dtype=dtype, device=inp.device)
        else:
            shape = [1] * inp.dim()
            min_out = torch.empty(shape, dtype=dtype, device=inp.device)
            max_out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            aminmax_kernel_1[(mid_size, 1)](
                inp,
                min_mid,
                max_mid,
                M,
                block_size,
            )
            aminmax_kernel_2[(1, 1)](
                min_mid, max_mid, min_out, max_out, mid_size, block_mid
            )
        return min_out, max_out

    if isinstance(dim, int):
        dim = [dim]
    # all() is required here: asserting on the bare generator expression is
    # always true (a generator object is truthy), which let out-of-range dims
    # through to be silently wrapped by the modulo below.
    assert all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # aminmax_kernel addresses its input as M rows of N contiguous elements;
    # presumably dim_compress rearranges `inp` so the reduced dims are
    # innermost -- confirm against dim_compress.
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1  # keepdim-style output shape
    M = inp.numel() // N

    if out is not None:
        # NOTE(review): `out` is neither resized nor shape-checked here, and
        # the squeeze below is applied to it as well -- confirm callers pass
        # tensors of the keepdim shape.
        min_out = out[0] if isinstance(out, tuple) else out
        max_out = out[1] if isinstance(out, tuple) else out
    else:
        min_out = torch.empty(shape, dtype=dtype, device=inp.device)
        max_out = torch.empty(shape, dtype=dtype, device=inp.device)

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        aminmax_kernel[grid](inp, min_out, max_out, M, N)
    if not keepdim:
        min_out = min_out.squeeze(dim=dim)
        max_out = max_out.squeeze(dim=dim)
    return min_out, max_out

174 min_out = min_out.squeeze(dim=dim) 

175 max_out = max_out.squeeze(dim=dim) 

176 return min_out, max_out