Coverage for src/flag_gems/runtime/backend/_spacemit/ops/argmax.py: 0%
105 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_min
# Module-level logger named after this module; argmax() emits its debug trace here.
logger: logging.Logger = logging.getLogger(name=__name__)
# First pass of the dim=None argmax. Program `pid` reduces the BLOCK_SIZE
# elements at flat offsets [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE), clipped
# to M. It stores their maximum in mid_value[pid] and the flat index (into inp)
# of that maximum in mid_index[pid].
# NOTE(review): inp is addressed by flat memory offset, so this assumes the
# caller passes a contiguous tensor — confirm at call sites.
@libentry()
@triton.jit
def argmax_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inp_ptrs = inp + offset
    mask = offset < M
    # Lanes past M read get_dtype_min(...), presumably the smallest value of the
    # element type, so they never exceed a real element.
    min_value = get_dtype_min(inp.type.element_ty)
    inp_val = tl.load(inp_ptrs, mask=mask, other=min_value)
    # The returned index is relative to this block; shift it to a flat index.
    max_val, max_index = tl.max(inp_val, axis=0, return_indices=True)
    max_index = max_index + pid * BLOCK_SIZE
    mid_value_ptr = mid_value + pid
    max_index_ptr = mid_index + pid
    tl.store(mid_value_ptr, max_val)
    tl.store(max_index_ptr, max_index)
# Second pass of the dim=None argmax, launched by argmax() as a single program.
# It finds the largest of the mid_size per-block maxima in mid_value and copies
# the matching flat index from mid_index into out.
@libentry()
@triton.jit
def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    offset = tl.arange(0, BLOCK_MID)
    mid_ptrs = mid_value + offset
    mask = offset < mid_size
    # Padding lanes (>= mid_size) read get_dtype_min(...), presumably the
    # smallest value of the element type, so they never exceed a real maximum.
    min_value = get_dtype_min(mid_value.type.element_ty)
    mid_val = tl.load(mid_ptrs, mask=mask, other=min_value)
    index_val = tl.argmax(mid_val, axis=0)
    # Unmasked gather: relies on index_val < mid_size. NOTE(review): this holds
    # if tl.argmax breaks ties toward the lowest index (Triton's default
    # tie_break_left=True) — confirm for the Triton version in use.
    mid_index_ptrs = mid_index + index_val
    out_val = tl.load(mid_index_ptrs)
    tl.store(out, out_val)
# Argmax along the middle axis of a contiguous input viewed as (M, N, K).
# Element (m, n, k) lives at m * N * K + n * K + k. Program (pid_m, pid_k)
# handles BLOCK_M rows for a single k and scans N in chunks of BLOCK_N. It
# writes each row's argmax to out_ptr[m * K + k].
# BLOCK_M / BLOCK_N are not passed by argmax(); they come from the "argmax"
# heuristic config below.
@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmax"))
@triton.jit
def argmax_kernel(
    inp_ptr,
    out_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Both program ids come from tle.program_id, consistent with the other
    # kernels in this file. pid_m feeds row_offsets * N * K below.
    # NOTE(review): tle.program_id is presumably the int64-widening wrapper,
    # which keeps that product from overflowing int32 on large inputs —
    # confirm against flag_gems.utils.triton_lang_extension.
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    start_row = pid_m * BLOCK_M
    row_offsets = start_row + tl.arange(0, BLOCK_M)
    row_mask = row_offsets < M
    dtype = inp_ptr.dtype.element_ty
    min_value = get_dtype_min(dtype)
    # Running per-row maximum and its column index. -1 marks "no column yet".
    row_max = tl.full((BLOCK_M,), min_value, dtype=dtype)
    row_argmax = tl.full((BLOCK_M,), -1, dtype=tl.int32)

    for block_start in range(0, N, BLOCK_N):
        col_offsets = block_start + tl.arange(0, BLOCK_N)
        col_mask = col_offsets < N
        mask = row_mask[:, None] & col_mask[None, :]
        input_ptrs = (
            inp_ptr + row_offsets[:, None] * N * K + col_offsets[None, :] * K + pid_k
        )
        current_block = tl.load(input_ptrs, mask=mask, other=min_value)

        block_max = tl.max(current_block, axis=1)
        block_argmax = tl.argmax(current_block, axis=1).to(tl.int32) + block_start

        # Take the new block if it is strictly larger. On equal values, take
        # it only if no column has been recorded yet or its index is smaller,
        # so ties keep the earliest column.
        update_mask = block_max > row_max
        tie_mask = (block_max == row_max) & (
            (row_argmax < 0) | (block_argmax < row_argmax)
        )
        choose_new = update_mask | tie_mask

        row_argmax = tl.where(choose_new, block_argmax, row_argmax)
        row_max = tl.where(update_mask, block_max, row_max)

    out_offsets = row_offsets * K + pid_k
    out_ptrs = out_ptr + out_offsets
    tl.store(out_ptrs, row_argmax.to(out_ptr.dtype.element_ty), mask=row_mask)
def argmax(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the maximum values of ``inp``, like ``torch.argmax``.

    With ``dim=None`` the result is an index into the row-major flattening
    of ``inp``. Otherwise the reduction runs along ``dim``.

    Args:
        inp: input tensor.
        dim: dimension to reduce, or ``None`` to reduce over all elements.
        keepdim: keep reduced dimension(s) with size 1.
        dtype: dtype of the intermediate per-block maxima in the ``dim=None``
            path; defaults to ``inp.dtype``. Unused when ``dim`` is given.

    Returns:
        An ``int64`` tensor of indices.

    Raises:
        IndexError: ``dim`` is out of range, or the reduced dim has size 0.
        RuntimeError: ``dim`` is ``None`` and ``inp`` has no elements.
    """
    logger.debug("GEMS_SPACEMIT ARGMAX")
    if dim is None:
        return _argmax_flat(inp, keepdim, dtype)
    return _argmax_along_dim(inp, dim, keepdim)


def _argmax_flat(inp, keepdim, dtype):
    """Two-pass argmax over every element of ``inp`` (the ``dim=None`` path)."""
    M = inp.numel()
    if M == 0:
        # next_power_of_2(0) == 0 would otherwise divide by zero below;
        # torch reports this case as a RuntimeError.
        raise RuntimeError(
            "argmax(): Expected reduction dim to be specified for input.numel() == 0."
        )
    # argmax_kernel_1 addresses inp by flat memory offset, which matches the
    # logical row-major index only for a contiguous tensor.
    inp = inp.contiguous()
    if dtype is None:
        dtype = inp.dtype
    # Pass 1: blocks of block_size (the power of two >= sqrt(M)), so there
    # are at most about sqrt(M) partial results. Pass 2 reduces those partial
    # results in a single program.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
    out_shape = [1] * inp.dim() if keepdim else []
    out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

    with torch_device_fn.device(inp.device):
        argmax_kernel_1[(mid_size, 1, 1)](
            inp,
            mid_value,
            mid_index,
            M,
            block_size,
        )
        argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
    return out


def _argmax_along_dim(inp, dim, keepdim):
    """Argmax of ``inp`` along ``dim`` using ``argmax_kernel`` on an (M, N, K) view."""
    # For dim wrapping, a 0-dim tensor behaves like a 1-element 1-D tensor,
    # matching torch: only dim -1 or 0 is accepted.
    ndim = max(inp.ndim, 1)
    if dim < -ndim or dim >= ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})"
        )
    if inp.ndim == 0:
        return torch.zeros([], dtype=torch.int64, device=inp.device)

    shape = inp.shape
    dim = dim % inp.ndim
    N = shape[dim]
    if N == 0:
        raise IndexError(
            f"argmax(): Expected reduction dim {dim} to have non-zero size."
        )
    # M = product of dims before `dim`, K = product of dims after it.
    # math.prod avoids the numel() // M // N division, which fails when M == 0.
    M = math.prod(shape[:dim])
    K = math.prod(shape[dim + 1 :])

    out_shape = list(shape)
    out_shape[dim] = 1
    out_index = torch.empty(out_shape, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)
    if out_index.numel() == 0:
        # M == 0 or K == 0: the result is empty; skip a zero-sized launch.
        return out_index

    # argmax_kernel addresses elements as m * N * K + n * K + k.
    inp = inp.contiguous()

    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]), K)

    with torch_device_fn.device(inp.device):
        argmax_kernel[grid](
            inp,
            out_index,
            M,
            N,
            K,
        )
    return out_index