Coverage for src/flag_gems/ops/fp8_mqa_logits.py: 22%
54 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.utils import libentry, libtuner
10logger = logging.getLogger(__name__)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("fp8_mqa_logits"),
    key=["M", "N", "D"],
)
@triton.jit
def _fp8_mqa_logits_kernel(
    Q,
    K,
    K_SCALES,
    WEIGHTS,
    CU_SEQLEN_KS,
    CU_SEQLEN_KE,
    LOGITS,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kd,
    M: tl.constexpr,
    H: tl.constexpr,
    D: tl.constexpr,
    N: tl.constexpr,
    CLEAN_LOGITS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """
    Triton kernel for FP8 MQA logits computation.

    Each program computes one BLOCK_M x BLOCK_N tile of
        logits[m, n] = sum_h(ReLU(score[m, h, n]) * weights[m, h])
    where
        score[m, h, n] = sum_d(q[m, h, d] * k[n, d] * k_scales[n])

    Q and K are addressed through the strides passed in. WEIGHTS is addressed
    as a dense row-major (M, H) array, K_SCALES / CU_SEQLEN_KS / CU_SEQLEN_KE
    with unit stride, and LOGITS as a dense row-major (M, N) array, so the
    caller must pass those tensors contiguous.

    The K tile is re-loaded (and re-scaled) for every head and every D-block;
    only the per-column scales are loaded once per program.

    When CLEAN_LOGITS is true, entries whose column index n falls outside
    [CU_SEQLEN_KS[m], CU_SEQLEN_KE[m]) are written as -inf.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    # 64-bit copies used only for address computation: with int32 offsets,
    # products such as offs_m * N or offs_m * stride_qm wrap around once the
    # addressed element index reaches 2**31.
    offs_m64 = offs_m.to(tl.int64)
    offs_n64 = offs_n.to(tl.int64)

    mask_m = offs_m < M
    mask_n = offs_n < N

    # Per-row valid key range; out-of-range rows get the full range [0, N).
    ks_start = tl.load(CU_SEQLEN_KS + offs_m, mask=mask_m, other=0)
    ke_end = tl.load(CU_SEQLEN_KE + offs_m, mask=mask_m, other=N)

    # One scale per key row, shared by all heads and D-blocks.
    k_scales = tl.load(K_SCALES + offs_n, mask=mask_n, other=1.0)
    k_scales = k_scales.to(tl.float32)

    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    for h_idx in range(H):
        weight_ptrs = WEIGHTS + offs_m64 * H + h_idx
        weight_h = tl.load(weight_ptrs, mask=mask_m, other=0.0)
        weight_h = weight_h.to(tl.float32)

        score_h = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + offs_d
            d_mask = d_offs < D

            q_ptrs = (
                Q
                + offs_m64[:, None] * stride_qm
                + h_idx * stride_qh
                + d_offs[None, :] * stride_qd
            )
            q = tl.load(q_ptrs, mask=mask_m[:, None] & d_mask[None, :], other=0.0)
            q = q.to(tl.float32)

            k_ptrs = K + offs_n64[:, None] * stride_kn + d_offs[None, :] * stride_kd
            k = tl.load(k_ptrs, mask=mask_n[:, None] & d_mask[None, :], other=0.0)
            # Dequantize: apply the per-row scale after widening to float32.
            k = k.to(tl.float32) * k_scales[:, None]

            score_h += tl.dot(q, tl.trans(k))

        # ReLU is applied per head, after the full reduction over D.
        score_h = tl.maximum(score_h, 0.0)
        acc += score_h * weight_h[:, None]

    if CLEAN_LOGITS:
        n_valid = (offs_n[None, :] >= ks_start[:, None]) & (
            offs_n[None, :] < ke_end[:, None]
        )
        acc = tl.where(n_valid, acc, float("-inf"))

    out_ptrs = LOGITS + offs_m64[:, None] * N + offs_n[None, :]
    tl.store(out_ptrs, acc, mask=mask_m[:, None] & mask_n[None, :])
def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
    clean_logits: bool,
) -> torch.Tensor:
    """Compute weighted, ReLU-activated MQA logits with the Triton kernel.

    Args:
        q: Query tensor of shape (M, H, D).
        kv: Pair ``(k_fp8, k_scales)``; ``k_fp8`` has shape (N, D) and
            ``k_scales`` holds one scale per key row.
        weights: Per-query, per-head weights, read as an (M, H) array.
        cu_seqlen_ks: Per-query start (inclusive) of the valid key range.
        cu_seqlen_ke: Per-query end (exclusive) of the valid key range.
        clean_logits: If true, entries outside the valid key range of their
            row are set to ``-inf``.

    Returns:
        A float32 tensor of shape (M, N) on ``q.device``.

    Raises:
        ValueError: If ``q`` is not 3-D (from shape unpacking) or ``k_fp8``
            is not a 2-D tensor whose second dimension equals D.
    """
    logger.debug("GEMS FP8_MQA_LOGITS")

    k_fp8, k_scales = kv

    M, H, D = q.shape
    if k_fp8.dim() != 2 or k_fp8.shape[1] != D:
        # The kernel reads k[n, d] for every d < D; a mismatching layout
        # would read out of bounds or produce garbage.
        raise ValueError(
            f"k_fp8 must have shape (N, {D}) to match q, got {tuple(k_fp8.shape)}"
        )
    N = k_fp8.shape[0]

    if M == 0 or N == 0:
        # Nothing to compute; avoid launching a kernel on an empty grid.
        return torch.empty((M, N), dtype=torch.float32, device=q.device)

    # The kernel addresses these tensors with implicit dense strides
    # (weights as row-major (M, H), the others with unit stride), so make
    # sure they really are dense. No-op for already-contiguous inputs.
    weights = weights.contiguous()
    k_scales = k_scales.contiguous()
    cu_seqlen_ks = cu_seqlen_ks.contiguous()
    cu_seqlen_ke = cu_seqlen_ke.contiguous()

    # Every in-bounds element is stored by the kernel, so no zero-fill needed.
    logits = torch.empty((M, N), dtype=torch.float32, device=q.device)

    def grid(META):
        return (
            triton.cdiv(M, META["BLOCK_M"]),
            triton.cdiv(N, META["BLOCK_N"]),
        )

    _fp8_mqa_logits_kernel[grid](
        q,
        k_fp8,
        k_scales,
        weights,
        cu_seqlen_ks,
        cu_seqlen_ke,
        logits,
        q.stride(0),  # stride_qm
        q.stride(1),  # stride_qh
        q.stride(2),  # stride_qd
        k_fp8.stride(0),  # stride_kn
        k_fp8.stride(1),  # stride_kd
        M,
        H,
        D,
        N,
        clean_logits,
    )

    return logits