Coverage for src/flag_gems/runtime/backend/_cambricon/ops/mul.py: 0%
72 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
6from ..utils.pointwise_dynamic import pointwise_dynamic
# Per-module logger nested under the package-wide "flag_gems" logger.
_package_logger = logging.getLogger("flag_gems")
logger = _package_logger.getChild(__name__.lstrip("."))
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func(x, y, inplace):
    # Elementwise product of two tensor operands (is_tensor=[True, True, False]).
    # `inplace` is a non-tensor flag that the body never reads; in this file
    # `mul_` passes True (together with out0=A) and `mul` passes False.
    product = x * y
    return product
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def mul_func_scalar(x, y, inplace):
    # Elementwise product of tensor `x` with non-tensor `y`
    # (is_tensor=[True, False, False]); `y` still takes part in type promotion.
    # `inplace` is not read by the body; callers pass True only from `mul_`.
    scaled = x * y
    return scaled
@pointwise_dynamic(
    is_tensor=[True, True, True, True],  # ar, ai, br, bi
    num_outputs=2,
    promotion_methods=[(0, 1, 2, 3, "DEFAULT"), (0, 1, 2, 3, "DEFAULT")],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    # Complex product on split components:
    #   (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
    # Output 0 is the real part, output 1 the imaginary part.
    return ar * br - ai * bi, ar * bi + ai * br
def _align_scalar_tensor_device(A, B):
    """Return ``(A, B)`` with a 0-dim operand moved to the other tensor's device.

    Acts only when both operands are tensors on different devices. The 0-dim
    operand is asserted to live on the CPU (a scalar tensor) and is copied to
    the other operand's device. Any other device mismatch is passed through
    unchanged for the kernel to report.
    """
    if not (
        isinstance(A, torch.Tensor)
        and isinstance(B, torch.Tensor)
        and A.device != B.device
    ):
        return A, B
    if A.dim() == 0:
        assert A.device == torch.device("cpu"), "expect scalar tensor on cpu"
        A = A.to(B.device)
    elif B.dim() == 0:
        assert B.device == torch.device("cpu"), "expect scalar tensor on cpu"
        B = B.to(A.device)
    return A, B


def _mul_complex(A, B):
    """Multiply when at least one operand is complex and at least one is a tensor.

    Complex tensors are viewed as real tensors with a trailing size-2
    (real, imag) axis via ``torch.view_as_real``, so the real-valued kernels
    can be used. The result is viewed back as complex and cast to
    ``torch.result_type(A, B)`` of the operands as passed in.
    """
    # Compute the result dtype before lifting scalars to tensors: a 0-dim
    # tensor participates in type promotion differently than a Python scalar.
    out_dtype = torch.result_type(A, B)

    # torch.view_as_real only accepts tensors, so lift a Python complex scalar
    # to a 0-dim tensor. The caller guarantees the other operand is a tensor.
    if isinstance(A, complex):
        A = torch.tensor(A, dtype=out_dtype, device=B.device)
    if isinstance(B, complex):
        B = torch.tensor(B, dtype=out_dtype, device=A.device)
    A, B = _align_scalar_tensor_device(A, B)

    A_is_complex = isinstance(A, torch.Tensor) and A.is_complex()
    B_is_complex = isinstance(B, torch.Tensor) and B.is_complex()

    if A_is_complex and B_is_complex:
        # 1) Both complex: compute real/imag parts in a common real dtype.
        Ar = torch.view_as_real(A)
        Br = torch.view_as_real(B)
        common_dtype = torch.promote_types(Ar.dtype, Br.dtype)
        ar, ai = Ar[..., 0].to(common_dtype), Ar[..., 1].to(common_dtype)
        br, bi = Br[..., 0].to(common_dtype), Br[..., 1].to(common_dtype)
        # Size the outputs by the broadcast of both operands, not A's shape
        # alone, so broadcastable operands of different shapes work.
        out_shape = torch.broadcast_shapes(ar.shape, br.shape)
        real_out = torch.empty(out_shape, dtype=common_dtype, device=ar.device)
        imag_out = torch.empty(out_shape, dtype=common_dtype, device=ar.device)
        mul_complex_kernel(ar, ai, br, bi, out0=real_out, out1=imag_out)
        out_real = torch.stack((real_out, imag_out), dim=-1)
    elif A_is_complex:
        # 2) A complex, B real: a real factor scales both parts. A tensor B
        # gets a trailing axis so it broadcasts over the (real, imag) axis.
        Ar = torch.view_as_real(A)
        if isinstance(B, torch.Tensor):
            out_real = mul_func(Ar, B.unsqueeze(-1), False)
        else:
            out_real = mul_func_scalar(Ar, B, False)
    else:
        # 3) A real, B complex: mirror image of case 2.
        Br = torch.view_as_real(B)
        if isinstance(A, torch.Tensor):
            out_real = mul_func(A.unsqueeze(-1), Br, False)
        else:
            out_real = mul_func_scalar(Br, A, False)  # Br is tensor, A is scalar
    return torch.view_as_complex(out_real).to(out_dtype)


def mul(A, B):
    """Elementwise ``A * B`` for tensors and/or Python scalars.

    Args:
        A, B: Tensors or Python scalars (int/float/bool/complex). A 0-dim
            tensor on a different device than the other tensor operand must
            be on the CPU; it is moved to the other operand's device.

    Returns:
        A tensor holding the product. Complex operands are handled by
        splitting them into real/imaginary parts, and the result is cast to
        ``torch.result_type(A, B)``. When both operands are Python scalars,
        the result is ``torch.tensor(A * B)``.
    """
    logger.debug("GEMS_CAMBRICON MUL")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)
    if not A_is_tensor and not B_is_tensor:
        # Both scalar (complex scalars included): plain Python multiplication.
        return torch.tensor(A * B)

    A_is_complex = (A_is_tensor and A.is_complex()) or isinstance(A, complex)
    B_is_complex = (B_is_tensor and B.is_complex()) or isinstance(B, complex)
    if A_is_complex or B_is_complex:
        return _mul_complex(A, B)

    if A_is_tensor and B_is_tensor:
        A, B = _align_scalar_tensor_device(A, B)
        return mul_func(A, B, False)
    if A_is_tensor:
        return mul_func_scalar(A, B, False)
    return mul_func_scalar(B, A, False)
def mul_(A, B):
    """In-place elementwise multiply ``A *= B``; returns ``A``.

    Args:
        A: Tensor that is written in place.
        B: Tensor or Python scalar. A 0-dim tensor on a different device than
            ``A`` must be on the CPU and is moved to ``A``'s device first.

    Returns:
        ``A`` after the update.
    """
    logger.debug("GEMS_CAMBRICON MUL_")
    if A.is_complex():
        # mul_func / mul_func_scalar are Triton kernels, and Triton has no
        # complex dtypes. The out-of-place `mul` handles complex operands by
        # splitting them into real/imaginary parts, so reuse it and write the
        # result back. copy_ casts to A's dtype, raises if the result does not
        # fit A's shape, and returns A.
        return A.copy_(mul(A, B))
    if isinstance(B, torch.Tensor):
        if B.device != A.device and B.dim() == 0:
            assert B.device == torch.device("cpu"), "expect scalar tensor on cpu"
            B = B.to(A.device)
        return mul_func(A, B, True, out0=A)
    return mul_func_scalar(A, B, True, out0=A)