Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/dot.py: 0%
61 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import triton_lang_extension as ext
9from flag_gems.utils.libentry import libentry
# Module logger, created as a child of the shared "flag_gems" logger. Leading
# dots are stripped from the module name before it is used as the child suffix.
_flag_gems_logger = logging.getLogger("flag_gems")
logger = _flag_gems_logger.getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def dot_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Multiply one BLOCK_SIZE-wide slice of x and y and store the summed product.

    Each program handles the element range starting at
    ``program_id(0) * BLOCK_SIZE`` and writes its float32 sum to ``out_ptr``
    (the same address for every program).
    """
    idx = ext.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < N

    # Lanes at or beyond N load 0.0, so they add nothing to the reduction.
    lhs = tl.load(x_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    rhs = tl.load(y_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)

    tl.store(out_ptr, tl.sum(lhs * rhs))
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": block}, num_warps=warps, num_stages=2)
        for block, warps in ((4096, 8), (8192, 8), (16384, 16), (32768, 16))
    ],
    key=["N"],
)
@triton.jit
def dot_kernel_1(x_ptr, y_ptr, mid_ptr, N, BLOCK_SIZE: tl.constexpr):
    """First reduction pass: write one float32 partial dot product per program.

    Program ``pid`` multiplies the slice of x and y that starts at
    ``pid * BLOCK_SIZE`` and stores the summed product at ``mid_ptr + pid``.
    BLOCK_SIZE is selected by the autotuner, keyed on N.
    """
    prog = ext.program_id(0)
    idx = prog * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < N

    # Lanes at or beyond N load 0.0, so they add nothing to the partial sum.
    lhs = tl.load(x_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    rhs = tl.load(y_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)

    tl.store(mid_ptr + prog, tl.sum(lhs * rhs))
@libentry()
@triton.jit
def dot_kernel_2(mid_ptr, out_ptr, M, BLOCK_MID: tl.constexpr):
    """Second reduction pass: sum the first M values at mid_ptr into out_ptr.

    The kernel does not read the program id; it loads up to BLOCK_MID values
    starting at ``mid_ptr`` and stores their sum at ``out_ptr``.
    """
    lanes = tl.arange(0, BLOCK_MID)
    # Lanes at or beyond M load 0.0 and therefore do not affect the total.
    partials = tl.load(mid_ptr + lanes, mask=lanes < M, other=0.0)
    tl.store(out_ptr, tl.sum(partials))
def dot(x, y):
    """Compute the dot product of two 1-D tensors with Triton kernels.

    Inputs with fewer than 4096 elements are reduced by a single program of
    ``dot_kernel``. Larger inputs use two passes: ``dot_kernel_1`` writes one
    float32 partial sum per program into a scratch buffer, and
    ``dot_kernel_2`` adds those partial sums together.

    Args:
        x: 1-D tensor.
        y: 1-D tensor with the same shape as ``x``.

    Returns:
        A 0-dim tensor on ``x.device`` with dtype ``x.dtype``.

    Raises:
        AssertionError: If the shapes differ or the inputs are not 1-D.
    """
    logger.debug("GEMS_KUNLUNXIN DOT")

    assert x.shape == y.shape, "Input vectors must have the same shape"
    assert x.dim() == 1, "Input must be 1D tensors"

    N = x.shape[0]

    if N == 0:
        # The dot product of empty vectors is zero. Returning here also avoids
        # passing triton.next_power_of_2(0) to the kernel as BLOCK_SIZE.
        return torch.zeros([], dtype=x.dtype, device=x.device)

    if N >= 4096:
        # Size the scratch buffer for the worst case (smallest block size = 4096).
        max_mid_size = triton.cdiv(N, 4096)

        # Number of programs used by the most recent grid evaluation. The
        # autotuner may pick a BLOCK_SIZE larger than 4096, in which case
        # fewer than max_mid_size partial sums are written by the launch.
        launched = [max_mid_size]

        def grid_1(meta):
            launched[0] = triton.cdiv(N, meta["BLOCK_SIZE"])
            return (launched[0],)

        mid = torch.empty((max_mid_size,), dtype=torch.float32, device=x.device)
        out = torch.empty([], dtype=x.dtype, device=x.device)

        with torch_device_fn.device(x.device):
            dot_kernel_1[grid_1](x, y, mid, N)
            # Reduce only the entries the final launch wrote. The rest of
            # `mid` is uninitialised (torch.empty) or left over from autotune
            # benchmark runs and must not be added to the result.
            mid_size = launched[0]
            block_mid = triton.next_power_of_2(mid_size)
            dot_kernel_2[(1,)](mid, out, mid_size, block_mid)

    else:
        block_size = triton.next_power_of_2(N)

        grid = (1, 1, 1)

        # Accumulate in float32, then cast the result to the input dtype.
        out = torch.empty([], dtype=torch.float32, device=x.device)

        with torch_device_fn.device(x.device):
            dot_kernel[grid](x, y, out, N, block_size)
            out = out.to(x.dtype)

    return out