Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/addmv.py: 0%
69 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import broadcastable_to, libentry
9from flag_gems.utils import triton_lang_extension as ext
# Module logger: a child of the "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def heur_block_n(args):
    """Choose BLOCK_N, the number of rows of ``mat`` handled per program.

    Smaller row tiles mean more programs, and therefore more parallelism.

    Args:
        args: Kernel-argument mapping supplied by ``triton.heuristics``.
            Only ``args["N"]`` (the row count) is read; it defaults to 0.

    Returns:
        A power of two. For ``N <= 64`` this is the next power of two
        >= ``N`` (never less than 1). Otherwise it is 64, 128 or 256,
        depending on ``N``.
    """
    import builtins

    N = args.get("N", 0)
    if N <= 64:
        # triton.next_power_of_2(0) returns 0. Clamp so an empty matrix never
        # yields a zero-width block: BLOCK_N is the divisor of the launch grid.
        return triton.next_power_of_2(builtins.max(N, 1))
    if N <= 256:
        return 64
    if N <= 1024:
        return 128
    return 256
def heur_block_m(args):
    """Choose BLOCK_M, the column-tile width of the reduction loop over ``M``.

    Args:
        args: Kernel-argument mapping supplied by ``triton.heuristics``.
            Only ``args["M"]`` (the reduction length) is read; it defaults
            to 0.

    Returns:
        The next power of two >= ``M``, capped at 4096 and never less than 1.
    """
    import builtins

    M = args.get("M", 0)
    # Larger tiles are intended to improve memory coalescing. The lower clamp
    # matters because triton.next_power_of_2(0) returns 0, which would give
    # an empty tile.
    return builtins.min(triton.next_power_of_2(builtins.max(M, 1)), 4096)
# Triton kernel computing Out = alpha * (A @ B) + beta * Inp, one tile of rows
# per program.
#   A:   matrix addressed as A[n, m] = A + n * stride_an + m * stride_am
#   B:   vector addressed as B[m] = B + m * stride_bm
#   Inp: addend addressed as Inp[n] = Inp + n * stride_in
#   Out: result addressed as Out[n] = Out + n * stride_outn
# Program `pid` owns rows [pid * BLOCK_N, (pid + 1) * BLOCK_N). It walks the M
# columns in BLOCK_M-wide tiles, accumulating in float32. Rows >= N are
# masked out.
@libentry()
@triton.heuristics(
    {
        "BLOCK_N": heur_block_n,
        "BLOCK_M": heur_block_m,
    }
)
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmv_kernel(
    A,
    B,
    Inp,
    Out,
    N: tl.constexpr,
    M: tl.constexpr,
    alpha,
    beta,
    stride_an: tl.constexpr,
    stride_am: tl.constexpr,
    stride_bm: tl.constexpr,
    stride_in: tl.constexpr,
    stride_outn: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid = ext.program_id(0)
    # Row indices form a (BLOCK_N, 1) column and column indices a (1, BLOCK_M)
    # row, so together they address a 2-D tile of A.
    offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]
    offset_m = tl.arange(0, BLOCK_M)[None, :]
    n_mask = offset_n < N
    A_ptrs = A + offset_n * stride_an + offset_m * stride_am
    B_ptrs = B + offset_m * stride_bm
    # Keep elementwise products in a float32 tile and reduce across columns
    # only once, after the loop.
    acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)
    for m in range(0, M, BLOCK_M):
        m_mask = m + offset_m < M
        a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)
        b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)
        acc += a * b
        A_ptrs += BLOCK_M * stride_am
        B_ptrs += BLOCK_M * stride_bm

    acc = tl.sum(acc, axis=1)[:, None]
    out_block = acc * alpha
    # torch.addmv ignores the input when beta == 0, so nan/inf in it must not
    # propagate. Only read Inp when it actually contributes.
    if beta != 0:
        Inp_ptrs = Inp + offset_n * stride_in
        inp = tl.load(Inp_ptrs, mask=n_mask, other=0.0).to(tl.float32)
        out_block += inp * beta
    Out_ptrs = Out + offset_n * stride_outn
    tl.store(Out_ptrs, out_block, mask=n_mask)
def addmv(self, mat, vec, *, beta=1, alpha=1):
    """Return ``beta * self + alpha * (mat @ vec)`` in a newly allocated tensor.

    Args:
        self: Tensor broadcastable to ``(N,)``, where ``N = mat.shape[0]``.
        mat: Matrix of shape ``(N, M)``.
        vec: Vector whose first dimension is ``M`` (expected to be 1-D).
        beta: Scale applied to ``self``.
        alpha: Scale applied to ``mat @ vec``.

    Returns:
        A tensor of shape ``(N,)`` with ``mat``'s dtype, on ``mat``'s device.

    Raises:
        AssertionError: If ``mat.shape[1] != vec.shape[0]`` or ``self`` cannot
            be broadcast to ``(N,)``.
    """
    logger.debug("GEMS_KUNLUNXIN ADDMV")
    assert mat.shape[1] == vec.shape[0], "incompatible dimensions"
    assert broadcastable_to(self.shape, (mat.shape[0],)), "Incompatible self shape"
    N, M = mat.shape
    out = torch.empty((N,), device=mat.device, dtype=mat.dtype)
    if N == 0:
        # Nothing to compute. Return before the block-size heuristics or the
        # launch grid ever see N == 0.
        return out
    self = self.broadcast_to(out.shape)
    if M == 0:
        # mat @ vec is an empty sum (all zeros), so the result is beta * self.
        # torch.addmv ignores self entirely when beta == 0. Skipping the
        # kernel also keeps the block-size heuristics from seeing M == 0.
        if beta == 0:
            return out.zero_()
        return out.copy_(self * beta)
    # One program per BLOCK_N rows; BLOCK_N is filled in by the heuristics.
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(mat.device):
        addmv_kernel[grid](
            mat,
            vec,
            self,
            out,
            N,
            M,
            alpha,
            beta,
            mat.stride(0),
            mat.stride(1),
            vec.stride(0),
            self.stride(0),
            out.stride(0),
        )
    return out
def addmv_out(self, mat, vec, *, beta=1, alpha=1, out=None):
    """Write ``beta * self + alpha * (mat @ vec)`` into ``out`` and return it.

    Args:
        self: Tensor broadcastable to ``(N,)``, where ``N = mat.shape[0]``.
        mat: Matrix of shape ``(N, M)``.
        vec: Vector whose first dimension is ``M`` (expected to be 1-D).
        beta: Scale applied to ``self``.
        alpha: Scale applied to ``mat @ vec``.
        out: Destination tensor of shape ``(N,)``. If ``None``, a new tensor
            with ``mat``'s dtype and device is allocated.

    Returns:
        ``out`` (or the newly allocated tensor), holding the result.

    Raises:
        AssertionError: If ``mat.shape[1] != vec.shape[0]``, ``self`` cannot
            be broadcast to ``(N,)``, or ``out`` does not have shape ``(N,)``.
    """
    logger.debug("GEMS_KUNLUNXIN ADDMV OUT")
    assert mat.shape[1] == vec.shape[0], "incompatible dimensions"
    assert broadcastable_to(self.shape, (mat.shape[0],)), "Incompatible self shape"
    N, M = mat.shape
    if out is None:
        out = torch.empty((N,), device=mat.device, dtype=mat.dtype)
    else:
        assert out.shape == (N,), "Incompatible output shape"
    if N == 0:
        # Nothing to compute. Return before the block-size heuristics or the
        # launch grid ever see N == 0.
        return out

    self = self.broadcast_to(out.shape)
    if M == 0:
        # mat @ vec is an empty sum (all zeros), so the result is beta * self.
        # torch.addmv ignores self entirely when beta == 0. Skipping the
        # kernel also keeps the block-size heuristics from seeing M == 0.
        if beta == 0:
            return out.zero_()
        return out.copy_(self * beta)
    # One program per BLOCK_N rows; BLOCK_N is filled in by the heuristics.
    grid = lambda META: (triton.cdiv(N, META["BLOCK_N"]),)
    with torch_device_fn.device(mat.device):
        addmv_kernel[grid](
            mat,
            vec,
            self,
            out,
            N,
            M,
            alpha,
            beta,
            mat.stride(0),
            mat.stride(1),
            vec.stride(0),
            self.stride(0),
            out.stride(0),
        )
    return out