Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/vdot.py: 0%
81 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from torch import Tensor
8from flag_gems import runtime
9from flag_gems.utils import libentry
11logger = logging.getLogger(__name__)
@triton.jit
def compute_vdot(
    inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj
):
    # Reduce one block of the conjugating dot product to a (real, imag) pair.
    #
    # inp_* / other_* hold the stored real and imaginary components of each
    # operand; the *_is_conj flags select which sign pattern is applied.
    # With a = inp_real, b = inp_imag, c = other_real, d = other_imag the
    # per-element terms that get summed are:
    #
    #   inp_is_conj  other_is_conj   real          imag
    #   False        False           a*c + b*d      a*d - b*c
    #   True         False           a*c - b*d      a*d + b*c
    #   False        True            a*c - b*d     -a*d - b*c
    #   True         True            a*c + b*d     -a*d + b*c
    real_real = inp_real * other_real
    imag_imag = inp_imag * other_imag
    real_imag = inp_real * other_imag
    imag_real = inp_imag * other_real

    if inp_is_conj:
        if other_is_conj:
            # Both operands flagged as conjugated.
            out_real = tl.sum(real_real + imag_imag)
            out_imag = tl.sum(-real_imag + imag_real)
        else:
            # Only inp flagged as conjugated.
            out_real = tl.sum(real_real - imag_imag)
            out_imag = tl.sum(real_imag + imag_real)
    else:
        if other_is_conj:
            # Only other flagged as conjugated.
            out_real = tl.sum(real_real - imag_imag)
            out_imag = tl.sum(-real_imag - imag_real)
        else:
            # Neither operand flagged as conjugated.
            out_real = tl.sum(real_real + imag_imag)
            out_imag = tl.sum(real_imag - imag_real)

    return out_real, out_imag
# Loads real and imaginary parts with explicit offsets so that older Triton
# versions, which do not provide tl.split, are still supported.
@libentry()
@triton.heuristics(runtime.get_heuristic_config("vdot"))
@triton.jit()
def vdot_kernel_complex(
    inp_ptr,
    other_ptr,
    out_ptr,
    n_elements,
    inp_is_conj: tl.constexpr,
    other_is_conj: tl.constexpr,
    inp_stride: tl.constexpr,
    other_stride: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Complex dot-product kernel operating on interleaved (real, imag) data.
    #
    # inp_ptr / other_ptr: operands addressed as alternating real/imag scalars.
    # out_ptr: two scalars, [real, imag]; every program atomically adds its
    #     partial sums, so the buffer must be zeroed before launch.
    # n_elements: number of scalars (twice the number of complex values).
    # inp_is_conj / other_is_conj: forwarded to compute_vdot to pick the signs.
    # inp_stride / other_stride: strides in complex elements.
    # BLOCK_SIZE: number of complex values handled by one program.
    pid = tl.program_id(0)

    # Scalar index of the real part of each complex value in this block.
    # The trailing tl.arange(0, 1) only adds zero.
    # NOTE(review): presumably a backend workaround; confirm before removing.
    base_offset = 2 * pid * BLOCK_SIZE + 2 * tl.arange(0, BLOCK_SIZE) + tl.arange(0, 1)

    inp_real_offset = inp_stride * base_offset
    inp_imag_offset = inp_real_offset + 1

    other_real_offset = other_stride * base_offset
    other_imag_offset = other_real_offset + 1

    mask = base_offset < n_elements

    # Masked-off lanes must read as zero: without `other`, their contents are
    # unspecified and would be folded into the block-wide sums below.
    inp_real = tl.load(inp_ptr + inp_real_offset, mask=mask, other=0.0)
    inp_imag = tl.load(inp_ptr + inp_imag_offset, mask=mask, other=0.0)

    other_real = tl.load(other_ptr + other_real_offset, mask=mask, other=0.0)
    other_imag = tl.load(other_ptr + other_imag_offset, mask=mask, other=0.0)

    # Compute based on conjugate flags
    out_real, out_imag = compute_vdot(
        inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj
    )

    # Accumulate this program's partial result into the shared output.
    tl.atomic_add(out_ptr, out_real)
    tl.atomic_add(out_ptr + 1, out_imag)
# Supports real-valued inputs only.
@libentry()
@triton.heuristics(runtime.get_heuristic_config("vdot"))
@triton.jit()
def dot_kernel(
    inp_ptr,
    other_ptr,
    out_ptr,
    n_elements,
    inp_stride: tl.constexpr,
    other_stride: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Real dot-product kernel: each program reduces one block of products and
    # atomically adds the partial sum to the single scalar at out_ptr, so the
    # output must be zeroed before launch.
    #
    # n_elements: number of elements in each operand.
    # inp_stride / other_stride: element strides of the two operands.
    # BLOCK_SIZE: number of elements handled by one program.
    pid = tl.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Masked-off lanes must read as zero: without `other`, their contents are
    # unspecified and would be folded into the block-wide sum below.
    # Products are formed in float32 regardless of the input dtype.
    inp = tl.load(inp_ptr + inp_stride * offset, mask=mask, other=0.0).to(tl.float32)
    other = tl.load(other_ptr + other_stride * offset, mask=mask, other=0.0).to(
        tl.float32
    )

    out = tl.sum(inp * other)
    tl.atomic_add(out_ptr, out)
def vdot(input: Tensor, other: Tensor) -> Tensor:
    """Compute the dot product of two 1-D tensors of the same dtype and size.

    Complex inputs go through ``vdot_kernel_complex``, which works on the
    interleaved real/imag view of each operand and receives the operands'
    lazy-conjugation flags. All other inputs go through ``dot_kernel``,
    which accumulates in float32 before the result is cast back to the
    input dtype.

    Args:
        input: First 1-D operand.
        other: Second 1-D operand, same dtype and size as ``input``.

    Returns:
        A 0-dim tensor on ``input``'s device holding the result. Empty
        inputs yield zero without launching a kernel.

    Raises:
        AssertionError: If the dtypes differ, either tensor is not 1-D, or
            the sizes differ.
    """
    logger.debug("GEMS_TSINGMICRO VDOT")

    assert (
        input.dtype == other.dtype
    ), f"Input tensors must have the same dtype. Got {input.dtype} and {other.dtype}."
    assert (
        input.ndim == 1 and other.ndim == 1
    ), f"Input tensors must be 1D. Got {input.ndim}D and {other.ndim}D."
    assert (
        input.size() == other.size()
    ), f"Input tensors must have the same size. Got {input.size()} and {other.size()}."

    inp = input
    inp_stride = inp.stride()[0]
    other_stride = other.stride()[0]

    if inp.is_complex():
        # Record the conj flags, then clear them with .conj() before taking
        # the real view; the kernel applies the matching signs instead.
        inp_is_conj = inp.is_conj()
        other_is_conj = other.is_conj()
        if inp_is_conj:
            inp = inp.conj()
        if other_is_conj:
            other = other.conj()

        inp_real = torch.view_as_real(inp)
        other_real = torch.view_as_real(other)

        n_elements = inp_real.numel()
        n_complex = inp.numel()

        # The kernel accumulates with atomic adds, so start from zero.
        output_real = torch.zeros(2, dtype=inp_real.dtype, device=inp.device)

        # Nothing to reduce: skip the zero-size launch and return 0.
        if n_complex == 0:
            return torch.view_as_complex(output_real)

        grid = lambda meta: (triton.cdiv(n_complex, meta["BLOCK_SIZE"]),)

        vdot_kernel_complex[grid](
            inp_real,
            other_real,
            output_real,
            n_elements=n_elements,
            inp_is_conj=inp_is_conj,
            other_is_conj=other_is_conj,
            inp_stride=inp_stride,
            other_stride=other_stride,
        )

        return torch.view_as_complex(output_real)
    else:
        # The kernel accumulates with atomic adds, so start from zero.
        output = torch.zeros([], dtype=torch.float32, device=inp.device)
        n_elements = inp.numel()

        # Nothing to reduce: skip the zero-size launch and return 0.
        if n_elements == 0:
            return output.to(inp.dtype)

        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
        dot_kernel[grid](
            inp,
            other,
            output,
            n_elements=n_elements,
            inp_stride=inp_stride,
            other_stride=other_stride,
        )
        return output.to(inp.dtype)