Coverage for src/flag_gems/runtime/backend/_sunrise/ops/aminmax.py: 0%
42 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
# Module-level logger; the CPU-reference wrappers in this module emit their
# debug trace messages through it.
logger: logging.Logger = logging.getLogger(__name__)
def _restore_result_type(template, items):
    """Pack *items* into the same tuple-like type as *template*.

    Multi-output torch ops return structseq types (e.g.
    ``torch.return_types.aminmax``) whose named fields (``.min`` / ``.max``)
    callers may use, so a plain ``tuple`` is not an adequate substitute.
    """
    result_type = type(template)
    if result_type is tuple:
        return tuple(items)
    if hasattr(result_type, "_make"):
        # collections.namedtuple-style classes take their fields positionally.
        return result_type._make(items)
    # structseq types (torch.return_types.*) are constructed from one sequence.
    return result_type(items)


def _copy_into(dst, src):
    """Fill *dst* in place with *src*, resizing *dst* first if shapes differ.

    This mirrors torch ``out=`` semantics, where a destination of the wrong
    shape (commonly an empty placeholder) is resized rather than rejected.
    """
    if dst.shape != src.shape:
        dst.resize_(src.shape)
    dst.copy_(src.to(device=dst.device))


def _aminmax_cpu_reference(op_name, inp, *args, out=None, **kwargs):
    """Run ``torch.<op_name>`` on CPU copies of the inputs and move results back.

    Args:
        op_name: Name of the ``torch`` function to call (e.g. ``"amin"``,
            ``"amax"``, ``"aminmax"``).
        inp: Input tensor. Results are returned on ``inp.device``.
        *args: Extra positional arguments for the torch function. Tensors among
            them are copied to CPU first.
        out: Optional destination. Use a single tensor for single-output ops
            and a tuple of tensors for multi-output ops. Each destination is
            resized to the result shape if needed and then filled in place.
        **kwargs: Extra keyword arguments for the torch function. Tensor values
            are copied to CPU first.

    Returns:
        When ``out`` is None, the result on ``inp.device``, keeping the
        (named) tuple type of multi-output results. Otherwise ``out`` itself.

    Raises:
        ValueError: If ``out`` does not match the op's output arity (a tensor
            for a multi-output op, a tuple for a single-output op, or a tuple
            of the wrong length).
    """

    def _to_cpu(value):
        return value.cpu() if isinstance(value, torch.Tensor) else value

    cpu_args = tuple(_to_cpu(arg) for arg in args)
    cpu_kwargs = {key: _to_cpu(value) for key, value in kwargs.items()}
    cpu_result = getattr(torch, op_name)(inp.cpu(), *cpu_args, **cpu_kwargs)

    if out is None:
        if isinstance(cpu_result, tuple):
            moved = [item.to(device=inp.device) for item in cpu_result]
            return _restore_result_type(cpu_result, moved)
        return cpu_result.to(device=inp.device)

    if isinstance(out, tuple):
        if not isinstance(cpu_result, tuple):
            raise ValueError(
                f"torch.{op_name} returns a single tensor; out must be a tensor"
            )
        if len(out) != len(cpu_result):
            # zip() would otherwise silently leave some outputs unwritten.
            raise ValueError(
                f"torch.{op_name} returns {len(cpu_result)} tensors, "
                f"but {len(out)} out tensors were given"
            )
        for cpu_item, out_item in zip(cpu_result, out):
            _copy_into(out_item, cpu_item)
        return out

    if isinstance(cpu_result, tuple):
        raise ValueError(
            f"torch.{op_name} returns {len(cpu_result)} tensors; "
            "out must be a tuple"
        )
    _copy_into(out, cpu_result)
    return out
def amin(inp, dim=None, keepdim=False, *, out=None):
    """Compute ``torch.amin`` via the CPU reference fallback.

    Args:
        inp: Input tensor.
        dim: Dimension(s) to reduce, forwarded to ``torch.amin``. ``None``
            means reduce over all dimensions.
        keepdim: Forwarded to ``torch.amin``.
        out: Optional destination tensor, filled in place.

    Returns:
        The reduced tensor on ``inp.device``, or ``out`` when it is given.
    """
    logger.debug("SUNRISE AMIN CPU REFERENCE")
    # torch.amin declares ``dim`` as a non-optional int list whose default
    # ``[]`` means "all dims". Forwarding None can be rejected by torch's
    # argument parser, so pass the equivalent empty tuple instead.
    if dim is None:
        dim = ()
    return _aminmax_cpu_reference("amin", inp, dim=dim, keepdim=keepdim, out=out)
def amin_out(inp, dim=None, keepdim=False, *, out=None):
    """Out-variant of ``amin``: an ``out`` tensor is mandatory.

    Raises:
        ValueError: If ``out`` is not supplied.
    """
    logger.debug("SUNRISE AMIN_OUT CPU REFERENCE")
    if out is not None:
        return amin(inp, dim=dim, keepdim=keepdim, out=out)
    raise ValueError("amin_out expects an out tensor")
def amax(inp, dim=None, keepdim=False, *, out=None):
    """Compute ``torch.amax`` via the CPU reference fallback.

    Args:
        inp: Input tensor.
        dim: Dimension(s) to reduce, forwarded to ``torch.amax``. ``None``
            means reduce over all dimensions.
        keepdim: Forwarded to ``torch.amax``.
        out: Optional destination tensor, filled in place.

    Returns:
        The reduced tensor on ``inp.device``, or ``out`` when it is given.
    """
    logger.debug("SUNRISE AMAX CPU REFERENCE")
    # torch.amax declares ``dim`` as a non-optional int list whose default
    # ``[]`` means "all dims". Forwarding None can be rejected by torch's
    # argument parser, so pass the equivalent empty tuple instead.
    if dim is None:
        dim = ()
    return _aminmax_cpu_reference("amax", inp, dim=dim, keepdim=keepdim, out=out)
def amax_out(inp, dim=None, keepdim=False, *, out=None):
    """Out-variant of ``amax``: an ``out`` tensor is mandatory.

    Raises:
        ValueError: If ``out`` is not supplied.
    """
    logger.debug("SUNRISE AMAX_OUT CPU REFERENCE")
    if out is not None:
        return amax(inp, dim=dim, keepdim=keepdim, out=out)
    raise ValueError("amax_out expects an out tensor")
def aminmax(inp, dim=None, keepdim=False, *, out=None):
    """Compute ``torch.aminmax`` via the CPU reference fallback.

    ``dim`` and ``keepdim`` are forwarded unchanged to ``torch.aminmax``,
    which accepts ``dim=None`` to reduce over the whole input. When ``out``
    is given, it is filled in place and returned.
    """
    logger.debug("SUNRISE AMINMAX CPU REFERENCE")
    forwarded = {"dim": dim, "keepdim": keepdim, "out": out}
    return _aminmax_cpu_reference("aminmax", inp, **forwarded)
def aminmax_out(inp, dim=None, keepdim=False, *, out=None):
    """Out-variant of ``aminmax``: an ``out`` destination is mandatory.

    Raises:
        ValueError: If ``out`` is not supplied.
    """
    logger.debug("SUNRISE AMINMAX_OUT CPU REFERENCE")
    if out is not None:
        return aminmax(inp, dim=dim, keepdim=keepdim, out=out)
    raise ValueError("aminmax_out expects an out tuple")