Coverage for src/flag_gems/runtime/backend/_nvidia/hopper/ops/fp8_einsum.py: 0%

11 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

from typing import List, Optional

import torch

from .w8a8_block_fp8_bmm import w8a8_block_fp8_bmm

6 

7 

def fp8_einsum(
    equation: str,
    x: torch.Tensor,
    xs: torch.Tensor,
    y: torch.Tensor,
    ys: torch.Tensor,
    block_size: Optional[List[int]] = None,
    output_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Block-wise FP8 einsum, mirroring deep_gemm.fp8_einsum.

    Only the ``"bhr,hdr->bhd"`` contraction is supported: h is the batch
    dimension and the per-head op is ``out[b,h,d] = sum_r x[b,h,r] * y[h,d,r]``.

    Args:
        equation: must be ``"bhr,hdr->bhd"``.
        x: (b, h, r) FP8 data.
        xs: (b, h, r // block_k) FP32 per-token scales.
        y: (h, d, r) FP8 data.
        ys: (h, d // block_n, r // block_k) FP32 per-block scales.
        block_size: [block_n, block_k] of the FP8 scaling grid. ``None``
            (the default) means ``[128, 128]``; any 2-element sequence is
            accepted and copied into a fresh list.
        output_dtype: dtype of the freshly allocated output.

    Returns:
        z: a newly allocated (b, h, d) tensor with the result. If the output
        is empty it is returned without launching the kernel; if r == 0
        (empty contraction) it is filled with zeros.

    Raises:
        AssertionError: if ``equation`` is unsupported or the operand shapes
            do not line up. These checks are ``assert`` statements and are
            skipped under ``python -O``.
    """
    # Build a fresh list on every call: a shared mutable default handed to
    # the kernel could be mutated there and leak into later calls.
    block_size = [128, 128] if block_size is None else list(block_size)
    assert (
        len(block_size) == 2
    ), f"block_size must be [block_n, block_k], got {block_size!r}"

    assert (
        equation == "bhr,hdr->bhd"
    ), f"fp8_einsum only supports 'bhr,hdr->bhd', got {equation!r}"
    b, h, r = x.shape
    h2, d, r2 = y.shape
    assert h2 == h and r2 == r, f"x {tuple(x.shape)} / y {tuple(y.shape)} mismatch"
    # Scale tensors must share the (b, h) / h leading dims of their data
    # tensors; the permute below additionally requires xs to be 3-D.
    assert (
        xs.dim() == 3 and tuple(xs.shape[:2]) == (b, h)
    ), f"xs {tuple(xs.shape)} does not match x {tuple(x.shape)}"
    assert (
        ys.dim() == 3 and ys.shape[0] == h
    ), f"ys {tuple(ys.shape)} does not match y {tuple(y.shape)}"

    z = torch.empty((b, h, d), device=x.device, dtype=output_dtype)

    # Degenerate shapes: nothing to compute (empty output), or an empty
    # contraction whose sum is zero by definition -- skip the kernel launch.
    if z.numel() == 0:
        return z
    if r == 0:
        return z.zero_()

    # h is the batch dim → BMM layout (B=h, M=b, N=d, K=r). The permutes are
    # pure views (last dim r stays contiguous); the kernel handles xs's strides.
    w8a8_block_fp8_bmm(
        x.permute(1, 0, 2),  # (h, b, r)
        y,  # (h, d, r)
        xs.permute(1, 0, 2),  # (h, b, r // block_k)
        ys,  # (h, d // block_n, r // block_k)
        block_size=block_size,
        z=z.permute(1, 0, 2),  # (h, b, d) view into the (b, h, d) output
        output_dtype=output_dtype,
    )
    return z