Coverage for src/flag_gems/runtime/backend/_ascend/ops/argmax.py: 0%

103 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as ext 

12from flag_gems.utils.limits import get_dtype_min 

13 

# Module logger, named after this op module under the Ascend backend namespace.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)

15 

16 

@libentry()
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the flat argmax reduction.

    Program ``pid`` loads elements ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``
    of the flat input (bounded by ``M``). It writes the slice's maximum to
    ``mid_value[pid]`` and the flat input index of that maximum to
    ``mid_index[pid]``.
    """
    block_id = ext.program_id(0)
    block_start = block_id * BLOCK_SIZE
    cols = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < M
    # Out-of-range lanes are filled with the dtype minimum so they do not beat
    # any in-range element.
    fill = get_dtype_min(inp.type.element_ty)
    vals = tl.load(inp + cols, mask=in_bounds, other=fill)
    block_max, block_argmax = tl.max(vals, axis=0, return_indices=True)
    tl.store(mid_value + block_id, block_max)
    # block_argmax is local to the slice; shift it to a flat input index.
    tl.store(mid_index + block_id, block_argmax + block_start)

38 

39 

@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage of the flat argmax reduction (single program).

    Finds the slot of ``mid_value[:mid_size]`` holding the largest partial
    maximum. It then copies the flat index stored in the same slot of
    ``mid_index`` to ``out``.
    """
    slots = tl.arange(0, BLOCK_MID)
    # Padding slots beyond mid_size read the dtype minimum so they cannot win.
    partial_max = tl.load(
        mid_value + slots,
        mask=slots < mid_size,
        other=get_dtype_min(mid_value.type.element_ty),
    )
    winner = tl.argmax(partial_max, axis=0)
    tl.store(out, tl.load(mid_index + winner))

52 

53 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax"))
@triton.jit
def argmax_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Argmax over the middle axis of ``inp`` addressed as row-major (M, N, K).

    Program ``(pid_m, pid_k)`` handles BLOCK_M consecutive indices of the
    outer axis M at inner position ``pid_k``. It scans the reduced axis N in
    BLOCK_N-wide chunks and writes the int64 argmax to
    ``out_index[m * K + pid_k]``.

    BLOCK_M / BLOCK_N are not passed by the caller. They are expected to come
    from the "argmax" heuristic config.
    """
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)

    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    dtype = inp.type.element_ty
    # bfloat16 running maxima are held in float32; other dtypes stay as-is.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    min_value = get_dtype_min(dtype)
    max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value)
    argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        inp_ptrs = inp + offset
        # Masked lanes read the dtype minimum so they do not beat any
        # in-bounds element.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
        local_max, local_argmax = tl.max(
            inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # Strict ">" keeps the earlier chunk on ties; together with the
        # left tie-break inside a chunk this yields the first occurrence.
        update = local_max > max_values
        max_values = tl.where(update, local_max, max_values)
        argmax_values = tl.where(update, start_n + local_argmax, argmax_values)

    offset_index = m_offset * K + pid_k
    out_index_ptrs = out_index + offset_index
    mask1 = m_offset < M
    tl.store(out_index_ptrs, argmax_values, mask=mask1)

94 

95 

def _argmax_flat(inp, keepdim, dtype):
    """Index of the maximum of ``inp`` in flattened (row-major) order.

    Two-stage reduction. ``argmax_kernel_1`` reduces blocks of about
    sqrt(numel) elements to per-block (max, index) pairs. ``argmax_kernel_2``
    then picks the winning block in a single program.
    """
    M = inp.numel()
    if M == 0:
        # The block-size arithmetic below degenerates for M == 0 (zero block
        # size); fail with an explicit message instead.
        raise IndexError(
            "argmax(): Expected reduction dim to be specified for input.numel() == 0."
        )
    # argmax_kernel_1 indexes the input as one flat buffer, so memory order
    # must equal logical order; transposed/sliced/expanded views would
    # otherwise be misread. No-op for already-contiguous tensors.
    inp = inp.contiguous()
    if dtype is None:
        dtype = inp.dtype
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
    # With keepdim, every dimension is kept with size 1 (torch.argmax semantics).
    out_shape = [1] * inp.ndim if keepdim else []
    out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

    with torch_device_fn.device(inp.device):
        argmax_kernel_1[(mid_size, 1, 1)](
            inp,
            mid_value,
            mid_index,
            M,
            block_size,
        )
        argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
    return out


def _argmax_along_dim(inp, dim, keepdim):
    """Argmax of ``inp`` along ``dim`` (already range-checked by the caller).

    The contiguous input is viewed as (M, N, K) with N = ``shape[dim]``.
    ``argmax_kernel`` reduces the N axis for every (m, k) pair.
    """
    shape = inp.shape
    dim = dim % inp.ndim
    if inp.numel() == 0:
        out_shape = list(shape)
        if keepdim:
            out_shape[dim] = 1
        else:
            del out_shape[dim]
        return torch.zeros(out_shape, dtype=torch.int64, device=inp.device)
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N

    # argmax_kernel computes row-major (M, N, K) offsets.
    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)

    def grid(meta):
        # One program per BLOCK_M slice of the outer axis, per inner position.
        return (triton.cdiv(M, meta["BLOCK_M"]), K)

    with torch_device_fn.device(inp.device):
        argmax_kernel[grid](
            inp,
            out_index,
            M,
            N,
            K,
        )
    return out_index


def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Ascend implementation of ``torch.argmax``.

    Args:
        inp: input tensor.
        dim: dimension to reduce; ``None`` reduces over all elements in
            flattened order.
        keepdim: keep reduced dimension(s) with size 1.
        dtype: dtype of the intermediate per-block maxima buffer in the
            ``dim=None`` path (defaults to ``inp.dtype``); unused when ``dim``
            is given.

    Returns:
        int64 tensor of indices.
    """
    logger.debug("GEMS_ASCEND ARGMAX")
    if dim is None:
        return _argmax_flat(inp, keepdim, dtype)
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    return _argmax_along_dim(inp, dim, keepdim)