Coverage for src/flag_gems/runtime/backend/_ascend/ops/polar.py: 0%

27 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8from flag_gems.utils.codegen_config_utils import CodeGenConfig 

9 

# Per-op logger: the last component of this module's dotted name is appended
# to the Ascend ops logger namespace.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)

11 

12 

# Code-generation settings handed to `pointwise_dynamic` for the polar kernel.
# The positional values are forwarded to CodeGenConfig unchanged; see its
# definition for the meaning of each slot.
config_ = CodeGenConfig(
    384,
    (48, 1, 1),
    32,
    False,
    # Parse the whole major-version component rather than only its first
    # character, so a two-digit major version (e.g. "10.0") is not read as 1.
    prefer_1d_tile=int(triton.__version__.split(".")[0]) < 3,
)

20 

21 

# Element-wise polar -> Cartesian conversion. Produces two outputs per element:
# abs * cos(angle) and abs * sin(angle), in that order.
@pointwise_dynamic(
    num_outputs=2,
    config=config_,
    promotion_methods=[((0, 1), "DEFAULT"), ((0, 1), "DEFAULT")],
)
@triton.jit
def polar_kernel(abs, angle):
    # First output: projection of the magnitude onto the cosine direction.
    cos_part = abs * tl.cos(angle)
    # Second output: projection of the magnitude onto the sine direction.
    sin_part = abs * tl.sin(angle)
    return cos_part, sin_part

35 

36 

def polar(abs, angle):
    """Build a complex tensor from polar coordinates (Ascend backend).

    The real and imaginary parts are computed on the device by
    ``polar_kernel`` and then combined into a complex tensor on the CPU,
    which is finally moved back to the device of ``abs``.

    Args:
        abs: Tensor of magnitudes. The outputs are allocated with
            ``torch.empty_like(abs)``, so the result follows its shape.
        angle: Tensor of angles passed to ``polar_kernel`` alongside ``abs``.

    Returns:
        A complex tensor on ``abs.device`` whose real part is
        ``abs * cos(angle)`` and whose imaginary part is ``abs * sin(angle)``.
        bfloat16 results are upcast to float32 before being combined.
    """
    logger.debug("GEMS_ASCEND POLAR")

    # Use separate output tensors instead of non-contiguous slices of a
    # (*, 2) buffer. On Ascend NPU without OPP, AsStrided (used for
    # output[..., 0] slicing) is not available.
    out_real = torch.empty_like(abs)
    out_imag = torch.empty_like(abs)

    polar_kernel(abs, angle, out0=out_real, out1=out_imag)

    # Combine into a complex tensor via a CPU round-trip.
    # On Ascend NPU without OPP, Pack (torch.stack on device), select+copy_,
    # and torch.complex are all broken or fall back to CPU.
    real_cpu = out_real.cpu()
    imag_cpu = out_imag.cpu()

    # view_as_complex only supports float16/float32/float64, so bfloat16
    # parts are widened to float32 first.
    if real_cpu.dtype == torch.bfloat16:
        real_cpu = real_cpu.float()
        imag_cpu = imag_cpu.float()

    # Stacking along a new trailing axis yields the (*, 2) layout that
    # view_as_complex reinterprets as (real, imag) pairs.
    output_cpu = torch.stack([real_cpu, imag_cpu], dim=-1)
    return torch.view_as_complex(output_cpu).to(abs.device)