Coverage for src/flag_gems/runtime/backend/_spacemit/ops/mv.py: 0%
33 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
11logger = logging.getLogger(__name__)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("mv"),
    key=["M", "N"],
)
@triton.jit
def mv_kernel(
    A,
    B,
    C,
    N,
    M,
    stride_an,
    stride_am,
    stride_bm,
    stride_cn,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Compute one BLOCK_N-row slice of the matrix-vector product C = A @ B.

    The program selected by ``tl.program_id(0)`` handles the rows starting at
    ``pid * BLOCK_N``. It walks the M dimension in steps of ``BLOCK_M``,
    accumulates the element-wise products of an A tile and the matching B
    slice, reduces the accumulator along axis 1 and stores the result to C.

    Args:
        A: Pointer to the matrix, addressed as shape ``[N, M]``.
        B: Pointer to the vector, addressed as shape ``[M]``.
        C: Pointer to the output, addressed as shape ``[N]``.
        N: Number of matrix rows / output elements.
        M: Number of matrix columns / vector elements (the reduced dimension).
        stride_an: Stride of A along the N dimension.
        stride_am: Stride of A along the M dimension.
        stride_bm: Stride of B.
        stride_cn: Stride of C.
        BLOCK_M: Compile-time tile width along M.
        BLOCK_N: Compile-time tile height along N.
    """
    pid = tl.program_id(0)
    block_start_n = pid * BLOCK_N
    # The accumulator keeps A's element type; products are summed tile by tile
    # and only collapsed along M after the loop.
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=A.dtype.element_ty)

    for m in range(0, M, BLOCK_M):
        a_block_ptr = tl.make_block_ptr(
            base=A,
            shape=[N, M],
            strides=[stride_an, stride_am],
            offsets=[block_start_n, m],
            block_shape=[BLOCK_N, BLOCK_M],
            order=[1, 0],
        )
        # Out-of-bounds lanes must read as zero: with no padding option their
        # value is undefined, and the tail lanes along M are summed into the
        # result whenever M is not a multiple of BLOCK_M.
        a = tl.load(
            a_block_ptr,
            boundary_check=(0, 1),
            padding_option="zero",
        ).to(A.dtype.element_ty)

        b_block_ptr = tl.make_block_ptr(
            base=B,
            shape=[M],
            strides=[stride_bm],
            offsets=[m],
            block_shape=[BLOCK_M],
            order=[0],
        )
        # Same zero padding for the vector slice; it is cast to A's element
        # type so the product matches the accumulator dtype.
        b = tl.load(
            b_block_ptr,
            boundary_check=(0,),
            padding_option="zero",
        ).to(A.dtype.element_ty)
        acc += a * b[None, :]

    # Collapse the M axis to get one value per row handled by this program.
    result = tl.sum(acc, axis=1)
    c_block_ptr = tl.make_block_ptr(
        base=C,
        shape=[N],
        strides=[stride_cn],
        offsets=[block_start_n],
        block_shape=[BLOCK_N],
        order=[0],
    )
    # The boundary check drops rows past N in the last program's tile.
    tl.store(c_block_ptr, result.to(C.dtype.element_ty), boundary_check=(0,))
def mv(inp, vec):
    """Return the matrix-vector product of ``inp`` and ``vec`` via ``mv_kernel``.

    Args:
        inp: Matrix operand; its shape is unpacked as ``(N, M)``.
        vec: Vector operand whose first dimension must equal ``inp.shape[1]``.

    Returns:
        A new tensor of shape ``(N,)`` created on ``inp``'s device with
        ``inp``'s dtype, filled by the kernel.

    Raises:
        AssertionError: If ``inp.shape[1]`` differs from ``vec.shape[0]``.
    """
    logger.debug("GEMS_SPACEMIT MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"

    n_rows, n_cols = inp.shape
    result = torch.empty((n_rows,), dtype=inp.dtype, device=inp.device)
    row_stride, col_stride = inp.stride(0), inp.stride(1)

    def launch_grid(META):
        # One program per BLOCK_N rows; BLOCK_N comes from the chosen config.
        return (triton.cdiv(n_rows, META["BLOCK_N"]),)

    # Launch with inp's device selected as the active device.
    with torch_device_fn.device(inp.device):
        mv_kernel[launch_grid](
            inp,
            vec,
            result,
            n_rows,
            n_cols,
            row_stride,
            col_stride,
            vec.stride(0),
            result.stride(0),
        )
    return result