Coverage for src/flag_gems/runtime/backend/_sunrise/ops/angle.py: 0%

32 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import pointwise_dynamic, tl_extra_shim 

9 

# Module logger: a child of the "flag_gems" logger, named after this module
# (any leading dots in __name__ are stripped first).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Bind atan2 from tl_extra_shim (presumably a backend-specific math shim) so
# the Triton kernels below can call it by bare name.
atan2 = tl_extra_shim.atan2

13 

14 

@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def angle_func(real, imag):
    """Elementwise angle ``atan2(imag, real)`` of a complex value given as parts.

    ``real`` and ``imag`` are the real and imaginary components. float16
    components are upcast to float32 before calling ``atan2``; every other
    dtype is passed through unchanged. The output dtype is chosen by
    ``pointwise_dynamic`` from the first argument using "DEFAULT" promotion.
    """
    # The dtype comparison is resolved when the kernel is specialized, so only
    # one of these branches is ever emitted.
    # NOTE(review): the fp16 upcast presumably exists because the atan2 shim
    # lacks a half-precision variant; confirm.
    if real.dtype == tl.float16:
        x = real.to(tl.float32)
        y = imag.to(tl.float32)
    else:
        x = real
        y = imag
    return atan2(y, x)

25 

26 

@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def angle_float_and_int(real):
    """Elementwise angle of a real-valued (floating-point or integer) tensor.

    Follows ``torch.angle`` semantics for real inputs:
    - ``pi`` for negative values;
    - ``0`` for non-negative values (``-0.0 >= 0`` is True, so ``-0.0``
      maps to ``0``);
    - NaN for NaN.

    Integer inputs are promoted to a floating output ("INT_TO_FLOAT").
    """
    zero = 0.0
    pi = math.pi
    real_positive = real >= zero
    result = tl.where(real_positive, zero, pi)
    # ``NaN >= 0`` is False, so NaN would otherwise be mapped to pi. NaN is the
    # only value unequal to itself; propagate it as torch.angle does. For
    # integer inputs this mask is always False.
    # NOTE(review): the Python-float constants are likely materialized as fp32
    # by tl.where, so float64 outputs may carry an fp32-rounded pi; confirm.
    is_nan = real != real
    result = tl.where(is_nan, real, result)
    return result

35 

36 

def angle(input_tensor: torch.Tensor) -> torch.Tensor:
    """Compute the elementwise angle (argument) of ``input_tensor``.

    Complex inputs (any complex dtype, including complex128) are split into
    real and imaginary parts and dispatched to ``angle_func``, which computes
    ``atan2(imag, real)``. All other dtypes are dispatched to
    ``angle_float_and_int``.

    Args:
        input_tensor: Tensor of any real or complex dtype.

    Returns:
        A new tensor with the elementwise angle, in radians.
    """
    logger.debug("GEMS ANGLE")
    if not input_tensor.is_complex():
        return angle_float_and_int(input_tensor)

    device = input_tensor.device
    # The real/imag parts are extracted on the host and copied back to the
    # original device as separate tensors.
    # NOTE(review): this round trip presumably works around missing complex
    # support on the device; confirm before removing it. Whether complex128
    # (float64 parts) is supported by the backend's atan2 is also unverified.
    host_tensor = input_tensor.to(device="cpu")
    real = host_tensor.real.to(device=device)
    imag = host_tensor.imag.to(device=device)
    return angle_func(real, imag)