Coverage for src/flag_gems/runtime/backend/_arm/ops/argmax.py: 0%
103 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
10# from ..runtime import torch_device_fn
11# from ..utils import libentry
12from flag_gems.utils import triton_lang_extension as tle
15# @libentry()
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the two-stage argmax: one partial result per program.

    Each program reduces the ``BLOCK_SIZE`` consecutive elements of ``inp``
    that start at ``program_id * BLOCK_SIZE`` and stores
    the block's maximum in ``mid_value[program_id]`` and the position of that
    maximum, expressed as an offset into ``inp``, in ``mid_index[program_id]``.

    Args:
        inp: pointer to the input, addressed linearly.
        mid_value: pointer receiving one maximum per program.
        mid_index: pointer receiving one index per program.
        M: number of valid elements in ``inp``.
        BLOCK_SIZE: compile-time number of elements handled by one program.
    """
    block_id = tle.program_id(0)
    block_start = block_id * BLOCK_SIZE
    lanes = block_start + tl.arange(0, BLOCK_SIZE)

    # Lanes past the end of the input read -inf so they cannot win the max.
    vals = tl.load(inp + lanes, mask=lanes < M, other=-float("inf"))
    best_val, best_lane = tl.max(vals, axis=0, return_indices=True)

    tl.store(mid_value + block_id, best_val)
    # best_lane is relative to this block; shift it to an offset into inp.
    tl.store(mid_index + block_id, best_lane + block_start)
37# @libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage of the two-stage argmax: reduce the partial results.

    Finds the position of the largest entry among the first ``mid_size``
    values of ``mid_value`` and copies the index stored at that position of
    ``mid_index`` into ``out``.

    Args:
        mid_value: pointer to the per-block maxima.
        mid_index: pointer to the per-block indices.
        out: pointer to the single output element.
        mid_size: number of valid partial results.
        BLOCK_MID: compile-time tile width used to load the partial results.
    """
    slots = tl.arange(0, BLOCK_MID)
    # Slots past mid_size read -inf so they cannot be selected.
    partial_max = tl.load(
        mid_value + slots, mask=slots < mid_size, other=-float("inf")
    )
    winner = tl.argmax(partial_max, axis=0)
    tl.store(out, tl.load(mid_index + winner))
50# @libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax"))
@triton.jit
def argmax_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Argmax over the middle axis of an input addressed as ``(M, N, K)``.

    Element ``(m, n, k)`` is read from ``inp + m * N * K + n * K + k``.
    Program axis 0 selects a tile of ``BLOCK_M`` rows and program axis 1
    selects ``k``. For every row in the tile the kernel scans ``n`` in tiles
    of ``BLOCK_N`` and writes the winning ``n`` to ``out_index + m * K + k``.

    Values are compared after conversion to float32. The first occurrence of
    the maximum wins: ties inside a tile break to the left, and a later tile
    only replaces the running result when its maximum is strictly greater.

    Args:
        inp: pointer to the input data.
        out_index: pointer to the int64 output indices.
        M: number of rows (product of the dimensions before the reduced one).
        N: length of the reduced dimension.
        K: product of the dimensions after the reduced one.
        BLOCK_M: compile-time tile size along M.
        BLOCK_N: compile-time tile size along N.
    """
    # set offset
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Running maximum and its index for each row of the tile.
    max_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("-inf"))
    argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        # Elementwise `&` (not Python `and`) combines the two block masks.
        mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
        inp_ptrs = inp + offset
        # Out-of-range lanes read -inf so they never win the reduction.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        local_max, local_argmax = tl.max(
            inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # if return indices is not supported, call a tl.argmax in addition
        # local_argmax = tl.argmax(inp_vals, 1)
        # Strict comparison keeps the earliest tile on ties.
        update = local_max > max_values
        max_values = tl.where(update, local_max, max_values)
        # local_argmax is tile-relative; add start_n to index into N.
        argmax_values = tl.where(update, start_n + local_argmax, argmax_values)

    offset_index = m_offset * K + pid_k
    out_index_ptrs = out_index + offset_index
    mask1 = m_offset < M
    tl.store(out_index_ptrs, argmax_values, mask=mask1)
def _launch_two_stage_argmax(flat_inp, numel, out, mid_dtype):
    """Write the argmax of ``numel`` linearly addressed elements into ``out``.

    ``flat_inp`` must be contiguous because the kernels index it linearly.
    ``out`` must hold one int64 element. ``mid_dtype`` is the dtype of the
    scratch buffer that holds the per-block maxima.
    """
    # Blocks of about sqrt(numel) elements keep both stages similar in size.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(numel)))
    mid_size = triton.cdiv(numel, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid_value = torch.empty((mid_size,), dtype=mid_dtype, device=flat_inp.device)
    mid_index = torch.empty((mid_size,), dtype=torch.int64, device=flat_inp.device)

    # with torch_device_fn.device(inp.device):
    argmax_kernel_1[(mid_size, 1, 1)](
        flat_inp,
        mid_value,
        mid_index,
        numel,
        block_size,
    )
    argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)


def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the int64 indices of the maximum values of ``inp``.

    Args:
        inp: input tensor.
        dim: dimension to reduce. When ``None`` the whole tensor is reduced
            and the result is an index into the flattened (row-major) input.
        keepdim: keep the reduced dimension(s) with size 1 in the output.
        dtype: dtype of the intermediate per-block maxima on the
            ``dim=None`` path; defaults to ``inp.dtype``. It is not used when
            ``dim`` is given.

    Returns:
        An int64 tensor on ``inp.device``. With ``dim=None`` it is 0-dim, or
        all-ones shaped when ``keepdim`` is set. Otherwise it has the shape of
        ``inp`` with ``dim`` removed, or set to 1 when ``keepdim`` is set. An
        empty input reduced along ``dim`` yields zeros.

    Raises:
        AssertionError: if ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logging.debug("GEMS ARGMAX")
    if dim is None:
        M = inp.numel()
        if dtype is None:
            dtype = inp.dtype
        # The kernels address memory linearly, so a non-contiguous view must
        # be materialised first; this is a no-op for contiguous inputs.
        inp = inp.contiguous()
        out_shape = [1] * inp.dim() if keepdim else []
        out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)
        _launch_two_stage_argmax(inp, M, out, dtype)
        return out

    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    if inp.numel() == 0:
        # Nothing to reduce: return zeros of the output shape, no launch.
        out_shape = list(shape)
        if keepdim:
            out_shape[dim] = 1
        else:
            del out_shape[dim]
        return torch.zeros(out_shape, dtype=torch.int64, device=inp.device)

    # View the input as (M, N, K) with N the reduced dimension.
    N = shape[dim]
    M = math.prod(shape[:dim])
    K = inp.numel() // M // N

    inp = inp.contiguous()

    shape_list = list(shape)
    shape_list[dim] = 1
    out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)

    # Decode-heavy path frequently reduces a single row over vocab; use
    # a two-stage reduction to parallelize across N and reduce launch cost.
    if M == 1 and K == 1:
        # out_index has exactly one element; the flat view shares its storage.
        _launch_two_stage_argmax(
            inp.reshape(-1), N, out_index.reshape(-1), inp.dtype
        )
        return out_index

    grid = lambda meta: (
        triton.cdiv(M, meta["BLOCK_M"]),
        K,
    )
    # with torch_device_fn.device(inp.device):
    argmax_kernel[grid](
        inp,
        out_index,
        M,
        N,
        K,
    )

    return out_index