Coverage for src/flag_gems/runtime/backend/_arm/ops/argmax.py: 0%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9 

10# from ..runtime import torch_device_fn 

11# from ..utils import libentry 

12from flag_gems.utils import triton_lang_extension as tle 

13 

14 

15# @libentry() 

@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage one of the two-stage flat argmax: reduce one chunk per program.

    Each program reads up to ``BLOCK_SIZE`` consecutive elements of ``inp``
    starting at ``pid * BLOCK_SIZE`` (positions at or beyond ``M`` are read
    as -inf), then stores the chunk maximum into ``mid_value[pid]`` and the
    position of that maximum, rebased to an index into the whole ``inp``
    buffer, into ``mid_index[pid]``.
    """
    block_id = tle.program_id(0)
    block_start = block_id * BLOCK_SIZE
    positions = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = positions < M
    values = tl.load(inp + positions, mask=in_bounds, other=-float("inf"))
    block_max, local_pos = tl.max(values, axis=0, return_indices=True)
    # local_pos is relative to this chunk; shift it by the chunk base so the
    # second stage can emit it unchanged.
    tl.store(mid_value + block_id, block_max)
    tl.store(mid_index + block_id, local_pos + block_start)

35 

36 

37# @libentry() 

@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage two of the two-stage flat argmax: pick the winning partial result.

    Loads the first ``mid_size`` entries of ``mid_value`` (lanes at or beyond
    ``mid_size`` are read as -inf), finds which entry holds the maximum, and
    copies the matching entry of ``mid_index`` into ``out``.
    """
    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < mid_size
    partial_max = tl.load(mid_value + lanes, mask=valid, other=-float("inf"))
    winner = tl.argmax(partial_max, axis=0)
    # mid_index[winner] is the index recorded by the stage-one program that
    # produced the overall maximum.
    tl.store(out, tl.load(mid_index + winner))

48 

49 

50# @libentry() 

@triton.heuristics(runtime.get_heuristic_config("argmax"))
@triton.jit
def argmax_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Argmax along the middle axis of a buffer addressed as (M, N, K).

    Element (m, n, k) is read from ``inp + m * N * K + n * K + k``. Program
    axis 0 selects a group of ``BLOCK_M`` rows and program axis 1 selects
    ``k``. For every row the kernel walks the N axis in tiles of ``BLOCK_N``
    and writes the index of the maximum to ``out_index + m * K + k``.
    Values are compared after conversion to float32.
    """
    row_block = tle.program_id(0)
    inner = tle.program_id(1)
    rows = row_block * BLOCK_M + tl.arange(0, BLOCK_M)
    row_valid = rows < M

    # Running best value / best index for each row handled by this program.
    best_val = tl.full([BLOCK_M], value=float("-inf"), dtype=tl.float32)
    best_idx = tl.full([BLOCK_M], value=0, dtype=tl.int64)
    for col_start in range(0, N, BLOCK_N):
        cols = col_start + tl.arange(0, BLOCK_N)
        flat = rows[:, None] * N * K + cols[None, :] * K + inner
        valid = row_valid[:, None] & (cols[None, :] < N)
        tile = tl.load(inp + flat, mask=valid, other=-float("inf")).to(tl.float32)
        tile_max, tile_arg = tl.max(
            tile, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # if return indices is not supported, call a tl.argmax in addition
        # tile_arg = tl.argmax(tile, 1)
        # Strict ">" means a later tile only replaces the running best when it
        # is strictly larger, so the earliest index is kept on ties.
        improved = tile_max > best_val
        best_val = tl.where(improved, tile_max, best_val)
        best_idx = tl.where(improved, col_start + tile_arg, best_idx)

    tl.store(out_index + rows * K + inner, best_idx, mask=row_valid)

88 

89 

def _argmax_two_stage(flat_inp, numel, mid_dtype, out):
    """Run the two-stage flat argmax over the first ``numel`` elements.

    Stage one (``argmax_kernel_1``) reduces chunks of ``block_size`` elements
    into per-chunk maxima and indices; stage two (``argmax_kernel_2``) picks
    the winning chunk and writes its index to ``out``.

    Args:
        flat_inp: Tensor whose memory is read linearly by the kernel; the
            caller must pass a contiguous tensor.
        numel: Number of elements to reduce; must be positive.
        mid_dtype: dtype of the intermediate per-chunk maxima buffer.
        out: int64 tensor that receives the single resulting index.
    """
    # Roughly sqrt(numel) elements per chunk, rounded up to a power of two.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    mid_size = triton.cdiv(numel, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid_value = torch.empty((mid_size,), dtype=mid_dtype, device=flat_inp.device)
    mid_index = torch.empty((mid_size,), dtype=torch.int64, device=flat_inp.device)

    argmax_kernel_1[(mid_size, 1, 1)](
        flat_inp,
        mid_value,
        mid_index,
        numel,
        block_size,
    )
    argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)


def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the maximum values of ``inp`` as an int64 tensor.

    Args:
        inp: Input tensor.
        dim: Dimension to reduce. ``None`` reduces over all elements and the
            result is the index into the flattened (row-major) input.
        keepdim: With ``dim=None`` the output has every dimension of size 1
            when true and is 0-d otherwise. With ``dim`` given, the reduced
            dimension is kept with size 1 when true and removed otherwise.
        dtype: Only used when ``dim`` is ``None``, as the dtype of the
            intermediate buffer of per-chunk maxima; defaults to ``inp.dtype``.
            NOTE(review): a dtype narrower than the input could alter the
            stage-two comparison - confirm this parameter is intended.

    Returns:
        An int64 tensor of indices on ``inp.device``.

    Raises:
        IndexError: If ``dim`` is ``None`` and ``inp`` has no elements.
        AssertionError: If ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logging.debug("GEMS ARGMAX")
    if dim is None:
        M = inp.numel()
        if M == 0:
            # Previously this fell through to a division by zero in the block
            # size computation; report the same condition torch.argmax does.
            raise IndexError(
                "argmax(): Expected reduction dim to be specified for "
                "input.numel() == 0."
            )
        if dtype is None:
            dtype = inp.dtype
        # The kernel reads memory linearly, so strided or transposed inputs
        # must be materialised in row-major order first (no-op if contiguous).
        inp = inp.contiguous()

        out_shape = [1] * inp.dim() if keepdim else []
        out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)
        _argmax_two_stage(inp, M, dtype, out)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    if inp.numel() == 0:
        # Nothing to launch for an empty input: return zeros of the
        # reduced shape.
        out_shape = list(shape)
        if keepdim:
            out_shape[dim] = 1
        else:
            del out_shape[dim]
        return torch.zeros(out_shape, dtype=torch.int64, device=inp.device)

    # View the input as (M, N, K) with N the reduced dimension.
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N

    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)

    # Decode-heavy path frequently reduces a single row over vocab; use
    # a two-stage reduction to parallelize across N and reduce launch cost.
    if M == 1 and K == 1:
        # out_index holds exactly one element here; reshape gives the kernel
        # a flat view of that same storage.
        _argmax_two_stage(inp.reshape(-1), N, inp.dtype, out_index.reshape(-1))
        return out_index

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    argmax_kernel[grid](
        inp,
        out_index,
        M,
        N,
        K,
    )

    return out_index