Coverage for src/flag_gems/runtime/backend/_spacemit/ops/argmax.py: 0%

105 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_min 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@libentry()
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage 1 of the full (dim=None) argmax: each program reduces one
    # BLOCK_SIZE-wide slice of the flat input to a (max value, flat index)
    # pair and writes it to slot `block_id` of mid_value / mid_index.
    block_id = tle.program_id(0)
    block_start = block_id * BLOCK_SIZE
    cols = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < M
    # Out-of-range lanes are padded with the dtype's minimum so they can
    # never beat a real element.
    fill = get_dtype_min(inp.type.element_ty)
    vals = tl.load(inp + cols, mask=in_bounds, other=fill)
    local_max, local_idx = tl.max(vals, axis=0, return_indices=True)
    tl.store(mid_value + block_id, local_max)
    # Convert the in-block position into a position in the flat input.
    tl.store(mid_index + block_id, local_idx + block_start)

38 

39 

@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    # Stage 2 of the full (dim=None) argmax: a single program scans the
    # mid_size partial maxima, finds the winning slot and copies that slot's
    # flat index from mid_index into *out.
    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < mid_size
    # BLOCK_MID is padded up (power of two); padding lanes read the dtype
    # minimum so they cannot be selected over a real partial maximum.
    lowest = get_dtype_min(mid_value.type.element_ty)
    partial_max = tl.load(mid_value + lanes, mask=valid, other=lowest)
    winner = tl.argmax(partial_max, axis=0)
    tl.store(out, tl.load(mid_index + winner))

52 

53 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax"))
@triton.jit
def argmax_kernel(
    inp_ptr,
    out_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Argmax over the middle axis of an input addressed as a contiguous
    # (M, N, K) buffer: element (m, n, k) is read from m * N * K + n * K + k.
    # Program (pid_m, pid_k) handles BLOCK_M consecutive rows m for a single
    # k, scans the N axis in chunks of BLOCK_N, and writes one index per row
    # to out_ptr[m * K + k].
    # BLOCK_M / BLOCK_N come from the "argmax" heuristic config (defined
    # elsewhere).
    # NOTE(review): pid_m uses tl.program_id while pid_k uses tle.program_id;
    # if the tle variant differs (e.g. in integer width) the row offset math
    # below may overflow for very large inputs — TODO confirm.
    pid_m = tl.program_id(0)
    pid_k = tle.program_id(1)
    start_row = pid_m * BLOCK_M
    row_offsets = start_row + tl.arange(0, BLOCK_M)
    row_mask = row_offsets < M
    dtype = inp_ptr.dtype.element_ty
    min_value = get_dtype_min(dtype)
    # Running per-row state: best value seen so far and its column index;
    # -1 is a sentinel meaning "no column recorded yet".
    row_max = tl.full((BLOCK_M,), min_value, dtype=dtype)
    # NOTE(review): indices are kept in int32, so N >= 2**31 would overflow.
    row_argmax = tl.full((BLOCK_M,), -1, dtype=tl.int32)

    for block_start in range(0, N, BLOCK_N):
        col_offsets = block_start + tl.arange(0, BLOCK_N)
        col_mask = col_offsets < N
        mask = row_mask[:, None] & col_mask[None, :]
        input_ptrs = (
            inp_ptr + row_offsets[:, None] * N * K + col_offsets[None, :] * K + pid_k
        )
        # Masked-out lanes (rows >= M or columns >= N) read the dtype minimum
        # so they cannot exceed any real value.
        current_block = tl.load(input_ptrs, mask=mask, other=min_value)

        block_max = tl.max(current_block, axis=1)
        # Column index of each row's chunk maximum, shifted to a global column.
        block_argmax = tl.argmax(current_block, axis=1).to(tl.int32) + block_start

        # A strictly larger chunk maximum always replaces the running state.
        update_mask = block_max > row_max
        # On an equal maximum, prefer the smaller index. The row_argmax < 0
        # case covers a row whose values so far all equal the dtype minimum
        # (e.g. all -inf), so such a row still gets a real index. Chunks are
        # visited in increasing column order, so a later chunk never has a
        # smaller index than one already recorded.
        tie_mask = (block_max == row_max) & (
            (row_argmax < 0) | (block_argmax < row_argmax)
        )
        choose_new = update_mask | tie_mask
        # NOTE(review): comparisons against NaN are false, so a NaN maximum is
        # never selected here; torch.argmax treats NaN as the maximum — confirm
        # whether that difference is intended.

        row_argmax = tl.where(choose_new, block_argmax, row_argmax)
        row_max = tl.where(update_mask, block_max, row_max)

    out_offsets = row_offsets * K + pid_k
    out_ptrs = out_ptr + out_offsets
    tl.store(out_ptrs, row_argmax.to(out_ptr.dtype.element_ty), mask=row_mask)

100 

101 

def _argmax_all(inp, keepdim, dtype):
    # Flat argmax over every element via the two-stage kernel pair.
    M = inp.numel()
    if M == 0:
        # Mirror torch: a full reduction has no identity for an empty input
        # (and block_size would be 0, making cdiv divide by zero).
        raise RuntimeError(
            "argmax(): Expected reduction dim to be specified for input.numel() == 0."
        )
    # argmax_kernel_1 addresses the input as a flat buffer (inp + offset),
    # which is only correct for contiguous memory.
    inp = inp.contiguous()
    if dtype is None:
        dtype = inp.dtype
    # Stage 1 uses about sqrt(M) programs of about sqrt(M) elements each.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
    # keepdim keeps the input's rank with every dimension collapsed to 1.
    out_shape = [1] * inp.dim() if keepdim else []
    out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

    with torch_device_fn.device(inp.device):
        argmax_kernel_1[(mid_size, 1, 1)](
            inp,
            mid_value,
            mid_index,
            M,
            block_size,
        )
        argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
    return out


def _argmax_along_dim(inp, dim, keepdim):
    # Argmax along one dimension, viewing the input as (M, N, K) with N = size(dim).
    # torch wraps dims of a 0-dim tensor as if it were 1-d (valid dims: -1, 0).
    ndim = max(inp.ndim, 1)
    if dim < -ndim or dim >= ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})"
        )
    if inp.ndim == 0:
        # The single element of a scalar tensor is trivially the maximum.
        return torch.zeros([], dtype=torch.int64, device=inp.device)

    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    if N == 0:
        raise IndexError(
            f"argmax(): Expected reduction dim {dim} to have non-zero size."
        )
    # Products of the leading / trailing sizes; unlike numel() // M // N this
    # stays well-defined when some other dimension is 0.
    M = math.prod(shape[:dim])
    K = math.prod(shape[dim + 1 :])

    out_shape = list(shape)
    out_shape[dim] = 1
    out_index = torch.empty(out_shape, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)
    if M == 0 or K == 0:
        # Nothing to reduce; avoid launching a zero-sized grid.
        return out_index

    # argmax_kernel computes addresses assuming a contiguous (M, N, K) layout.
    inp = inp.contiguous()
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), K)
    with torch_device_fn.device(inp.device):
        argmax_kernel[grid](
            inp,
            out_index,
            M,
            N,
            K,
        )
    return out_index


def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Return int64 indices of the maximum values of ``inp``.

    Args:
        inp: input tensor.
        dim: dimension to reduce; ``None`` reduces over all elements and
            returns a flat index.
        keepdim: keep the reduced dimension(s) with size 1.
        dtype: only used by the ``dim=None`` path, as the dtype of the
            intermediate per-block maxima (defaults to ``inp.dtype``).

    Raises:
        IndexError: ``dim`` is out of range, or the reduced dimension is empty.
        RuntimeError: ``dim`` is None and ``inp`` has no elements.
    """
    logger.debug("GEMS_SPACEMIT ARGMAX")
    if dim is None:
        return _argmax_all(inp, keepdim, dtype)
    return _argmax_along_dim(inp, dim, keepdim)