Coverage for src/flag_gems/runtime/backend/_ascend/ops/polar.py: 0%
27 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
8from flag_gems.utils.codegen_config_utils import CodeGenConfig
# Module logger; its name ends with the last component of this module's dotted path.
logger = logging.getLogger(f"flag_gems.runtime._ascend.ops.{__name__.rpartition('.')[2]}")
# Code-generation settings passed to ``pointwise_dynamic`` for the polar kernel.
# The positional values are forwarded to ``CodeGenConfig`` unchanged; their
# meaning is defined by that class (not visible from this file).
config_ = CodeGenConfig(
    384,
    (48, 1, 1),
    32,
    False,
    # 1-D tiling is preferred on Triton releases older than 3.x. Parse the
    # whole major component instead of only the first character, so that a
    # multi-digit major version is compared correctly.
    prefer_1d_tile=int(triton.__version__.split(".")[0]) < 3,
)
# Element-wise kernel: from a magnitude and a phase, produce the two Cartesian
# components. Two outputs are declared, and each one uses "DEFAULT" type
# promotion over both inputs (argument indices 0 and 1).
@pointwise_dynamic(
    num_outputs=2,
    promotion_methods=[
        ((0, 1), "DEFAULT"),
        ((0, 1), "DEFAULT"),
    ],
    config=config_,
)
@triton.jit
def polar_kernel(abs, angle):
    # First output: abs * cos(angle); second output: abs * sin(angle).
    re_part = abs * tl.cos(angle)
    im_part = abs * tl.sin(angle)
    return re_part, im_part
def polar(abs, angle):
    """Build a complex tensor from the polar pair (``abs``, ``angle``).

    The real and imaginary parts are computed on device by ``polar_kernel``
    and then assembled into a complex tensor on the CPU before being moved
    back to ``abs.device``.

    Args:
        abs: Tensor passed to the kernel as the magnitude; it also determines
            the shape, dtype and device of the two intermediate buffers.
        angle: Tensor passed to the kernel as the phase.

    Returns:
        The result of ``torch.view_as_complex`` on the stacked
        (real, imag) pairs, placed on ``abs.device``.
    """
    logger.debug("GEMS_ASCEND POLAR")

    # Write each component into its own buffer rather than into strided
    # slices of a single (*, 2) tensor: on Ascend NPU without OPP, AsStrided
    # (needed for output[..., 0] slicing) is not available.
    real_part = torch.empty_like(abs)
    imag_part = torch.empty_like(abs)
    polar_kernel(abs, angle, out0=real_part, out1=imag_part)

    # Assemble the complex result on the host. On Ascend NPU without OPP,
    # Pack (torch.stack on device), select+copy_, and torch.complex are all
    # broken or fall back to CPU anyway.
    host_parts = [real_part.cpu(), imag_part.cpu()]

    # view_as_complex only accepts float16/float32/float64, so bfloat16
    # components are widened to float32 first.
    if host_parts[0].dtype == torch.bfloat16:
        host_parts = [part.float() for part in host_parts]

    interleaved = torch.stack(host_parts, dim=-1)
    return torch.view_as_complex(interleaved).to(abs.device)