Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/w8a8_block_fp8_bmm.py: 0%

276 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1from typing import List, Optional 

2 

3import torch 

4import triton 

5from triton.experimental import gluon 

6from triton.experimental.gluon import language as gl 

7from triton.experimental.gluon.language.nvidia.hopper import ( 

8 fence_async_shared, 

9 mbarrier, 

10 tma, 

11 warpgroup_mma, 

12 warpgroup_mma_wait, 

13) 

14from triton.experimental.gluon.nvidia.hopper import TensorDescriptor 

15from triton.language.core import _aggregate as aggregate 

16 

# Torch dtypes this module accepts, paired with the matching Gluon element type.
# Any dtype missing from this table is rejected by `_gl_dtype`.
_TORCH_TO_GL_DTYPE = dict(
    [
        (torch.float8_e4m3fn, gl.float8e4nv),
        (torch.float8_e5m2, gl.float8e5),
        (torch.bfloat16, gl.bfloat16),
        (torch.float16, gl.float16),
        (torch.float32, gl.float32),
    ]
)

24 

25 

def _gl_dtype(t: torch.Tensor):
    """Return the Gluon element type that corresponds to ``t.dtype``.

    Raises:
        TypeError: if ``t.dtype`` has no entry in ``_TORCH_TO_GL_DTYPE``.
    """
    try:
        gl_dtype = _TORCH_TO_GL_DTYPE[t.dtype]
    except KeyError as e:
        raise TypeError(f"Unsupported tensor dtype: {t.dtype}") from e
    return gl_dtype

31 

32 

@gluon.constexpr_function
def get_warps_per_cta(BLOCK_M, BLOCK_N, num_warps):
    """Split ``num_warps`` into a ``[warps_along_m, warps_along_n]`` pair.

    The search starts from ``[4, 1]`` and doubles one axis per step until the
    product equals ``num_warps``. The M axis is doubled while ``BLOCK_M`` is
    larger than 16 rows per M-warp; otherwise the N axis is doubled.

    Args:
        BLOCK_M: tile extent along M.
        BLOCK_N: tile extent along N (accepted for a uniform signature; not read).
        num_warps: total warp count; must be ``4 * 2**k``.

    Returns:
        A two-element list ``[warps_along_m, warps_along_n]``.

    Raises:
        ValueError: if ``num_warps`` cannot be reached by doubling from 4.
    """
    # Doubling from a product of 4 only ever reaches 4, 8, 16, ...; any other
    # request would make the loop below spin forever.
    if num_warps < 4 or num_warps & (num_warps - 1) != 0:
        raise ValueError(
            f"num_warps must be a power of two >= 4, got {num_warps}"
        )

    warps_per_cta = [4, 1]
    m = 16
    while warps_per_cta[0] * warps_per_cta[1] != num_warps:
        if BLOCK_M > m * warps_per_cta[0]:
            warps_per_cta[0] *= 2
        else:
            warps_per_cta[1] *= 2
    return warps_per_cta

43 

44 

@gluon.constexpr_function
def get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps):
    """Pick the N extent of the MMA instruction shape for a tile.

    Searches downward from 256 in steps of 8 for the largest ``n`` that both
    divides ``BLOCK_N`` and does not exceed the per-warp-column share of
    ``BLOCK_N`` (never less than 8).

    Args:
        BLOCK_M: tile extent along M.
        BLOCK_N: tile extent along N.
        num_warps: warp count used to derive how many warp columns share N.

    Returns:
        The chosen instruction N, a multiple of 8 in ``[8, 256]``.

    Raises:
        AssertionError: if no multiple of 8 satisfies both constraints.
    """
    m = 16
    m_reps = triton.cdiv(BLOCK_M, m)
    n_reps = triton.cdiv(num_warps, m_reps)
    max_n = max(BLOCK_N // n_reps, 8)
    n = 256
    # Stop once n drops below 8: without this bound a failed search reaches
    # n == 0 and `BLOCK_N % n` raises ZeroDivisionError before the assertion
    # below can report the real problem.
    while n >= 8 and (n > max_n or BLOCK_N % n != 0):
        n -= 8
    assert n >= 8, "expected to find a valid n"
    return n

56 

57 

@gluon.constexpr_function
def pick_wgmma_layout(dtype, BLOCK_M, BLOCK_N, num_warps):
    """Build the version-3 ``NVMMADistributedLayout`` for an accumulator tile.

    The instruction shape is ``[16, n, k]`` where ``n`` comes from
    ``get_instr_shape_n`` and ``k`` is ``256 // dtype.primitive_bitwidth``;
    the warp arrangement comes from ``get_warps_per_cta``.
    """
    instr_m = 16
    instr_k = 256 // dtype.primitive_bitwidth
    instr_n = get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps)
    warp_grid = get_warps_per_cta(BLOCK_M, BLOCK_N, num_warps)
    layout = gl.NVMMADistributedLayout(
        version=[3, 0],
        warps_per_cta=warp_grid,
        instr_shape=[instr_m, instr_n, instr_k],
    )
    return layout

69 

70 

@aggregate
class Config:
    """Compile-time launch parameters plus the tile counts derived from them.

    Every field is a ``gl.constexpr``; the derived counts are computed once in
    ``__init__`` with integer floor division.
    """

    # Problem extents and tiling.
    B: gl.constexpr
    M: gl.constexpr
    M_aligned: gl.constexpr
    N: gl.constexpr
    K: gl.constexpr
    BLOCK_M: gl.constexpr
    BLOCK_N: gl.constexpr
    BLOCK_K: gl.constexpr
    # Scheduling / launch knobs.
    TILE_ORDER: gl.constexpr
    SWAP_AB: gl.constexpr
    num_warps: gl.constexpr
    num_stages: gl.constexpr
    num_sms: gl.constexpr
    # Strides of xs (per-token scale) in the caller's [B, M, num_kb] tensor.
    xs_sB: gl.constexpr
    xs_sM: gl.constexpr
    xs_sKb: gl.constexpr
    # Derived tile counts.
    num_m_tiles: gl.constexpr
    num_n_tiles: gl.constexpr
    num_k_blocks: gl.constexpr
    num_tiles_per_batch: gl.constexpr
    num_tiles: gl.constexpr

    @gluon.constexpr_function
    def __init__(
        self,
        B,
        M,
        M_aligned,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        TILE_ORDER,
        SWAP_AB,
        num_warps,
        num_stages,
        num_sms,
        xs_sB,
        xs_sM,
        xs_sKb,
    ):
        """Wrap each argument as a constexpr and derive the tile counts."""
        # Pass-through parameters.
        self.B = gl.constexpr(B)
        self.M = gl.constexpr(M)
        self.M_aligned = gl.constexpr(M_aligned)
        self.N = gl.constexpr(N)
        self.K = gl.constexpr(K)
        self.BLOCK_M = gl.constexpr(BLOCK_M)
        self.BLOCK_N = gl.constexpr(BLOCK_N)
        self.BLOCK_K = gl.constexpr(BLOCK_K)
        self.TILE_ORDER = gl.constexpr(TILE_ORDER)
        self.SWAP_AB = gl.constexpr(SWAP_AB)
        self.num_warps = gl.constexpr(num_warps)
        self.num_stages = gl.constexpr(num_stages)
        self.num_sms = gl.constexpr(num_sms)
        self.xs_sB = gl.constexpr(xs_sB)
        self.xs_sM = gl.constexpr(xs_sM)
        self.xs_sKb = gl.constexpr(xs_sKb)

        # Derived counts: tiles along M and N, K blocks, and output tiles
        # per batch / in total.
        tiles_m = M_aligned // BLOCK_M
        tiles_n = N // BLOCK_N
        tiles_per_batch = tiles_m * tiles_n
        self.num_m_tiles = gl.constexpr(tiles_m)
        self.num_n_tiles = gl.constexpr(tiles_n)
        self.num_k_blocks = gl.constexpr(K // BLOCK_K)
        self.num_tiles_per_batch = gl.constexpr(tiles_per_batch)
        self.num_tiles = gl.constexpr(B * tiles_per_batch)

140 

141 

@aggregate
class BarrierCounter:
    """Cursor over a ring of ``num_barriers`` barriers.

    ``index`` is the current ring slot and ``phase`` is a bit that flips each
    time the cursor wraps back to slot 0.
    """

    index: gl.tensor
    phase: gl.tensor
    num_barriers: gl.constexpr

    @gluon.constexpr_function
    def __init__(self, index, phase, num_barriers):
        """Store the cursor state; ``num_barriers`` is wrapped as a constexpr."""
        self.index = index
        self.phase = phase
        self.num_barriers = gl.constexpr(num_barriers)

    @gluon.must_use_result
    @gluon.jit
    def increment(self):
        # Returns a NEW counter advanced by one slot; the receiver is not
        # modified, so callers must rebind the result.
        if self.num_barriers == 1:
            # Single-slot ring: every step wraps, so only the phase toggles.
            return BarrierCounter(gl.to_tensor(0), self.phase ^ 1, self.num_barriers)
        else:
            bumped = self.index + 1
            wrapped = bumped == self.num_barriers
            new_index = gl.where(wrapped, 0, bumped)
            new_phase = gl.where(wrapped, self.phase ^ 1, self.phase)
            return BarrierCounter(new_index, new_phase, self.num_barriers)

164 

165 

@aggregate
class Channel:
    """Multi-stage shared-memory ring linking the load and compute partitions.

    Each stage owns one x tile, one y tile, a ``ready`` barrier and an
    ``empty`` barrier.
    """

    x_smem: gl.shared_memory_descriptor
    y_smem: gl.shared_memory_descriptor
    ready_bars: gl.shared_memory_descriptor
    empty_bars: gl.shared_memory_descriptor
    num_stages: gl.constexpr

    @gluon.constexpr_function
    def __init__(self, x_smem, y_smem, ready_bars, empty_bars, num_stages):
        """Bundle already-allocated buffers and barriers."""
        self.x_smem = x_smem
        self.y_smem = y_smem
        self.ready_bars = ready_bars
        self.empty_bars = empty_bars
        self.num_stages = gl.constexpr(num_stages)

    @gluon.jit
    def alloc(
        BLOCK_M: gl.constexpr,
        BLOCK_N: gl.constexpr,
        BLOCK_K: gl.constexpr,
        x_dtype: gl.constexpr,
        x_layout: gl.constexpr,
        y_dtype: gl.constexpr,
        y_layout: gl.constexpr,
        num_stages: gl.constexpr,
        num_warps: gl.constexpr,
    ):
        # Allocates the staging buffers and barriers for `num_stages` stages.
        # x uses a 3D box [1, BLOCK_M, BLOCK_K] (x is permuted/non-contig at the
        # global level); y uses a 2D box. xs is loaded directly with gl.load and
        # is not staged through shared memory.
        x_smem = gl.allocate_shared_memory(
            x_dtype, [num_stages, 1, BLOCK_M, BLOCK_K], x_layout
        )
        y_smem = gl.allocate_shared_memory(
            y_dtype, [num_stages, BLOCK_N, BLOCK_K], y_layout
        )
        ready_bars = gl.allocate_shared_memory(
            gl.int64, [num_stages, 1], mbarrier.MBarrierLayout()
        )
        empty_bars = gl.allocate_shared_memory(
            gl.int64, [num_stages, 1], mbarrier.MBarrierLayout()
        )
        for stage in gl.static_range(num_stages):
            ready_bar = ready_bars.index(stage)
            empty_bar = empty_bars.index(stage)
            mbarrier.init(ready_bar, count=1)
            mbarrier.init(empty_bar, count=1)
            # Arrive once on each `empty` barrier right after initialising it.
            mbarrier.arrive(empty_bar, count=1)
        return Channel(x_smem, y_smem, ready_bars, empty_bars, num_stages)

    @gluon.jit
    def release(self):
        # Keeps both tile buffers alive up to this point, then invalidates
        # every barrier of the ring.
        self.x_smem._keep_alive()
        self.y_smem._keep_alive()
        for stage in gl.static_range(self.num_stages):
            mbarrier.invalidate(self.ready_bars.index(stage))
            mbarrier.invalidate(self.empty_bars.index(stage))

221 

222 

@gluon.jit
def get_tile(tile_id, config):
    # Decompose a linear tile id into (batch_id, m_tile_id, n_tile_id).
    # TILE_ORDER: 0 = horizontal (N fastest within batch — favours x reuse across N sweep)
    #             anything else = vertical (M fastest within batch — favours y reuse across M sweep)
    batch_id = tile_id // config.num_tiles_per_batch
    in_batch = tile_id % config.num_tiles_per_batch
    if config.TILE_ORDER != 0:
        # Vertical: consecutive ids walk down M first.
        m_tile_id = in_batch % config.num_m_tiles
        n_tile_id = in_batch // config.num_m_tiles
    else:
        # Horizontal: consecutive ids walk across N first.
        n_tile_id = in_batch % config.num_n_tiles
        m_tile_id = in_batch // config.num_n_tiles
    return batch_id, m_tile_id, n_tile_id

236 

237 

@gluon.jit
def compute_partition(channel, config, tensors):
    # Consumer side of the pipeline. For every output tile assigned to this
    # program it waits for each staged (x, y) K-block, multiplies the pair,
    # scales the partial product by the x/y block scales, accumulates in fp32,
    # and finally stores the converted tile through shared memory.
    x_desc, y_desc, xs_ptr, z_desc, ys_ptr = tensors
    first_tile = gl.program_id(0)
    stage_ctr = BarrierCounter(
        index=gl.to_tensor(0), phase=gl.to_tensor(0), num_barriers=config.num_stages
    )

    # With SWAP_AB the product is computed transposed, as [BLOCK_N, BLOCK_M],
    # so the per-token scale vector runs along dim 1 instead of dim 0.
    if config.SWAP_AB:
        acc_layout: gl.constexpr = pick_wgmma_layout(
            x_desc.dtype, config.BLOCK_N, config.BLOCK_M, num_warps=config.num_warps
        )
        scale_layout: gl.constexpr = gl.SliceLayout(0, acc_layout)
    else:
        acc_layout: gl.constexpr = pick_wgmma_layout(
            x_desc.dtype, config.BLOCK_M, config.BLOCK_N, num_warps=config.num_warps
        )
        scale_layout: gl.constexpr = gl.SliceLayout(1, acc_layout)

    # Staging buffer for the output tile.
    out_smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for(
        [1, config.BLOCK_M, config.BLOCK_N], z_desc.dtype
    )
    out_smem = gl.allocate_shared_memory(
        z_desc.dtype, [1, config.BLOCK_M, config.BLOCK_N], out_smem_layout
    )

    # Row offsets inside a tile: one xs value per token along BLOCK_M.
    row_in_tile = gl.arange(0, config.BLOCK_M, layout=scale_layout)

    for tile in range(first_tile, config.num_tiles, config.num_sms):
        batch_id, m_tile_id, n_tile_id = get_tile(tile, config)
        m_start = m_tile_id * config.BLOCK_M
        n_start = n_tile_id * config.BLOCK_N
        # ys is indexed as a dense (batch, N/BLOCK_N, K/BLOCK_K) grid: one
        # scale per (n_tile, k_block).
        ys_base = (batch_id * config.num_n_tiles + n_tile_id) * config.num_k_blocks
        # xs is the caller's [B, M, num_kb] tensor, addressed via its strides;
        # rows at or beyond M are masked off.
        rows = m_start + row_in_tile
        row_ok = rows < config.M
        xs_row_base = batch_id * config.xs_sB + rows * config.xs_sM

        if config.SWAP_AB:
            zero_acc = gl.zeros(
                (config.BLOCK_N, config.BLOCK_M), dtype=gl.float32, layout=acc_layout
            )
            accum = gl.zeros(
                (config.BLOCK_N, config.BLOCK_M), dtype=gl.float32, layout=acc_layout
            )
        else:
            zero_acc = gl.zeros(
                (config.BLOCK_M, config.BLOCK_N), dtype=gl.float32, layout=acc_layout
            )
            accum = gl.zeros(
                (config.BLOCK_M, config.BLOCK_N), dtype=gl.float32, layout=acc_layout
            )

        for k in range(0, config.K, config.BLOCK_K):
            k_block_idx = k // config.BLOCK_K
            slot, phase = stage_ctr.index, stage_ctr.phase
            x_stage = channel.x_smem.index(slot)  # [1, BLOCK_M, BLOCK_K]
            y_stage = channel.y_smem.index(slot)  # [BLOCK_N, BLOCK_K]
            ready_bar = channel.ready_bars.index(slot)
            empty_bar = channel.empty_bars.index(slot)
            # Block until the loader has filled this stage.
            mbarrier.wait(ready_bar, phase)

            x_tile = x_stage.reshape((config.BLOCK_M, config.BLOCK_K))
            y_tile = y_stage

            # Combined scale for this K block: per-row x scale times the single
            # y scale of this (n_tile, k_block).
            x_scale = gl.load(
                xs_ptr + xs_row_base + k_block_idx * config.xs_sKb,
                mask=row_ok,
                other=0.0,
            )
            y_scale = gl.load(ys_ptr + ys_base + k_block_idx)
            scale = x_scale * y_scale

            # Each K block is multiplied into a fresh accumulator
            # (use_acc=False) and only then scaled and added to `accum`.
            if config.SWAP_AB:
                x_tile_t = x_tile.permute((1, 0))
                mma_token = warpgroup_mma(
                    y_tile, x_tile_t, zero_acc, use_acc=False, is_async=True
                )
                block_prod = warpgroup_mma_wait(num_outstanding=0, deps=(mma_token,))
                accum += block_prod * scale[None, :]
            else:
                y_tile_t = y_tile.permute((1, 0))
                mma_token = warpgroup_mma(
                    x_tile, y_tile_t, zero_acc, use_acc=False, is_async=True
                )
                block_prod = warpgroup_mma_wait(num_outstanding=0, deps=(mma_token,))
                accum += block_prod * scale[:, None]

            # Hand the stage back to the loader.
            mbarrier.arrive(empty_bar)
            stage_ctr = stage_ctr.increment()

        result = accum.to(z_desc.dtype)
        if config.SWAP_AB:
            # Undo the transposed compute so the tile is [BLOCK_M, BLOCK_N].
            result = result.permute((1, 0))
        # The previous tile's store must finish before out_smem is overwritten.
        tma.store_wait(pendings=0)
        out_smem.reshape((config.BLOCK_M, config.BLOCK_N)).store(result)
        fence_async_shared()
        tma.async_copy_shared_to_global(z_desc, [batch_id, m_start, n_start], out_smem)

    # Drain the last outstanding store before the partition exits.
    tma.store_wait(pendings=0)

340 

341 

@gluon.jit
def load_partition(channel, config, tensors):
    # Producer side of the pipeline: for every tile assigned to this program,
    # issue one asynchronous x copy and one y copy per K block into the next
    # free stage of `channel`.
    x_desc, y_desc, xs_ptr, z_desc, ys_ptr = tensors
    first_tile = gl.program_id(0)
    stage_ctr = BarrierCounter(
        index=gl.to_tensor(0), phase=gl.to_tensor(0), num_barriers=config.num_stages
    )

    # Expected transaction size per stage. This is the ELEMENT count of the x
    # and y boxes, so it equals the byte count only for one-byte element types.
    tx_bytes: gl.constexpr = (
        config.BLOCK_M * config.BLOCK_K + config.BLOCK_N * config.BLOCK_K
    )

    for tile in range(first_tile, config.num_tiles, config.num_sms):
        batch_id, m_tile_id, n_tile_id = get_tile(tile, config)
        m_start = m_tile_id * config.BLOCK_M
        n_start = n_tile_id * config.BLOCK_N

        # y is addressed through a 2D descriptor, so the batch index is folded
        # into the row coordinate.
        y_row_start = batch_id * config.N + n_start

        for k in range(0, config.K, config.BLOCK_K):
            slot, phase = stage_ctr.index, stage_ctr.phase
            x_stage = channel.x_smem.index(slot)
            y_stage = channel.y_smem.index(slot)
            ready_bar = channel.ready_bars.index(slot)
            empty_bar = channel.empty_bars.index(slot)
            # Wait until the consumer has released this stage.
            mbarrier.wait(empty_bar, phase)

            # Both copies signal the same `ready` barrier.
            mbarrier.expect(ready_bar, tx_bytes)
            tma.async_copy_global_to_shared(
                x_desc, [batch_id, m_start, k], ready_bar, x_stage
            )
            tma.async_copy_global_to_shared(
                y_desc, [y_row_start, k], ready_bar, y_stage
            )

            stage_ctr = stage_ctr.increment()

376 

377 

@triton.autotune(
    configs=[
        triton.Config(
            {"TILE_ORDER": tile_order},
            num_warps=warps,
            num_stages=stages,
        )
        for warps in (4, 8)
        for stages in (4, 6, 8)
        # 0 = horizontal (n fastest), 1 = vertical (m fastest)
        for tile_order in (0, 1)
    ],
    key=["B", "M_aligned", "N", "K"],
)
@gluon.jit
def w8a8_block_fp8_bmm_kernel(
    x_desc,
    y_desc,
    xs_ptr,
    z_desc,
    ys_ptr,
    xs_sB: gl.constexpr,
    xs_sM: gl.constexpr,
    xs_sKb: gl.constexpr,
    B: gl.constexpr,
    M: gl.constexpr,
    M_aligned: gl.constexpr,
    N: gl.constexpr,
    K: gl.constexpr,
    BLOCK_M: gl.constexpr,
    BLOCK_N: gl.constexpr,
    BLOCK_K: gl.constexpr,
    TILE_ORDER: gl.constexpr,
    SWAP_AB: gl.constexpr,
    num_warps: gl.constexpr,
    num_stages: gl.constexpr,
    num_sms: gl.constexpr,
):
    # Kernel entry point: packs the compile-time parameters into a Config,
    # allocates the staging channel, then runs `compute_partition` and
    # `load_partition` under gl.warp_specialize. TILE_ORDER, num_warps and
    # num_stages are supplied by the autotuner.
    cfg = Config(
        B=B,
        M=M,
        M_aligned=M_aligned,
        N=N,
        K=K,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        TILE_ORDER=TILE_ORDER,
        SWAP_AB=SWAP_AB,
        num_warps=num_warps,
        num_stages=num_stages,
        num_sms=num_sms,
        xs_sB=xs_sB,
        xs_sM=xs_sM,
        xs_sKb=xs_sKb,
    )
    operands = (x_desc, y_desc, xs_ptr, z_desc, ys_ptr)
    # The staging buffers reuse the layouts carried by the x / y descriptors.
    pipe = Channel.alloc(
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        x_dtype=x_desc.dtype,
        x_layout=gl.constexpr(x_desc.layout),
        y_dtype=y_desc.dtype,
        y_layout=gl.constexpr(y_desc.layout),
        num_stages=num_stages,
        num_warps=num_warps,
    )

    # First entry is compute, second is the loader; the two trailing lists
    # give the loader's warp count (1) and register budget (24).
    gl.warp_specialize(
        [
            (compute_partition, (pipe, cfg, operands)),
            (load_partition, (pipe, cfg, operands)),
        ],
        [1],
        [24],
    )

    pipe.release()

452 

453 

454def w8a8_block_fp8_bmm( 

455 x: torch.Tensor, 

456 y: torch.Tensor, 

457 xs: torch.Tensor, 

458 ys: torch.Tensor, 

459 block_size: List[int] = [128, 128], 

460 z: Optional[torch.Tensor] = None, 

461 output_dtype: torch.dtype = torch.bfloat16, 

462): 

463 # x: [B, M, K] fp8 

464 # y: [B, N, K] fp8 

465 # xs: [B, M, K // block_k] f32 

466 # ys: [B, N // block_n, K // block_k] f32 

467 # z: [B, M, N] out_dtype 

468 assert len(block_size) == 2 

469 BLOCK_N, BLOCK_K = block_size 

470 assert ( 

471 BLOCK_N == 128 and BLOCK_K == 128 

472 ), "this kernel assumes 128x128 block-wise FP8 scales" 

473 

474 assert x.ndim == 3 and y.ndim == 3 and xs.ndim == 3 and ys.ndim == 3 

475 assert x.shape[0] == y.shape[0] == xs.shape[0] == ys.shape[0] 

476 assert x.shape[-1] == y.shape[-1] 

477 assert x.shape[:-1] == xs.shape[:-1] 

478 assert x.stride(-1) == 1 and y.stride(-1) == 1 

479 

480 device = x.device 

481 B, M, K = x.shape 

482 _, N, _ = y.shape 

483 assert K % BLOCK_K == 0 and N % BLOCK_N == 0 

484 num_kb = K // BLOCK_K 

485 

486 if z is None: 

487 z = torch.empty((B, M, N), device=device, dtype=output_dtype) 

488 else: 

489 assert z.shape == (B, M, N) and z.device == device and z.dtype == output_dtype 

490 assert z.stride(-1) == 1 

491 

492 BLOCK_M = max(8, min(64, 1 << ((M - 1).bit_length()))) 

493 SWAP_AB = 1 if BLOCK_M < 64 else 0 

494 

495 M_aligned = triton.cdiv(M, BLOCK_M) * BLOCK_M 

496 

497 x_gl_dtype = _gl_dtype(x) 

498 y_gl_dtype = _gl_dtype(y) 

499 z_gl_dtype = _gl_dtype(z) 

500 

501 x_layout = gl.NVMMASharedLayout.get_default_for([1, BLOCK_M, BLOCK_K], x_gl_dtype) 

502 x_desc = TensorDescriptor.from_tensor( 

503 x, block_shape=[1, BLOCK_M, BLOCK_K], layout=x_layout 

504 ) 

505 

506 assert y.is_contiguous(), "y must be contiguous so it can be viewed as (B*N, K)" 

507 y_flat = y.view(B * N, K) 

508 y_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_N, BLOCK_K], y_gl_dtype) 

509 y_desc = TensorDescriptor.from_tensor( 

510 y_flat, block_shape=[BLOCK_N, BLOCK_K], layout=y_layout 

511 ) 

512 

513 assert xs.ndim == 3 and xs.shape == (B, M, num_kb) 

514 xs_sB, xs_sM, xs_sKb = xs.stride() 

515 

516 z_layout = gl.NVMMASharedLayout.get_default_for([1, BLOCK_M, BLOCK_N], z_gl_dtype) 

517 z_desc = TensorDescriptor.from_tensor( 

518 z, block_shape=[1, BLOCK_M, BLOCK_N], layout=z_layout 

519 ) 

520 

521 num_sms = torch.cuda.get_device_properties(device).multi_processor_count 

522 w8a8_block_fp8_bmm_kernel[(num_sms,)]( 

523 x_desc, 

524 y_desc, 

525 xs, 

526 z_desc, 

527 ys, 

528 xs_sB=xs_sB, 

529 xs_sM=xs_sM, 

530 xs_sKb=xs_sKb, 

531 B=B, 

532 M=M, 

533 M_aligned=M_aligned, 

534 N=N, 

535 K=K, 

536 BLOCK_M=BLOCK_M, 

537 BLOCK_N=BLOCK_N, 

538 BLOCK_K=BLOCK_K, 

539 SWAP_AB=SWAP_AB, 

540 num_sms=num_sms, 

541 ) 

542 

543 return z