Coverage for src/flag_gems/ops/copysign.py: 58%
33 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
def _unwrap_if_constexpr(o):
    """Return the wrapped value if ``o`` is a ``tl.constexpr``, else ``o`` unchanged."""
    if isinstance(o, tl.constexpr):
        return o.value
    return o
@tl.constexpr
def _get_uint_dtype(num_bits):
    """Return the unsigned Triton integer dtype that is ``num_bits`` wide.

    ``num_bits`` may be a plain int or a ``tl.constexpr`` wrapping one.
    """
    width = _unwrap_if_constexpr(num_bits)
    unsigned = False
    return tl.core.get_int_dtype(width, unsigned)
@tl.constexpr
def _get_sign_bit_mask(num_bits):
    """Return an integer whose only set bit is the top bit of a ``num_bits``-wide word.

    ``num_bits`` may be a plain int or a ``tl.constexpr`` wrapping one.
    """
    width = _unwrap_if_constexpr(num_bits)
    top_bit_index = width - 1
    return 1 << top_bit_index
# torch.copysign is a "binary float op": integer inputs produce a floating-point
# result, hence INT_TO_FLOAT rather than DEFAULT promotion.
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, 1, "INT_TO_FLOAT")])
@triton.jit
def copysign_func(input, other):
    # Element-wise copysign: magnitude of `input`, sign of `other`.
    # The raw sign bit of `other` is tested instead of `other < 0`, so a
    # negative zero in `other` yields a negative result.
    abs_val = tl.abs(input)
    # The sign bit belongs to `other`, so size the bitcast from `other`'s own
    # dtype. `input` and `other` are not guaranteed to share a width (e.g. fp32
    # input with fp16/bf16 other, or an integer input with a float other), and a
    # bitcast between types of different widths does not compile.
    num_bits: tl.constexpr = other.dtype.primitive_bitwidth
    uint_dtype = _get_uint_dtype(num_bits)
    sign_bit_mask: tl.constexpr = _get_sign_bit_mask(num_bits)
    other_u = other.to(uint_dtype, bitcast=True)
    # NOTE(review): for unsigned-integer or bool `other` the top bit is not a
    # sign bit, so such values would be read as negative; confirm callers only
    # pass floating-point or signed-integer `other`.
    return tl.where((other_u & sign_bit_mask) != 0, -abs_val, abs_val)
def copysign(input, other, *, out=None):
    """Return a tensor with the magnitude of ``input`` and the sign of ``other``.

    Args:
        input: tensor supplying the magnitudes.
        other: tensor supplying the signs.
        out: optional destination tensor. When given, the result is written
            into ``out`` and ``out`` is returned.

    Returns:
        A newly allocated result tensor, or ``out`` when it is provided.
    """
    logger.debug("GEMS COPYSIGN")
    if out is None:
        return copysign_func(input, other)
    # Honor `out` (same handling as copysign_out) instead of silently dropping it.
    copysign_func(input, other, out0=out)
    return out
def copysign_out(input, other, *, out=None):
    """Compute copysign(input, other), writing into ``out`` when it is supplied.

    Returns ``out`` if it was given; otherwise returns a freshly allocated
    result tensor.
    """
    logger.debug("GEMS COPYSIGN_OUT")
    if out is not None:
        copysign_func(input, other, out0=out)
        return out
    return copysign_func(input, other)