Coverage for src/flag_gems/runtime/backend/_sunrise/ops/add.py: 0%
119 statements
coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils import pointwise_dynamic
7from flag_gems.utils.codegen_config_utils import CodeGenConfig
8from flag_gems.utils.pointwise_dynamic import ComplexMode
logger = logging.getLogger(__name__)

# Code-generation config used by `add_func` (the general tensor-tensor path).
# The positional arguments follow CodeGenConfig's signature, which is not
# visible here; the first one is presumably max_tile_size (the broadcast config
# below is described as using a "smaller max_tile_size") — TODO confirm the
# meaning of each positional argument against codegen_config_utils.
config_for_general = CodeGenConfig(
    1024,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=False,
    # num_warps=2
)
# Tensor-tensor kernel: computes x + y * alpha elementwise. Per `is_tensor`,
# x and y are tensor arguments and alpha is a non-tensor scalar. `add` uses it
# for the general real path and as the real-domain kernel for complex inputs.
# `promotion_methods` presumably derives the output dtype from args 0 and 1
# using DEFAULT promotion — TODO confirm against pointwise_dynamic.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_for_general,
)
@triton.jit
def add_func(x, y, alpha):
    return x + y * alpha
# Code-generation config used by `add_func_broadcast`. Compared with the general
# config, it uses a smaller first argument (128 vs 1024) and 1D tiling. The
# selection heuristic in `should_use_broadcast_configs` refers to this as a
# "smaller max_tile_size" — TODO confirm the positional-argument meanings.
config_for_broadcast = CodeGenConfig(
    128,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
    # num_warps=4
)
# Same elementwise x + y * alpha as `add_func`, but compiled with
# `config_for_broadcast`. `add` selects it when `should_use_broadcast_configs`
# is true and passes a preallocated `out0`.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_for_broadcast,
)
@triton.jit
def add_func_broadcast(x, y, alpha):
    return x + y * alpha
# Tensor + scalar variant: x is a tensor; y and alpha are non-tensor scalars
# (per `is_tensor`). Uses the default codegen config (no `config=` given).
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha):
    return x + y * alpha
# Scalar + tensor variant: y is a tensor; x and alpha are non-tensor scalars
# (per `is_tensor`). Uses the default codegen config (no `config=` given).
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha):
    return x + y * alpha
def get_best_strided_output_tensor(A, B):
    """Allocate a float32 output for ``A op B`` whose layout mirrors an input.

    The output has the broadcast shape of ``A`` and ``B`` and lives on
    ``A.device``. If one input already has the broadcast shape, the output
    reuses that input's memory layout so the kernel can walk input and output
    with matching strides. Otherwise the output is contiguous.

    The layout is copied with ``torch.empty_like(..., preserve_format)`` rather
    than ``as_strided``. ``empty_like`` copies strides only when the source is
    dense and non-overlapping, and falls back to a contiguous layout otherwise.
    Grafting arbitrary strides onto a fresh buffer could overrun its storage (a
    sliced input) or make the output alias itself (an expanded, stride-0 input).

    Args:
        A: first input tensor; its device is used for the output.
        B: second input tensor.

    Returns:
        An uninitialized float32 tensor of the broadcast shape. The dtype is
        always float32; the caller in this module casts the result afterwards.
    """
    broadcast_shape = torch.broadcast_shapes(A.shape, B.shape)
    dtype = torch.float32
    for src in (A, B):
        if src.shape == broadcast_shape:
            return torch.empty_like(
                src,
                dtype=dtype,
                device=A.device,
                memory_format=torch.preserve_format,
            )
    return torch.empty(broadcast_shape, device=A.device, dtype=dtype)
def is_power_of_two(n):
    """Return True when the integer ``n`` is a positive power of two (1, 2, 4, ...)."""
    if n <= 0:
        return False
    # A power of two has a single set bit, so clearing the lowest set bit
    # (n & (n - 1)) leaves zero.
    return (n & (n - 1)) == 0
def should_use_broadcast_configs(A, B):
    """Decide whether ``add`` should use the 1D-tiled broadcast kernel config.

    This returns True only when all of the following hold:

    * the two inputs have different shapes (broadcasting is involved);
    * both inputs have at least two dims and identical trailing two dims;
    * the last dim size is not a power of two;
    * the promoted result dtype is float16 or float32.

    In that situation 1D tiling with a smaller max_tile_size performs better.
    """
    if A.shape == B.shape:
        return False
    if len(A.shape) < 2 or len(B.shape) < 2:
        return False
    if A.shape[-2:] != B.shape[-2:]:
        return False
    last_dim = A.shape[-1]
    if last_dim > 0 and (last_dim & (last_dim - 1)) == 0:
        return False
    return torch.result_type(A, B) in [torch.float16, torch.float32]
# Register complex support (elementwise).
# `add_func` is registered in ELEMENTWISE complex mode. The two scalar variants
# are registered with tensorize_scalars=True and fallback_target=add_func. This
# presumably turns complex Python scalars into tensors and redirects the call
# to `add_func` — TODO confirm against ComplexMode / register_complex.
add_func.register_complex(mode=ComplexMode.ELEMENTWISE)
add_func_tensor_scalar.register_complex(
    mode=ComplexMode.ELEMENTWISE, tensorize_scalars=True, fallback_target=add_func
)
add_func_scalar_tensor.register_complex(
    mode=ComplexMode.ELEMENTWISE, tensorize_scalars=True, fallback_target=add_func
)
116def _view_as_real_ptpu_safe(x: torch.Tensor) -> torch.Tensor:
117 """`torch.view_as_real(x)` with a CPU bounce when x is on PTPU.
119 [sunrise fix] PTPU lacks `aten::view_as_real`. The surrounding complex
120 branch uses the result only as a read-only input to the triton `add_func`
121 kernel (which IS PTPU-native), and the subsequent `.to(common_dtype)` would
122 materialize a non-aliasing copy anyway — so it is safe to break alias
123 semantics here. Per the FlagGems Sunrise skill, do not generically
124 monkey-patch view_as_real (alias/view primitive). Compute stays on PTPU.
125 """
126 try:
127 return torch.view_as_real(x)
128 except NotImplementedError:
129 if x.device.type != "ptpu":
130 raise
131 return torch.view_as_real(x.cpu()).to(x.device)
134def _view_as_complex_ptpu_safe(x: torch.Tensor) -> torch.Tensor:
135 """`torch.view_as_complex(x)` with a CPU bounce when x is on PTPU.
137 See `_view_as_real_ptpu_safe` above. Used here to recompose the complex
138 output after the PTPU-native real-domain `add_func(Ar, Br, alpha)` finishes.
139 """
140 try:
141 return torch.view_as_complex(x)
142 except NotImplementedError:
143 if x.device.type != "ptpu":
144 raise
145 return torch.view_as_complex(x.cpu()).to(x.device)
148def _scalar_complex_as_real_ptpu_safe(
149 scalar, complex_dtype: torch.dtype, target_shape, device: torch.device
150) -> torch.Tensor:
151 """Broadcast a python scalar to `view_as_real`-shaped tensor on `device`.
153 [sunrise fix] The natural code path is
155 torch.view_as_real(
156 torch.tensor(scalar, dtype=complex_dtype, device=device).expand_as(ref)
157 )
159 On PTPU this dies at the `view_as_real` step (no kernel) and the obvious
160 CPU fallback (`.cpu()`) also dies because PTPU's `direct_copy_kernel_ptpu`
161 has no entry for `ComplexHalf` / `ComplexFloat`. So instead we build the
162 complex scalar AND take its real view ENTIRELY on CPU, then only move the
163 final real-dtype tensor onto PTPU (which the device's copy_ DOES support).
164 """
165 cpu_scalar = torch.tensor(scalar, dtype=complex_dtype, device="cpu").expand(
166 target_shape
167 )
168 cpu_real = torch.view_as_real(cpu_scalar).contiguous()
169 if device.type == "cpu":
170 return cpu_real
171 return cpu_real.to(device)
def _is_complex_operand(x):
    """Return True if ``x`` is a complex tensor or a Python ``complex`` scalar."""
    return (isinstance(x, torch.Tensor) and x.is_complex()) or isinstance(x, complex)


def _real_dtype_of(complex_dtype: torch.dtype) -> torch.dtype:
    """Return the component dtype of ``complex_dtype`` (e.g. complex64 -> float32)."""
    return {
        torch.complex32: torch.float16,
        torch.complex64: torch.float32,
        torch.complex128: torch.float64,
    }[complex_dtype]


def _complex_operand_as_real(x, complex_dtype, ref):
    """Return operand ``x`` in the (..., 2) ``view_as_real`` layout of ``complex_dtype``.

    Args:
        x: a complex tensor, a real tensor, or a Python (complex or real) scalar.
        complex_dtype: the complex result dtype of the whole addition.
        ref: the tensor operand of the addition. Python scalars are expanded to
            its shape and placed on its device.

    Returns:
        A real tensor of dtype ``_real_dtype_of(complex_dtype)`` whose last
        dimension (size 2) holds the real and imaginary parts.
    """
    real_dtype = _real_dtype_of(complex_dtype)
    if not isinstance(x, torch.Tensor):
        # Python scalar (complex or real): build it and its real view on CPU;
        # only the real-dtype tensor is moved to the device.
        return _scalar_complex_as_real_ptpu_safe(
            x, complex_dtype, ref.shape, ref.device
        )
    if x.is_complex():
        # Take the real view first, then cast in the real domain. This avoids a
        # complex->complex dtype cast on the device.
        return _view_as_real_ptpu_safe(x).to(real_dtype)
    # Real tensor: pair it with a zero imaginary part in the real domain,
    # instead of casting it to a complex dtype on the device.
    x_real = x.to(real_dtype)
    return torch.stack((x_real, torch.zeros_like(x_real)), dim=-1)


def _add_complex(A, B, alpha):
    """Compute ``A + B * alpha`` when at least one operand is complex.

    Both operands are converted to the real (..., 2) layout in the component
    dtype of ``torch.result_type(A, B)``. ``add_func`` adds them, and the result
    is viewed back as complex. Either operand may be a tensor or a Python scalar.
    """
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)
    if not (A_is_tensor or B_is_tensor):
        # Two Python scalars: mirror the real scalar/scalar path of `add`.
        return torch.tensor(A + B * alpha)
    result_dtype = torch.result_type(A, B)
    ref = A if A_is_tensor else B
    Ar = _complex_operand_as_real(A, result_dtype, ref)
    Br = _complex_operand_as_real(B, result_dtype, ref)
    if Br.device != Ar.device:
        # Same policy as the real tensor/tensor path (follow A's device). The
        # copy happens on the real-dtype view, not on a complex tensor.
        Br = Br.to(Ar.device)
    # NOTE: alpha scales the real and imaginary components independently. That
    # matches complex arithmetic only for a real alpha; a complex alpha is not
    # supported by this path.
    out_real = add_func(Ar, Br, alpha)
    return _view_as_complex_ptpu_safe(out_real).to(result_dtype)


def add(A, B, *, alpha=1):
    """Return ``A + B * alpha``, following ``torch.add`` semantics.

    Args:
        A: a tensor or a Python scalar.
        B: a tensor or a Python scalar.
        alpha: scalar multiplier applied to ``B``.

    Returns:
        A new tensor with the result.

        * If either operand is complex (a complex tensor or a Python
          ``complex``), the sum is computed on real views via ``add_func``.
        * Two real tensors dispatch to ``add_func``, or to the 1D-tiled
          ``add_func_broadcast`` when ``should_use_broadcast_configs`` says so.
        * A tensor with a scalar uses the tensor/scalar kernels.
        * Two scalars produce ``torch.tensor(A + B * alpha)``.
    """
    logger.debug("GEMS ADD")
    if _is_complex_operand(A) or _is_complex_operand(B):
        return _add_complex(A, B, alpha)
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        if B.device != A.device:
            B = B.to(A.device)
        if should_use_broadcast_configs(A, B):
            # Float32 output laid out like the full-shape input, cast afterwards.
            out = get_best_strided_output_tensor(A, B)
            add_func_broadcast(A, B, alpha, out0=out)
            return out.to(torch.result_type(A, B))
        return add_func(A, B, alpha)
    if isinstance(A, torch.Tensor):
        return add_func_tensor_scalar(A, B, alpha)
    if isinstance(B, torch.Tensor):
        return add_func_scalar_tensor(A, B, alpha)
    return torch.tensor(A + B * alpha)
def add_(A, B, *, alpha=1):
    """In-place ``A += B * alpha``; writes into and returns ``A``'s storage.

    ``A`` must be a tensor. ``B`` may be a tensor, which is moved to ``A``'s
    device if needed, or a Python scalar. A non-tensor ``A`` cannot be updated
    in place, so it raises ``ValueError``.
    """
    logger.debug("GEMS ADD_")
    if not isinstance(A, torch.Tensor):
        raise ValueError("Unreachable.")
    if not isinstance(B, torch.Tensor):
        return add_func_tensor_scalar(A, B, alpha, out0=A)
    if A.device != B.device:
        B = B.to(A.device)
    return add_func(A, B, alpha, out0=A)