Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/addmm.py: 0%
85 statements
coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8# from flag_gems import runtime
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import broadcastable_to, libentry
11from flag_gems.utils import triton_lang_extension as ext
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Launch-parameter selection for addmm_kernel.
#
# By default the kernel is wrapped in triton.autotune with an empty config list
# and ``generate_configs="addmm"``. NOTE(review): that keyword is not part of
# upstream Triton and presumably asks the Kunlunxin Triton fork to synthesize
# its own addmm candidates — confirm against the vendor docs. Tuning is keyed
# on the problem sizes M, N and K.
autotune_decorator = triton.autotune(
    configs=[],
    generate_configs="addmm",
    key=["M", "N", "K"],
)

# Autotuning stays enabled unless KLX_USE_AUTOTUNE is set to anything but "1".
KLX_USE_AUTOTUNE = os.environ.get("KLX_USE_AUTOTUNE", "1") == "1"

if not KLX_USE_AUTOTUNE:

    def _edge_block_size(extent):
        """Map an M or N extent to a block size: 1 -> 2, <=32 -> extent, else 128."""
        # NOTE(review): an extent of 1 is widened to 2, presumably because the
        # backend rejects size-1 blocks — confirm. Extents up to 32 are used
        # as-is and need not be powers of two (upstream Triton's tl.arange
        # would reject that); presumably the XPU fork allows it — verify.
        if extent == 1:
            return 2
        return extent if extent <= 32 else 128

    def heur_block_m(args):
        """BLOCK_SIZE_M heuristic, derived from the launch argument ``M``."""
        return _edge_block_size(args["M"])

    def heur_block_n(args):
        """BLOCK_SIZE_N heuristic, derived from the launch argument ``N``."""
        return _edge_block_size(args["N"])

    def heur_block_k(args):
        """BLOCK_SIZE_K heuristic: the full K extent, capped at 128."""
        return min(args["K"], 128)

    # Replace the autotuner with fixed, shape-derived block sizes.
    autotune_decorator = triton.heuristics(
        {
            "BLOCK_SIZE_M": heur_block_m,
            "BLOCK_SIZE_N": heur_block_n,
            "BLOCK_SIZE_K": heur_block_k,
        }
    )
@libentry()
@autotune_decorator
@triton.jit(do_not_specialize=["alpha", "beta"])
def addmm_kernel(
    a_ptr,
    b_ptr,
    i_ptr,
    c_ptr,
    alpha,
    beta,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,
    stride_in,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one output tile of ``C = alpha * (A @ B) + beta * I``.

    Program ``(pid_m, pid_n)`` owns the BLOCK_SIZE_M x BLOCK_SIZE_N tile of C
    whose first row is ``pid_m * BLOCK_SIZE_M`` and first column is
    ``pid_n * BLOCK_SIZE_N``. A is (M, K), B is (K, N), and I and C are (M, N).
    Every operand is addressed through its explicit strides, so broadcast
    (stride-0) bias views are read correctly. The product is accumulated in
    float32 over K in BLOCK_SIZE_K chunks, and out-of-range elements are
    masked to zero.

    When ``beta == 0`` the bias term is dropped, matching ``torch.addmm``, so
    NaN/Inf values in I do not reach C. The result is cast directly to C's
    element type before the store.
    """
    pid_m = ext.program_id(0)
    pid_n = ext.program_id(1)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)

    # Row/column bounds do not change across the K loop; build them once.
    mask_m = offs_m[:, None] < M
    mask_n = offs_n[None, :] < N

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Number of valid K columns left starting at this chunk.
        k_remaining = K - k * BLOCK_SIZE_K
        a = tl.load(
            a_ptrs,
            mask=mask_m & (offs_k[None, :] < k_remaining),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < k_remaining) & mask_n,
            other=0.0,
        )
        accumulator += tl.dot(a, b, allow_tf32=False)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    c_mask = mask_m & mask_n
    i_ptrs = i_ptr + stride_im * offs_m[:, None] + stride_in * offs_n[None, :]
    # Scale the bias in float32 to avoid rounding in a low-precision dtype.
    bias = tl.load(i_ptrs, mask=c_mask, other=0.0).to(tl.float32)

    accumulator = accumulator * alpha
    # torch.addmm semantics: with beta == 0 the input is ignored, so 0 * NaN or
    # 0 * Inf from the bias must not be added in.
    accumulator += tl.where(beta == 0, 0.0, bias * beta)

    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]
    # Cast straight to the output element type (not the bias dtype) to avoid
    # double rounding when bias and output dtypes differ.
    tl.store(c_ptrs, accumulator.to(c_ptr.dtype.element_ty), mask=c_mask)
def _check_addmm_operands(bias, mat1, mat2):
    """Validate addmm operand shapes and return the problem sizes ``(M, N, K)``.

    Raises:
        AssertionError: if mat1's columns do not match mat2's rows, or if bias
            cannot be broadcast to the (M, N) result shape.
    """
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    assert broadcastable_to(
        bias.shape, (mat1.shape[0], mat2.shape[1])
    ), "Incompatible input shape"
    M, K = mat1.shape
    _, N = mat2.shape
    return M, N, K


def _launch_addmm(bias, mat1, mat2, out, alpha, beta, M, N, K):
    """Write ``beta * bias + alpha * (mat1 @ mat2)`` into ``out`` and return it.

    ``out`` must already have shape (M, N); the caller is responsible for
    allocating or validating it.
    """
    if M == 0 or N == 0:
        # Nothing to write; skip launching an empty grid.
        return out

    mat1 = mat1.contiguous()
    # mat2 is consumed through its strides, so no contiguous copy is made.
    # Broadcasting yields stride-0 dims, which the kernel handles via strides.
    bias = bias.broadcast_to(out.shape)

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_SIZE_M"]),
        triton.cdiv(N, META["BLOCK_SIZE_N"]),
    )
    with torch_device_fn.device(mat1.device):
        addmm_kernel[grid](
            mat1,
            mat2,
            bias,
            out,
            alpha,
            beta,
            M,
            N,
            K,
            mat1.stride(0),
            mat1.stride(1),
            mat2.stride(0),
            mat2.stride(1),
            bias.stride(0),
            bias.stride(1),
            out.stride(0),
            out.stride(1),
        )
    return out


def addmm(bias, mat1, mat2, *, beta=1.0, alpha=1.0):
    """Return ``beta * bias + alpha * (mat1 @ mat2)`` as a new tensor.

    Args:
        bias: tensor broadcastable to (M, N).
        mat1: (M, K) matrix.
        mat2: (K, N) matrix.
        beta: scale applied to ``bias``; with 0 the bias is ignored.
        alpha: scale applied to the matrix product.

    Returns:
        A new (M, N) tensor with mat1's dtype, on mat1's device.

    Raises:
        AssertionError: on incompatible operand shapes.
    """
    logger.debug("GEMS_KUNLUNXIN ADDMM")
    M, N, K = _check_addmm_operands(bias, mat1, mat2)
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    return _launch_addmm(bias, mat1, mat2, out, alpha, beta, M, N, K)


def addmm_out(bias, mat1, mat2, *, beta=1.0, alpha=1.0, out=None):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` into ``out``.

    Args:
        bias: tensor broadcastable to (M, N).
        mat1: (M, K) matrix.
        mat2: (K, N) matrix.
        beta: scale applied to ``bias``; with 0 the bias is ignored.
        alpha: scale applied to the matrix product.
        out: optional (M, N) destination. If None, a new tensor with mat1's
            dtype and device is allocated.

    Returns:
        ``out`` (or the newly allocated tensor) holding the result.

    Raises:
        AssertionError: on incompatible operand shapes or a wrongly-shaped
            ``out``.
    """
    logger.debug("GEMS_KUNLUNXIN ADDMM_OUT")
    M, N, K = _check_addmm_operands(bias, mat1, mat2)
    if out is None:
        out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    else:
        assert out.shape == (M, N), "Incompatible output shape"
    return _launch_addmm(bias, mat1, mat2, out, alpha, beta, M, N, K)