Coverage for src/flag_gems/ops/mul.py: 83%
36 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils import pointwise_dynamic
7from flag_gems.utils.pointwise_dynamic import ComplexMode
9logger = logging.getLogger(__name__)
# Tensor * tensor pointwise kernel body: returns the product of its two
# operands. Both operands take part in the "DEFAULT" promotion rule
# (presumably standard type promotion - confirm against pointwise_dynamic).
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func(x, y):
    product = x * y
    return product
# Tensor * scalar pointwise kernel body: the first operand is declared as a
# tensor and the second as a non-tensor (is_tensor=[True, False]). The
# promotion rule is the same "DEFAULT" one used by the tensor/tensor kernel.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func_scalar(x, y):
    product = x * y
    return product
# Complex multiplication expressed on separate real/imaginary parts:
#   (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
# All four inputs are tensors and the kernel produces two outputs
# (real part first, imaginary part second).
@pointwise_dynamic(
    is_tensor=[True, True, True, True],  # ar, ai, br, bi
    num_outputs=2,
    promotion_methods=[(0, 1, 2, 3, "DEFAULT"), (0, 1, 2, 3, "DEFAULT")],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    # Real part: product of the real parts minus product of the imaginary parts.
    re_part = ar * br - ai * bi
    # Imaginary part: sum of the two cross terms.
    im_part = ar * bi + ai * br
    return re_part, im_part
# Register complex support.
# The tensor/tensor kernel is registered in CROSS mode with
# mul_complex_kernel as its cross kernel.
mul_func.register_complex(
    mode=ComplexMode.CROSS,
    cross_kernel=mul_complex_kernel,
)
# The tensor/scalar kernel is registered in CROSS mode with
# tensorize_scalars=True and mul_func as its fallback target (presumably the
# scalar is turned into a tensor and handed to mul_func - confirm against
# pointwise_dynamic's register_complex).
mul_func_scalar.register_complex(
    mode=ComplexMode.CROSS,
    tensorize_scalars=True,
    fallback_target=mul_func,
)
def mul(A, B):
    """Multiply ``A`` by ``B``, dispatching on which operands are tensors.

    Dispatch rules:
      * both are ``torch.Tensor``  -> ``mul_func(A, B)``
      * only ``A`` is a tensor     -> ``mul_func_scalar(A, B)``
      * only ``B`` is a tensor     -> ``mul_func_scalar(B, A)`` (the tensor
        is moved to the first position, which the scalar kernel expects)
      * neither is a tensor        -> ``torch.tensor(A * B)``

    Returns:
        The result of the selected kernel, or a new tensor wrapping the
        Python-level product when both operands are scalars.
    """
    logger.debug("GEMS MUL")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if a_is_tensor and b_is_tensor:
        return mul_func(A, B)
    if a_is_tensor:
        return mul_func_scalar(A, B)
    if b_is_tensor:
        # Swap so that the tensor operand comes first.
        return mul_func_scalar(B, A)
    # Both scalar
    return torch.tensor(A * B)
def mul_(A, B):
    """Multiply ``A`` by ``B``, passing ``A`` as the kernel's output.

    ``A`` is handed to the kernel as ``out0`` (presumably so the result is
    written into ``A`` in place - confirm against pointwise_dynamic). The
    tensor/tensor kernel is used when ``B`` is a ``torch.Tensor``; otherwise
    the tensor/scalar kernel is used.

    Returns:
        Whatever the selected kernel returns.
    """
    logger.debug("GEMS MUL_")
    kernel = mul_func if isinstance(B, torch.Tensor) else mul_func_scalar
    return kernel(A, B, out0=A)