Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/w8a8_block_fp8_bmm.py: 0%
276 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1from typing import List, Optional
3import torch
4import triton
5from triton.experimental import gluon
6from triton.experimental.gluon import language as gl
7from triton.experimental.gluon.language.nvidia.hopper import (
8 fence_async_shared,
9 mbarrier,
10 tma,
11 warpgroup_mma,
12 warpgroup_mma_wait,
13)
14from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
15from triton.language.core import _aggregate as aggregate
# Lookup table from torch dtypes to the corresponding Gluon element types.
# Consulted by `_gl_dtype`; a dtype missing from this table is reported there
# as unsupported.
_TORCH_TO_GL_DTYPE = {
    torch.float8_e4m3fn: gl.float8e4nv,
    torch.float8_e5m2: gl.float8e5,
    torch.bfloat16: gl.bfloat16,
    torch.float16: gl.float16,
    torch.float32: gl.float32,
}
def _gl_dtype(t: torch.Tensor):
    """Return the Gluon element type matching the dtype of tensor ``t``.

    Raises:
        TypeError: if ``t.dtype`` has no entry in ``_TORCH_TO_GL_DTYPE``.
    """
    try:
        gl_dtype = _TORCH_TO_GL_DTYPE[t.dtype]
    except KeyError as err:
        # Surface a dtype-oriented error instead of a bare dictionary miss.
        raise TypeError(f"Unsupported tensor dtype: {t.dtype}") from err
    return gl_dtype
@gluon.constexpr_function
def get_warps_per_cta(BLOCK_M, BLOCK_N, num_warps):
    """Split ``num_warps`` into a ``[warps_m, warps_n]`` grid.

    The grid starts at ``[4, 1]`` and one dimension is doubled per step until
    the product equals ``num_warps``: the M dimension is doubled while
    ``BLOCK_M`` exceeds ``16 * warps_m``, otherwise the N dimension is doubled.

    Args:
        BLOCK_M: tile extent along M.
        BLOCK_N: tile extent along N (unused; kept for a uniform signature).
        num_warps: total warps to distribute; must be a power of two >= 4.

    Returns:
        A two-element list ``[warps_m, warps_n]`` whose product is ``num_warps``.

    Raises:
        AssertionError: if ``num_warps`` cannot be reached by doubling from 4.
    """
    # The product starts at 4 and only ever doubles, so any other value would
    # make the loop below spin forever; fail fast instead.
    assert (
        num_warps >= 4 and (num_warps & (num_warps - 1)) == 0
    ), f"num_warps must be a power of two >= 4, got {num_warps}"
    warps_per_cta = [4, 1]
    m = 16
    while warps_per_cta[0] * warps_per_cta[1] != num_warps:
        if BLOCK_M > m * warps_per_cta[0]:
            warps_per_cta[0] *= 2
        else:
            warps_per_cta[1] *= 2
    return warps_per_cta
@gluon.constexpr_function
def get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps):
    """Pick the N extent of the MMA instruction shape for a tile.

    Searches multiples of 8 from 256 downwards and returns the largest one
    that both divides ``BLOCK_N`` and does not exceed
    ``max(BLOCK_N // n_reps, 8)``, where ``n_reps`` is the number of warps left
    over per 16-row slice of ``BLOCK_M``.

    Args:
        BLOCK_M: tile extent along M.
        BLOCK_N: tile extent along N.
        num_warps: number of warps sharing the tile.

    Returns:
        The selected instruction N (a multiple of 8 in ``[8, 256]``).

    Raises:
        AssertionError: if no multiple of 8 satisfies both constraints.
    """
    m = 16
    m_reps = triton.cdiv(BLOCK_M, m)
    n_reps = triton.cdiv(num_warps, m_reps)
    max_n = max(BLOCK_N // n_reps, 8)
    # Bounded search: never lets n reach 0, so an unsatisfiable BLOCK_N yields
    # the intended assertion error rather than a modulo-by-zero.
    for n in range(256, 7, -8):
        if n <= max_n and BLOCK_N % n == 0:
            return n
    raise AssertionError("expected to find a valid n")
@gluon.constexpr_function
def pick_wgmma_layout(dtype, BLOCK_M, BLOCK_N, num_warps):
    """Build the version-3 ``NVMMADistributedLayout`` used for the accumulator.

    The instruction shape is ``[16, n, 256 // bitwidth(dtype)]`` with ``n``
    chosen by ``get_instr_shape_n`` and the warp grid chosen by
    ``get_warps_per_cta``.
    """
    instr_m = 16
    instr_k = 256 // dtype.primitive_bitwidth
    instr_n = get_instr_shape_n(BLOCK_M, BLOCK_N, num_warps)
    warp_grid = get_warps_per_cta(BLOCK_M, BLOCK_N, num_warps)
    layout = gl.NVMMADistributedLayout(
        version=[3, 0],
        warps_per_cta=warp_grid,
        instr_shape=[instr_m, instr_n, instr_k],
    )
    return layout
@aggregate
class Config:
    # Bundle of compile-time constants describing one launch: problem sizes,
    # tile sizes, scheduling knobs, xs strides and the tile counts derived
    # from them.

    # Problem sizes (M_aligned is the M extent used for tile counting).
    B: gl.constexpr
    M: gl.constexpr
    M_aligned: gl.constexpr
    N: gl.constexpr
    K: gl.constexpr
    # Tile sizes.
    BLOCK_M: gl.constexpr
    BLOCK_N: gl.constexpr
    BLOCK_K: gl.constexpr
    # Scheduling / pipelining knobs.
    TILE_ORDER: gl.constexpr
    SWAP_AB: gl.constexpr
    num_warps: gl.constexpr
    num_stages: gl.constexpr
    num_sms: gl.constexpr
    # xs (per-token scale) strides into the caller's [B, M, num_kb] tensor.
    xs_sB: gl.constexpr
    xs_sM: gl.constexpr
    xs_sKb: gl.constexpr
    # Derived: tile counts.
    num_m_tiles: gl.constexpr
    num_n_tiles: gl.constexpr
    num_k_blocks: gl.constexpr
    num_tiles_per_batch: gl.constexpr
    num_tiles: gl.constexpr

    @gluon.constexpr_function
    def __init__(
        self,
        B,
        M,
        M_aligned,
        N,
        K,
        BLOCK_M,
        BLOCK_N,
        BLOCK_K,
        TILE_ORDER,
        SWAP_AB,
        num_warps,
        num_stages,
        num_sms,
        xs_sB,
        xs_sM,
        xs_sKb,
    ):
        # Tile counts use floor division of the given extents by the tile sizes.
        m_tiles = M_aligned // BLOCK_M
        n_tiles = N // BLOCK_N
        k_blocks = K // BLOCK_K
        tiles_per_batch = m_tiles * n_tiles

        # Problem sizes.
        self.B = gl.constexpr(B)
        self.M = gl.constexpr(M)
        self.M_aligned = gl.constexpr(M_aligned)
        self.N = gl.constexpr(N)
        self.K = gl.constexpr(K)

        # Tile sizes.
        self.BLOCK_M = gl.constexpr(BLOCK_M)
        self.BLOCK_N = gl.constexpr(BLOCK_N)
        self.BLOCK_K = gl.constexpr(BLOCK_K)

        # Scheduling / pipelining knobs.
        self.TILE_ORDER = gl.constexpr(TILE_ORDER)
        self.SWAP_AB = gl.constexpr(SWAP_AB)
        self.num_warps = gl.constexpr(num_warps)
        self.num_stages = gl.constexpr(num_stages)
        self.num_sms = gl.constexpr(num_sms)

        # xs strides.
        self.xs_sB = gl.constexpr(xs_sB)
        self.xs_sM = gl.constexpr(xs_sM)
        self.xs_sKb = gl.constexpr(xs_sKb)

        # Derived tile counts.
        self.num_m_tiles = gl.constexpr(m_tiles)
        self.num_n_tiles = gl.constexpr(n_tiles)
        self.num_k_blocks = gl.constexpr(k_blocks)
        self.num_tiles_per_batch = gl.constexpr(tiles_per_batch)
        self.num_tiles = gl.constexpr(B * tiles_per_batch)
@aggregate
class BarrierCounter:
    # Cursor over a ring of `num_barriers` barriers: the current slot index
    # plus a phase bit that flips every time the index wraps back to 0.
    index: gl.tensor
    phase: gl.tensor
    num_barriers: gl.constexpr

    @gluon.constexpr_function
    def __init__(self, index, phase, num_barriers):
        self.num_barriers = gl.constexpr(num_barriers)
        self.index = index
        self.phase = phase

    @gluon.must_use_result
    @gluon.jit
    def increment(self):
        # Returns a new counter advanced by one slot; `self` is not modified.
        flipped_phase = self.phase ^ 1
        if self.num_barriers == 1:
            # Single-slot ring: every step wraps, so the phase always flips.
            return BarrierCounter(gl.to_tensor(0), flipped_phase, self.num_barriers)
        advanced = self.index + 1
        wrapped = advanced == self.num_barriers
        new_index = gl.where(wrapped, 0, advanced)
        new_phase = gl.where(wrapped, flipped_phase, self.phase)
        return BarrierCounter(new_index, new_phase, self.num_barriers)
@aggregate
class Channel:
    # Producer/consumer channel between `load_partition` and
    # `compute_partition`: `num_stages` staging slots for x and y tiles, plus
    # one "ready" and one "empty" barrier per slot.
    x_smem: gl.shared_memory_descriptor
    y_smem: gl.shared_memory_descriptor
    ready_bars: gl.shared_memory_descriptor
    empty_bars: gl.shared_memory_descriptor
    num_stages: gl.constexpr

    @gluon.constexpr_function
    def __init__(self, x_smem, y_smem, ready_bars, empty_bars, num_stages):
        self.x_smem = x_smem
        self.y_smem = y_smem
        self.ready_bars = ready_bars
        self.empty_bars = empty_bars
        self.num_stages = gl.constexpr(num_stages)

    @gluon.jit
    def alloc(
        BLOCK_M: gl.constexpr,
        BLOCK_N: gl.constexpr,
        BLOCK_K: gl.constexpr,
        x_dtype: gl.constexpr,
        x_layout: gl.constexpr,
        y_dtype: gl.constexpr,
        y_layout: gl.constexpr,
        num_stages: gl.constexpr,
        num_warps: gl.constexpr,
    ):
        # Allocates the staging buffers and barriers and returns a Channel.
        # Called without an instance (`Channel.alloc(...)`); `num_warps` is
        # accepted but not used in this body.
        # x: 3D box [1, BLOCK_M, BLOCK_K] (x is permuted/non-contig at the global level).
        # y: 2D box. xs is loaded directly with gl.load (not staged through smem).
        x_smem = gl.allocate_shared_memory(
            x_dtype, [num_stages, 1, BLOCK_M, BLOCK_K], x_layout
        )
        y_smem = gl.allocate_shared_memory(
            y_dtype, [num_stages, BLOCK_N, BLOCK_K], y_layout
        )
        ready_bars = gl.allocate_shared_memory(
            gl.int64, [num_stages, 1], mbarrier.MBarrierLayout()
        )
        empty_bars = gl.allocate_shared_memory(
            gl.int64, [num_stages, 1], mbarrier.MBarrierLayout()
        )
        for i in gl.static_range(num_stages):
            mbarrier.init(ready_bars.index(i), count=1)
            mbarrier.init(empty_bars.index(i), count=1)
            # Each "empty" barrier gets one arrival up front. The loader's very
            # first wait on a slot uses phase 0 (see load_partition), so this
            # presumably marks every slot as free before anything has been
            # consumed -- confirm against the mbarrier phase semantics.
            mbarrier.arrive(empty_bars.index(i), count=1)
        return Channel(x_smem, y_smem, ready_bars, empty_bars, num_stages)

    @gluon.jit
    def release(self):
        # Tears down the channel: invalidates every barrier of every stage.
        # NOTE(review): `_keep_alive()` presumably extends the lifetime of the
        # staging buffers up to this point -- confirm against the gluon API.
        self.x_smem._keep_alive()
        self.y_smem._keep_alive()
        for i in gl.static_range(self.num_stages):
            mbarrier.invalidate(self.ready_bars.index(i))
            mbarrier.invalidate(self.empty_bars.index(i))
@gluon.jit
def get_tile(tile_id, config):
    # Decode a flat tile id into (batch, m-tile, n-tile) coordinates.
    # TILE_ORDER: 0 = horizontal (N fastest within batch — favours x reuse across N sweep)
    #             1 = vertical   (M fastest within batch — favours y reuse across M sweep)
    tiles_per_batch: gl.constexpr = config.num_tiles_per_batch
    batch_idx = tile_id // tiles_per_batch
    tile_in_batch = tile_id % tiles_per_batch
    if config.TILE_ORDER == 0:
        # N index varies fastest.
        n_idx = tile_in_batch % config.num_n_tiles
        m_idx = tile_in_batch // config.num_n_tiles
    else:
        # M index varies fastest.
        m_idx = tile_in_batch % config.num_m_tiles
        n_idx = tile_in_batch // config.num_m_tiles
    return batch_idx, m_idx, n_idx
@gluon.jit
def compute_partition(channel, config, tensors):
    # Consumer side of the pipeline. For every tile assigned to this program
    # it waits for staged x/y blocks, multiplies them with warpgroup_mma,
    # scales each K-block partial product by (x scale * y scale), accumulates
    # in fp32, and finally writes the tile to z through a TMA store.
    x_desc, y_desc, xs_ptr, z_desc, ys_ptr = tensors
    start_pid = gl.program_id(0)
    # Ring cursor over the channel stages; advanced once per K block below.
    counter = BarrierCounter(
        index=gl.to_tensor(0), phase=gl.to_tensor(0), num_barriers=config.num_stages
    )

    # With SWAP_AB the product is computed transposed (accumulator is
    # [BLOCK_N, BLOCK_M]), so the per-token scale vector must broadcast along
    # dim 0 of the accumulator instead of dim 1.
    if config.SWAP_AB:
        mma_layout: gl.constexpr = pick_wgmma_layout(
            x_desc.dtype, config.BLOCK_N, config.BLOCK_M, num_warps=config.num_warps
        )
        xs_load_layout: gl.constexpr = gl.SliceLayout(0, mma_layout)
    else:
        mma_layout: gl.constexpr = pick_wgmma_layout(
            x_desc.dtype, config.BLOCK_M, config.BLOCK_N, num_warps=config.num_warps
        )
        xs_load_layout: gl.constexpr = gl.SliceLayout(1, mma_layout)

    # Single output staging buffer, reused by every tile this program handles.
    z_smem_layout: gl.constexpr = gl.NVMMASharedLayout.get_default_for(
        [1, config.BLOCK_M, config.BLOCK_N], z_desc.dtype
    )
    z_smem = gl.allocate_shared_memory(
        z_desc.dtype, [1, config.BLOCK_M, config.BLOCK_N], z_smem_layout
    )

    # xs in-tile lane indices (one fp32 per token along BLOCK_M).
    xs_lane = gl.arange(0, config.BLOCK_M, layout=xs_load_layout)

    # Tiles are strided across programs: start at this program's id and step
    # by num_sms (the launch grid size used by the host wrapper).
    for tile_id in range(start_pid, config.num_tiles, config.num_sms):
        batch_id, m_tile_id, n_tile_id = get_tile(tile_id, config)
        m_start = m_tile_id * config.BLOCK_M
        n_start = n_tile_id * config.BLOCK_N
        # ys layout matches the scale grid (N/BLOCK_N, K/BLOCK_K); one scale per (n_tile, k_block).
        ys_base = (batch_id * config.num_n_tiles + n_tile_id) * config.num_k_blocks
        # xs is the caller's [B, M, num_kb] tensor (strided, possibly non-contig).
        xs_m = m_start + xs_lane
        # Rows at or beyond M (tile padding up to M_aligned) are masked out.
        xs_mask = xs_m < config.M
        xs_row_base = batch_id * config.xs_sB + xs_m * config.xs_sM

        # `partial_zero` is the accumulator operand handed to each MMA (which
        # runs with use_acc=False); `acc` is the running scaled sum.
        if config.SWAP_AB:
            partial_zero = gl.zeros(
                (config.BLOCK_N, config.BLOCK_M), dtype=gl.float32, layout=mma_layout
            )
            acc = gl.zeros(
                (config.BLOCK_N, config.BLOCK_M), dtype=gl.float32, layout=mma_layout
            )
        else:
            partial_zero = gl.zeros(
                (config.BLOCK_M, config.BLOCK_N), dtype=gl.float32, layout=mma_layout
            )
            acc = gl.zeros(
                (config.BLOCK_M, config.BLOCK_N), dtype=gl.float32, layout=mma_layout
            )

        for k in range(0, config.K, config.BLOCK_K):
            k_block_idx = k // config.BLOCK_K
            index, phase = counter.index, counter.phase
            x_slot = channel.x_smem.index(index)  # [1, BLOCK_M, BLOCK_K]
            y_slot = channel.y_smem.index(index)  # [BLOCK_N, BLOCK_K]
            ready_bar = channel.ready_bars.index(index)
            empty_bar = channel.empty_bars.index(index)
            # Block until the loader has filled this stage.
            mbarrier.wait(ready_bar, phase)

            # Drop the leading unit dimension of the 3D x box.
            x = x_slot.reshape((config.BLOCK_M, config.BLOCK_K))
            y = y_slot

            # Per-token x scales for this K block (0.0 for masked rows) and
            # the single y scale for this (n_tile, k_block).
            x_s = gl.load(
                xs_ptr + xs_row_base + k_block_idx * config.xs_sKb,
                mask=xs_mask,
                other=0.0,
            )
            y_s = gl.load(ys_ptr + ys_base + k_block_idx)
            xy_s = x_s * y_s

            # Each K block is multiplied on its own (use_acc=False) and waited
            # on immediately so it can be scaled before being added to `acc`.
            if config.SWAP_AB:
                x_t = x.permute((1, 0))
                partial_async = warpgroup_mma(
                    y, x_t, partial_zero, use_acc=False, is_async=True
                )
                partial = warpgroup_mma_wait(num_outstanding=0, deps=(partial_async,))
                acc = acc + partial * xy_s[None, :]
            else:
                y_t = y.permute((1, 0))
                partial_async = warpgroup_mma(
                    x, y_t, partial_zero, use_acc=False, is_async=True
                )
                partial = warpgroup_mma_wait(num_outstanding=0, deps=(partial_async,))
                acc = acc + partial * xy_s[:, None]

            # Hand the stage back to the loader and move to the next slot.
            mbarrier.arrive(empty_bar)
            counter = counter.increment()

        acc_out = acc.to(z_desc.dtype)
        if config.SWAP_AB:
            # Undo the transposed compute so the tile is [BLOCK_M, BLOCK_N].
            acc_out = acc_out.permute((1, 0))
        # The previous tile's TMA store must have drained before z_smem is
        # overwritten.
        tma.store_wait(pendings=0)
        z_smem.reshape((config.BLOCK_M, config.BLOCK_N)).store(acc_out)
        fence_async_shared()
        tma.async_copy_shared_to_global(z_desc, [batch_id, m_start, n_start], z_smem)

    # Drain the last outstanding store before the partition exits.
    tma.store_wait(pendings=0)
@gluon.jit
def load_partition(channel, config, tensors):
    # Producer side of the pipeline. For every tile assigned to this program
    # and every K block, it waits for a free stage, then issues TMA loads of
    # the x block and the y block into that stage; both loads signal the
    # stage's "ready" barrier.
    x_desc, y_desc, xs_ptr, z_desc, ys_ptr = tensors
    start_pid = gl.program_id(0)
    # Ring cursor over the channel stages; must advance in lock-step with the
    # cursor in compute_partition (one step per K block).
    counter = BarrierCounter(
        index=gl.to_tensor(0), phase=gl.to_tensor(0), num_barriers=config.num_stages
    )

    # Same tile striding as compute_partition: start at this program's id and
    # step by num_sms.
    for tile_id in range(start_pid, config.num_tiles, config.num_sms):
        batch_id, m_tile_id, n_tile_id = get_tile(tile_id, config)
        m_start = m_tile_id * config.BLOCK_M
        n_start = n_tile_id * config.BLOCK_N

        # y is addressed through a flattened (B*N, K) descriptor, so the batch
        # offset is folded into the row coordinate.
        y_row = batch_id * config.N + n_start

        for k in range(0, config.K, config.BLOCK_K):
            index, phase = counter.index, counter.phase
            x_slot = channel.x_smem.index(index)
            y_slot = channel.y_smem.index(index)
            ready_bar = channel.ready_bars.index(index)
            empty_bar = channel.empty_bars.index(index)
            # Block until the consumer has released this stage.
            mbarrier.wait(empty_bar, phase)

            # Expected transaction size comes from the descriptors' block
            # types rather than a hard-coded element count, so it stays
            # correct for element types wider than one byte.
            mbarrier.expect(
                ready_bar, x_desc.block_type.nbytes + y_desc.block_type.nbytes
            )
            tma.async_copy_global_to_shared(
                x_desc, [batch_id, m_start, k], ready_bar, x_slot
            )
            tma.async_copy_global_to_shared(y_desc, [y_row, k], ready_bar, y_slot)

            counter = counter.increment()
# Autotuning sweeps warps (4, 8), pipeline depth (4, 6, 8) and tile traversal
# order; results are cached per (B, M_aligned, N, K).
@triton.autotune(
    configs=[
        triton.Config({"TILE_ORDER": tile_order}, num_warps=nw, num_stages=ns)
        for nw in (4, 8)
        for ns in (4, 6, 8)
        for tile_order in (0, 1)  # 0=horizontal (n fastest), 1=vertical (m fastest)
    ],
    key=["B", "M_aligned", "N", "K"],
)
@gluon.jit
def w8a8_block_fp8_bmm_kernel(
    x_desc,
    y_desc,
    xs_ptr,
    z_desc,
    ys_ptr,
    xs_sB: gl.constexpr,
    xs_sM: gl.constexpr,
    xs_sKb: gl.constexpr,
    B: gl.constexpr,
    M: gl.constexpr,
    M_aligned: gl.constexpr,
    N: gl.constexpr,
    K: gl.constexpr,
    BLOCK_M: gl.constexpr,
    BLOCK_N: gl.constexpr,
    BLOCK_K: gl.constexpr,
    TILE_ORDER: gl.constexpr,
    SWAP_AB: gl.constexpr,
    num_warps: gl.constexpr,
    num_stages: gl.constexpr,
    num_sms: gl.constexpr,
):
    # Kernel entry point: packs the constexpr arguments into a Config,
    # allocates the staging channel, runs the compute and load partitions via
    # warp specialization, then tears the channel down.
    #
    # x_desc / y_desc / z_desc are tensor descriptors built by the host
    # wrapper; xs_ptr / ys_ptr are the scale tensors, read with plain loads.
    # TILE_ORDER, num_warps and num_stages are supplied by the autotuner.
    config = Config(
        B=B,
        M=M,
        M_aligned=M_aligned,
        N=N,
        K=K,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        TILE_ORDER=TILE_ORDER,
        SWAP_AB=SWAP_AB,
        num_warps=num_warps,
        num_stages=num_stages,
        num_sms=num_sms,
        xs_sB=xs_sB,
        xs_sM=xs_sM,
        xs_sKb=xs_sKb,
    )
    tensors = (x_desc, y_desc, xs_ptr, z_desc, ys_ptr)
    # Staging buffers reuse the layouts carried by the x/y descriptors.
    channel = Channel.alloc(
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        x_dtype=x_desc.dtype,
        x_layout=gl.constexpr(x_desc.layout),
        y_dtype=y_desc.dtype,
        y_layout=gl.constexpr(y_desc.layout),
        num_stages=num_stages,
        num_warps=num_warps,
    )

    # NOTE(review): the trailing [1] and [24] presumably give the warp count
    # and register budget of the load partition -- confirm against the
    # gl.warp_specialize signature of the Triton version in use.
    gl.warp_specialize(
        [
            (compute_partition, (channel, config, tensors)),
            (load_partition, (channel, config, tensors)),
        ],
        [1],
        [24],
    )

    channel.release()
def w8a8_block_fp8_bmm(
    x: torch.Tensor,
    y: torch.Tensor,
    xs: torch.Tensor,
    ys: torch.Tensor,
    block_size: List[int] = [128, 128],
    z: Optional[torch.Tensor] = None,
    output_dtype: torch.dtype = torch.bfloat16,
):
    """Batched block-scaled FP8 matmul: ``z[b] = (x[b] * xs[b]) @ (y[b] * ys[b]).T``.

    Args:
        x: ``[B, M, K]`` fp8 activations; the last dimension must have stride 1.
        y: ``[B, N, K]`` fp8 weights; must be contiguous.
        xs: ``[B, M, K // block_k]`` per-token scales for ``x`` (may be strided).
        ys: ``[B, N // block_n, K // block_k]`` per-block scales for ``y``.
            A non-contiguous tensor is copied, because the kernel indexes it
            as a dense array.
        block_size: ``[block_n, block_k]``; only ``[128, 128]`` is supported.
            The default list is only unpacked, never mutated.
        z: optional ``[B, M, N]`` output buffer of dtype ``output_dtype`` on the
            same device as ``x``, with stride 1 in the last dimension.
        output_dtype: dtype of the output when ``z`` is not supplied.

    Returns:
        The ``[B, M, N]`` output tensor (``z`` itself when it was supplied).

    Raises:
        AssertionError: if any shape, stride, device or dtype precondition fails.
        TypeError: if the dtype of ``x``, ``y`` or ``z`` has no Gluon equivalent.
    """
    assert len(block_size) == 2
    BLOCK_N, BLOCK_K = block_size
    assert (
        BLOCK_N == 128 and BLOCK_K == 128
    ), "this kernel assumes 128x128 block-wise FP8 scales"

    assert x.ndim == 3 and y.ndim == 3 and xs.ndim == 3 and ys.ndim == 3
    assert x.shape[0] == y.shape[0] == xs.shape[0] == ys.shape[0]
    assert x.shape[-1] == y.shape[-1]
    assert x.shape[:-1] == xs.shape[:-1]
    assert x.stride(-1) == 1 and y.stride(-1) == 1

    device = x.device
    B, M, K = x.shape
    _, N, _ = y.shape
    assert K % BLOCK_K == 0 and N % BLOCK_N == 0
    num_kb = K // BLOCK_K

    assert xs.shape == (B, M, num_kb)
    # The kernel reads ys at (batch * N/BLOCK_N + n_tile) * num_kb + k_block
    # with no strides, so both the shape and a dense layout are required.
    assert ys.shape == (
        B,
        N // BLOCK_N,
        num_kb,
    ), "ys must have shape (B, N // block_n, K // block_k)"
    if not ys.is_contiguous():
        ys = ys.contiguous()
    assert y.is_contiguous(), "y must be contiguous so it can be viewed as (B*N, K)"

    if z is None:
        z = torch.empty((B, M, N), device=device, dtype=output_dtype)
    else:
        assert z.shape == (B, M, N) and z.device == device and z.dtype == output_dtype
        assert z.stride(-1) == 1

    # Nothing to compute for an empty output; skip descriptor creation and
    # the launch entirely.
    if z.numel() == 0:
        return z

    # Smallest power of two >= M, clamped to [8, 64]. Small-M problems compute
    # the transposed product (SWAP_AB) so M maps onto the MMA's N dimension.
    BLOCK_M = max(8, min(64, 1 << ((M - 1).bit_length())))
    SWAP_AB = 1 if BLOCK_M < 64 else 0

    M_aligned = triton.cdiv(M, BLOCK_M) * BLOCK_M

    x_gl_dtype = _gl_dtype(x)
    y_gl_dtype = _gl_dtype(y)
    z_gl_dtype = _gl_dtype(z)

    # x and z keep their 3D shape (boxes carry a leading unit batch dim).
    x_layout = gl.NVMMASharedLayout.get_default_for([1, BLOCK_M, BLOCK_K], x_gl_dtype)
    x_desc = TensorDescriptor.from_tensor(
        x, block_shape=[1, BLOCK_M, BLOCK_K], layout=x_layout
    )

    # y is flattened to 2D; the kernel folds the batch into the row index.
    y_flat = y.view(B * N, K)
    y_layout = gl.NVMMASharedLayout.get_default_for([BLOCK_N, BLOCK_K], y_gl_dtype)
    y_desc = TensorDescriptor.from_tensor(
        y_flat, block_shape=[BLOCK_N, BLOCK_K], layout=y_layout
    )

    # xs is passed as a raw tensor together with its strides.
    xs_sB, xs_sM, xs_sKb = xs.stride()

    z_layout = gl.NVMMASharedLayout.get_default_for([1, BLOCK_M, BLOCK_N], z_gl_dtype)
    z_desc = TensorDescriptor.from_tensor(
        z, block_shape=[1, BLOCK_M, BLOCK_N], layout=z_layout
    )

    # One program per SM; each program strides over the tiles by num_sms.
    num_sms = torch.cuda.get_device_properties(device).multi_processor_count
    w8a8_block_fp8_bmm_kernel[(num_sms,)](
        x_desc,
        y_desc,
        xs,
        z_desc,
        ys,
        xs_sB=xs_sB,
        xs_sM=xs_sM,
        xs_sKb=xs_sKb,
        B=B,
        M=M,
        M_aligned=M_aligned,
        N=N,
        K=K,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        SWAP_AB=SWAP_AB,
        num_sms=num_sms,
    )

    return z