Coverage for src/flag_gems/runtime/backend/_spacemit/ops/mm.py: 0%

89 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import torch 

2import triton 

3import triton.language as tl 

4import triton.language.extra.smt as smt 

5 

6from flag_gems import runtime 

7from flag_gems.utils import libentry, libtuner 

8 

9 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm_spacemit"),
    key=["M", "N", "K"],
)
@triton.jit
def mm_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    M,
    N,
    K,
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    EVEN_K: tl.constexpr,
    SPLIT_M: tl.constexpr,
    SPLIT_N: tl.constexpr,
    SPLIT_K: tl.constexpr,
    SUB_BLK_M: tl.constexpr,
    SUB_BLK_N: tl.constexpr,
    MICRO_M: tl.constexpr,
    MICRO_K: tl.constexpr,
    MICRO_N: tl.constexpr,
    SUB_BLK_K: tl.constexpr,
):
    """Compute one ``BLOCK_SIZE_M x BLOCK_SIZE_N`` tile of ``C = A @ B``.

    The program at grid position ``(program_id(0), program_id(1))`` owns the
    output tile whose origin is ``(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N)``.
    The A and B block pointers start at K-offset 0 and are never advanced, so
    a single ``BLOCK_SIZE_K``-wide block is expected to span the reduction
    dimension.

    Exactly one strategy runs, chosen by the first true flag in the order
    ``EVEN_K``, ``SPLIT_M``, ``SPLIT_N``, ``SPLIT_K``:

    * ``EVEN_K``  - one dot product over the whole tile.
    * ``SPLIT_M`` - B is loaded once; row strips of ``SUB_BLK_M`` rows are
      processed in an ``smt.parallel`` loop.
    * ``SPLIT_N`` - the tile is cut into ``SUB_BLK_M x SUB_BLK_N`` sub-blocks
      that share B through an ``smt.alloc`` buffer and an ``smt.mbarrier``.
    * ``SPLIT_K`` - partial products over ``SUB_BLK_K``-wide K slices are
      summed in a sequential ``tl.range`` loop.

    If none of the four flags is set the kernel stores nothing.

    NOTE(review): the ``smt.*`` helpers belong to a vendor extension whose
    definitions are not visible here. ``smt.view(x, offsets, shape, micro)``
    is presumably "sub-window of ``x`` at ``offsets`` with ``shape``, laid out
    in ``micro`` tiles" - confirm against the extension's documentation.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    # Origin of this program's output tile; reused by every strategy below.
    row_start = pid_m * BLOCK_SIZE_M
    col_start = pid_n * BLOCK_SIZE_N

    a_block_ptr = tl.make_block_ptr(
        base=a_ptr,
        shape=[M, K],
        strides=[stride_am, stride_ak],
        offsets=[row_start, 0],
        block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K],
        order=[1, 0],
    )
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr,
        shape=[K, N],
        strides=[stride_bk, stride_bn],
        offsets=[0, col_start],
        block_shape=[BLOCK_SIZE_K, BLOCK_SIZE_N],
        order=[1, 0],
    )

    if EVEN_K:
        # Whole-tile path: one A block, one B block, one dot.
        a_desc = smt.descriptor_load(a_block_ptr, (0, 0))
        a_tile = smt.view(
            a_desc, (0, 0), (BLOCK_SIZE_M, BLOCK_SIZE_K), (MICRO_M, MICRO_K)
        )
        b_desc = smt.descriptor_load(b_block_ptr, (0, 0))
        b_tile = smt.view(
            b_desc, (0, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N)
        )
        acc = smt.dot(a_tile, b_tile)
        # Re-view the product with (1, 1) micro tiles before the store.
        acc = smt.view(acc, (0, 0), (BLOCK_SIZE_M, BLOCK_SIZE_N), (1, 1))
        out = acc.to(c_ptr.dtype.element_ty)
        out_ptr = tl.make_block_ptr(
            base=c_ptr,
            shape=[M, N],
            strides=[stride_cm, stride_cn],
            offsets=[row_start, col_start],
            block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
            order=[1, 0],
        )
        tl.store(out_ptr, out, boundary_check=(0, 1))

    elif SPLIT_M:
        # B is shared by every row strip, so it is loaded once up front.
        b_desc = smt.descriptor_load(b_block_ptr, (0, 0))
        b_tile = smt.view(
            b_desc, (0, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N)
        )
        # Ceil-divide the rows that actually exist in this tile (the last
        # tile along M may be partial) into SUB_BLK_M-row strips.
        rows_in_tile = min(BLOCK_SIZE_M, M - row_start)
        strip_count = (rows_in_tile + SUB_BLK_M - 1) // SUB_BLK_M
        for s in smt.parallel(0, strip_count):
            strip_row = s * SUB_BLK_M
            a_desc = smt.descriptor_load(a_block_ptr, (0, 0))
            a_strip = smt.view(
                a_desc,
                (strip_row, 0),
                (SUB_BLK_M, BLOCK_SIZE_K),
                (MICRO_M, MICRO_K),
            )
            acc = smt.dot(a_strip, b_tile)
            acc = smt.view(acc, (0, 0), (SUB_BLK_M, BLOCK_SIZE_N), (1, 1))
            out = acc.to(c_ptr.dtype.element_ty)
            out_ptr = tl.make_block_ptr(
                base=c_ptr,
                shape=[M, N],
                strides=[stride_cm, stride_cn],
                offsets=[row_start + strip_row, col_start],
                block_shape=[SUB_BLK_M, BLOCK_SIZE_N],
                order=[1, 0],
            )
            tl.store(out_ptr, out, boundary_check=(0, 1))

    elif SPLIT_N:
        # Ceil-divide the valid part of the tile along both output axes.
        rows_in_tile = min(BLOCK_SIZE_M, M - row_start)
        sub_num_m = (rows_in_tile + SUB_BLK_M - 1) // SUB_BLK_M
        cols_in_tile = min(BLOCK_SIZE_N, N - col_start)
        sub_num_n = (cols_in_tile + SUB_BLK_N - 1) // SUB_BLK_N
        total_sub_blocks = sub_num_m * sub_num_n

        # Scratch buffer holding this tile's slice of B, viewed in
        # (MICRO_K, MICRO_N) micro tiles.
        b_scratch = smt.alloc(shape=[BLOCK_SIZE_K, BLOCK_SIZE_N])
        b_scratch_view = smt.view(
            b_scratch, (0, 0), (BLOCK_SIZE_K, BLOCK_SIZE_N), (MICRO_K, MICRO_N)
        )
        # One arrival is expected per column strip (see the s_m == 0 branch).
        bar = smt.mbarrier(flag=0, expect_count=sub_num_n)

        for s in smt.parallel(0, total_sub_blocks):
            # Flat sub-block index -> (row strip, column strip).
            s_m = s // sub_num_n
            s_n = s % sub_num_n
            strip_row = s_m * SUB_BLK_M
            strip_col = s_n * SUB_BLK_N

            a_desc = smt.descriptor_load(a_block_ptr, (0, 0))
            a_strip = smt.view(
                a_desc,
                (strip_row, 0),
                (SUB_BLK_M, BLOCK_SIZE_K),
                (MICRO_M, MICRO_K),
            )
            b_scratch_sub = smt.view(
                b_scratch_view, (0, strip_col), (BLOCK_SIZE_K, SUB_BLK_N)
            )
            if s_m == 0:
                # The first row strip fills the scratch buffer with its
                # column slice of B, then arrives on the barrier.
                b_desc = smt.descriptor_load(b_block_ptr, (0, 0))
                b_strip = smt.view(
                    b_desc,
                    (0, strip_col),
                    (BLOCK_SIZE_K, SUB_BLK_N),
                    (MICRO_K, MICRO_N),
                )
                tl.store(b_scratch_sub, b_strip, boundary_check=(0, 1, 2, 3))
                smt.barrier_arrive(bar)
            else:
                # Every other row strip waits on the barrier before reading.
                smt.barrier_wait(bar, flag=1)

            b_shared = tl.load(b_scratch_sub, boundary_check=(0, 1, 2, 3))
            acc = smt.dot(a_strip, b_shared)
            acc = smt.view(acc, (0, 0), (SUB_BLK_M, SUB_BLK_N), (1, 1))
            out = acc.to(c_ptr.dtype.element_ty)
            out_ptr = tl.make_block_ptr(
                base=c_ptr,
                shape=[M, N],
                strides=[stride_cm, stride_cn],
                offsets=[
                    row_start + strip_row,
                    col_start + strip_col,
                ],
                block_shape=[SUB_BLK_M, SUB_BLK_N],
                order=[1, 0],
            )
            tl.store(out_ptr, out, boundary_check=(0, 1))

    elif SPLIT_K:
        # The running sum is created in A's element type and viewed in
        # (MICRO_M, MICRO_N) micro tiles.
        acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=a_ptr.type.element_ty)
        acc = smt.view(
            acc, (0, 0), (BLOCK_SIZE_M, BLOCK_SIZE_N), (MICRO_M, MICRO_N)
        )
        k_steps = (K + SUB_BLK_K - 1) // SUB_BLK_K
        for k in tl.range(0, k_steps):
            k_start = k * SUB_BLK_K
            a_desc = smt.descriptor_load(a_block_ptr, (0, 0))
            a_slice = smt.view(
                a_desc,
                (0, k_start),
                (BLOCK_SIZE_M, SUB_BLK_K),
                (MICRO_M, MICRO_K),
            )
            b_desc = smt.descriptor_load(b_block_ptr, (0, 0))
            b_slice = smt.view(
                b_desc,
                (k_start, 0),
                (SUB_BLK_K, BLOCK_SIZE_N),
                (MICRO_K, MICRO_N),
            )
            acc += smt.dot(a_slice, b_slice)

        acc = smt.view(acc, (0, 0), (BLOCK_SIZE_M, BLOCK_SIZE_N), (1, 1))
        out = acc.to(c_ptr.dtype.element_ty)

        out_ptr = tl.make_block_ptr(
            base=c_ptr,
            shape=[M, N],
            strides=[stride_cm, stride_cn],
            offsets=[row_start, col_start],
            block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_N],
            order=[1, 0],
        )
        tl.store(out_ptr, out, boundary_check=(0, 1))

212 

213 

def mm(a, b):
    """Return the matrix product of two 2-D tensors using ``mm_kernel``.

    Args:
        a: Tensor of shape ``(M, K)``.
        b: Tensor of shape ``(K, N)``.

    Returns:
        A newly allocated ``(M, N)`` tensor on ``a``'s device with ``a``'s
        dtype. When ``K == 0`` the result is all zeros (an empty sum); when
        ``M == 0`` or ``N == 0`` the result is empty. Neither case launches
        the kernel.

    Raises:
        AssertionError: If ``a.shape[1] != b.shape[0]``.
    """
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    # Degenerate shapes never reach the kernel: an empty output would ask for
    # a zero-sized launch grid, and K == 0 would make next_power_of_2(K) == 0,
    # i.e. zero-sized BLOCK_SIZE_K / SUB_BLK_K.
    if M == 0 or N == 0:
        return c
    if K == 0:
        return c.zero_()

    # Copy only when required, and only once the inputs are known to be valid.
    if not a.is_contiguous():
        a = a.contiguous()
    if not b.is_contiguous():
        b = b.contiguous()

    # launch kernel
    def grid(META):
        # One program per BLOCK_SIZE_M x BLOCK_SIZE_N output tile.
        return (
            triton.cdiv(M, META["BLOCK_SIZE_M"]),
            triton.cdiv(N, META["BLOCK_SIZE_N"]),
        )

    # The kernel never advances its block pointers along K, so one K block
    # has to span the whole reduction dimension.
    BLOCK_SIZE_K = triton.next_power_of_2(K)
    # Width of one K slice on the kernel's SPLIT_K path, capped at 512.
    SUB_BLK_K = min(512, BLOCK_SIZE_K)

    mm_kernel[grid](
        a,
        b,
        c,
        M,
        N,
        K,
        a.stride(0),
        a.stride(1),
        b.stride(0),
        b.stride(1),
        c.stride(0),
        c.stride(1),
        BLOCK_SIZE_K=BLOCK_SIZE_K,
        SUB_BLK_K=SUB_BLK_K,
    )
    return c