Coverage for src/flag_gems/runtime/backend/_ascend/ops/argmin.py: 0%
98 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as ext
11from flag_gems.utils.limits import get_dtype_max
13logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')
@libentry()
@triton.jit
def argmin_kernel_1(
    inp,
    mid_value,
    mid_index,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the flattened argmin: reduce one chunk per program.

    Each program loads up to ``BLOCK_SIZE`` consecutive elements of ``inp``
    starting at ``pid * BLOCK_SIZE``, finds their minimum, and writes
    the minimum to ``mid_value[pid]`` and its position (shifted by the chunk
    start, so it is relative to the start of ``inp``) to ``mid_index[pid]``.

    Args:
        inp: pointer to the input elements, addressed linearly.
        mid_value: pointer receiving one minimum per program.
        mid_index: pointer receiving one index per program.
        M: number of valid elements reachable from ``inp``.
        BLOCK_SIZE: compile-time number of elements handled per program.
    """
    block_id = ext.program_id(0)
    block_start = block_id * BLOCK_SIZE
    positions = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = positions < M

    # Lanes past the end of the input read the value supplied by
    # get_dtype_max for the element type instead of touching memory.
    fill = get_dtype_max(inp.type.element_ty)
    values = tl.load(inp + positions, mask=in_bounds, other=fill)
    block_min, local_pos = tl.min(values, axis=0, return_indices=True)

    tl.store(mid_value + block_id, block_min)
    # local_pos is relative to this chunk; add the chunk start before storing.
    tl.store(mid_index + block_id, local_pos + block_start)
@libentry()
@triton.jit
def argmin_kernel_2(
    mid_value,
    mid_index,
    out,
    mid_size,
    BLOCK_MID: tl.constexpr,
):
    """Second stage of the flattened argmin: combine the per-chunk results.

    Loads the first ``mid_size`` entries of ``mid_value``, finds which entry
    holds the smallest value, and copies the matching entry of ``mid_index``
    to ``out``.

    Args:
        mid_value: pointer to the per-chunk minima.
        mid_index: pointer to the per-chunk indices.
        out: pointer receiving the single selected index.
        mid_size: number of valid entries in ``mid_value`` / ``mid_index``.
        BLOCK_MID: compile-time number of lanes used for the load.
    """
    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < mid_size

    # Lanes at or beyond mid_size read the get_dtype_max value instead of memory.
    fill = get_dtype_max(mid_value.type.element_ty)
    partial_mins = tl.load(mid_value + lanes, mask=valid, other=fill)

    winner = tl.argmin(partial_mins, axis=0)
    tl.store(out, tl.load(mid_index + winner))
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_M": 1, "BLOCK_N": 512}),
        triton.Config({"BLOCK_M": 4, "BLOCK_N": 256}),
        triton.Config({"BLOCK_M": 8, "BLOCK_N": 128}),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def argmin_kernel(
    inp,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Argmin along the middle axis of an input addressed as (M, N, K).

    Element ``(m, n, k)`` is read from ``inp + m * N * K + n * K + k``.
    Program axis 0 selects a group of ``BLOCK_M`` values of ``m`` and program
    axis 1 selects ``k``. For each selected ``m`` the kernel scans ``n`` in
    steps of ``BLOCK_N`` and writes the index of the minimum to
    ``out_index + m * K + k``.

    Args:
        inp: pointer to the input elements.
        out_index: pointer receiving the resulting indices.
        M: extent of the leading (kept) axis.
        N: extent of the reduced axis.
        K: extent of the trailing (kept) axis.
        BLOCK_M: compile-time number of rows handled per program.
        BLOCK_N: compile-time number of reduced-axis elements per step.
    """
    row_block = ext.program_id(0)
    inner = ext.program_id(1)

    rows = row_block * BLOCK_M + tl.arange(0, BLOCK_M)

    elem_ty = inp.type.element_ty
    # bfloat16 running minima are kept in float32; other types are kept as-is.
    acc_ty = tl.float32 if elem_ty is tl.bfloat16 else elem_ty
    fill = get_dtype_max(elem_ty)
    best_val = tl.full([BLOCK_M], dtype=acc_ty, value=fill)
    best_idx = tl.full([BLOCK_M], dtype=tl.int64, value=0)

    for col_start in range(0, N, BLOCK_N):
        cols = col_start + tl.arange(0, BLOCK_N)
        addr = rows[:, None] * N * K + cols[None, :] * K + inner
        in_bounds = rows[:, None] < M and cols[None, :] < N
        tile = tl.load(inp + addr, mask=in_bounds, other=fill)
        tile_min, tile_pos = tl.min(
            tile, 1, return_indices=True, return_indices_tie_break_left=True
        )
        # Strict comparison: a later step replaces the running result only
        # when it is strictly smaller, so equal values keep the earlier index.
        improved = tile_min < best_val
        best_val = tl.where(improved, tile_min, best_val)
        best_idx = tl.where(improved, col_start + tile_pos, best_idx)

    row_ok = rows < M
    tl.store(out_index + (rows * K + inner), best_idx, mask=row_ok)
def argmin(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the indices of the minimum values of ``inp``.

    Args:
        inp: input tensor. bfloat16 inputs are converted to float32 and
            processed through the same function.
        dim: dimension to reduce. When ``None`` the minimum is searched over
            all elements and the returned index refers to the flattened
            (row-major) input.
        keepdim: when true, the reduced dimension is kept with size 1. With
            ``dim=None`` every dimension is kept with size 1.
        dtype: only used when ``dim`` is ``None``, as the dtype of the
            intermediate buffer of per-chunk minima; defaults to ``inp.dtype``.

    Returns:
        An int64 tensor of indices on the same device as ``inp``.

    Raises:
        AssertionError: if ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_ASCEND ARGMIN")
    if inp.dtype == torch.bfloat16:
        return argmin(inp.to(torch.float32), dim=dim, keepdim=keepdim, dtype=dtype)

    if dim is None:
        # The kernels address the input linearly from its base pointer, so a
        # strided or transposed view must be materialised in row-major order;
        # otherwise the result would not match the logical flattened index.
        inp = inp.contiguous()
        M = inp.numel()
        if dtype is None:
            dtype = inp.dtype

        # Two-stage reduction: roughly sqrt(M) chunks of roughly sqrt(M)
        # elements each, then one program combining the per-chunk results.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)

        mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)

        # With keepdim every reduced dimension is retained with size 1;
        # without it the result is a 0-d tensor.
        out_shape = [1] * inp.ndim if keepdim else []
        out = torch.empty(out_shape, dtype=torch.int64, device=inp.device)

        with torch_device_fn.device(inp.device):
            argmin_kernel_1[(mid_size, 1, 1)](
                inp,
                mid_value,
                mid_index,
                M,
                block_size,
            )
            argmin_kernel_2[(1, 1, 1)](
                mid_value,
                mid_index,
                out,
                mid_size,
                block_mid,
            )
        return out
    else:
        assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
        shape = inp.shape
        dim = dim % inp.ndim

        # View the input as (M, N, K): N is the reduced dimension, M the
        # product of the dimensions before it, K the product of those after.
        N = shape[dim]
        M = math.prod(shape[:dim])
        K = inp.numel() // M // N

        inp = inp.contiguous()

        shape_list = list(shape)
        shape_list[dim] = 1
        out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
        if not keepdim:
            out_index = torch.squeeze(out_index, dim)

        grid = lambda meta: (
            triton.cdiv(M, meta["BLOCK_M"]),
            K,
        )
        with torch_device_fn.device(inp.device):
            argmin_kernel[grid](
                inp,
                out_index,
                M,
                N,
                K,
            )

        return out_index