Coverage for src/flag_gems/ops/addcmul.py: 95%
22 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils import pointwise_dynamic
# Module-level logger; addcmul_out and addcmul emit their debug trace through it.
logger = logging.getLogger(__name__)
@pointwise_dynamic(
    is_tensor=[True, True, True, False], promotion_methods=[(0, 1, 2, "DEFAULT")]
)
@triton.jit
def addcmul_forward(x, t1, t2, value):
    # Elementwise body of addcmul: x + value * t1 * t2.
    # The first three arguments are tensors and `value` is a scalar (per
    # `is_tensor` above). The product is grouped as (value * t1) * t2,
    # the same left-to-right order as the plain expression, so rounding
    # is unchanged.
    scaled_t1 = value * t1
    return x + scaled_t1 * t2
def addcmul_out(inp, tensor1, tensor2, *, value=1.0, out):
    """Write ``inp + value * tensor1 * tensor2`` into ``out`` and return it.

    Args:
        inp: Tensor added to the scaled product.
        tensor1: First factor of the product.
        tensor2: Second factor of the product.
        value: Scalar multiplier applied to ``tensor1 * tensor2``.
        out: Destination tensor. It is resized in place when its shape
            differs from the broadcast shape of the three inputs.

    Returns:
        ``out`` itself, after the kernel has written into it.
    """
    logger.debug("GEMS ADDCMUL_OUT")
    target_shape = torch.broadcast_shapes(inp.shape, tensor1.shape, tensor2.shape)
    # Element-wise tuple comparison, equivalent to comparing the shapes as lists.
    needs_resize = tuple(out.shape) != tuple(target_shape)
    if needs_resize:
        # Make room for the full broadcast result before the kernel writes it.
        out.resize_(target_shape)
    addcmul_forward(inp, tensor1, tensor2, value, out0=out)
    return out
def addcmul(inp, tensor1, tensor2, *, value=1.0):
    """Functional entry; keep alongside ``addcmul.out`` for dispatch coverage.

    Computes ``inp + value * tensor1 * tensor2`` into a freshly allocated
    tensor by delegating to ``addcmul_out``.

    Args:
        inp: Tensor added to the scaled product.
        tensor1: First factor of the product.
        tensor2: Second factor of the product.
        value: Scalar multiplier applied to ``tensor1 * tensor2``.

    Returns:
        A new tensor on ``inp.device``. It has the broadcast shape of the
        three inputs and the DEFAULT elementwise-promoted dtype of
        ``(inp, tensor1, tensor2)``.
    """
    # Local import: only needed to pick the result dtype below.
    from torch._prims_common import (
        ELEMENTWISE_TYPE_PROMOTION_KIND,
        elementwise_dtypes,
    )

    logger.debug("GEMS ADDCMUL")
    broadcast_shape = torch.broadcast_shapes(inp.shape, tensor1.shape, tensor2.shape)
    # Use torch's DEFAULT elementwise promotion over the three tensors. This
    # is the rule torch applies to addcmul, and the one named by the kernel's
    # promotion_methods=(0, 1, 2, "DEFAULT").
    # Pairwise torch.promote_types is wrong here: it lets a 0-dim tensor
    # raise the dtype of dimensioned tensors of the same category. For
    # example, a float32 `inp` with a 0-dim float64 `tensor2` would yield
    # float64 instead of torch's float32.
    # `value` is intentionally excluded, matching torch's reference addcmul.
    _, result_dtype = elementwise_dtypes(
        inp,
        tensor1,
        tensor2,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    )
    out = torch.empty(broadcast_shape, device=inp.device, dtype=result_dtype)
    return addcmul_out(inp, tensor1, tensor2, value=value, out=out)