Coverage for src/flag_gems/runtime/backend/_sunrise/ops/vdot.py: 0%
147 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch import Tensor
8from flag_gems import runtime
9from flag_gems.utils import libentry, tensor_wrapper
# Module-level logger for this op; vdot() emits a debug record on every call.
logger = logging.getLogger(__name__)
14def _view_as_complex_ptpu_safe(x: torch.Tensor) -> torch.Tensor:
15 """`torch.view_as_complex(x)` with a CPU bounce when x is on PTPU."""
16 try:
17 return torch.view_as_complex(x)
18 except NotImplementedError:
19 if x.device.type != "ptpu":
20 raise
21 return torch.view_as_complex(x.cpu()).to(x.device)
@triton.jit
def compute_vdot(
    inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj
):
    # Reduce one tile of conj(a) * b, where a and b are the *logical* values of
    # the two inputs, and return (sum of real parts, sum of imaginary parts).
    #
    # inp_real/inp_imag and other_real/other_imag are the raw stored components
    # x = xr + i*xi and y = yr + i*yi. A set `*_is_conj` flag means the tensor
    # carried PyTorch's lazy conjugate bit, i.e. its logical value is
    # conj(stored). The flags come in as constexpr from the calling kernel, so
    # only one branch below is compiled.
    #
    # Case 1: neither flag set  -> conj(x) * y
    #   real = xr*yr + xi*yi          imag = xr*yi - xi*yr
    # Case 2: inp_is_conj only  -> conj(conj(x)) * y = x * y
    #   real = xr*yr - xi*yi          imag = xr*yi + xi*yr
    # Case 3: other_is_conj only -> conj(x) * conj(y) = conj(x * y)
    #   real = xr*yr - xi*yi          imag = -xr*yi - xi*yr
    # Case 4: both flags set    -> x * conj(y)
    #   real = xr*yr + xi*yi          imag = -xr*yi + xi*yr
    if not inp_is_conj and not other_is_conj:  # Case 1
        out_real = tl.sum(inp_real * other_real + inp_imag * other_imag)
        out_imag = tl.sum(inp_real * other_imag - inp_imag * other_real)
    elif inp_is_conj and not other_is_conj:  # Case 2
        out_real = tl.sum(inp_real * other_real - inp_imag * other_imag)
        out_imag = tl.sum(inp_real * other_imag + inp_imag * other_real)
    elif not inp_is_conj and other_is_conj:  # Case 3
        out_real = tl.sum(inp_real * other_real - inp_imag * other_imag)
        out_imag = tl.sum(-inp_real * other_imag - inp_imag * other_real)
    else:  # Case 4
        out_real = tl.sum(inp_real * other_real + inp_imag * other_imag)
        out_imag = tl.sum(-inp_real * other_imag + inp_imag * other_real)

    return out_real, out_imag
# Complex inputs arrive as interleaved (real, imag) scalar buffers. The two
# components are loaded with separate strided loads rather than tl.split, so
# this kernel also works on older Triton versions that lack tl.split.
@libentry()
@triton.jit()
def vdot_kernel_complex(
    inp_ptr,
    other_ptr,
    out_ptr,
    n_elements,
    inp_is_conj: tl.constexpr,
    other_is_conj: tl.constexpr,
    inp_stride: tl.constexpr,
    other_stride: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # n_elements counts scalars (two per complex value). The strides are
    # applied in complex-element units: complex element k starts at scalar
    # offset (2 * k) * stride, and its imaginary part is the next scalar.
    # Each program steps through the complex index range by
    # num_programs * BLOCK_SIZE, keeps float32 real/imag accumulators, and
    # writes them to out_ptr[2 * pid] and out_ptr[2 * pid + 1].
    prog = tl.program_id(0)
    n_complex = n_elements // 2
    step = tl.num_programs(0) * BLOCK_SIZE
    lanes = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    sum_real = tl.zeros([], dtype=tl.float32)
    sum_imag = tl.zeros([], dtype=tl.float32)

    for base in range(0, n_complex, step):
        k = base + lanes
        valid = k < n_complex
        scalar_off = k * 2
        inp_re_ptr = inp_ptr + scalar_off * inp_stride
        other_re_ptr = other_ptr + scalar_off * other_stride

        a_re = tl.load(inp_re_ptr, mask=valid, other=0.0)
        a_im = tl.load(inp_re_ptr + 1, mask=valid, other=0.0)
        b_re = tl.load(other_re_ptr, mask=valid, other=0.0)
        b_im = tl.load(other_re_ptr + 1, mask=valid, other=0.0)

        tile_real, tile_imag = compute_vdot(
            a_re, a_im, b_re, b_im, inp_is_conj, other_is_conj
        )
        sum_real += tile_real
        sum_imag += tile_imag

    tl.store(out_ptr + prog * 2, sum_real)
    tl.store(out_ptr + prog * 2 + 1, sum_imag)
@libentry()
@triton.jit()
def reduce_kernel_complex(input_ptr, out_ptr, n_blocks, BLOCK_SIZE: tl.constexpr):
    # Sum n_blocks interleaved (real, imag) partial-sum pairs from input_ptr in
    # a single tile and write the totals to out_ptr[0] (real) and out_ptr[1]
    # (imag). Pairs at index >= BLOCK_SIZE are not read, so the caller must
    # pass BLOCK_SIZE >= n_blocks. Only program 0 stores the result.
    slot = tl.arange(0, BLOCK_SIZE)
    valid = slot < n_blocks

    total_real = tl.sum(tl.load(input_ptr + slot * 2, mask=valid, other=0.0))
    total_imag = tl.sum(tl.load(input_ptr + slot * 2 + 1, mask=valid, other=0.0))

    if tl.program_id(0) == 0:
        tl.store(out_ptr, total_real)
        tl.store(out_ptr + 1, total_imag)
# Real-valued inputs only: each program accumulates a float32 partial sum of
# inp * other over its share of the elements.
@libentry()
@triton.heuristics(runtime.get_heuristic_config("vdot"))
@triton.jit()
def dot_kernel(
    inp_ptr,
    other_ptr,
    out_ptr,
    n_elements,
    inp_stride: tl.constexpr,
    other_stride: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Program `prog` starts at tile `prog` and advances by
    # num_programs * BLOCK_SIZE elements per iteration (grid-stride loop), so
    # any launch size covers all n_elements. Loads are upcast to float32
    # before multiplying. The partial sum is written to out_ptr[prog].
    prog = tl.program_id(0)
    step = tl.num_programs(0) * BLOCK_SIZE
    lanes = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    partial = tl.zeros([], dtype=tl.float32)
    for base in range(0, n_elements, step):
        idx = base + lanes
        valid = idx < n_elements
        a = tl.load(inp_ptr + inp_stride * idx, mask=valid, other=0.0).to(
            tl.float32
        )
        b = tl.load(other_ptr + other_stride * idx, mask=valid, other=0.0).to(
            tl.float32
        )
        partial += tl.sum(a * b)

    tl.store(out_ptr + prog, partial)
@libentry()
@triton.jit()
def reduce_kernel(
    partial_sums_ptr,
    output_ptr,
    n_blocks,
    BLOCK_SIZE: tl.constexpr,
):
    # Sum the first n_blocks values at partial_sums_ptr in a single tile and
    # store the total to output_ptr[0]. Entries at index >= BLOCK_SIZE are not
    # read, so the caller must pass BLOCK_SIZE >= n_blocks. Only program 0
    # stores the result.
    lane = tl.arange(0, BLOCK_SIZE)
    total = tl.sum(
        tl.load(partial_sums_ptr + lane, mask=lane < n_blocks, other=0.0)
    )

    if tl.program_id(0) == 0:
        tl.store(output_ptr, total)
@libentry()
@triton.heuristics(runtime.get_heuristic_config("vdot"))
@triton.jit()
def dot_kernel_fp32(
    inp_ptr,
    other_ptr,
    out_ptr,
    n_elements,
    inp_stride: tl.constexpr,
    other_stride: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Single-pass dot product. Each program reduces exactly one BLOCK_SIZE tile
    # of inp * other and atomically adds its partial sum into out_ptr[0].
    # There is no loop, so the launch grid must cover
    # cdiv(n_elements, BLOCK_SIZE) tiles. out_ptr must be zero-initialised by
    # the caller. The order of the atomic adds is not fixed, so results may
    # differ in the last bits between runs.
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Masked-off lanes must read as 0.0. Without `other`, Triton leaves their
    # values undefined and they would leak into tl.sum below (last tile).
    inp = tl.load(inp_ptr + inp_stride * offset, mask=mask, other=0.0)
    other = tl.load(other_ptr + other_stride * offset, mask=mask, other=0.0)

    out = tl.sum(inp * other)
    tl.atomic_add(out_ptr, out)
def _vdot_complex(inp: Tensor, other: Tensor) -> Tensor:
    """Complex path: per-program partial sums, then a single-tile reduction."""
    inp_stride = inp.stride()[0]
    other_stride = other.stride()[0]

    # A tensor with the lazy conjugate bit stores x but represents conj(x).
    # Drop the bit (conj() of a conjugated view) and let the kernel fold the
    # conjugation into its formulas via the *_is_conj flags.
    inp_is_conj = inp.is_conj()
    if inp_is_conj:
        inp = inp.conj()
    other_is_conj = other.is_conj()
    if other_is_conj:
        other = other.conj()

    # Reinterpret the complex storage as interleaved (real, imag) scalars.
    inp_real = tensor_wrapper.TypedPtr.reinterpret_tensor(inp, inp.dtype.to_real())
    other_real = tensor_wrapper.TypedPtr.reinterpret_tensor(
        other, other.dtype.to_real()
    )

    n_complex = inp.numel()
    n_elements = n_complex * 2  # scalar count
    block_size = runtime.get_heuristic_config("vdot")["BLOCK_SIZE"](
        {"n_elements": n_elements}
    )
    # Cap the program count; the kernel loops over any remaining elements.
    grid_size = min(triton.cdiv(n_complex, block_size), 1024)

    # Each program stores a (real, imag) pair at [2*pid, 2*pid + 1], so the
    # buffer needs two slots per program.
    partial_real_sums = torch.empty(
        2 * grid_size, dtype=inp_real.dtype, device=inp.device
    )
    vdot_kernel_complex[(grid_size,)](
        inp_real,
        other_real,
        partial_real_sums,
        n_elements=n_elements,
        inp_is_conj=inp_is_conj,
        other_is_conj=other_is_conj,
        inp_stride=inp_stride,
        other_stride=other_stride,
        BLOCK_SIZE=block_size,
    )

    output_real = torch.empty(2, dtype=inp_real.dtype, device=inp.device)
    reduce_kernel_complex[(1,)](
        partial_real_sums,
        output_real,
        grid_size,
        BLOCK_SIZE=triton.next_power_of_2(grid_size),
    )
    # Shape-(2,) (real, imag) buffer -> 0-dim complex tensor.
    return _view_as_complex_ptpu_safe(output_real)


def _vdot_fp32(inp: Tensor, other: Tensor) -> Tensor:
    """float32 path: one tile per program, combined with atomic adds."""
    # Must start at zero because dot_kernel_fp32 only accumulates into it.
    output = torch.zeros([], dtype=torch.float32, device=inp.device)
    n_elements = inp.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    dot_kernel_fp32[grid](
        inp,
        other,
        output,
        n_elements=n_elements,
        inp_stride=inp.stride()[0],
        other_stride=other.stride()[0],
    )
    return output


def _vdot_real(inp: Tensor, other: Tensor) -> Tensor:
    """Other real dtypes: float32 partial sums, then a single-tile reduction."""
    n_elements = inp.numel()
    block_size = runtime.get_heuristic_config("vdot")["BLOCK_SIZE"](
        {"n_elements": n_elements}
    )
    # Launch at most 1024 programs; dot_kernel's grid-stride loop covers the
    # rest. The launch grid, the partial-sum buffer and the reduction must all
    # use this same count (launching cdiv(n, block) programs overruns the
    # buffer and drops partial sums once that exceeds 1024).
    grid_size = min(triton.cdiv(n_elements, block_size), 1024)

    partial_sums = torch.empty(grid_size, dtype=torch.float32, device=inp.device)
    dot_kernel[(grid_size,)](
        inp,
        other,
        partial_sums,
        n_elements=n_elements,
        inp_stride=inp.stride()[0],
        other_stride=other.stride()[0],
        BLOCK_SIZE=block_size,
    )

    output = torch.empty([], dtype=inp.dtype, device=inp.device)
    reduce_kernel[(1,)](
        partial_sums,
        output,
        grid_size,
        BLOCK_SIZE=triton.next_power_of_2(grid_size),
    )
    return output


def vdot(input: Tensor, other: Tensor):
    """Dot product of two 1-D tensors, conjugating the first one if complex.

    Computes ``sum(conj(input) * other)``; for real dtypes this is simply
    ``sum(input * other)``. Partial sums are accumulated in float32 by the
    kernels.

    Args:
        input: 1-D tensor.
        other: 1-D tensor with the same dtype and size as ``input``.

    Returns:
        A 0-dim result tensor on ``input.device``. An empty input yields zero.

    Raises:
        AssertionError: if the dtypes differ, either tensor is not 1-D, or
            their sizes differ.
    """
    logger.debug("GEMS VDOT")

    assert (
        input.dtype == other.dtype
    ), f"Input tensors must have the same dtype. Got {input.dtype} and {other.dtype}."
    assert (
        input.ndim == 1 and other.ndim == 1
    ), f"Input tensors must be 1D. Got {input.ndim}D and {other.ndim}D."
    assert (
        input.size() == other.size()
    ), f"Input tensors must have the same size. Got {input.size()} and {other.size()}."

    if input.numel() == 0:
        # An empty sum is zero. Launching a zero-sized grid, or a reduction with
        # BLOCK_SIZE == next_power_of_2(0) == 0, would fail.
        if input.is_complex():
            zeros = torch.zeros(2, dtype=input.dtype.to_real(), device=input.device)
            return _view_as_complex_ptpu_safe(zeros)
        return torch.zeros([], dtype=input.dtype, device=input.device)

    if input.is_complex():
        return _vdot_complex(input, other)
    if input.dtype == torch.float32:
        return _vdot_fp32(input, other)
    return _vdot_real(input, other)