Coverage for src/flag_gems/ops/copysign.py: 58%

33 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic 

7 

8logger = logging.getLogger(__name__) 

9 

10 

def _unwrap_if_constexpr(o):
    """Return the wrapped ``.value`` of a ``tl.constexpr``; pass anything else through."""
    if isinstance(o, tl.constexpr):
        return o.value
    return o

13 

14 

@tl.constexpr
def _get_uint_dtype(num_bits):
    """Look up the unsigned integer dtype that is ``num_bits`` wide.

    ``num_bits`` may be a plain value or a ``tl.constexpr``; it is unwrapped
    before the lookup.
    """
    width = _unwrap_if_constexpr(num_bits)
    # Second argument is the "signed" flag; False requests the unsigned variant.
    return tl.core.get_int_dtype(width, False)

19 

20 

@tl.constexpr
def _get_sign_bit_mask(num_bits):
    """Build an integer mask with only the most significant of ``num_bits`` bits set.

    ``num_bits`` may be a plain value or a ``tl.constexpr``; it is unwrapped
    before the mask is computed.
    """
    width = _unwrap_if_constexpr(num_bits)
    top_bit_index = width - 1
    return 1 << top_bit_index

25 

26 

@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def copysign_func(input, other):
    """Return values with the magnitude of ``input`` and the sign of ``other``.

    The sign is read from the most significant bit of ``other``'s raw bit
    pattern instead of comparing ``other`` against zero, so the result depends
    only on that bit.
    """
    # Magnitude of input, sign of other
    abs_val = tl.abs(input)
    # The bit width must be taken from `other`, the value that is actually
    # reinterpreted below: a bitcast needs source and destination types of the
    # same width, and the mask has to address the top bit of that same width.
    num_bits: tl.constexpr = other.dtype.primitive_bitwidth
    uint_dtype = _get_uint_dtype(num_bits)
    sign_bit_mask: tl.constexpr = _get_sign_bit_mask(num_bits)
    # Reinterpret the bits of `other` as an unsigned integer (no value conversion).
    other_u = other.to(uint_dtype, bitcast=True)
    # Sign bit set -> negated magnitude, otherwise the magnitude itself.
    return tl.where((other_u & sign_bit_mask) != 0, -abs_val, abs_val)

39 

40 

def copysign(input, other, *, out=None):
    """Combine the magnitude of ``input`` with the sign of ``other``.

    Args:
        input: Operand supplying the magnitude.
        other: Operand supplying the sign.
        out: Optional destination. When provided, the result is written into
            it and it is returned instead of a newly produced result.

    Returns:
        ``out`` when it is provided, otherwise the value returned by
        ``copysign_func(input, other)``.
    """
    logger.debug("GEMS COPYSIGN")
    if out is None:
        return copysign_func(input, other)
    # Honor the caller-supplied destination rather than silently dropping it;
    # `out0` is the keyword copysign_out already uses to pass a destination.
    copysign_func(input, other, out0=out)
    return out

44 

45 

def copysign_out(input, other, *, out=None):
    """Compute copysign, writing into ``out`` when a destination is supplied.

    Args:
        input: Operand supplying the magnitude.
        other: Operand supplying the sign.
        out: Optional destination, forwarded to ``copysign_func`` as ``out0``.

    Returns:
        ``out`` when it is provided, otherwise the value returned by
        ``copysign_func(input, other)``.
    """
    logger.debug("GEMS COPYSIGN_OUT")
    if out is not None:
        # Destination supplied: fill it and return the same object.
        copysign_func(input, other, out0=out)
        return out
    return copysign_func(input, other)