Coverage for src/flag_gems/runtime/backend/_cambricon/ops/addcmul.py: 0%
18 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
6from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module (any leading dots stripped from __name__).
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
# Elementwise body of addcmul: x + value * t1 * t2.
# `is_tensor` marks x, t1 and t2 as tensor arguments and `value` as a
# non-tensor argument. The output dtype is presumably promoted from arguments
# 0, 1 and 2 using the "DEFAULT" rule — semantics live in pointwise_dynamic
# (not visible here), TODO confirm.
@pointwise_dynamic(
    is_tensor=[True, True, True, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def addcmul_forward(x, t1, t2, value):
    # Same evaluation order as `value * t1 * t2`, i.e. (value * t1) * t2,
    # so rounding is unchanged; the partial product is just given a name.
    scaled_t1 = value * t1
    return x + scaled_t1 * t2
def addcmul(inp, tensor1, tensor2, *, value=1.0, out=None):
    """Return ``inp + value * tensor1 * tensor2``, computed elementwise.

    The signature appears to mirror ``torch.addcmul``.

    Args:
        inp: Tensor added to the scaled product.
        tensor1: First factor of the product.
        tensor2: Second factor of the product.
        value: Non-tensor multiplier applied to the product (default 1.0).
        out: Optional destination tensor. If its shape differs from the
            broadcast shape of the three inputs, it is resized in place
            (``resize_``) before the kernel writes into it.

    Returns:
        ``out`` when it was given; otherwise the tensor produced by the
        pointwise kernel.
    """
    logger.debug("GEMS_CAMBRICON ADDCMUL FORWARD")

    # No destination supplied: let the pointwise wrapper allocate the result.
    if out is None:
        return addcmul_forward(inp, tensor1, tensor2, value)

    # Make the destination match the broadcast result shape before writing.
    target_shape = torch.broadcast_shapes(inp.shape, tensor1.shape, tensor2.shape)
    if tuple(out.shape) != tuple(target_shape):
        out.resize_(target_shape)

    # The wrapper receives the destination via its `out0` keyword.
    addcmul_forward(inp, tensor1, tensor2, value, out0=out)
    return out