Coverage for src/flag_gems/runtime/backend/_sunrise/ops/aminmax.py: 0%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4 

5logger = logging.getLogger(__name__) 

6 

7 

8def _aminmax_cpu_reference(op_name, inp, *args, out=None, **kwargs): 

9 cpu_inp = inp.cpu() 

10 cpu_args = tuple( 

11 arg.cpu() if isinstance(arg, torch.Tensor) else arg for arg in args 

12 ) 

13 cpu_kwargs = { 

14 key: value.cpu() if isinstance(value, torch.Tensor) else value 

15 for key, value in kwargs.items() 

16 } 

17 cpu_result = getattr(torch, op_name)(cpu_inp, *cpu_args, **cpu_kwargs) 

18 

19 if out is None: 

20 if isinstance(cpu_result, tuple): 

21 return tuple(item.to(device=inp.device) for item in cpu_result) 

22 return cpu_result.to(device=inp.device) 

23 

24 if isinstance(out, tuple): 

25 for cpu_item, out_item in zip(cpu_result, out): 

26 out_item.copy_(cpu_item.to(device=out_item.device)) 

27 return out 

28 

29 out.copy_(cpu_result.to(device=out.device)) 

30 return out 

31 

32 

def amin(inp, dim=None, keepdim=False, *, out=None):
    """Compute ``amin`` of ``inp`` through the CPU reference helper.

    Args:
        inp: Input tensor.
        dim: Dimension or dimensions to reduce. ``None`` (the default) is
            translated to an empty tuple, i.e. reduce over every dimension.
        keepdim: Forwarded unchanged to the underlying op.
        out: Optional destination tensor, forwarded to the helper.

    Returns:
        Whatever ``_aminmax_cpu_reference`` returns for ``"amin"``.
    """
    logger.debug("SUNRISE AMIN CPU REFERENCE")
    # NOTE(review): torch.amin appears to declare ``dim`` as a non-optional int
    # list, so forwarding a literal None can be rejected; an empty tuple is the
    # op's own spelling of "all dimensions".
    reduce_dims = () if dim is None else dim
    return _aminmax_cpu_reference(
        "amin", inp, dim=reduce_dims, keepdim=keepdim, out=out
    )

36 

37 

def amin_out(inp, dim=None, keepdim=False, *, out=None):
    """Out-variant of ``amin``: requires ``out`` and delegates to ``amin``.

    Raises:
        ValueError: If ``out`` is ``None``.
    """
    logger.debug("SUNRISE AMIN_OUT CPU REFERENCE")
    if out is not None:
        return amin(inp, dim=dim, keepdim=keepdim, out=out)
    raise ValueError("amin_out expects an out tensor")

43 

44 

def amax(inp, dim=None, keepdim=False, *, out=None):
    """Compute ``amax`` of ``inp`` through the CPU reference helper.

    Args:
        inp: Input tensor.
        dim: Dimension or dimensions to reduce. ``None`` (the default) is
            translated to an empty tuple, i.e. reduce over every dimension.
        keepdim: Forwarded unchanged to the underlying op.
        out: Optional destination tensor, forwarded to the helper.

    Returns:
        Whatever ``_aminmax_cpu_reference`` returns for ``"amax"``.
    """
    logger.debug("SUNRISE AMAX CPU REFERENCE")
    # NOTE(review): torch.amax appears to declare ``dim`` as a non-optional int
    # list, so forwarding a literal None can be rejected; an empty tuple is the
    # op's own spelling of "all dimensions".
    reduce_dims = () if dim is None else dim
    return _aminmax_cpu_reference(
        "amax", inp, dim=reduce_dims, keepdim=keepdim, out=out
    )

48 

49 

def amax_out(inp, dim=None, keepdim=False, *, out=None):
    """Out-variant of ``amax``: requires ``out`` and delegates to ``amax``.

    Raises:
        ValueError: If ``out`` is ``None``.
    """
    logger.debug("SUNRISE AMAX_OUT CPU REFERENCE")
    if out is not None:
        return amax(inp, dim=dim, keepdim=keepdim, out=out)
    raise ValueError("amax_out expects an out tensor")

55 

56 

def aminmax(inp, dim=None, keepdim=False, *, out=None):
    """Compute ``aminmax`` of ``inp`` through the CPU reference helper.

    ``dim`` and ``keepdim`` are forwarded as keyword arguments and ``out`` is
    passed along as the helper's destination.
    """
    logger.debug("SUNRISE AMINMAX CPU REFERENCE")
    reduction_options = {"dim": dim, "keepdim": keepdim}
    return _aminmax_cpu_reference("aminmax", inp, out=out, **reduction_options)

60 

61 

def aminmax_out(inp, dim=None, keepdim=False, *, out=None):
    """Out-variant of ``aminmax``: requires ``out`` and delegates to ``aminmax``.

    Raises:
        ValueError: If ``out`` is ``None``.
    """
    logger.debug("SUNRISE AMINMAX_OUT CPU REFERENCE")
    if out is not None:
        return aminmax(inp, dim=dim, keepdim=keepdim, out=out)
    raise ValueError("aminmax_out expects an out tuple")