Coverage for src/flag_gems/runtime/backend/_sunrise/ops/mul.py: 0%
78 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils import pointwise_dynamic
7from flag_gems.utils.codegen_config_utils import CodeGenConfig
# Module logger, created as a child of the package-wide "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Code-generation settings handed to the tensor-tensor multiply wrapper below.
# NOTE(review): the four leading arguments are positional; they look like tile
# size, per-axis grid limits, a warp limit and a block-pointer switch, but the
# CodeGenConfig signature is not visible in this file -- confirm before editing.
config_for_broadcast = CodeGenConfig(
    8192,
    (65536,) * 3,
    32,
    True,
    prefer_1d_tile=False,
)
# Tensor * tensor product. Both arguments are declared as tensors, the result
# dtype follows the "DEFAULT" promotion of arguments 0 and 1, and the wrapper is
# generated with the module's ``config_for_broadcast`` settings.
@pointwise_dynamic(
    config=config_for_broadcast,
    is_tensor=[True, True],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def mul_func(x, y):
    product = x * y
    return product
# Tensor * scalar product. Argument 0 is declared as a tensor and argument 1 as
# a non-tensor; the result dtype follows the "DEFAULT" promotion of both.
@pointwise_dynamic(
    promotion_methods=[(0, 1, "DEFAULT")],
    is_tensor=[True, False],
)
@triton.jit
def mul_func_scalar(x, y):
    product = x * y
    return product
# Complex product expressed on separate real/imaginary planes:
#   (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
# Takes four tensor inputs (ar, ai, br, bi) and yields two outputs
# (real part, imaginary part), each promoted over all four inputs.
@pointwise_dynamic(
    num_outputs=2,
    is_tensor=[True, True, True, True],
    promotion_methods=[(0, 1, 2, 3, "DEFAULT"), (0, 1, 2, 3, "DEFAULT")],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    re_part = ar * br - ai * bi
    im_part = ar * bi + ai * br
    return re_part, im_part
def mul(A, B):
    """Return the element-wise product of ``A`` and ``B``.

    Each operand may be a ``torch.Tensor`` or a Python scalar. Real operands
    go straight to ``mul_func`` / ``mul_func_scalar``. When either operand is
    complex, the complex tensors are reinterpreted as real ``(..., 2)``
    tensors on the CPU, the arithmetic runs on real-valued tensors moved back
    to the operand's device, and the result is reassembled into a complex
    tensor on the CPU before being returned on the device.
    NOTE(review): the CPU round trip presumably works around missing complex
    support on this backend -- confirm.

    Args:
        A: Left operand (tensor or Python scalar).
        B: Right operand (tensor or Python scalar).

    Returns:
        A tensor holding ``A * B``. For complex products the dtype is
        ``torch.result_type(A, B)`` and the device is ``A``'s when ``A`` is a
        complex tensor, otherwise the device of the tensor operand.
    """
    logger.debug("GEMS MUL")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)
    A_is_complex = A.is_complex() if A_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)

    if not (A_is_complex or B_is_complex):
        if A_is_tensor and B_is_tensor:
            return mul_func(A, B)
        if A_is_tensor:
            return mul_func_scalar(A, B)
        if B_is_tensor:
            return mul_func_scalar(B, A)
        # Both scalar
        return torch.tensor(A * B)

    if not (A_is_tensor or B_is_tensor):
        # Two Python scalars, at least one complex: nothing to launch. Without
        # this guard the complex branches below would call ``.device`` on a
        # plain Python number.
        return torch.tensor(A * B)

    # Resolve the result dtype from the ORIGINAL operands: a Python scalar and
    # a 0-dim tensor do not always promote identically.
    out_dtype = torch.result_type(A, B)
    # A Python scalar has no device of its own; it borrows the tensor's.
    a_device = A.device if A_is_tensor else B.device
    b_device = B.device if B_is_tensor else A.device
    # Promote Python complex scalars to 0-dim CPU tensors so the tensor code
    # paths below can split them into real and imaginary parts.
    if A_is_complex and not A_is_tensor:
        A = torch.tensor(A, dtype=out_dtype)
    if B_is_complex and not B_is_tensor:
        B = torch.tensor(B, dtype=out_dtype)

    # 1) A and B are both complex
    if A_is_complex and B_is_complex:
        Ar = torch.view_as_real(A.to(device="cpu"))
        Br = torch.view_as_real(B.to(device="cpu"))
        ar, ai = Ar[..., 0].to(a_device), Ar[..., 1].to(a_device)
        br, bi = Br[..., 0].to(b_device), Br[..., 1].to(b_device)
        common_dtype = torch.promote_types(ar.dtype, br.dtype)
        ar, ai = ar.to(common_dtype), ai.to(common_dtype)
        br, bi = br.to(common_dtype), bi.to(common_dtype)

        # Size the outputs for the broadcast of both operands, not just A.
        out_shape = torch.broadcast_shapes(ar.shape, br.shape)
        real_out = torch.empty(out_shape, dtype=common_dtype, device=a_device)
        imag_out = torch.empty(out_shape, dtype=common_dtype, device=a_device)
        mul_complex_kernel(ar, ai, br, bi, out0=real_out, out1=imag_out)

        out = torch.view_as_complex(torch.stack((real_out, imag_out), dim=-1).cpu())
        return out.to(out_dtype).to(a_device)

    # 2) A complex, B real
    if A_is_complex:
        Ar = torch.view_as_real(A.to(device="cpu")).to(a_device)
        if isinstance(B, torch.Tensor):
            # The trailing unit axis lets B broadcast over the (real, imag) pair.
            out_real = mul_func(Ar, B.unsqueeze(-1))
        else:
            out_real = mul_func_scalar(Ar, B)
        return torch.view_as_complex(out_real.cpu()).to(out_dtype).to(a_device)

    # 3) A real, B complex
    Br = torch.view_as_real(B.to(device="cpu")).to(b_device)
    if isinstance(A, torch.Tensor):
        # The trailing unit axis lets A broadcast over the (real, imag) pair.
        out_real = mul_func(A.unsqueeze(-1), Br)
    else:
        out_real = mul_func_scalar(Br, A)  # Br is tensor, A is scalar
    return torch.view_as_complex(out_real.cpu()).to(out_dtype).to(b_device)
def mul_(A, B):
    """Multiply ``A`` by ``B`` in place and return the kernel wrapper's result.

    ``A`` is passed as the output buffer (``out0=A``). The tensor-tensor
    kernel is used when ``B`` is a ``torch.Tensor``; otherwise ``B`` is
    handed to the tensor-scalar kernel.
    """
    logger.debug("GEMS MUL_")
    kernel = mul_func if isinstance(B, torch.Tensor) else mul_func_scalar
    return kernel(A, B, out0=A)