Coverage for src/flag_gems/runtime/backend/_spacemit/ops/addmm.py: 0% (53 statements)
Report generated by coverage.py v7.6.9 on 2026-06-10 07:09 +0800.
1import logging
3import torch
4import triton
5import triton.language as tl
6import triton.language.extra.smt as smt
8from flag_gems import runtime
9from flag_gems.utils import libentry, libtuner
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
# Tiled addmm kernel for the SpacemiT backend:
#     C[M, N] = alpha * (A[M, K] @ B[K, N]) + beta * BIAS[M, N]
# Each program instance produces one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C,
# selected by (program_id(0), program_id(1)).
#
# `addmm` passes only BLOCK_SIZE_K and SUB_BLK_K explicitly; the remaining
# constexprs (BLOCK_SIZE_M/N, EVEN_K, MICRO_M/K/N) are presumably supplied by
# the "addmm_spacemit" tuned configs via libtuner — TODO confirm.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("addmm_spacemit"),
    key=["M", "N", "K"],
)
@triton.jit
def addmm_kernel(
    a_ptr,  # A (mat1), logical shape (M, K)
    b_ptr,  # B (mat2), logical shape (K, N)
    bias_ptr,  # bias, logical shape (M, N)
    c_ptr,  # output C, logical shape (M, N)
    alpha,  # scale applied to A @ B
    beta,  # scale applied to bias
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_im,  # bias row stride
    stride_in,  # bias column stride
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    EVEN_K: tl.constexpr,  # selects the single-dot path; exact condition comes from the configs — TODO confirm
    BLOCK_SIZE_K: tl.constexpr,  # K-extent of the A/B block pointers (addmm passes next_power_of_2(K))
    MICRO_M: tl.constexpr,
    MICRO_K: tl.constexpr,
    MICRO_N: tl.constexpr,
    SUB_BLK_K: tl.constexpr,  # K-chunk handled per smt.dot on the non-EVEN_K path
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # A's block pointer starts at this tile's first row and k = 0. Its K-extent
    # is the whole BLOCK_SIZE_K, so it is never advanced below.
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr,
        shape=[M, K],
        strides=[stride_am, stride_ak],
        offsets=[pid_m * BLOCK_SIZE_M, 0],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
        order=[1, 0],
    )

    # B's block pointer starts at k = 0 and this tile's first column.
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr,
        shape=[K, N],
        strides=[stride_bk, stride_bn],
        offsets=[0, pid_n * BLOCK_SIZE_N],
        block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N],
        order=[1, 0],
    )

    if EVEN_K:
        # Single-shot path: load each operand's full K-extent once and issue
        # one smt.dot. smt.view's last argument appears to be the micro-tile
        # shape used to re-layout data for smt.dot (smt semantics are not
        # visible in this file).
        a_descriptor_load = smt.descriptor_load(a_block_ptr, (0, 0))
        a = smt.view(
            a_descriptor_load,
            (0, 0),
            (BLOCK_SIZE_M, BLOCK_SIZE_K),
            (MICRO_M, MICRO_K),
        )
        b_descriptor_load = smt.descriptor_load(b_block_ptr, (0, 0))
        b = smt.view(
            b_descriptor_load,
            (0, 0),
            (BLOCK_SIZE_K, BLOCK_SIZE_N),
            (MICRO_K, MICRO_N),
        )
        accumulator = smt.dot(a, b)
    else:
        # Chunked path: accumulate over K in SUB_BLK_K-wide slices.
        # NOTE(review): the accumulator takes A's element type, so fp16/bf16
        # inputs would accumulate in reduced precision — confirm intended.
        accumulator = tl.zeros(
            (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=a_ptr.type.element_ty
        )
        accumulator = smt.view(
            accumulator, (0, 0), (BLOCK_SIZE_M, BLOCK_SIZE_N), (MICRO_M, MICRO_N)
        )
        sub_num = (K + SUB_BLK_K - 1) // SUB_BLK_K  # ceil(K / SUB_BLK_K)
        for k in tl.range(0, sub_num):
            # The same (0, 0) descriptor load is re-issued each iteration. The
            # k-th SUB_BLK_K slice is then selected through smt.view offsets.
            # Presumably the backend materializes only the viewed slice — TODO confirm.
            a_descriptor_load = smt.descriptor_load(a_block_ptr, (0, 0))
            a = smt.view(
                a_descriptor_load,
                (0, k * SUB_BLK_K),
                (BLOCK_SIZE_M, SUB_BLK_K),
                (MICRO_M, MICRO_K),
            )
            b_descriptor_load = smt.descriptor_load(b_block_ptr, (0, 0))
            b = smt.view(
                b_descriptor_load,
                (k * SUB_BLK_K, 0),
                (SUB_BLK_K, BLOCK_SIZE_N),
                (MICRO_K, MICRO_N),
            )
            accumulator += smt.dot(a, b)
    # Re-view the accumulator with a (1, 1) micro shape. Presumably this is a
    # plain 2-D layout for the element-wise epilogue below.
    accumulator = smt.view(accumulator, (0, 0), (BLOCK_SIZE_M, BLOCK_SIZE_N), (1, 1))

    bias_block_ptr = tl.make_block_ptr(
        base=bias_ptr,
        shape=[M, N],
        strides=[stride_im, stride_in],
        offsets=[pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
        order=[1, 0],
    )
    # Edge tiles: out-of-bounds lanes of this load are never written back,
    # because the final store is also boundary-checked.
    bias = tl.load(bias_block_ptr, boundary_check=(0, 1))
    accumulator = accumulator * alpha + bias * beta
    # Cast to C's element type before storing.
    c = accumulator.to(c_ptr.dtype.element_ty)

    c_block_ptr = tl.make_block_ptr(
        base=c_ptr,
        shape=[M, N],
        strides=[stride_cm, stride_cn],
        offsets=[pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
        order=[1, 0],
    )

    tl.store(c_block_ptr, c, boundary_check=(0, 1))
def addmm(bias, mat1, mat2, *, beta=1, alpha=1):
    """Compute ``beta * bias + alpha * (mat1 @ mat2)`` on the SpacemiT backend.

    Mirrors ``torch.addmm``. When ``beta == 0``, ``bias`` is ignored, so NaN/Inf
    values in it do not propagate into the result.

    Args:
        bias: tensor broadcastable to ``(M, N)``.
        mat1: ``(M, K)`` matrix.
        mat2: ``(K, N)`` matrix.
        beta: scale applied to ``bias``.
        alpha: scale applied to ``mat1 @ mat2``.

    Returns:
        A new ``(M, N)`` tensor with ``mat1``'s dtype and device.

    Raises:
        AssertionError: if ``mat1.shape[1] != mat2.shape[0]``.
    """
    logger.debug("GEMS_SPACEMIT ADDMM")
    assert mat1.shape[1] == mat2.shape[0], "Incompatible dimensions"
    M, K = mat1.shape
    _, N = mat2.shape

    mat1 = mat1.contiguous()
    mat2 = mat2.contiguous()
    out = torch.empty((M, N), device=mat1.device, dtype=mat1.dtype)
    # broadcast_to also validates that bias is broadcastable to (M, N), even
    # in the cases below where its values end up unused.
    bias = bias.broadcast_to(out.shape)

    if out.numel() == 0:
        # Empty result: nothing to compute, so skip the kernel launch.
        return out

    if beta == 0:
        # torch.addmm ignores bias when beta == 0. In the kernel, bias * 0 would
        # still turn NaN/Inf into NaN, so feed zeros instead.
        bias = torch.zeros_like(out)
    else:
        bias = bias.contiguous()

    if K == 0:
        # Empty inner dimension: mat1 @ mat2 is all zeros, so out = beta * bias.
        # This must be handled on the host: next_power_of_2(0) == 0 would give
        # the kernel a zero-sized K block and SUB_BLK_K == 0.
        return out.copy_(bias * beta)

    def grid(META):
        # One program per (BLOCK_SIZE_M, BLOCK_SIZE_N) output tile. The block
        # sizes come from the autotuned config.
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
            triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    # The kernel's A/B block pointers span all of K (padded to a power of two).
    # SUB_BLK_K caps the per-dot K chunk on the kernel's non-EVEN_K path.
    BLOCK_SIZE_K = triton.next_power_of_2(K)
    SUB_BLK_K = min(1024, BLOCK_SIZE_K)

    addmm_kernel[grid](
        mat1,
        mat2,
        bias,
        out,
        alpha,
        beta,
        M,
        N,
        K,
        mat1.stride(0),
        mat1.stride(1),
        mat2.stride(0),
        mat2.stride(1),
        bias.stride(0),
        bias.stride(1),
        out.stride(0),
        out.stride(1),
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        SUB_BLK_K=SUB_BLK_K,
    )
    return out