Coverage for src/flag_gems/ops/conj_physical.py: 74%
31 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import libentry, libtuner
10logger = logging.getLogger(__name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("conj_physical"),
    key=["n_elements"],
)
@triton.jit
def conj_physical_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    """Copy interleaved value pairs from ``in_ptr`` to ``out_ptr``, negating the second.

    Logical element ``i`` occupies slots ``2 * i`` and ``2 * i + 1`` of both
    buffers. The first slot is copied unchanged and the second is stored
    negated. ``conj_physical`` passes ``torch.view_as_real`` views here, so
    the two slots are the real and imaginary parts of a complex number.

    Args:
        in_ptr: Source buffer of interleaved pairs.
        out_ptr: Destination buffer with the same layout as ``in_ptr``.
        n_elements: Number of pairs (not scalar slots) to process.
        BLOCK_SIZE: Number of pairs handled by one program instance.
    """
    # Widen to 64 bits before scaling: ``offsets * 2`` would wrap around in
    # 32-bit arithmetic once the tensor holds more than 2**30 pairs.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Each logical element spans two consecutive scalar slots.
    base = offsets * 2
    real = tl.load(in_ptr + base, mask=mask)
    imag = tl.load(in_ptr + base + 1, mask=mask)

    tl.store(out_ptr + base, real, mask=mask)
    tl.store(out_ptr + base + 1, -imag, mask=mask)
def conj_physical(input: torch.Tensor) -> torch.Tensor:
    """Return the element-wise complex conjugate of ``input``.

    Non-complex tensors are returned as-is (the very same object, no copy).
    For complex tensors, a contiguous copy of the data is viewed as
    interleaved (real, imag) pairs via ``torch.view_as_real`` and
    ``conj_physical_kernel`` writes the pairs, with the imaginary part
    negated, into a newly allocated tensor.

    Args:
        input: Tensor to conjugate.

    Returns:
        ``input`` itself when it is not complex; otherwise a new tensor
        allocated with ``torch.empty_like`` of the contiguous source and
        filled with the conjugated values.
    """
    logger.debug("GEMS Conj_Physical")
    if not input.is_complex():
        return input

    n_elements = input.numel()
    # ``contiguous()`` hands back the same tensor when no copy is needed.
    src = input.contiguous()
    output = torch.empty_like(src)
    if n_elements == 0:
        # Nothing to compute: skip tuning and launching over an empty grid.
        return output

    # The kernel addresses the data as flat (real, imag) scalar pairs.
    in_real = torch.view_as_real(src)
    out_real = torch.view_as_real(output)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    conj_physical_kernel[grid](in_real, out_real, n_elements)

    return output