Coverage for src/flag_gems/runtime/backend/_sunrise/ops/mul.py: 0%
99 statements
coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils import pointwise_dynamic
7from flag_gems.utils.codegen_config_utils import CodeGenConfig
8from flag_gems.utils.pointwise_dynamic import ComplexMode
logger = logging.getLogger(__name__)

# Code-generation config handed to `pointwise_dynamic` for the tensor-tensor
# `mul_func` kernel (the scalar and complex kernels use the default config).
# NOTE(review): the first four arguments are positional; their meaning
# (presumably tile-size / grid limits, warp count and a boolean switch) should
# be confirmed against flag_gems.utils.codegen_config_utils.CodeGenConfig.
config_for_broadcast = CodeGenConfig(
    8192,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=False,
    # num_warps=16  (left disabled)
)
# Elementwise tensor * tensor, generated by `pointwise_dynamic`.
# Both arguments are tensors (is_tensor=[True, True]). The output dtype uses
# DEFAULT promotion over inputs 0 and 1. Codegen uses `config_for_broadcast`.
# A complex CROSS-mode kernel (`mul_complex_kernel`) is attached to this
# function by the `register_complex` call further down in this module.
@pointwise_dynamic(
    is_tensor=[True, True],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_for_broadcast,
)
@triton.jit
def mul_func(x, y):
    return x * y
# Elementwise tensor * scalar, generated by `pointwise_dynamic`.
# The first argument is a tensor and the second a non-tensor value
# (is_tensor=[True, False]), so callers must pass the tensor first.
# The output dtype uses DEFAULT promotion over both arguments.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func_scalar(x, y):
    return x * y
# Complex multiply on split real/imaginary planes:
#   (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
# Takes four real tensors and returns two outputs (num_outputs=2): the real
# plane and the imaginary plane. Each output's dtype uses DEFAULT promotion
# over all four inputs.
@pointwise_dynamic(
    is_tensor=[True, True, True, True],  # ar, ai, br, bi
    num_outputs=2,
    promotion_methods=[(0, 1, 2, 3, "DEFAULT"), (0, 1, 2, 3, "DEFAULT")],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    real = ar * br - ai * bi
    imag = ar * bi + ai * br
    return real, imag
# Register complex support.
# `mul_func` gets `mul_complex_kernel` as its CROSS-mode complex kernel.
# `mul_func_scalar` also uses CROSS mode, with `tensorize_scalars=True` and
# `fallback_target=mul_func`. Presumably the scalar is turned into a tensor so
# the tensor-tensor path can be reused; TODO confirm in pointwise_dynamic.
mul_func.register_complex(mode=ComplexMode.CROSS, cross_kernel=mul_complex_kernel)
mul_func_scalar.register_complex(
    mode=ComplexMode.CROSS, tensorize_scalars=True, fallback_target=mul_func
)
58def _view_as_real_ptpu_safe(x: torch.Tensor) -> torch.Tensor:
59 """`torch.view_as_real(x)` with a CPU bounce when x is on PTPU."""
60 try:
61 return torch.view_as_real(x)
62 except NotImplementedError:
63 if x.device.type != "ptpu":
64 raise
65 return torch.view_as_real(x.cpu()).to(x.device)
68def _view_as_complex_ptpu_safe(x: torch.Tensor) -> torch.Tensor:
69 """`torch.view_as_complex(x)` with a CPU bounce when x is on PTPU."""
70 try:
71 return torch.view_as_complex(x)
72 except NotImplementedError:
73 if x.device.type != "ptpu":
74 raise
75 return torch.view_as_complex(x.cpu()).to(x.device)
78def _scalar_complex_as_real_ptpu_safe(
79 scalar, complex_dtype: torch.dtype, target_shape, device: torch.device
80) -> torch.Tensor:
81 """Broadcast a python complex scalar to a `view_as_real`-shaped tensor."""
82 cpu_scalar = torch.tensor(scalar, dtype=complex_dtype, device="cpu").expand(
83 target_shape
84 )
85 cpu_real = torch.view_as_real(cpu_scalar).contiguous()
86 if device.type == "cpu":
87 return cpu_real
88 return cpu_real.to(device)
def _operand_as_real_ptpu_safe(
    value, complex_dtype: torch.dtype, target_shape, device: torch.device
) -> torch.Tensor:
    """Turn one ``mul`` operand into its ``view_as_real`` representation.

    Non-tensor values are treated as scalars and broadcast to ``target_shape``
    on ``device`` as ``complex_dtype``. Tensors that are already complex keep
    their own dtype. Real tensors are first cast to ``complex_dtype``.
    """
    if not isinstance(value, torch.Tensor):
        return _scalar_complex_as_real_ptpu_safe(
            value, complex_dtype, target_shape, device
        )
    as_complex = value if value.is_complex() else value.to(complex_dtype)
    return _view_as_real_ptpu_safe(as_complex)
100def _complex_mul(A, B):
101 result_dtype = torch.result_type(A, B)
102 shape_a = A.shape if isinstance(A, torch.Tensor) else torch.Size([])
103 shape_b = B.shape if isinstance(B, torch.Tensor) else torch.Size([])
104 target_shape = torch.broadcast_shapes(shape_a, shape_b)
105 device = A.device if isinstance(A, torch.Tensor) else B.device
107 A_is_complex = (isinstance(A, torch.Tensor) and A.is_complex()) or isinstance(
108 A, complex
109 )
110 B_is_complex = (isinstance(B, torch.Tensor) and B.is_complex()) or isinstance(
111 B, complex
112 )
114 if A_is_complex and B_is_complex:
115 Ar = _operand_as_real_ptpu_safe(A, result_dtype, target_shape, device)
116 Br = _operand_as_real_ptpu_safe(B, result_dtype, target_shape, device)
117 ar, ai = Ar[..., 0], Ar[..., 1]
118 br, bi = Br[..., 0], Br[..., 1]
119 common_dtype = torch.promote_types(ar.dtype, br.dtype)
120 ar, ai = ar.to(common_dtype), ai.to(common_dtype)
121 br, bi = br.to(common_dtype), bi.to(common_dtype)
122 real_out, imag_out = mul_complex_kernel(ar, ai, br, bi)
123 out = torch.stack((real_out, imag_out), dim=-1)
124 return _view_as_complex_ptpu_safe(out.contiguous()).to(result_dtype)
126 if A_is_complex:
127 Ar = _operand_as_real_ptpu_safe(A, result_dtype, target_shape, device)
128 if isinstance(B, torch.Tensor):
129 Br = B.unsqueeze(-1)
130 out_real = mul_func(Ar, Br)
131 else:
132 out_real = mul_func_scalar(Ar, B)
133 return _view_as_complex_ptpu_safe(out_real.contiguous()).to(result_dtype)
135 Br = _operand_as_real_ptpu_safe(B, result_dtype, target_shape, device)
136 if isinstance(A, torch.Tensor):
137 Ar = A.unsqueeze(-1)
138 out_real = mul_func(Ar, Br)
139 else:
140 out_real = mul_func_scalar(Br, A)
141 return _view_as_complex_ptpu_safe(out_real.contiguous()).to(result_dtype)
def mul(A, B):
    """Elementwise ``A * B`` for tensors and/or Python scalars.

    Dispatch order:
      * both operands are Python scalars: ``torch.tensor(A * B)``. This is
        checked first so that complex scalar pairs are handled here as well
        instead of being routed to the tensor-oriented complex path;
      * either operand complex: ``_complex_mul``;
      * tensor * tensor: ``mul_func``;
      * tensor * scalar (either order): ``mul_func_scalar`` with the tensor
        passed first, as that kernel expects.
    """
    logger.debug("GEMS MUL")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    if not (a_is_tensor or b_is_tensor):
        # Both scalar
        return torch.tensor(A * B)

    A_is_complex = A.is_complex() if a_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if b_is_tensor else isinstance(B, complex)
    if A_is_complex or B_is_complex:
        return _complex_mul(A, B)
    if a_is_tensor and b_is_tensor:
        return mul_func(A, B)
    if a_is_tensor:
        return mul_func_scalar(A, B)
    # Only B is a tensor; multiplication commutes, so pass it first.
    return mul_func_scalar(B, A)
def mul_(A, B):
    """In-place ``A *= B``; the kernel writes its result into ``A`` (``out0=A``).

    A tensor ``B`` selects ``mul_func``; anything else is treated as a scalar
    and selects ``mul_func_scalar``. Returns whatever the chosen kernel returns.
    """
    logger.debug("GEMS MUL_")
    kernel = mul_func if isinstance(B, torch.Tensor) else mul_func_scalar
    return kernel(A, B, out0=A)