Coverage for src/flag_gems/ops/aminmax.py: 51%
120 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry, libtuner
11from flag_gems.utils import triton_lang_extension as ext
12from flag_gems.utils.limits import get_dtype_max, get_dtype_min
14logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def aminmax_kernel_1(
    inp,
    min_out,
    max_out,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the whole-tensor reduction: one partial min/max per program.

    Each program reads the ``BLOCK_SIZE`` consecutive elements of ``inp``
    starting at ``program_id * BLOCK_SIZE`` (elements at index ``>= M`` are
    masked out), reduces them, and writes the block minimum to
    ``min_out[program_id]`` and the block maximum to ``max_out[program_id]``.
    """
    block_id = ext.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M
    src = inp + idx

    # Masked-out lanes are padded so they cannot win the reduction: the min
    # load is padded with get_dtype_max(...) and the max load with
    # get_dtype_min(...) (presumably the dtype's extreme values -- confirm
    # against flag_gems.utils.limits).
    elem_ty = inp.type.element_ty
    pad_for_min = get_dtype_max(elem_ty)
    pad_for_max = get_dtype_min(elem_ty)
    min_lanes = tl.load(src, mask=in_bounds, other=pad_for_min)
    max_lanes = tl.load(src, mask=in_bounds, other=pad_for_max)

    block_min = tl.min(min_lanes)
    block_max = tl.max(max_lanes)

    tl.store(min_out + block_id, block_min)
    tl.store(max_out + block_id, block_max)
@libentry()
@triton.jit
def aminmax_kernel_2(
    min_inp, max_inp, min_out, max_out, mid_size, BLOCK_MID: tl.constexpr
):
    """Second stage of the whole-tensor reduction: fold the partial results.

    Loads the first ``mid_size`` entries of ``min_inp`` and ``max_inp`` in a
    single ``BLOCK_MID``-wide vector, reduces each to one value, and stores
    the results at ``min_out`` and ``max_out`` respectively.
    """
    lane = tl.arange(0, BLOCK_MID)
    live = lane < mid_size

    # Lanes beyond mid_size are padded with values that cannot win:
    # get_dtype_max(...) for the min reduction, get_dtype_min(...) for the
    # max reduction (presumably the dtype's extremes -- confirm).
    pad_for_min = get_dtype_max(min_inp.type.element_ty)
    pad_for_max = get_dtype_min(max_inp.type.element_ty)
    partial_mins = tl.load(min_inp + lane, mask=live, other=pad_for_min)
    partial_maxs = tl.load(max_inp + lane, mask=live, other=pad_for_max)

    overall_min = tl.min(partial_mins)
    overall_max = tl.max(partial_maxs)

    tl.store(min_out, overall_min)
    tl.store(max_out, overall_max)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def aminmax_kernel(
    inp,
    min_out,
    max_out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-wise min/max over an input addressed as ``row * N + col``.

    Each program handles ``BLOCK_M`` rows. It walks the ``N`` columns in
    ``BLOCK_N``-wide tiles, keeping element-wise running minima and maxima,
    then reduces along the column axis and stores one minimum and one
    maximum per row at ``min_out[row]`` / ``max_out[row]``. Rows ``>= M``
    and columns ``>= N`` are masked out.
    """
    elem_ty = inp.type.element_ty
    lowest = get_dtype_min(elem_ty)
    highest = get_dtype_max(elem_ty)

    # Rows owned by this program, shaped (BLOCK_M, 1) so they broadcast
    # against the (1, BLOCK_N) column indices below.
    row_idx = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    valid_rows = row_idx < M
    row_base = inp + row_idx * N
    min_dst = min_out + row_idx
    max_dst = max_out + row_idx

    # bfloat16 inputs are accumulated in float32; every other dtype
    # accumulates in its own type.
    acc_type = tl.float32 if elem_ty is tl.bfloat16 else elem_ty
    running_min = tl.full([BLOCK_M, BLOCK_N], value=highest, dtype=acc_type)
    running_max = tl.full([BLOCK_M, BLOCK_N], value=lowest, dtype=acc_type)
    for start in range(0, N, BLOCK_N):
        col_idx = start + tl.arange(0, BLOCK_N)[None, :]
        valid = valid_rows & (col_idx < N)
        tile = tl.load(row_base + col_idx, mask=valid, other=lowest)
        # The tl.where guards keep masked-out lanes at their previous value,
        # so the padding value used by the load never reaches the result.
        running_min = tl.where(valid, tl.minimum(running_min, tile), running_min)
        running_max = tl.where(valid, tl.maximum(running_max, tile), running_max)

    row_min = tl.min(running_min, axis=1)[:, None]
    row_max = tl.max(running_max, axis=1)[:, None]
    tl.store(min_dst, row_min, valid_rows)
    tl.store(max_dst, row_max, valid_rows)
def _aminmax_unpack_out(out):
    """Split the ``out`` argument into its (min, max) destination tensors."""
    if isinstance(out, tuple):
        return out[0], out[1]
    # A non-tuple ``out`` is used as the destination for both results.
    return out, out


def aminmax(inp, dim=None, keepdim=False, *, out=None):
    """Compute the minimum and maximum of ``inp`` with Triton kernels.

    Args:
        inp: Input tensor.
        dim: ``None`` to reduce over every element, or an int / iterable of
            ints naming the dimensions to reduce. Negative values count from
            the last dimension.
        keepdim: When true, reduced dimensions are kept with size 1 in the
            results; otherwise they are squeezed away.
        out: Optional destination. A tuple is read as ``(min_out, max_out)``;
            any other value is used for both results. The kernels write into
            these tensors through their base pointer, so they are presumably
            expected to be contiguous and correctly sized -- confirm against
            callers.

    Returns:
        A ``(min, max)`` tuple of tensors.

    Raises:
        AssertionError: If any entry of ``dim`` lies outside
            ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS AMINMAX")

    if dim is None:
        # aminmax_kernel_1 addresses the input as ``inp + linear_offset``, so a
        # strided view has to be materialised first or the wrong elements
        # would be read.
        inp = inp.contiguous()
        M = inp.numel()
        # Two-stage reduction: roughly sqrt(M) blocks of roughly sqrt(M)
        # elements each, then a single program folds the per-block partials.
        block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
        mid_size = triton.cdiv(M, block_size)
        block_mid = triton.next_power_of_2(mid_size)
        dtype = inp.dtype
        min_mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
        max_mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)

        if out is not None:
            min_out, max_out = _aminmax_unpack_out(out)
            if not keepdim:
                min_out = min_out.squeeze()
                max_out = max_out.squeeze()
        else:
            shape = [1] * inp.dim() if keepdim else []
            min_out = torch.empty(shape, dtype=dtype, device=inp.device)
            max_out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            aminmax_kernel_1[(mid_size, 1)](
                inp,
                min_mid,
                max_mid,
                M,
                block_size,
            )
            aminmax_kernel_2[(1, 1)](
                min_mid, max_mid, min_out, max_out, mid_size, block_mid
            )
        return min_out, max_out

    if isinstance(dim, int):
        dim = [dim]
    # ``all`` is required here: asserting on a bare generator expression is
    # always true, which would let out-of-range dims be silently wrapped by
    # the modulo below.
    assert all(-inp.ndim <= i < inp.ndim for i in dim), "Invalid dim"
    dtype = inp.dtype

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # NOTE(review): dim_compress presumably moves the reduced dims innermost
    # and returns a contiguous tensor, which is what the row-major addressing
    # in aminmax_kernel needs -- confirm in flag_gems.utils.
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    M = inp.numel() // N

    if out is not None:
        min_out, max_out = _aminmax_unpack_out(out)
    else:
        min_out = torch.empty(shape, dtype=dtype, device=inp.device)
        max_out = torch.empty(shape, dtype=dtype, device=inp.device)

    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    with torch_device_fn.device(inp.device):
        aminmax_kernel[grid](inp, min_out, max_out, M, N)

    if not keepdim:
        # Only squeeze tensors that still carry the size-1 reduced dims.
        # A caller-supplied ``out`` that already has the reduced shape must be
        # left alone: squeezing it by the input's dim indices could raise or
        # drop an unrelated size-1 axis.
        if min_out.ndim == inp.ndim:
            min_out = min_out.squeeze(dim=dim)
        if max_out.ndim == inp.ndim:
            max_out = max_out.squeeze(dim=dim)
    return min_out, max_out