Coverage for src/flag_gems/runtime/backend/_sunrise/ops/angle.py: 0%
32 statements
coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.utils import pointwise_dynamic, tl_extra_shim
# Child of the "flag_gems" logger, named after this module (leading dots stripped).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Backend-specific atan2 resolved through the Triton extra-function shim;
# called from inside the `angle_func` kernel.
atan2 = tl_extra_shim.atan2
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def angle_func(real, imag):
    # Element-wise argument of a complex number given as separate real and
    # imaginary parts: atan2(imag, real). The output dtype follows `real`
    # (DEFAULT promotion on argument 0).
    #
    # fp16 operands are widened to fp32 before calling atan2 (presumably the
    # shim's atan2 has no fp16 variant -- confirm); the branch is resolved at
    # compile time because it only inspects the dtype.
    if real.dtype == tl.float16:
        real = real.to(tl.float32)
        imag = imag.to(tl.float32)
    return atan2(imag, real)
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def angle_float_and_int(real):
    # Element-wise angle of a real-valued (float or integer) input, matching
    # torch.angle on real tensors:
    #   * 0  for non-negative values,
    #   * pi for negative values,
    #   * NaN for NaN inputs (NaN is propagated).
    # Integer inputs are promoted to a floating output (INT_TO_FLOAT).
    zero = 0.0
    pi = math.pi
    result = tl.where(real >= zero, zero, pi)
    # NaN compares false against everything, so the select above would map
    # NaN to pi; torch.angle propagates NaN instead. `real != real` is true
    # only for NaN (always false for integer inputs, making this a no-op).
    result = tl.where(real != real, real, result)
    # NOTE(review): `pi` is likely materialized as an fp32 constant, so fp64
    # outputs may differ from torch's fp64 pi in the last bits -- confirm if
    # float64 inputs are expected on this backend.
    return result
def angle(input_tensor: torch.Tensor) -> torch.Tensor:
    """Compute the element-wise angle (argument) of ``input_tensor``.

    Complex inputs are split into real and imaginary parts and evaluated as
    ``atan2(imag, real)`` by ``angle_func``. Real and integer inputs are handled
    by ``angle_float_and_int``, which yields 0 for non-negative values and pi
    for negative values.

    Args:
        input_tensor: Tensor of a real, integer, or complex dtype.

    Returns:
        A floating-point tensor of the same shape holding the angles.
    """
    logger.debug("GEMS ANGLE")
    # Previously only complex32/complex64 were routed here, so complex128 fell
    # through to the real-valued kernel. Every complex dtype needs the
    # atan2 path.
    # NOTE(review): complex128 yields float64 parts; confirm the backend's
    # atan2 supports fp64.
    if input_tensor.is_complex():
        device = input_tensor.device
        # Extract real/imag on the host, then copy both parts back to the
        # original device.
        # NOTE(review): presumably a workaround for this backend not supporting
        # complex views on device; it costs a device->host->device round trip.
        host_tensor = input_tensor.to(device="cpu")
        real = host_tensor.real.to(device=device)
        imag = host_tensor.imag.to(device=device)
        return angle_func(real, imag)
    return angle_float_and_int(input_tensor)