Coverage for src/flag_gems/runtime/backend/_arm/ops/min.py: 0%
91 statements
coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
3from collections import namedtuple
5import numpy as np
6import torch
7import triton
8import triton.language as tl
10from flag_gems import runtime
12# from ..runtime import torch_device_fn
13# from ..utils import libentry
14from flag_gems.utils import triton_lang_extension as tle
# @libentry()
@triton.jit
def min_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the full-tensor min reduction.

    Program ``block_id`` loads elements ``[block_id * BLOCK_SIZE,
    (block_id + 1) * BLOCK_SIZE)`` of ``inp``. Lanes at or past ``M`` read
    as +inf, so they never win. The program writes that chunk's minimum to
    ``mid[block_id]``.
    """
    block_id = tle.program_id(0)
    lanes = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < M
    chunk = tl.load(inp + lanes, mask=in_bounds, other=float("inf"))
    tl.store(mid + block_id, tl.min(chunk))
# @libentry()
@triton.jit
def min_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage of the full-tensor min reduction.

    A single program loads the first ``mid_size`` partial minima from
    ``mid``. Lanes up to ``BLOCK_MID`` beyond that read as +inf. It stores
    their minimum into the scalar at ``out``.
    """
    lanes = tl.arange(0, BLOCK_MID)
    partials = tl.load(mid + lanes, mask=lanes < mid_size, other=float("inf"))
    tl.store(out, tl.min(partials))
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 8}, num_warps=1),
        triton.Config({"BLOCK_SIZE": 2}, num_warps=2),
        triton.Config({"BLOCK_SIZE": 16}, num_warps=4),
        triton.Config({"BLOCK_SIZE": 32}, num_warps=4),
    ],
    key=["M"],  # re-tune when tensor size changes
)
# @libentry()
@triton.jit
def min_kernel_3(inp, out, M, BLOCK_SIZE: tl.constexpr):
    """Single-pass min over the first ``M`` elements of ``inp`` using atomics.

    Each program reduces one ``BLOCK_SIZE`` chunk. It then folds the chunk
    minimum into the scalar at ``out`` with ``tl.atomic_min``. Because of
    that, the caller presumably has to initialise ``*out`` to a value no
    smaller than the true minimum (e.g. +inf) before launch.
    """
    pid = tle.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < M
    # Masked-off lanes must hold the identity of min (+inf). Without `other`,
    # their contents are undefined and could win the reduction in the last,
    # partially filled block.
    x = tl.load(inp + offsets, mask=mask, other=float("inf"))
    min_val = tl.min(x, axis=None)
    tl.atomic_min(out, min_val)
def heur_block_n(args):
    """Pick BLOCK_N as ``triton.next_power_of_2`` of ``args["N"]``.

    ``args`` must support ``["N"]``. It is presumably the kernel-argument
    mapping Triton hands to heuristics; confirm at the call site.
    """
    n_cols = args["N"]
    return triton.next_power_of_2(n_cols)
# @libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("min"),
    key=[
        "M",
        "N",
    ],
)
@triton.jit
def min_kernel(
    inp,
    out_value,
    out_index,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Reduce ``inp`` over its middle axis, producing min values and argmins.

    Layout:
    - Addressing ``m * N * K + n * K + k`` treats ``inp`` as a contiguous
      ``(M, N, K)`` buffer and reduces over ``N``. This assumes the caller
      passes a contiguous tensor (TODO confirm at call sites).
    - Results are written at ``m * K + k``, i.e. an ``(M, K)`` layout, into
      both ``out_value`` and ``out_index``.

    Grid:
    - Program axis 0 covers ``BLOCK_M`` consecutive ``m`` values.
    - Program axis 1 selects a single ``k``.
    - The loop walks ``N`` in ``BLOCK_N`` chunks.
    """
    # set offset
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

    # Running min per row, accumulated in float32 and seeded with +inf.
    # NOTE(review): for fp64 input, tl.where below would presumably produce
    # fp64 and change this loop-carried variable's type, and large int64
    # values compared in fp32 may lose precision. Verify before using this
    # kernel for those dtypes.
    min_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf"))
    argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
    for start_n in range(0, N, BLOCK_N):
        n_offset = start_n + tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
        mask = m_offset[:, None] < M and n_offset[None, :] < N
        inp_ptrs = inp + offset
        # Out-of-range elements read as +inf so they never become the min.
        inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf"))
        local_min, local_argmin = tl.min(inp_vals, 1, return_indices=True)
        # if return indices is not supported, call tl.argmin in addition
        # local_argmin = tl.argmin(inp_vals, 1)
        # Strict `<` keeps the earlier chunk on ties across chunks. The tie
        # order within one chunk depends on tl.min(return_indices=True)
        # (NOTE(review): confirm it picks the first occurrence).
        update = local_min < min_values
        min_values = tl.where(update, local_min, min_values)
        # local_argmin is relative to the chunk; shift it to an index along N.
        argmin_values = tl.where(update, start_n + local_argmin, argmin_values)

    offset_index = m_offset * K + pid_k
    out_value_ptrs = out_value + offset_index
    out_index_ptrs = out_index + offset_index
    mask1 = m_offset < M
    tl.store(out_value_ptrs, min_values, mask=mask1)
    tl.store(out_index_ptrs, argmin_values, mask=mask1)
def min(inp):
    """Return the minimum over all elements of ``inp`` as a 0-d tensor.

    Uses a two-stage reduction:
    - ``min_kernel_1`` reduces chunks of ``block_size`` elements into
      ``mid``.
    - ``min_kernel_2`` reduces ``mid`` into the scalar ``out``.

    Args:
        inp: input tensor of any shape. It is reduced as a flat sequence.

    Returns:
        A 0-d tensor with ``inp.dtype`` on ``inp.device``.

    Raises:
        RuntimeError: if ``inp`` has no elements, matching ``torch.min``.
    """
    logging.debug("GEMS MIN")
    # The kernels address the input with flat offsets, so they need a dense
    # row-major buffer; a strided view would otherwise be read incorrectly.
    inp = inp.contiguous()
    M = inp.numel()
    if M == 0:
        # Without this guard, next_power_of_2(0) == 0 and cdiv divides by zero.
        raise RuntimeError(
            "min(): Expected reduction dim to be specified for "
            "input.numel() == 0. Specify the reduction dim with the "
            "'dim' argument."
        )
    # About sqrt(M) elements per first-stage program, giving about sqrt(M)
    # partial results, so both stages do a similar amount of work.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    dtype = inp.dtype
    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)
    # Use two-stage reduction for broader dtype support on Triton CPU.
    min_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
    min_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out
# Result type with the same field names as torch.return_types.min. It is
# built once at import time instead of on every call.
_MinOut = namedtuple("min", ["values", "indices"])


def min_dim(inp, dim=None, keepdim=False):
    """Return the minimum values and their indices of ``inp`` along ``dim``.

    The reduction runs on the host through NumPy, and the results are moved
    back to ``inp.device``.

    Args:
        inp: input tensor. bfloat16 is supported via an exact float32
            upcast. A 0-d tensor is treated as having one dimension, as torch
            does, and always yields 0-d results.
        dim: dimension to reduce, in ``[-ndim, ndim)``.
        keepdim: if True, keep the reduced dimension with size 1.

    Returns:
        A namedtuple ``(values, indices)``. ``values`` has ``inp.dtype`` and
        ``indices`` is int64. On ties the first occurrence is selected (via
        ``np.argmin``).

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logging.debug("GEMS MIN DIM")
    is_scalar = inp.ndim == 0
    ndim = 1 if is_scalar else inp.ndim
    assert dim >= -ndim and dim < ndim, "Invalid dim"
    dim = dim % ndim

    src = inp.detach()
    if is_scalar:
        # Reduce a 1-element view. The result is 0-d whatever keepdim says.
        src = src.reshape(1)
        keepdim = False
    if src.dtype == torch.bfloat16:
        # NumPy has no bfloat16. The float32 upcast is exact, and the values
        # are cast back to the input dtype below.
        src = src.float()
    inp_np = src.cpu().numpy()
    # np.argmin returns a NumPy scalar, not an ndarray, when the result is
    # 0-d (e.g. for 1-D input), and torch.from_numpy only accepts ndarrays.
    out_index_np = np.asarray(np.argmin(inp_np, axis=dim), dtype=np.int64)
    # take_along_axis needs an index array of the same rank as the input, so
    # its result keeps the reduced dim with size 1.
    gather_index = np.expand_dims(out_index_np, axis=dim)
    out_value_np = np.take_along_axis(inp_np, gather_index, axis=dim)
    out_index = torch.from_numpy(out_index_np).to(inp.device)
    out_value = torch.from_numpy(out_value_np).to(device=inp.device, dtype=inp.dtype)
    if keepdim:
        out_index = out_index.unsqueeze(dim)
    else:
        out_value = out_value.squeeze(dim)
    return _MinOut(values=out_value, indices=out_index)