Coverage for src/flag_gems/runtime/backend/_ascend/ops/argmax.py: 0%
103 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as ext
12from flag_gems.utils.limits import get_dtype_min
# Per-op logger named after this module's last dotted component, placed under
# the Ascend ops logging namespace.
logger = logging.getLogger("flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1])
@libentry()
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the flat (``dim=None``) argmax.

    Program ``i`` scans elements ``[i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE)`` of
    the flat input. Lanes at or beyond ``M`` read the dtype's minimum. The
    program stores the slice maximum to ``mid_value[i]`` and the flat
    (global) index of that maximum to ``mid_index[i]``.
    """
    block_id = ext.program_id(0)
    block_start = block_id * BLOCK_SIZE
    lanes = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < M
    # Padding lanes get the dtype's lowest value so they cannot exceed any
    # real element of the slice.
    lowest = get_dtype_min(inp.type.element_ty)
    values = tl.load(inp + lanes, mask=in_bounds, other=lowest)
    block_max, local_idx = tl.max(values, axis=0, return_indices=True)
    tl.store(mid_value + block_id, block_max)
    # Convert the in-slice position into a position in the flat input.
    tl.store(mid_index + block_id, block_start + local_idx)
@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the flat (``dim=None``) argmax, run as a single program.

    Finds the position of the largest of the ``mid_size`` partial maxima in
    ``mid_value`` and copies the flat index stored at that position of
    ``mid_index`` into ``out``. Lanes at or beyond ``mid_size`` read the
    dtype's minimum. Only lanes ``[0, BLOCK_MID)`` are examined, so
    ``BLOCK_MID`` must be at least ``mid_size``.
    """
    lanes = tl.arange(0, BLOCK_MID)
    lowest = get_dtype_min(mid_value.type.element_ty)
    partial_max = tl.load(mid_value + lanes, mask=lanes < mid_size, other=lowest)
    winner = tl.argmax(partial_max, axis=0)
    tl.store(out, tl.load(mid_index + winner))
@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax"))
@triton.jit
def argmax_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Argmax over the middle axis of an input addressed as (M, N, K).

    The flat input offset of element (m, n, k) is ``m * N * K + n * K + k``,
    which assumes ``inp`` is contiguous in that layout. Program
    (``pid_m``, ``pid_k``) handles ``BLOCK_M`` consecutive m values for a
    single k. It scans the reduced axis N in chunks of ``BLOCK_N``. The
    winning n for (m, k) is written to ``out_index[m * K + k]``.

    Ties resolve to the smallest n. Within a chunk this comes from
    ``return_indices_tie_break_left=True``. Across chunks it comes from the
    strict ``>`` update, which keeps the earlier chunk.
    """
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    dtype = inp.type.element_ty
    # bfloat16 running maxima are kept in float32; other dtypes keep their own.
    acc_type = tl.float32 if dtype is tl.bfloat16 else dtype
    min_value = get_dtype_min(dtype)
    max_values = tl.full([BLOCK_M], dtype=acc_type, value=min_value)
    argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)

    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        # Element-wise mask; out-of-range lanes read the dtype's minimum.
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        inp_vals = tl.load(inp + offset, mask=mask, other=min_value)
        local_max, local_argmax = tl.max(
            inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # Strict comparison keeps the earliest chunk on ties.
        update = local_max > max_values
        max_values = tl.where(update, local_max, max_values)
        argmax_values = tl.where(update, start_n + local_argmax, argmax_values)

    out_index_ptrs = out_index + m_offset * K + pid_k
    tl.store(out_index_ptrs, argmax_values, mask=m_offset < M)
def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the maximum values of ``inp`` (like ``torch.argmax``).

    Args:
        inp: Input tensor.
        dim: Dimension to reduce. ``None`` reduces over all elements, taken
            in logical (row-major) order.
        keepdim: If True, reduced dimensions are kept with size 1. With
            ``dim=None`` the result has shape ``[1] * inp.ndim``.
        dtype: Used only when ``dim is None``. It sets the dtype of the buffer
            that holds per-block maxima and defaults to ``inp.dtype``.

    Returns:
        An ``int64`` tensor of indices.

    Raises:
        AssertionError: If ``dim`` is out of range for ``inp.ndim``.
        RuntimeError: If ``dim is None`` and ``inp`` has no elements.
    """
    logger.debug("GEMS_ASCEND ARGMAX")
    if dim is None:
        return _argmax_flat(inp, keepdim, dtype)
    return _argmax_along_dim(inp, dim, keepdim)


def _argmax_flat(inp, keepdim, dtype):
    """Two-stage argmax over every element of ``inp`` (the ``dim=None`` case)."""
    M = inp.numel()
    if M == 0:
        # With M == 0 the block size below would be 0 and triton.cdiv would
        # raise ZeroDivisionError; report a meaningful error instead.
        raise RuntimeError(
            "argmax(): Expected reduction dim to be specified for input.numel() == 0."
        )
    if dtype is None:
        dtype = inp.dtype
    # argmax_kernel_1 addresses the input as one flat buffer. Storage order
    # must therefore equal logical order, or non-contiguous (e.g. transposed)
    # inputs yield indices into storage rather than into the logical tensor.
    inp = inp.contiguous()

    # Stage 1 launches mid_size programs of block_size (~sqrt(M)) elements.
    # Stage 2 reduces the mid_size partial results in a single program.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
    # For a full reduction, keepdim keeps every dimension with size 1.
    out_shape = [1] * inp.ndim if keepdim else []
    out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

    with torch_device_fn.device(inp.device):
        argmax_kernel_1[(mid_size, 1, 1)](inp, mid_value, mid_index, M, block_size)
        argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
    return out


def _argmax_along_dim(inp, dim, keepdim):
    """Argmax along ``dim``, viewing ``inp`` as (M, N, K) with N = ``inp.shape[dim]``."""
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    shape = list(inp.shape)

    out_shape = list(shape)
    if keepdim:
        out_shape[dim] = 1
    else:
        del out_shape[dim]

    if inp.numel() == 0:
        # NOTE(review): torch.argmax raises when shape[dim] == 0; this returns
        # zeros instead. Confirm this is intended.
        return torch.zeros(out_shape, dtype=torch.int64, device=inp.device)

    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N

    # argmax_kernel computes flat offsets assuming a contiguous (M, N, K) layout.
    inp = inp.contiguous()
    # A contiguous tensor of out_shape stores result (m, k) at flat m * K + k,
    # which is where argmax_kernel writes it.
    out_index = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]), K)

    with torch_device_fn.device(inp.device):
        argmax_kernel[grid](inp, out_index, M, N, K)
    return out_index