Coverage for src/flag_gems/ops/asinh.py: 78%

18 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic 

7 

# Module-level logger named after this module; the op entry points log through it.
logger = logging.getLogger(name=__name__)

9 

10 

# asinh(x) = sign(x) * asinh(|x|), evaluated piecewise in float32 so the
# result stays accurate across the whole float32 range:
#
# * |x| < 2**-12: asinh(x) = x - x**3/6 + ..., and the relative correction
#   x**2/6 is below half an fp32 ulp, so x itself is the rounded answer.
#   Returning x directly also preserves the sign of -0.0 (the previous
#   version mapped -0.0 to +0.0).
# * |x| > 2**12: asinh(|x|) = log(2|x|) + O(1/x**2), and the O(1/x**2) term
#   is far below fp32 resolution. Using log(|x|) + ln(2) avoids computing
#   x*x, which overflows to inf for |x| > ~1.8e19 and made the naive
#   formula return inf for large finite inputs.
# * otherwise: asinh(|x|) = log1p(u), where
#   u = |x| + x**2 / (1 + sqrt(1 + x**2)) equals |x| + sqrt(x**2 + 1) - 1
#   without cancellation. log1p(u) is evaluated as log(w) * u / (w - 1) with
#   w = 1 + u (Goldberg, Theorem 4), which compensates for the rounding of
#   1 + u. The naive log(|x| + sqrt(x**2 + 1)) loses most of its relative
#   precision for small |x| (e.g. it returns 0 for |x| = 1e-8).
#
# Splitting off sign(x) keeps -inf -> -inf; the form x + sqrt(x^2 + 1)
# would evaluate -inf + inf = NaN. NaN inputs propagate through the middle
# branch. tl.where evaluates every branch, so unselected lanes may hold
# inf/NaN (e.g. 0/0 when x == 0); those values are discarded.
# All inputs are computed in float32 intermediates.
# INT_TO_FLOAT promotion handles integer input tensors.
# NOTE(review): accuracy of the middle branch near w ~ 1 still depends on
# how precisely tl.log is implemented on the target backend — confirm
# against torch.asinh in the op tests.
@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def asinh_func(x):
    x_fp32 = x.to(tl.float32)
    abs_x = tl.abs(x_fp32)

    # Large |x|: log(2|x|) written as log(|x|) + ln(2) so that 2|x| cannot
    # overflow for |x| near the float32 maximum; inf maps to inf.
    y_large = tl.log(abs_x) + 0.6931471805599453

    # Moderate |x|: log1p(u) with the cancellation-free argument u.
    x_sq = abs_x * abs_x
    u = abs_x + x_sq / (1.0 + tl.sqrt(1.0 + x_sq))
    w = 1.0 + u
    # On lanes where this branch is selected, 1 < w < 2**24, so w - 1 is
    # exact. The ratio u / (w - 1) then corrects log(w) for the rounding
    # error made when forming w.
    y_mid = tl.log(w) * (u / (w - 1.0))

    y = tl.where(abs_x > 4096.0, y_large, y_mid)
    signed = tl.where(x_fp32 < 0.0, -y, y)
    # Tiny |x| (including +/-0.0): asinh(x) rounds to x itself.
    return tl.where(abs_x < 2.44140625e-4, x_fp32, signed)

23 

24 

def asinh(A):
    """Return the element-wise inverse hyperbolic sine of ``A``.

    Delegates to the ``asinh_func`` pointwise kernel, which promotes
    integer inputs to floating point (``INT_TO_FLOAT``).
    """
    logger.debug("GEMS ASINH")
    result = asinh_func(A)
    return result

28 

29 

def asinh_out(A, out):
    """Compute the element-wise inverse hyperbolic sine of ``A`` into ``out``.

    ``out`` is forwarded to ``asinh_func`` as its ``out0`` argument
    (presumably the destination tensor — confirm against
    ``pointwise_dynamic``). The kernel's return value is passed through
    unchanged.
    """
    logger.debug("GEMS ASINH_OUT")
    result = asinh_func(A, out0=out)
    return result