Coverage for src/flag_gems/ops/mv.py: 56%
39 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, libtuner
10from flag_gems.utils import triton_lang_extension as ext
# Module-level logger named after this module; mv() emits its debug trace through it.
logger = logging.getLogger(__name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mv"),
    key=["M", "N"],
    strategy=["align32", "align32"],
    flagtune_op_name="mv",
)
@triton.jit
def mv_kernel(
    A,
    B,
    C,
    N,
    M,
    stride_an,
    stride_am,
    stride_bm,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    """Compute one BLOCK_N-row slice of the matrix-vector product C = A @ B.

    Each program instance owns BLOCK_N consecutive rows of A (selected by
    program id along axis 0), sweeps the M columns in BLOCK_M-wide tiles,
    accumulates the element-wise products in float32, and finally reduces
    along the column axis before storing one value per row into C.

    Args:
        A: pointer to the matrix; element (n, m) is read at
            ``A + n * stride_an + m * stride_am``.
        B: pointer to the vector; element m is read at ``B + m * stride_bm``.
        C: pointer to the output; element n is written at
            ``C + n * stride_cn``.
        N: number of rows of A (and entries of C) that are in bounds.
        M: number of columns of A (and entries of B) that are in bounds.
        stride_an, stride_am: element strides of A along rows / columns.
        stride_bm: element stride of B.
        stride_cn: element stride of C.
        BLOCK_N: compile-time number of rows handled per program.
        BLOCK_M: compile-time number of columns consumed per loop step.
    """
    block_id = ext.program_id(0)

    # Column vector of row indices and row vector of column indices, so that
    # combining them broadcasts to a (BLOCK_N, BLOCK_M) tile.
    rows = block_id * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    cols = tl.arange(0, BLOCK_M)[None, :]

    # Rows past N belong to the ragged last block and must not be touched.
    row_in_bounds = rows < N

    a_tile_ptrs = A + rows * stride_an + cols * stride_am
    b_tile_ptrs = B + cols * stride_bm

    # Accumulate in float32 regardless of the input element type.
    partial = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    for col_start in range(0, M, BLOCK_M):
        # Columns past M in the final tile are masked and read as 0.0, so
        # they contribute nothing to the sum.
        col_in_bounds = col_start + cols < M
        a_tile = tl.load(a_tile_ptrs, mask=row_in_bounds & col_in_bounds, other=0.0)
        b_tile = tl.load(b_tile_ptrs, mask=col_in_bounds, other=0.0)
        partial += a_tile.to(tl.float32) * b_tile.to(tl.float32)

        # Slide both tiles BLOCK_M columns to the right.
        a_tile_ptrs += BLOCK_M * stride_am
        b_tile_ptrs += BLOCK_M * stride_bm

    # Collapse the column axis: one dot product per row.
    row_totals = tl.sum(partial, axis=1)

    out_ptrs = C + rows * stride_cn
    # Restore the trailing axis so the value shape matches `out_ptrs`/mask.
    tl.store(out_ptrs, row_totals[:, None], mask=row_in_bounds)
def mv(inp, vec):
    """Return the matrix-vector product of ``inp`` and ``vec``.

    Launches ``mv_kernel`` on the device of ``inp`` with one program per
    ``BLOCK_N`` rows and writes the result into a freshly allocated tensor.

    Args:
        inp: matrix operand; its shape is unpacked as ``(N, M)``.
        vec: vector operand; ``vec.shape[0]`` must equal ``inp.shape[1]``.

    Returns:
        A new tensor of shape ``(N,)`` with the device and dtype of ``inp``.

    Raises:
        AssertionError: if ``inp.shape[1] != vec.shape[0]``.
    """
    logger.debug("GEMS MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"

    n_rows, n_cols = inp.shape
    out = torch.empty((n_rows,), device=inp.device, dtype=inp.dtype)

    def grid(meta):
        # 1-D launch grid: enough programs to cover every row, given the
        # BLOCK_N of the config being launched.
        return (triton.cdiv(n_rows, meta["BLOCK_N"]),)

    # Strides are passed explicitly so the kernel does its own addressing.
    strides = (inp.stride(0), inp.stride(1), vec.stride(0), out.stride(0))

    with torch_device_fn.device(inp.device):
        mv_kernel[grid](inp, vec, out, n_rows, n_cols, *strides)
    return out