Coverage for src/flag_gems/ops/addcdiv.py: 94%
17 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils import pointwise_dynamic
# Module-level logger; the addcdiv entry points below emit debug traces through it.
logger = logging.getLogger(__name__)
# Elementwise addcdiv body, wrapped by ``pointwise_dynamic``:
#   * ``is_tensor`` marks x, t1 and t2 as tensor operands and ``value`` as a
#     non-tensor (scalar) argument.
#   * ``promotion_methods`` applies the "DEFAULT" promotion rule across
#     operands 0, 1 and 2 to pick the result dtype.
@pointwise_dynamic(
    is_tensor=[True, True, True, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def addcdiv_kernel(x, t1, t2, value):
    # Form the quotient first, then scale it. This keeps the
    # ``value * (t1 / t2)`` association of the original expression.
    quotient = t1 / t2
    return x + value * quotient
def addcdiv_out(inp, tensor1, tensor2, *, value=1.0, out):
    """Write ``inp + value * (tensor1 / tensor2)`` into ``out`` and return it.

    Args:
        inp: Tensor added to the scaled quotient.
        tensor1: Numerator tensor.
        tensor2: Denominator tensor.
        value: Scalar multiplier applied to the quotient (keyword-only).
        out: Destination tensor, forwarded to the kernel as ``out0``
            (keyword-only).

    Returns:
        The same ``out`` object that was passed in.
    """
    logger.debug("GEMS ADDCDIV_OUT")
    destination = out
    # The kernel fills ``destination`` in place. Its return value is not
    # needed because the caller's buffer is handed back directly.
    addcdiv_kernel(inp, tensor1, tensor2, value, out0=destination)
    return destination
def addcdiv(inp, tensor1, tensor2, value=1.0):
    """Return ``inp + value * (tensor1 / tensor2)`` as a freshly allocated tensor.

    Functional entry point: CUDA may dispatch here without going through
    ``addcdiv.out``, so this path has to produce its own result tensor.

    The result is deliberately NOT pre-allocated with
    ``torch.empty_like(inp)``:

    * ``torch.addcdiv`` broadcasts all three operands, so the result can be
      larger than ``inp`` (e.g. ``inp`` of shape ``(3,)`` with ``tensor1`` of
      shape ``(2, 3)``).
    * The kernel's promotion rule spans operands 0, 1 and 2, so the result
      dtype can differ from ``inp.dtype`` (e.g. a float16 ``inp`` with
      float32 ``tensor1``/``tensor2``).

    Omitting ``out0`` leaves allocation to the ``pointwise_dynamic`` wrapper.
    NOTE(review): this relies on that wrapper allocating the output with the
    broadcast shape and promoted dtype when ``out0`` is omitted, as other
    functional ops built on it do. Confirm against ``pointwise_dynamic``.

    Args:
        inp: Tensor added to the scaled quotient.
        tensor1: Numerator tensor.
        tensor2: Denominator tensor.
        value: Scalar multiplier applied to the quotient.

    Returns:
        A new tensor holding the elementwise result.
    """
    logger.debug("GEMS ADDCDIV")
    return addcdiv_kernel(inp, tensor1, tensor2, value)