Coverage for src/flag_gems/ops/asinh.py: 78%
18 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic
8logger = logging.getLogger(__name__)
# asinh(x) = sign(x) * log(|x| + sqrt(x^2 + 1)), evaluated piecewise in
# float32 so the result stays accurate across the whole float32 range:
#   * |x| < 0.05 : odd Taylor series x - x^3/6 + 3x^5/40. The log form
#                  rounds 1 + |x| first and drops the low bits of tiny
#                  inputs (1e-8 would come out as 0). Multiplying x itself
#                  also keeps the sign of -0.0.
#   * |x| > 1e9  : log(|x|) + ln(2). Here sqrt(x^2 + 1) equals |x| to
#                  float32 precision, and squaring |x| would overflow to
#                  inf once |x| exceeds ~1.8e19, turning finite inputs
#                  into inf.
#   * otherwise  : the direct log form.
# The sign(x) * log(|x| + ...) shape preserves the sign on -inf input
# (the naive x + sqrt(x^2+1) form evaluates to -inf + inf = NaN).
# NaN fails every comparison below, falls through to the direct log form
# and propagates as NaN.
# INT_TO_FLOAT promotion handles integer input tensors.
@pointwise_dynamic(promotion_methods=[(0, "INT_TO_FLOAT")])
@triton.jit
def asinh_func(x):
    x_fp32 = x.to(tl.float32)
    abs_x = tl.abs(x_fp32)
    x_sq = abs_x * abs_x

    # Mid-range: direct formula (x_sq may be inf for huge |x|; that lane is
    # discarded by the select below).
    y_mid = tl.log(abs_x + tl.sqrt(x_sq + 1.0))

    # Large |x|: asinh(x) ~= log(2|x|) = log|x| + ln 2, with no squaring.
    y_big = tl.log(abs_x) + 0.6931471805599453
    y = tl.where(abs_x > 1.0e9, y_big, y_mid)
    y_signed = tl.where(x_fp32 < 0.0, -y, y)

    # Small |x|: x * (1 - x^2/6 + 3x^4/40); truncation error is below
    # float32 rounding for |x| < 0.05.
    y_small = x_fp32 * (1.0 - x_sq * (0.16666666666666666 - x_sq * 0.075))

    return tl.where(abs_x < 0.05, y_small, y_signed)
def asinh(A):
    """Return the element-wise inverse hyperbolic sine of ``A``.

    Emits a debug log record and forwards ``A`` to ``asinh_func``.
    """
    logger.debug("GEMS ASINH")
    result = asinh_func(A)
    return result
def asinh_out(A, out):
    """Compute the element-wise inverse hyperbolic sine of ``A`` into ``out``.

    Emits a debug log record and forwards ``A`` to ``asinh_func``, passing
    ``out`` as its ``out0`` keyword argument; whatever that call returns is
    returned unchanged.
    """
    logger.debug("GEMS ASINH_OUT")
    result = asinh_func(A, out0=out)
    return result