Coverage for src/flag_gems/runtime/backend/_spacemit/ops/flash_attention.py: 0%

135 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import torch 

2import triton 

3import triton.language as tl 

4import triton.language.extra.smt as smt 

5 

6from flag_gems import runtime 

7from flag_gems.utils import libentry, libtuner 

8 

9 

@triton.jit
def _attn_fwd_inner(
    acc,
    l_i_2d,
    m_i_2d,
    Q_block_ptr,
    K_block_ptr,
    V_block_ptr,
    start_m,
    qk_scale,
    BLOCK_M: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    offs_m: tl.constexpr,
    offs_n: tl.constexpr,
    Q_CTX: tl.constexpr,
    KV_CTX: tl.constexpr,
    MICRO_M: tl.constexpr,
    MICRO_K: tl.constexpr,
    MICRO_N: tl.constexpr,
    num_m_tiles: tl.constexpr,
    num_n_tiles: tl.constexpr,
):
    """Run one pass of the online-softmax attention loop over key/value blocks.

    The key/value range walked depends on ``STAGE``:

    * ``STAGE == 1``: ``[0, start_m * BLOCK_M)``, no causal mask.
    * ``STAGE == 2``: ``[start_m * BLOCK_M, (start_m + 1) * BLOCK_M)``, with
      the causal mask applied.
    * ``STAGE == 4``: ``[0, KV_CTX)``, with the causal mask applied.
    * any other value: ``[0, KV_CTX)``, no causal mask.

    Every iteration additionally masks key positions ``>= KV_CTX``. Masked
    scores are replaced by ``-1.0e6`` (not ``-inf``).

    Args:
        acc: running output accumulator, rescaled and updated per key block.
        l_i_2d: running softmax denominator, laid out ``(num_m_tiles, MICRO_M)``.
        m_i_2d: running row maximum, laid out ``(num_m_tiles, MICRO_M)``.
        Q_block_ptr, K_block_ptr, V_block_ptr: block pointers for the query
            tile and for the start of the key/value sequence.
        start_m: index of the query block along the sequence axis.
        qk_scale: multiplier applied to the raw ``Q @ K^T`` scores.
        offs_m: absolute query row indices of this block.
        offs_n: key offsets ``0 .. BLOCK_N - 1`` inside one key block.
        Q_CTX, KV_CTX: query and key/value sequence lengths.
        num_m_tiles, num_n_tiles: micro-tile counts along the query and key
            axes, used to reshape the offsets for broadcasting.

    Returns:
        The updated ``(acc, l_i_2d, m_i_2d)`` triple.
    """
    # Pick the key/value range [lo, hi) handled by this pass.
    if STAGE == 1:
        tl.static_assert(BLOCK_M >= BLOCK_N)
        lo = 0
        hi = start_m * BLOCK_M
    elif STAGE == 2:
        tl.static_assert(BLOCK_M >= BLOCK_N)
        lo = tl.multiple_of(start_m * BLOCK_M, BLOCK_M)
        hi = (start_m + 1) * BLOCK_M
    else:
        lo = 0
        hi = KV_CTX

    # Move the key/value pointers to the first block of the range.
    K_block_ptr = tl.advance(K_block_ptr, (lo, 0))
    V_block_ptr = tl.advance(V_block_ptr, (lo, 0))

    # The query tile is loaded once and reused for every key block.
    q_loaded = smt.descriptor_load(Q_block_ptr, (0, 0))
    q = smt.view(q_loaded, (0, 0), (BLOCK_M, BLOCK_SIZE_K), (MICRO_M, MICRO_K))

    # Row / column indices reshaped so they broadcast against the 4-D scores.
    rows_4d = tl.reshape(offs_m, (num_m_tiles, 1, MICRO_M, 1))
    cols_4d = tl.reshape(offs_n, (1, num_n_tiles, 1, MICRO_N))

    # Shift applied to query rows in the causal comparison when the two
    # sequence lengths differ.
    kv_shift = KV_CTX - Q_CTX

    for start_n in tl.range(lo, hi, BLOCK_N):
        kv_start = tl.multiple_of(start_n, BLOCK_N)

        k_loaded = smt.descriptor_load(K_block_ptr, (0, 0))
        k_tiles = smt.view(
            k_loaded, (0, 0), (BLOCK_N, BLOCK_SIZE_K), (MICRO_N, MICRO_K)
        )
        k_t = tl.permute(k_tiles, (1, 0, 3, 2))

        scores = smt.dot(q, k_t) * qk_scale

        # Mask key positions that fall past the end of the sequence.
        in_range = (kv_start + offs_n) < KV_CTX
        in_range_4d = tl.reshape(in_range, (1, num_n_tiles, 1, MICRO_N))
        scores = tl.where(in_range_4d, scores, -1.0e6)

        if STAGE == 2 or STAGE == 4:
            # Causal mask: a query row only sees keys at or before its
            # (shifted) position.
            visible = (rows_4d + kv_shift) >= (kv_start + cols_4d)
            keep = visible & in_range_4d
            scores = tl.where(keep, scores, -1.0e6)

        # New running maximum per query row (reduce inner, then outer key axis).
        block_max = tl.max(tl.max(scores, axis=3), axis=1)
        m_new_2d = tl.maximum(m_i_2d, block_max)

        scores = scores - tl.reshape(m_new_2d, (num_m_tiles, 1, MICRO_M, 1))

        p = tl.math.exp(scores)
        block_sum = tl.sum(tl.sum(p, axis=3), axis=1)

        # Rescale the previous denominator and accumulator to the new maximum.
        rescale_2d = tl.math.exp(m_i_2d - m_new_2d)
        l_i_2d = l_i_2d * rescale_2d + block_sum
        acc = acc * tl.reshape(rescale_2d, (num_m_tiles, 1, MICRO_M, 1))

        v_loaded = smt.descriptor_load(V_block_ptr, (0, 0))
        v = smt.view(v_loaded, (0, 0), (BLOCK_N, BLOCK_SIZE_K), (MICRO_K, MICRO_N))

        p_lowp = p.to(v.dtype)
        p_lowp = smt.view(p_lowp, (0, 0), (BLOCK_M, BLOCK_N), (MICRO_M, MICRO_K))

        acc += smt.dot(p_lowp, v)

        m_i_2d = m_new_2d
        V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
        K_block_ptr = tl.advance(K_block_ptr, (BLOCK_N, 0))

    return acc, l_i_2d, m_i_2d

104 

105 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("attention"),
    key=["Q_CTX", "KV_CTX", "HEAD_DIM", "GROUP_SIZE"],
)
@triton.jit
def _attn_fwd(
    Q,
    K,
    V,
    M,
    Out,
    acc_buffer,
    sm_scale,
    stride_qz: tl.constexpr,
    stride_qh: tl.constexpr,
    stride_qm: tl.constexpr,
    stride_qk: tl.constexpr,
    stride_kz: tl.constexpr,
    stride_kh: tl.constexpr,
    stride_kn: tl.constexpr,
    stride_kk: tl.constexpr,
    stride_vz: tl.constexpr,
    stride_vh: tl.constexpr,
    stride_vn: tl.constexpr,
    stride_vk: tl.constexpr,
    stride_oz: tl.constexpr,
    stride_oh: tl.constexpr,
    stride_om: tl.constexpr,
    stride_on: tl.constexpr,
    Z,
    H_Q,
    H_KV,
    GROUP_SIZE,
    Q_CTX,
    KV_CTX,
    HEAD_DIM,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    STAGE: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    MICRO_M: tl.constexpr,
    MICRO_K: tl.constexpr,
    MICRO_N: tl.constexpr,
    num_ctas: tl.constexpr,
):
    """Attention forward kernel using a persistent loop over query blocks.

    The work is split into ``cdiv(Q_CTX, BLOCK_M) * Z * H_Q`` tasks, one per
    (batch, query head, query block). The program with id ``pid`` processes
    tasks ``pid, pid + num_ctas, pid + 2 * num_ctas, ...``.

    For each task the kernel builds block pointers for Q/K/V/Out, runs
    ``_attn_fwd_inner`` according to ``STAGE``, normalises the accumulator by
    the softmax denominator and stores:

    * the attention output block into ``Out``;
    * ``m_i + log(l_i)`` per query row into ``M``.

    ``STAGE`` dispatch (value passed on to ``_attn_fwd_inner`` in brackets):

    * ``4``: one pass over all keys with the causal mask [4].
    * otherwise, if bit 0 is set: one pass with stage ``4 - STAGE``
      (``STAGE == 1`` -> [3], ``STAGE == 3`` -> [1]).
    * otherwise, if bit 1 is set: one additional pass over the diagonal
      block [2].

    Args:
        Q, K, V: input tensors; ``stride_*z/h/m|n/k`` are their four strides.
        M: per-row output buffer, indexed as ``task_hz_idx * Q_CTX + row``.
        Out: output tensor, addressed with the ``stride_o*`` strides.
        acc_buffer: not referenced by this kernel body.
        sm_scale: multiplier applied to the ``Q @ K^T`` scores.
        Z, H_Q: batch size and number of query heads.
        H_KV: number of key/value heads (not referenced in the body; the
            key/value head is derived as ``off_hq // GROUP_SIZE``).
        GROUP_SIZE: number of query heads sharing one key/value head.
        Q_CTX, KV_CTX: query and key/value sequence lengths.
        HEAD_DIM: head dimension used as the logical width of every block
            pointer; ``BLOCK_SIZE_K`` is the block width actually loaded.
        BLOCK_M, BLOCK_N, MICRO_M, MICRO_K, MICRO_N, num_ctas: tiling
            parameters (not passed by the visible launch site, so presumably
            supplied by the tuned configs — confirm against the config file).
    """
    NUM_BLOCKS_M = tl.cdiv(Q_CTX, BLOCK_M)
    NUM_BLOCKS = NUM_BLOCKS_M * Z * H_Q

    pid = tl.program_id(0)
    # Number of tasks assigned to this program (zero when pid >= NUM_BLOCKS).
    sub_num = tl.cdiv(max(NUM_BLOCKS - pid, 0), num_ctas)

    for block_idx in tl.range(0, sub_num):
        task_hz_idx = (pid + num_ctas * block_idx) // NUM_BLOCKS_M
        task_m_idx = (pid + num_ctas * block_idx) % NUM_BLOCKS_M

        off_z = task_hz_idx // H_Q
        off_hq = task_hz_idx % H_Q
        # Grouped-query attention: several query heads map to one KV head.
        off_hkv = off_hq // GROUP_SIZE

        # 64-bit offsets so large batch/head strides do not overflow.
        q_offset = off_z.to(tl.int64) * stride_qz + off_hq.to(tl.int64) * stride_qh
        k_offset = off_z.to(tl.int64) * stride_kz + off_hkv.to(tl.int64) * stride_kh
        v_offset = off_z.to(tl.int64) * stride_vz + off_hkv.to(tl.int64) * stride_vh
        o_offset = off_z.to(tl.int64) * stride_oz + off_hq.to(tl.int64) * stride_oh

        Q_block_ptr = tl.make_block_ptr(
            base=Q + q_offset,
            shape=(Q_CTX, HEAD_DIM),
            strides=(stride_qm, stride_qk),
            offsets=(task_m_idx * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_SIZE_K),
            order=(1, 0),
        )
        K_block_ptr = tl.make_block_ptr(
            base=K + k_offset,
            shape=(KV_CTX, HEAD_DIM),
            strides=(stride_kn, stride_kk),
            offsets=(0, 0),
            block_shape=(BLOCK_N, BLOCK_SIZE_K),
            order=(1, 0),
        )
        V_block_ptr = tl.make_block_ptr(
            base=V + v_offset,
            shape=(KV_CTX, HEAD_DIM),
            strides=(stride_vn, stride_vk),
            offsets=(0, 0),
            block_shape=(BLOCK_N, BLOCK_SIZE_K),
            order=(1, 0),
        )
        O_block_ptr = tl.make_block_ptr(
            base=Out + o_offset,
            shape=(Q_CTX, HEAD_DIM),
            strides=(stride_om, stride_on),
            offsets=(task_m_idx * BLOCK_M, 0),
            block_shape=(BLOCK_M, BLOCK_SIZE_K),
            order=(1, 0),
        )

        offs_m = task_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tl.arange(0, BLOCK_N)

        # Compute tile counts from META-provided BLOCK_*/MICRO_*.
        num_m_tiles: tl.constexpr = BLOCK_M // MICRO_M
        num_n_tiles: tl.constexpr = BLOCK_N // MICRO_N
        num_k_tiles: tl.constexpr = BLOCK_SIZE_K // MICRO_N

        # Online-softmax state: running max, running denominator, accumulator.
        m_i_2d = tl.zeros([num_m_tiles, MICRO_M], dtype=tl.float32) - float("inf")
        l_i_2d = tl.zeros([num_m_tiles, MICRO_M], dtype=tl.float32) + 1.0
        acc_4d = tl.zeros(
            [num_m_tiles, num_k_tiles, MICRO_M, MICRO_N], dtype=tl.float32
        )

        # NOTE: _attn_fwd_inner declares exactly 21 parameters, ending with
        # num_m_tiles and num_n_tiles. No head-dimension tile count is passed;
        # a 22nd positional argument would not bind to any parameter.
        if STAGE == 4:
            acc_4d, l_i_2d, m_i_2d = _attn_fwd_inner(
                acc_4d,
                l_i_2d,
                m_i_2d,
                Q_block_ptr,
                K_block_ptr,
                V_block_ptr,
                task_m_idx,
                sm_scale,
                BLOCK_M,
                BLOCK_SIZE_K,
                BLOCK_N,
                4,
                offs_m,
                offs_n,
                Q_CTX,
                KV_CTX,
                MICRO_M,
                MICRO_K,
                MICRO_N,
                num_m_tiles,
                num_n_tiles,
            )
        else:
            if STAGE & 1:
                acc_4d, l_i_2d, m_i_2d = _attn_fwd_inner(
                    acc_4d,
                    l_i_2d,
                    m_i_2d,
                    Q_block_ptr,
                    K_block_ptr,
                    V_block_ptr,
                    task_m_idx,
                    sm_scale,
                    BLOCK_M,
                    BLOCK_SIZE_K,
                    BLOCK_N,
                    4 - STAGE,
                    offs_m,
                    offs_n,
                    Q_CTX,
                    KV_CTX,
                    MICRO_M,
                    MICRO_K,
                    MICRO_N,
                    num_m_tiles,
                    num_n_tiles,
                )
            if STAGE & 2:
                acc_4d, l_i_2d, m_i_2d = _attn_fwd_inner(
                    acc_4d,
                    l_i_2d,
                    m_i_2d,
                    Q_block_ptr,
                    K_block_ptr,
                    V_block_ptr,
                    task_m_idx,
                    sm_scale,
                    BLOCK_M,
                    BLOCK_SIZE_K,
                    BLOCK_N,
                    2,
                    offs_m,
                    offs_n,
                    Q_CTX,
                    KV_CTX,
                    MICRO_M,
                    MICRO_K,
                    MICRO_N,
                    num_m_tiles,
                    num_n_tiles,
                )

        # Back to a plain (BLOCK_M, BLOCK_SIZE_K) layout for the epilogue.
        acc_2d = smt.view(acc_4d, (0, 0), (BLOCK_M, BLOCK_SIZE_K), (1, 1))
        m_i = tl.reshape(m_i_2d, (BLOCK_M,))
        l_i = tl.reshape(l_i_2d, (BLOCK_M,))

        # Running max plus log of the denominator, stored per row into M.
        m_i = m_i + tl.math.log(l_i)
        accumulator = acc_2d / l_i[:, None]

        # Rows past the end of the query sequence are not written to M.
        mask_m = offs_m < Q_CTX
        m_ptrs = M + task_hz_idx * Q_CTX + offs_m
        tl.store(m_ptrs, m_i.to(M.type.element_ty), mask=mask_m)

        tl.store(
            O_block_ptr, accumulator.to(Out.type.element_ty), boundary_check=(0, 1)
        )

308 

309 

class Attention(torch.autograd.Function):
    """Autograd wrapper that launches the ``_attn_fwd`` Triton kernel.

    Only ``forward`` is defined in this class; no ``backward`` is provided.
    """

    @staticmethod
    def forward(ctx, q, k, v, sm_scale, is_causal, enable_gqa: bool):
        """Compute attention for ``q``, ``k``, ``v`` and return the output.

        The tensors are indexed as ``(batch, heads, sequence, head_dim)``:
        dimension 1 is read as the head count, dimension 2 as the sequence
        length and the last dimension as the head size.

        Args:
            ctx: autograd context (unused).
            q, k, v: query, key and value tensors; their last dimensions
                must all be equal.
            sm_scale: multiplier applied to the ``Q @ K^T`` scores. ``None``
                selects ``1 / sqrt(head_dim)``, the default used by
                ``torch.nn.functional.scaled_dot_product_attention``.
            is_causal: whether to apply a causal mask.
            enable_gqa: whether grouped-query attention is requested. It is
                forced on when the query and key head counts differ.

        Returns:
            A tensor created with ``torch.empty_like(q)`` and filled by the
            kernel.

        Raises:
            AssertionError: if the head dimensions differ, or if the query
                head count is not a multiple of the key/value head count.
        """
        HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]
        assert HEAD_DIM_Q == HEAD_DIM_K == v.shape[-1]
        BLOCK_SIZE_K = triton.next_power_of_2(HEAD_DIM_K)

        # Both public wrappers default ``scale`` to None and forward it
        # unchanged; the kernel multiplies the scores by this value, so a
        # concrete number is required.
        if sm_scale is None:
            sm_scale = 1.0 / (HEAD_DIM_K**0.5)

        Q_CTX = q.shape[2]
        KV_CTX = k.shape[2]

        H_Q = q.shape[1]
        H_KV = k.shape[1]

        # Differing head counts imply grouped-query attention.
        if H_Q != H_KV:
            enable_gqa = True

        if enable_gqa:
            assert (
                H_Q % H_KV == 0
            ), f"GQA requires H_Q % H_KV == 0, got H_Q={H_Q}, H_KV={H_KV}"
            GROUP_SIZE = H_Q // H_KV
        else:
            assert H_Q == H_KV
            GROUP_SIZE = 1

        o = torch.empty_like(q)
        # float32 scratch tensor passed as the kernel's ``acc_buffer``.
        acc = torch.empty(
            (q.shape[0], q.shape[1], q.shape[2], HEAD_DIM_K),
            dtype=torch.float32,
            device=q.device,
        )
        # Per-row float32 buffer the kernel writes alongside the output.
        M = torch.empty(
            (q.shape[0], q.shape[1], q.shape[2]), dtype=torch.float32, device=q.device
        )

        # STAGE selects the kernel's masking strategy:
        #   1 -> non-causal,
        #   3 -> causal with equal query / key lengths,
        #   4 -> causal with differing query / key lengths.
        if is_causal:
            STAGE = 3 if (Q_CTX == KV_CTX) else 4
        else:
            STAGE = 1

        # Fallback program count when the chosen config has no "num_ctas".
        num_ctas = 16
        grid = lambda META: (META.get("num_ctas", num_ctas),)

        _attn_fwd[grid](
            q,
            k,
            v,
            M,
            o,
            acc,
            sm_scale,
            q.stride(0),
            q.stride(1),
            q.stride(2),
            q.stride(3),
            k.stride(0),
            k.stride(1),
            k.stride(2),
            k.stride(3),
            v.stride(0),
            v.stride(1),
            v.stride(2),
            v.stride(3),
            o.stride(0),
            o.stride(1),
            o.stride(2),
            o.stride(3),
            q.shape[0],
            H_Q=H_Q,
            H_KV=H_KV,
            GROUP_SIZE=GROUP_SIZE,
            Q_CTX=Q_CTX,
            KV_CTX=KV_CTX,
            HEAD_DIM=HEAD_DIM_K,
            STAGE=STAGE,
            BLOCK_SIZE_K=BLOCK_SIZE_K,
        )

        return o

389 

390 

def flash_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Run the Triton attention kernel on ``query``, ``key`` and ``value``.

    The inputs are handed to ``Attention.apply`` as given (no copy and no
    layout change).

    Args:
        query, key, value: attention inputs.
        attn_mask: accepted for signature compatibility; not forwarded.
        dropout_p: accepted for signature compatibility; not forwarded.
        is_causal: whether to apply a causal mask.
        scale: score multiplier, forwarded as ``sm_scale``.
        enable_gqa: whether grouped-query attention is requested.

    Returns:
        The attention output produced by ``Attention.apply``.
    """
    # attn_mask and dropout_p are intentionally left out of the call below.
    output = Attention.apply(query, key, value, scale, is_causal, enable_gqa)
    return output

402 

403 

def scaled_dot_product_attention(
    query,
    key,
    value,
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
    scale=None,
    enable_gqa=False,
):
    """Run the Triton attention kernel on contiguous copies of the inputs.

    Each of ``query``, ``key`` and ``value`` is cloned and made contiguous
    before being handed to ``Attention.apply``, so the caller's tensors are
    never passed to the kernel directly.

    Args:
        query, key, value: attention inputs.
        attn_mask: accepted for signature compatibility; not forwarded.
        dropout_p: accepted for signature compatibility; not forwarded.
        is_causal: whether to apply a causal mask.
        scale: score multiplier, forwarded as ``sm_scale``.
        enable_gqa: whether grouped-query attention is requested.

    Returns:
        The attention output produced by ``Attention.apply``.
    """
    # Copy first, then force a contiguous layout, in query/key/value order.
    q_copy, k_copy, v_copy = (
        tensor.clone().contiguous() for tensor in (query, key, value)
    )
    return Attention.apply(q_copy, k_copy, v_copy, scale, is_causal, enable_gqa)