Coverage for src/flag_gems/runtime/backend/_sunrise/ops/logaddexp.py: 0%

20 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic 

7 

8logger = logging.getLogger(__name__) 

9 

10 

# Element-wise log(exp(x) + exp(y)), evaluated in float32 using the
# overflow-safe form
#     m + log(1 + exp(-|x - y|)),   m = max(x, y).
# NOTE(review): float64 inputs are also narrowed to float32 here -- confirm
# that this is acceptable for the target backend.
@pointwise_dynamic(is_tensor=[True, True], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def logaddexp_func(x, y):
    x_f32 = x.to(tl.float32)
    y_f32 = y.to(tl.float32)
    m = tl.maximum(x_f32, y_f32)
    # Equal operands have a zero gap by definition. Computing the gap as
    # x - y would yield inf - inf = NaN when both operands are the same
    # infinity, which would turn the result into NaN instead of +/-inf.
    # For finite equal operands x - y is already exactly 0, so nothing else
    # changes; NaN operands never compare equal and still propagate.
    delta = tl.where(x_f32 == y_f32, 0.0, x_f32 - y_f32)
    return m + tl.log(1.0 + tl.exp(-tl.abs(delta)))

20 

21 

def logaddexp(self, other):
    """Return the result of applying ``logaddexp_func`` to ``self`` and ``other``.

    Emits a debug log record before dispatching to the kernel wrapper.
    """
    logger.debug("GEMS LOGADDEXP")
    result = logaddexp_func(self, other)
    return result

25 

26 

def logaddexp_out(self, other, out):
    """Run ``logaddexp_func`` on ``self`` and ``other`` with ``out`` as ``out0``.

    Emits a debug log record, invokes the kernel wrapper, and returns the
    ``out`` object that was passed in (not the wrapper's own return value).
    """
    logger.debug("GEMS LOGADDEXP_OUT")
    # `out` is forwarded as the `out0` keyword; presumably the wrapper writes
    # the result into it in place -- verify against pointwise_dynamic.
    logaddexp_func(self, other, out0=out)
    return out