Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/mul.py: 0%
64 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
6from ..utils.pointwise_dynamic import pointwise_dynamic
8logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Tensor-by-tensor multiply. The wrapper is configured with the "DEFAULT"
# promotion method over operands 0 and 1.
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func(x, y):
    # Elementwise product of the two operands.
    product = x * y
    return product
# Tensor-by-scalar multiply. `is_tensor=[True, False]` declares the first
# operand as a tensor and the second as a non-tensor value.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func_scalar(x, y):
    # Elementwise product of the tensor operand with the scalar operand.
    scaled = x * y
    return scaled
# Complex multiply on split components: takes the real/imaginary parts of
# both operands as four separate tensors and produces two outputs.
@pointwise_dynamic(
    is_tensor=[True, True, True, True],
    num_outputs=2,
    promotion_methods=[(0, 1, 2, 3, "DEFAULT"), (0, 1, 2, 3, "DEFAULT")],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    # (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
    rr = ar * br
    ii = ai * bi
    ri = ar * bi
    ir = ai * br
    out_re = rr - ii
    out_im = ri + ir
    return out_re, out_im
def _split_complex(t):
    """Split complex tensor ``t`` into contiguous real and imaginary tensors.

    Returns ``(real, imag, was_half)``. When the components are float16
    (i.e. ``t`` is complex32) they are upcast to float32 so that the
    products and sums of the complex multiply do not lose precision;
    ``was_half`` records that the upcast happened.
    """
    parts = torch.view_as_real(t.resolve_conj())
    real = parts[..., 0].contiguous()
    imag = parts[..., 1].contiguous()
    was_half = real.dtype == torch.float16
    if was_half:
        real, imag = real.to(torch.float32), imag.to(torch.float32)
    return real, imag, was_half


def mul(A, B):
    """Multiply ``A`` by ``B`` elementwise.

    Either operand may be a ``torch.Tensor`` or a Python scalar. Complex
    operands are handled by operating on their real views:

    * complex x complex uses ``mul_complex_kernel`` on split components;
    * complex x real multiplies the ``[..., 2]`` real view by the real
      operand (unsqueezed on the last dim when it is a tensor);
    * a Python ``complex`` scalar paired with a tensor is first wrapped in a
      0-dim complex tensor, so it follows the complex paths instead of being
      passed to the scalar kernel.

    Returns a tensor; when both operands are Python scalars the result is
    ``torch.tensor(A * B)``.
    """
    logger.debug("GEMS MUL")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)

    # A Python complex scalar cannot go through mul_func_scalar: scaling the
    # [re, im] real view elementwise is not a complex product. Promote it to
    # a 0-dim tensor of the promoted complex dtype on the tensor's device.
    if A_is_tensor and isinstance(B, complex):
        B = torch.tensor(B, dtype=torch.result_type(A, B), device=A.device)
        B_is_tensor = True
    elif B_is_tensor and isinstance(A, complex):
        A = torch.tensor(A, dtype=torch.result_type(A, B), device=B.device)
        A_is_tensor = True

    A_is_complex = A_is_tensor and A.is_complex()
    B_is_complex = B_is_tensor and B.is_complex()

    if A_is_complex and B_is_complex:
        ar, ai, a_half = _split_complex(A)
        br, bi, b_half = _split_complex(B)
        real_out, imag_out = mul_complex_kernel(ar, ai, br, bi)
        # Only narrow back to float16 when BOTH inputs were half precision;
        # otherwise the wider operand's precision must be kept.
        if a_half and b_half:
            real_out = real_out.to(torch.float16)
            imag_out = imag_out.to(torch.float16)
        out = torch.view_as_complex(torch.stack((real_out, imag_out), dim=-1))
        return out.to(torch.result_type(A, B))

    if A_is_complex or B_is_complex:
        # Exactly one operand is a complex tensor: scale both components of
        # its real view by the real operand.
        complex_operand, real_operand = (A, B) if A_is_complex else (B, A)
        parts = torch.view_as_real(complex_operand.resolve_conj())
        if isinstance(real_operand, torch.Tensor):
            # Trailing unit dim lets the real tensor broadcast over [re, im].
            out_real = mul_func(parts, real_operand.unsqueeze(-1))
        else:
            out_real = mul_func_scalar(parts, real_operand)
        return torch.view_as_complex(out_real.contiguous())

    if A_is_tensor and B_is_tensor:
        return mul_func(A, B)
    if A_is_tensor:
        return mul_func_scalar(A, B)
    if B_is_tensor:
        return mul_func_scalar(B, A)
    # Both scalar
    return torch.tensor(A * B)
def mul_(A, B):
    """Multiply tensor ``A`` by ``B`` in place and return ``A``.

    ``B`` may be a tensor or a Python scalar. Real operands are written
    directly into ``A`` through the pointwise kernels' ``out0`` argument.
    When either operand is complex, the product is computed by ``mul`` (which
    has the complex handling) and copied back into ``A``.

    Raises:
        RuntimeError: if ``B`` is complex but ``A`` is not, since a complex
            result cannot be stored in a real destination.
    """
    logger.debug("GEMS MUL_")
    B_is_tensor = isinstance(B, torch.Tensor)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)

    if A.is_complex() or B_is_complex:
        if not A.is_complex():
            raise RuntimeError(
                f"result type {torch.result_type(A, B)} can't be cast to "
                f"the desired output type {A.dtype}"
            )
        # The real-valued kernels cannot operate on complex storage directly;
        # reuse the out-of-place complex path and write the result back.
        A.copy_(mul(A, B))
        return A

    if B_is_tensor:
        return mul_func(A, B, out0=A)
    return mul_func_scalar(A, B, out0=A)