Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/fp8_einsum.py: 0%
11 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1from typing import List
3import torch
5from .w8a8_block_fp8_bmm import w8a8_block_fp8_bmm
def fp8_einsum(
    equation: str,
    x: torch.Tensor,
    xs: torch.Tensor,
    y: torch.Tensor,
    ys: torch.Tensor,
    block_size: List[int] = [128, 128],
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Block-wise FP8 einsum, mirroring deep_gemm.fp8_einsum.

    Only the ``"bhr,hdr->bhd"`` contraction is supported: h is the batch
    dimension and the per-head op is ``out[b,h,d] = sum_r x[b,h,r] * y[h,d,r]``.

    Args:
        equation: must be ``"bhr,hdr->bhd"``.
        x: (b, h, r) FP8 data.
        xs: (b, h, r // block_k) FP32 per-token scales.
        y: (h, d, r) FP8 data.
        ys: (h, d // block_n, r // block_k) FP32 per-block scales.
        block_size: [block_n, block_k] of the FP8 scaling grid. The default
            list object is shared across calls; this function does not modify
            it (assumes w8a8_block_fp8_bmm does not either).
        output_dtype: dtype of the freshly allocated output.

    Returns:
        z: a newly allocated (b, h, d) tensor with the result. When b, h or d
        is zero the (empty) tensor is returned without launching the kernel.

    Raises:
        AssertionError: if ``equation`` is unsupported or the operand shapes
            are inconsistent with each other.
        ValueError: if ``x`` or ``y`` is not 3-D (from the shape unpacking).
    """
    # Checks raise AssertionError explicitly rather than via ``assert``: same
    # exception type as before, but not stripped under ``python -O``, so bad
    # shapes can never reach the GPU kernel unchecked.
    if equation != "bhr,hdr->bhd":
        raise AssertionError(
            f"fp8_einsum only supports 'bhr,hdr->bhd', got {equation!r}"
        )
    b, h, r = x.shape
    h2, d, r2 = y.shape
    if h2 != h or r2 != r:
        raise AssertionError(f"x {tuple(x.shape)} / y {tuple(y.shape)} mismatch")
    # The call below treats xs as (b, h, .) (it permutes dims 0 and 1) and ys
    # as (h, ., .); make sure those leading dims agree with x / y.
    if xs.dim() != 3 or tuple(xs.shape[:2]) != (b, h):
        raise AssertionError(
            f"xs {tuple(xs.shape)} must be (b, h, r // block_k) with b={b}, h={h}"
        )
    if ys.dim() != 3 or ys.shape[0] != h:
        raise AssertionError(
            f"ys {tuple(ys.shape)} must be (h, d // block_n, r // block_k) with h={h}"
        )

    z = torch.empty((b, h, d), device=x.device, dtype=output_dtype)
    if z.numel() == 0:
        # Degenerate output: nothing to compute, so skip the kernel launch.
        return z

    # h is the batch dim -> BMM layout (B=h, M=b, N=d, K=r). The permutes are
    # pure views (last dim r stays contiguous); the kernel handles xs's strides.
    w8a8_block_fp8_bmm(
        x.permute(1, 0, 2),  # (h, b, r)
        y,  # (h, d, r)
        xs.permute(1, 0, 2),  # (h, b, r // block_k)
        ys,  # (h, d // block_n, r // block_k)
        block_size=block_size,
        z=z.permute(1, 0, 2),  # (h, b, d) view into the (b, h, d) output
        output_dtype=output_dtype,
    )
    return z