Coverage for src/flag_gems/runtime/backend/_sunrise/ops/mv.py: 0%
40 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry, libtuner
11from flag_gems.utils import triton_lang_extension as ext
13logger = logging.getLogger(__name__)
# Matrix-vector product kernel: C[n] = sum_m A[n, m] * B[m].
# The tuner picks BLOCK_N / BLOCK_M per (M, N); config source and expansion
# strategy depend on whether USE_FLAGTUNE=1 is set in the environment.
@libentry()
@libtuner(
    configs=(
        runtime.ops_get_configs("mv", pre_hook=None)
        if os.environ.get("USE_FLAGTUNE") == "1"
        else runtime.get_tuned_config("mv")
    ),
    key=["M", "N"],
    strategy=(
        runtime.get_expand_config("mv")["strategy"]
        if os.environ.get("USE_FLAGTUNE") == "1"
        else ["align32", "align32"]
    ),
)
@triton.jit
def mv_kernel(
    A,
    B,
    C,
    N,
    M,
    stride_an,
    stride_am,
    stride_bm,
    stride_cn,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    # Each program instance owns one band of BLOCK_N consecutive rows of A.
    row_start = ext.program_id(0) * BLOCK_N
    rows = row_start + tl.arange(0, BLOCK_N)[:, None]  # (BLOCK_N, 1)
    cols = tl.arange(0, BLOCK_M)[None, :]  # (1, BLOCK_M)
    row_ok = rows < N

    a_ptrs = A + rows * stride_an + cols * stride_am
    b_ptrs = B + cols * stride_bm

    # Accumulate element-wise products in float32; the reduction over the
    # column axis happens once, after the whole row has been swept.
    partial = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    for col_start in range(0, M, BLOCK_M):
        # Mask the ragged tail when M is not a multiple of BLOCK_M.
        col_ok = col_start + cols < M
        a_tile = tl.load(a_ptrs, mask=row_ok & col_ok, other=0.0).to(tl.float32)
        b_tile = tl.load(b_ptrs, mask=col_ok, other=0.0).to(tl.float32)
        partial += a_tile * b_tile
        a_ptrs += BLOCK_M * stride_am
        b_ptrs += BLOCK_M * stride_bm

    row_sums = tl.sum(partial, axis=1)
    # Restore the trailing axis so the value shape matches the (BLOCK_N, 1)
    # pointer and mask shapes.
    tl.store(C + rows * stride_cn, row_sums[:, None], mask=row_ok)
def mv(inp, vec):
    """Compute the matrix-vector product of ``inp`` and ``vec``.

    Launches ``mv_kernel`` over a 1-D grid with one program per ``BLOCK_N``
    rows of ``inp``.

    Args:
        inp: 2-D tensor; its shape is unpacked as ``(N, M)``.
        vec: tensor whose first dimension must equal ``M``. Only its
            stride along dimension 0 is passed to the kernel, so it is
            presumably 1-D -- confirm against callers.

    Returns:
        A new tensor of shape ``(N,)`` with ``inp``'s dtype, on ``inp``'s
        device.

    Raises:
        AssertionError: if ``inp.shape[1] != vec.shape[0]``.
    """
    logger.debug("GEMS MV")
    assert inp.shape[1] == vec.shape[0], "incompatible dimensions"
    N, M = inp.shape

    # Allocate the result directly on the input's device. Creating it on the
    # default device and moving it afterwards costs an extra allocation and
    # a copy of uninitialised memory.
    out = torch.empty((N,), dtype=inp.dtype, device=inp.device)

    def grid(META):
        # One program per BLOCK_N-row band; BLOCK_N is chosen by the tuner.
        return (triton.cdiv(N, META["BLOCK_N"]),)

    with torch_device_fn.device(inp.device):
        mv_kernel[grid](
            inp,
            vec,
            out,
            N,
            M,
            inp.stride(0),
            inp.stride(1),
            vec.stride(0),
            out.stride(0),
        )
    return out