Coverage for src/flag_gems/ops/logaddexp.py: 75%
20 statements
coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic
# Logger named after this module; the op entry points emit their
# "GEMS ..." debug messages through it.
logger = logging.getLogger(name=__name__)
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def logaddexp_func(x, y):
    """Elementwise ``log(exp(x) + exp(y))`` evaluated without overflow.

    Uses ``log(exp(x) + exp(y)) = m + log1p(exp(-|x - y|))`` with
    ``m = max(x, y)``, computed in float32 whatever the input dtype. The
    value produced here is float32; the cast to the final output dtype is
    presumably handled by the ``pointwise_dynamic`` wrapper (DEFAULT
    promotion of both inputs) -- NOTE(review): confirm against its codegen.

    Special values (matching ``torch.logaddexp``):
      * ``x == y == +inf`` gives ``+inf`` and ``x == y == -inf`` gives
        ``-inf``. The naive ``x - y`` would be ``inf - inf = nan`` there.
      * A NaN in either input propagates to the result through ``delta``.
    """
    x_f32 = x.to(tl.float32)
    y_f32 = y.to(tl.float32)
    m = tl.maximum(x_f32, y_f32)
    # Equal operands always contribute log(2). Forcing delta to 0 stops equal
    # infinities from producing inf - inf = nan. For equal finite values
    # x - y is already exactly 0, so only the infinite case changes.
    delta = tl.where(x_f32 == y_f32, 0.0, x_f32 - y_f32)
    z = tl.exp(-tl.abs(delta))  # lies in [0, 1] unless delta is NaN
    # log1p(z) without a log1p intrinsic. Plain log(1 + z) returns 0 once
    # z < 2**-24, because 1 + z rounds to 1. Scaling log(u) by z / (u - 1)
    # cancels the rounding error of u = 1 + z; u - 1 is exact for u in [1, 2].
    # When u rounds to 1, log1p(z) == z to working precision.
    u = 1.0 + z
    log1p_z = tl.where(u == 1.0, z, tl.log(u) * z / (u - 1.0))
    return m + log1p_z
def logaddexp(self, other):
    """Compute ``log(exp(self) + exp(other))`` elementwise with the Triton kernel.

    Args:
        self: First operand; ``logaddexp_func`` declares it a tensor.
        other: Second operand; also declared a tensor by ``logaddexp_func``.

    Returns:
        The result produced by ``logaddexp_func`` for the two operands.
    """
    logger.debug("GEMS LOGADDEXP")
    result = logaddexp_func(self, other)
    return result
def logaddexp_out(self, other, out):
    """Write ``log(exp(self) + exp(other))`` into ``out`` and hand it back.

    Args:
        self: First operand, passed to ``logaddexp_func``.
        other: Second operand, passed to ``logaddexp_func``.
        out: Destination, passed to the kernel wrapper as ``out0``.

    Returns:
        ``out`` itself, after the kernel has filled it.
    """
    logger.debug("GEMS LOGADDEXP_OUT")
    destination = out
    logaddexp_func(self, other, out0=destination)
    return destination