Coverage for src/flag_gems/runtime/backend/_spacemit/ops/argmin.py: 0%
103 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_max
# Module-level logger; argmin() emits a debug record through it on each call.
logger = logging.getLogger(name=__name__)
@libentry()
@triton.jit
def argmin_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """Stage 1 of the flat argmin: reduce one BLOCK_SIZE-wide chunk per program.

    Program ``pid`` reads elements ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``
    of the linear buffer ``inp``. Lanes at or past ``M`` are filled with the
    value returned by ``get_dtype_max`` so they cannot be smaller than a real
    element. The chunk minimum is written to ``mid_value[pid]`` and its index
    within the whole buffer to ``mid_index[pid]``.
    """
    pid = tle.program_id(0)
    chunk_start = pid * BLOCK_SIZE
    lanes = chunk_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < M
    pad = get_dtype_max(inp.type.element_ty)
    chunk = tl.load(inp + lanes, mask=in_bounds, other=pad)
    # The returned index is relative to this chunk; shifting it by the chunk
    # start turns it into an index into the full buffer.
    chunk_min, local_idx = tl.min(chunk, axis=0, return_indices=True)
    tl.store(mid_value + pid, chunk_min)
    tl.store(mid_index + pid, local_idx + chunk_start)
@libentry()
@triton.jit
def argmin_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
    """Stage 2 of the flat argmin: select the winning chunk from stage 1.

    Loads the ``mid_size`` per-chunk minima into a ``BLOCK_MID``-wide vector.
    Unused tail slots are filled with the value from ``get_dtype_max``. The
    kernel finds the position of the smallest minimum and copies the matching
    global index from ``mid_index`` into the scalar ``out``.
    """
    slots = tl.arange(0, BLOCK_MID)
    pad = get_dtype_max(mid_value.type.element_ty)
    minima = tl.load(mid_value + slots, mask=slots < mid_size, other=pad)
    best_slot = tl.argmin(minima, axis=0)
    tl.store(out, tl.load(mid_index + best_slot))
@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmin"))
@triton.jit
def argmin_kernel(
    inp_ptr,
    out_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Argmin over the middle axis of a contiguous (M, N, K) view of ``inp_ptr``.

    Program ``(pid_m, pid_k)`` handles ``BLOCK_M`` rows starting at
    ``pid_m * BLOCK_M`` for the fixed trailing position ``k = pid_k``. It walks
    the N axis in ``BLOCK_N``-wide tiles and keeps a running minimum and its
    column index per row. The index is stored to ``out_ptr[row * K + pid_k]``.
    Ties across tiles keep the smaller (earlier) column index.

    ``BLOCK_M`` and ``BLOCK_N`` come from the "argmin" heuristic config rather
    than from the launcher.
    """
    # Both axes use tle.program_id, as pid_k and argmin_kernel_1 already did,
    # so the row-offset products below are formed with the same index type.
    # NOTE(review): assumes tle.program_id widens to int64, which keeps
    # row * N * K from overflowing int32 on large inputs — confirm.
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    row_offsets = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    row_mask = row_offsets < M

    dtype = inp_ptr.dtype.element_ty
    # Half-precision floats are compared in fp32, and 8/16-bit ints in int32.
    acc_type = (
        tl.float32
        if (dtype is tl.bfloat16 or dtype is tl.float16)
        else (
            tl.int32
            if (dtype is tl.int16 or dtype is tl.int8 or dtype is tl.uint8)
            else dtype
        )
    )
    max_value = get_dtype_max(dtype)
    max_value_acc = get_dtype_max(acc_type)
    row_min = tl.full((BLOCK_M,), max_value_acc, dtype=acc_type)
    # -1 means "no column chosen yet" for that row.
    row_argmin = tl.full((BLOCK_M,), -1, dtype=tl.int32)

    for block_start in range(0, N, BLOCK_N):
        col_offsets = block_start + tl.arange(0, BLOCK_N)
        col_mask = col_offsets < N
        mask = row_mask[:, None] & col_mask[None, :]
        input_ptrs = (
            inp_ptr + row_offsets[:, None] * N * K + col_offsets[None, :] * K + pid_k
        )
        current_block = tl.load(input_ptrs, mask=mask, other=max_value).to(acc_type)

        # A single reduction yields both the tile minimum and its column, as
        # in argmin_kernel_1, instead of separate tl.min and tl.argmin passes.
        block_min, block_argmin = tl.min(current_block, axis=1, return_indices=True)
        block_argmin = block_argmin.to(tl.int32) + block_start

        # A strictly smaller value always wins. An equal value wins only if the
        # row has no index yet or the new index is earlier.
        update_mask = block_min < row_min
        tie_mask = (block_min == row_min) & (
            (row_argmin < 0) | (block_argmin < row_argmin)
        )
        choose_new = update_mask | tie_mask

        row_argmin = tl.where(choose_new, block_argmin, row_argmin)
        row_min = tl.where(update_mask, block_min, row_min)

    out_offsets = row_offsets * K + pid_k
    tl.store(
        out_ptr + out_offsets, row_argmin.to(out_ptr.dtype.element_ty), mask=row_mask
    )
def argmin(inp, dim=None, keepdim=False, *, dtype=None):
    """Return int64 indices of the minimum values of ``inp``.

    Args:
        inp: Input tensor. It is made contiguous before being handed to the
            Triton kernels, which address it as a flat row-major buffer.
        dim: Dimension to reduce. ``None`` reduces over all elements and
            returns an index into the flattened tensor.
        keepdim: If true, reduced dimension(s) are kept with size 1.
        dtype: Only used when ``dim is None``. It sets the dtype of the
            intermediate per-chunk minimum buffer (default: ``inp.dtype``).

    Returns:
        An ``int64`` index tensor on ``inp.device``.

    Raises:
        RuntimeError: ``dim is None`` and ``inp`` has no elements.
        IndexError: ``dim`` is out of range, or the reduced dimension is empty.
    """
    logger.debug("GEMS_SPACEMIT ARGMIN")
    if dim is None:
        return _argmin_all(inp, keepdim, dtype)
    return _argmin_along_dim(inp, dim, keepdim)


def _argmin_all(inp, keepdim, dtype):
    """Two-stage argmin over every element of ``inp`` (the ``dim=None`` case)."""
    M = inp.numel()
    if M == 0:
        # No element to index. Without this check, block_size below would be 0
        # and triton.cdiv would divide by zero.
        raise RuntimeError(
            "argmin(): Expected reduction dim to be specified for input.numel() == 0."
        )
    # argmin_kernel_1 reads `inp` as a flat buffer. A strided view (e.g. a
    # transpose) must be made contiguous, or the returned flat index is wrong.
    inp = inp.contiguous()
    if dtype is None:
        dtype = inp.dtype
    # Roughly sqrt(M) chunks of roughly sqrt(M) elements, rounded up to a power of 2.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
    out_shape = [1] * inp.dim() if keepdim else []
    out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

    with torch_device_fn.device(inp.device):
        argmin_kernel_1[(mid_size, 1, 1)](inp, mid_value, mid_index, M, block_size)
        argmin_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
    return out


def _argmin_along_dim(inp, dim, keepdim):
    """Argmin along ``dim``, computed by argmin_kernel on an (M, N, K) view of ``inp``."""
    # For range checking, a 0-d tensor is treated as having one dimension, so
    # dim may be -1 or 0.
    ndim = max(inp.ndim, 1)
    if dim < -ndim or dim >= ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})"
        )
    if inp.ndim == 0:
        # A scalar's only element is its minimum, so the result is a 0-d index of 0.
        return torch.zeros([], dtype=torch.int64, device=inp.device)

    dim = dim % inp.ndim
    shape = inp.shape
    N = shape[dim]
    if N == 0:
        raise IndexError(
            f"argmin(): Expected reduction dim {dim} to have non-zero size."
        )
    M = math.prod(shape[:dim])
    # K is the product of the trailing dims. The old numel // M // N form
    # divided by zero when a non-reduced dim was empty.
    K = math.prod(shape[dim + 1 :])

    inp = inp.contiguous()
    out_shape = list(shape)
    out_shape[dim] = 1
    out_index = torch.empty(out_shape, dtype=torch.int64, device=inp.device)
    if not keepdim:
        out_index = torch.squeeze(out_index, dim)
    if out_index.numel() == 0:
        # An empty non-reduced dim leaves nothing to compute; skip the empty grid.
        return out_index

    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]), K)
    with torch_device_fn.device(inp.device):
        argmin_kernel[grid](inp, out_index, M, N, K)
    return out_index