Coverage for src/flag_gems/ops/atanh.py: 57%

21 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8 

# Module-scoped logger; the public entry points below emit debug traces through it.
logger = logging.getLogger(name=__name__)

10 

11 

@pointwise_dynamic(promotion_methods=[(0, "DEFAULT")])
@triton.jit
def atanh_func(x):
    """Elementwise inverse hyperbolic tangent, used as a ``pointwise_dynamic`` body.

    Uses atanh(x) = 0.5 * log((1 + x) / (1 - x)) = 0.5 * log1p(2x / (1 - x)).

    Compute dtype: float64 inputs are kept in float64 (casting them to float32
    would throw away precision); every other input dtype is computed in float32.
    The result is returned in that compute dtype. The store into the output
    tensor performs the final cast, which keeps integer inputs (promoted to a
    float output by "DEFAULT") from being truncated back to integers before
    being written.

    Out-of-domain inputs follow the formula: x = +/-1 gives +/-inf, while
    |x| > 1 and NaN give NaN.
    """
    # Compile-time branch on the element dtype; only the taken branch is traced.
    if x.dtype == tl.float64:
        xf = x
    else:
        xf = x.to(tl.float32)

    denom = 1.0 - xf

    # Near zero, (1 + x) / (1 - x) rounds to a value close to 1 and log() then
    # loses most significant digits (tiny x collapses to 0, -0.0 loses its sign).
    # Evaluate log1p(y) with y = 2x / (1 - x) via the Kahan/Goldberg identity
    #   log1p(y) = log(u) * y / (u - 1),  u = 1 + y,
    # which cancels the rounding error made when forming u. When u rounds to
    # exactly 1, log1p(y) == y to working precision.
    y = 2.0 * xf / denom
    u = 1.0 + y
    small = tl.where(u == 1.0, y, tl.math.log(u) * (y / (u - 1.0)))

    # Away from zero the direct formula is well conditioned. It also produces
    # +/-inf at x = +/-1 and NaN for |x| > 1 or NaN input.
    large = tl.math.log((1.0 + xf) / denom)

    # tl.where evaluates both branches. The inf/NaN values that `small` can produce
    # for |x| >= 0.5 (e.g. at x == 1) are never selected.
    return 0.5 * tl.where(tl.abs(xf) < 0.5, small, large)

25 

26 

def atanh(A):
    """Compute the inverse hyperbolic tangent of ``A`` elementwise.

    Returns whatever the ``atanh_func`` pointwise kernel produces when no
    output tensor is supplied (presumably a freshly allocated tensor).
    """
    logger.debug("GEMS ATANH")
    out = atanh_func(A)
    return out

30 

31 

def atanh_(A):
    """Apply the inverse hyperbolic tangent to ``A`` in place.

    The kernel is given ``A`` both as its input and as its ``out0`` output
    buffer. The same tensor object ``A`` is then returned.
    """
    logger.debug("GEMS ATANH_")
    target = A
    atanh_func(target, out0=target)
    return target