Coverage for src/flag_gems/runtime/backend/_spacemit/ops/argmin.py: 0%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_max 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@libentry()
@triton.jit
def argmin_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage one of the flattened argmin: reduce one chunk of ``inp``.

    Each program handles the ``BLOCK_SIZE`` consecutive elements starting at
    ``program_id * BLOCK_SIZE``. Lanes whose offset is >= ``M`` are masked out.
    The chunk minimum is written to ``mid_value[program_id]`` and its offset
    from the start of ``inp`` to ``mid_index[program_id]``.
    """
    pid = tle.program_id(0)
    chunk_start = pid * BLOCK_SIZE
    offsets = chunk_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < M

    # Masked lanes are filled with get_dtype_max(...) -- presumably the largest
    # value of the element type, so padding cannot win the minimum.
    pad_value = get_dtype_max(inp.type.element_ty)
    values = tl.load(inp + offsets, mask=in_bounds, other=pad_value)

    chunk_min, local_index = tl.min(values, axis=0, return_indices=True)

    # local_index is relative to the chunk; rebase it onto the whole buffer.
    tl.store(mid_value + pid, chunk_min)
    tl.store(mid_index + pid, local_index + chunk_start)

36 

37 

@libentry()
@triton.jit
def argmin_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage two of the flattened argmin: pick the winner among the chunks.

    Loads the first ``mid_size`` entries of ``mid_value`` (``BLOCK_MID`` lanes,
    the excess masked), finds the position of the smallest one, and copies the
    matching entry of ``mid_index`` into ``out``.
    """
    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < mid_size

    # Lanes beyond mid_size are filled with get_dtype_max(...) so that,
    # presumably, they can never be selected as the minimum.
    pad_value = get_dtype_max(mid_value.type.element_ty)
    partial_mins = tl.load(mid_value + lanes, mask=valid, other=pad_value)

    winner = tl.argmin(partial_mins, axis=0)
    tl.store(out, tl.load(mid_index + winner))

49 

50 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmin"))
@triton.jit
def argmin_kernel(
    inp_ptr,
    out_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Argmin along the middle axis of an input addressed as (M, N, K).

    Element (m, n, k) is read from ``inp_ptr + m * N * K + n * K + k`` and the
    result for (m, k) is written to ``out_ptr + m * K + k``. Grid axis 0 walks
    the rows in groups of ``BLOCK_M``; grid axis 1 selects ``k``. The reduced
    axis is scanned in tiles of ``BLOCK_N`` columns while a running minimum
    and its column index are kept per row.
    """
    # Both axes go through the project's program_id wrapper, like the other
    # kernels in this file, so row and column indices share one index type.
    # NOTE(review): this assumes the wrapper yields an index wide enough that
    # row_offsets * N * K below cannot wrap for large inputs -- confirm
    # against triton_lang_extension.
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    start_row = pid_m * BLOCK_M
    row_offsets = start_row + tl.arange(0, BLOCK_M)
    row_mask = row_offsets < M

    # Compare in a wider type: 16-bit floats in float32, 8/16-bit integers in
    # int32; every other dtype is compared as-is.
    dtype = inp_ptr.dtype.element_ty
    acc_type = (
        tl.float32
        if (dtype is tl.bfloat16 or dtype is tl.float16)
        else (
            tl.int32
            if (dtype is tl.int16 or dtype is tl.int8 or dtype is tl.uint8)
            else dtype
        )
    )
    max_value = get_dtype_max(dtype)
    max_value_acc = get_dtype_max(acc_type)
    row_min = tl.full((BLOCK_M,), max_value_acc, dtype=acc_type)
    # -1 marks "no column chosen yet" for a row.
    row_argmin = tl.full((BLOCK_M,), -1, dtype=tl.int32)

    for block_start in range(0, N, BLOCK_N):
        col_offsets = block_start + tl.arange(0, BLOCK_N)
        col_mask = col_offsets < N
        mask = row_mask[:, None] & col_mask[None, :]
        input_ptrs = (
            inp_ptr + row_offsets[:, None] * N * K + col_offsets[None, :] * K + pid_k
        )
        # Masked positions are filled with get_dtype_max(dtype) before widening.
        current_block = tl.load(input_ptrs, mask=mask, other=max_value).to(acc_type)

        block_min = tl.min(current_block, axis=1)
        # tl.argmin is tile-relative; shift it to a column index within the row.
        block_argmin = tl.argmin(current_block, axis=1).to(tl.int32) + block_start

        update_mask = block_min < row_min
        # On an exact tie, take the tile's index only if the row has no index
        # yet (its minimum equals the initial sentinel) or the new index is
        # smaller than the stored one.
        tie_mask = (block_min == row_min) & (
            (row_argmin < 0) | (block_argmin < row_argmin)
        )
        choose_new = update_mask | tie_mask

        row_argmin = tl.where(choose_new, block_argmin, row_argmin)
        row_min = tl.where(update_mask, block_min, row_min)

    out_offsets = row_offsets * K + pid_k
    tl.store(
        out_ptr + out_offsets, row_argmin.to(out_ptr.dtype.element_ty), mask=row_mask
    )

109 

110 

def argmin(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the minimum values of ``inp``.

    Args:
        inp: Input tensor.
        dim: Dimension to reduce. ``None`` reduces over the flattened tensor.
        keepdim: If true, reduced dimensions are kept with size 1 (all
            dimensions when ``dim`` is ``None``).
        dtype: Element type of the intermediate per-chunk minima buffer used
            when ``dim`` is ``None``; defaults to ``inp.dtype``. Not used when
            ``dim`` is given.

    Returns:
        An ``int64`` tensor of indices on the same device as ``inp``.

    Raises:
        IndexError: If ``dim`` is outside ``[-inp.ndim, inp.ndim - 1]``, if
            the reduced dimension has size 0, or if ``dim`` is ``None`` and
            ``inp`` has no elements.
    """
    logger.debug("GEMS_SPACEMIT ARGMIN")
    if dim is None:
        # The first-stage kernel reads `inp` as a flat buffer, so memory order
        # must equal logical (row-major) order for the index to be meaningful.
        inp = inp.contiguous()
        M = inp.numel()
        if M == 0:
            # Without this guard block_size would be 0 and cdiv would divide
            # by zero; report the condition explicitly instead.
            raise IndexError(
                "argmin(): Expected reduction dim to be specified for input.numel() == 0."
            )
        if dtype is None:
            dtype = inp.dtype
        # Two-stage reduction: about sqrt(M) chunks of about sqrt(M) elements.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
        out_shape = [1] * inp.dim() if keepdim else []
        out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

        with torch_device_fn.device(inp.device):
            argmin_kernel_1[(mid_size, 1, 1)](
                inp,
                mid_value,
                mid_index,
                M,
                block_size,
            )
            argmin_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
        return out

    if dim < -inp.ndim or dim >= inp.ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of [{-inp.ndim}, {inp.ndim - 1}], but got {dim})"
        )

    shape = inp.shape
    dim = dim % inp.ndim
    # View the tensor as (M, N, K): leading dims, reduced dim, trailing dims.
    N = shape[dim]
    M = math.prod(shape[:dim])
    # Product of the trailing dims; unlike numel() // M // N this is defined
    # even when M or N is 0.
    K = math.prod(shape[dim + 1 :])
    if N == 0:
        raise IndexError(
            f"argmin(): Expected reduction dim {dim} to have non-zero size."
        )

    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)

    if out_index.numel() == 0:
        # Nothing to compute; avoid launching a grid with a zero-sized axis.
        return out_index

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), K)
    with torch_device_fn.device(inp.device):
        argmin_kernel[grid](
            inp,
            out_index,
            M,
            N,
            K,
        )

    return out_index