Coverage for src/flag_gems/runtime/backend/_ascend/ops/argmin.py: 0%

98 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as ext 

11from flag_gems.utils.limits import get_dtype_max 

12 

# Module logger, namespaced under the Ascend ops hierarchy by this module's
# short name (the last dotted component of __name__).
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)

14 

15 

@libentry()
@triton.jit
def argmin_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the flat argmin: one partial minimum per BLOCK_SIZE chunk.

    Program ``pid`` scans flat elements [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)
    of ``inp``. Elements at or beyond ``M`` are read as the dtype's max value.
    It writes the chunk minimum to ``mid_value[pid]`` and that minimum's flat
    index into ``inp`` to ``mid_index[pid]``.
    """
    block_id = ext.program_id(0)
    block_start = block_id * BLOCK_SIZE
    positions = block_start + tl.arange(0, BLOCK_SIZE)
    # Padding with the dtype max means out-of-range lanes cannot beat a real element.
    pad = get_dtype_max(inp.type.element_ty)
    chunk = tl.load(inp + positions, mask=positions < M, other=pad)
    chunk_min, local_pos = tl.min(chunk, axis=0, return_indices=True)
    tl.store(mid_value + block_id, chunk_min)
    # Convert the in-chunk position into an absolute flat index.
    tl.store(mid_index + block_id, local_pos + block_start)

38 

39 

@libentry()
@triton.jit
def argmin_kernel_2(
    mid_value,
    mid_index,
    out,
    mid_size,
    BLOCK_MID: tl.constexpr,
):
    """Stage 2 of the flat argmin: pick the winning chunk and emit its index.

    The kernel loads the ``mid_size`` partial minima, padded up to
    ``BLOCK_MID`` with the dtype's max value. It locates the smallest with
    ``tl.argmin`` and copies the flat index recorded for that chunk from
    ``mid_index`` into ``out``. It never reads the program id, so it is meant
    for a single-program launch.
    """
    lanes = tl.arange(0, BLOCK_MID)
    pad = get_dtype_max(mid_value.type.element_ty)
    partial_mins = tl.load(mid_value + lanes, mask=lanes < mid_size, other=pad)
    best_chunk = tl.argmin(partial_mins, axis=0)
    tl.store(out, tl.load(mid_index + best_chunk))

58 

59 

# Argmin along one axis. Autotuned over three (BLOCK_M, BLOCK_N) tilings and
# re-tuned for each distinct (M, N, K).
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 1, "BLOCK_N": 512}),
        triton.Config({"BLOCK_M": 4, "BLOCK_N": 256}),
        triton.Config({"BLOCK_M": 8, "BLOCK_N": 128}),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def argmin_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Reduces over the middle axis of `inp` viewed as (M, N, K). Element
    # (m, n, k) is read from flat offset m * N * K + n * K + k, so `inp` must be
    # densely packed in that order.
    # Program (pid_m, pid_k) handles rows m in [pid_m * BLOCK_M, (pid_m + 1) * BLOCK_M)
    # for the single trailing position k = pid_k. It scans N in BLOCK_N-wide
    # chunks and stores int64 indices (0..N-1) to out_index[m * K + k].
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)

    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    dtype = inp.type.element_ty
    # bfloat16 would be tracked in float32. The Python wrapper in this file
    # upcasts bfloat16 inputs before launching, so that branch is not hit from there.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    # Out-of-range lanes are filled with the dtype's max so they cannot beat real elements.
    max_value = get_dtype_max(dtype)
    min_values = tl.full([BLOCK_M], dtype=acc_type, value=max_value)
    argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        # NOTE(review): this offset arithmetic uses the program-id/arange integer
        # type. Confirm it cannot overflow for inputs with >= 2**31 elements.
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        # NOTE(review): this relies on Triton lowering `and` between tensors to an
        # element-wise logical and. `&` would be the more explicit spelling.
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
        # Per-row minimum of this chunk. The leftmost index wins ties.
        local_min, local_argmin = tl.min(
            inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # Strict `<` keeps the earlier chunk on ties, so the first occurrence of
        # the row minimum is reported.
        # NOTE(review): NaN values are not special-cased here. torch.argmin
        # reports NaN positions, so confirm the expected behavior for NaN inputs.
        update = local_min < min_values
        min_values = tl.where(update, local_min, min_values)
        argmin_values = tl.where(update, start_n + local_argmin, argmin_values)

    offset_index = m_offset * K + pid_k
    out_index_ptrs = out_index + offset_index
    # Only rows inside [0, M) are written.
    mask1 = m_offset < M
    tl.store(out_index_ptrs, argmin_values, mask=mask1)

106 

107 

def argmin(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the minimum values of ``inp`` (like ``torch.argmin``).

    Args:
        inp: Input tensor. bfloat16 inputs are upcast to float32 before the
            kernels run. The conversion is exact, so the indices are unchanged.
        dim: Dimension to reduce. ``None`` reduces over the flattened input.
        keepdim: Keep the reduced dimension(s) with size 1 in the result. With
            ``dim=None`` the result has every dimension of size 1.
        dtype: Dtype of the per-block partial minima in the ``dim=None`` path.
            Defaults to ``inp.dtype`` and is ignored when ``dim`` is given.

    Returns:
        An int64 tensor of indices on ``inp.device``.

    Raises:
        RuntimeError: ``dim`` is None and ``inp`` has no elements.
        IndexError: the reduced dimension has size 0.
        AssertionError: ``dim`` is out of range.
    """
    logger.debug("GEMS_ASCEND ARGMIN")
    if inp.dtype == torch.bfloat16:
        # Presumably the Ascend kernels do not handle bfloat16 directly.
        # bfloat16 -> float32 is exact, so ordering and indices are preserved.
        return argmin(inp.to(torch.float32), dim=dim, keepdim=keepdim, dtype=dtype)
    if dim is None:
        return _argmin_flat(inp, keepdim, dtype)
    return _argmin_along_dim(inp, dim, keepdim)


def _argmin_flat(inp, keepdim, dtype):
    """Compute the argmin over all elements with a two-stage (per-block, then global) reduction."""
    M = inp.numel()
    if M == 0:
        raise RuntimeError(
            "argmin(): Expected reduction dim to be specified for input.numel() == 0."
        )
    if dtype is None:
        dtype = inp.dtype
    # argmin_kernel_1 addresses the input as one dense flat buffer. Strided or
    # sliced views must therefore be materialized in logical (row-major) order,
    # otherwise wrong elements are read and storage-order indices are returned.
    inp = inp.contiguous()
    # About sqrt(M) elements per block yields about sqrt(M) partial results for stage 2.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
    # keepdim=True keeps every dimension with size 1. Otherwise the result is 0-d.
    out_shape = [1] * inp.ndim if keepdim else []
    out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

    with torch_device_fn.device(inp.device):
        argmin_kernel_1[(mid_size, 1, 1)](
            inp,
            mid_value,
            mid_index,
            M,
            block_size,
        )
        argmin_kernel_2[(1, 1, 1)](
            mid_value,
            mid_index,
            out,
            mid_size,
            block_mid,
        )
    return out


def _argmin_along_dim(inp, dim, keepdim):
    """Compute the argmin along ``dim``, viewing the contiguous input as (M, N, K)."""
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    if N == 0:
        raise IndexError(
            f"argmin(): Expected reduction dim {dim} to have non-zero size."
        )
    M = math.prod(shape[:dim])
    # Product of the trailing sizes. Unlike numel() // M // N, this cannot divide
    # by zero when a non-reduced dimension is empty.
    K = math.prod(shape[dim + 1 :])

    inp = inp.contiguous()

    out_shape = list(shape)
    out_shape[dim] = 1
    out_index = torch.empty(out_shape, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)
    if out_index.numel() == 0:
        # Another dimension is empty, so there is nothing to compute. Skip the launch.
        return out_index

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    with torch_device_fn.device(inp.device):
        argmin_kernel[grid](
            inp,
            out_index,
            M,
            N,
            K,
        )
    return out_index