Coverage for src/flag_gems/ops/fp8_paged_mqa_logits.py: 9%
125 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import torch
2import triton
3import triton.language as tl
5from flag_gems import runtime
def cdiv(x: int, y: int) -> int:
    """Return ``x`` divided by ``y``, rounded up when ``y`` is positive.

    Used to size 1-D/2-D launch grids: the number of ``y``-wide tiles
    needed to cover ``x`` elements.
    """
    quotient, _ = divmod(x + y - 1, y)
    return quotient
@triton.autotune(
    configs=runtime.get_tuned_config("fp8_paged_mqa_logits"),
    key=["heads", "dim", "block_size"],
)
@triton.jit
def fp8_paged_mqa_logits_kernel(
    q_ptr,
    kv_ptr,
    weights_ptr,
    logits_ptr,
    block_tables_ptr,
    context_lens_ptr,
    stride_qb,
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kvblk,
    stride_kvpos,
    stride_kvone,
    stride_kvbyte,
    stride_wrow,
    stride_wh,
    stride_lrow,
    stride_lcol,
    stride_btb,
    stride_bts,
    next_n: tl.constexpr,
    heads: tl.constexpr,
    dim: tl.constexpr,
    block_size: tl.constexpr,
    max_model_len,
    dim_plus_4: tl.constexpr,
    BLOCK_KV: tl.constexpr,
    BLOCK_D: tl.constexpr,
    NUM_D_TILES: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Compute one (query row, KV tile) tile of weighted ReLU MQA logits.

    Grid axis 0 enumerates query rows ``pid_row = batch_idx * next_n +
    next_n_idx``; axis 1 enumerates ``BLOCK_KV``-wide tiles of KV positions.

    For every valid KV position ``t`` in the tile this writes
    ``logits[pid_row, t] = sum_h w[pid_row, h] *
    relu(scale_t * dot(kv_t, q[batch_idx, next_n_idx, h]))``.
    A position is valid when ``t < context_len`` and
    ``t <= context_len - next_n + next_n_idx``. Invalid positions, and whole
    tiles starting at or past ``context_len``, are NOT written: the caller
    must pre-fill ``logits`` with ``-inf``.

    KV addressing: logical position ``t`` maps to physical block
    ``block_tables[batch_idx, t // block_size]`` and slot
    ``t % block_size``. From that slot the kernel reads ``dim`` bytes
    (bitcast to float8e4nv) followed by a 4-byte float32 scale.

    NOTE(review): this assumes an interleaved per-slot layout
    ``[dim fp8 bytes | 4 scale bytes]``. DeepGEMM's test helper
    ``kv_cache_cast_to_fp8`` appears to store all fp8 values of a block
    first and all scales after them; confirm which layout the producer of
    ``kv_cache`` uses.

    The feature dimension is covered by ``NUM_D_TILES`` (1 or 2) tiles of
    ``BLOCK_D`` columns each.
    """
    # Only one or two D tiles are implemented below, and together they must
    # span all `dim` features; otherwise columns of the dot product would be
    # silently dropped. NOTE(review): Triton's autotuner is expected to skip
    # configs that fail a compile-time assertion - confirm for the pinned
    # Triton version.
    tl.static_assert(NUM_D_TILES >= 1, "NUM_D_TILES must be 1 or 2")
    tl.static_assert(NUM_D_TILES <= 2, "NUM_D_TILES must be 1 or 2")
    tl.static_assert(
        BLOCK_D * NUM_D_TILES >= dim, "BLOCK_D * NUM_D_TILES must cover dim"
    )

    pid_row = tl.program_id(0)
    pid_kv_tile = tl.program_id(1)

    batch_idx = pid_row // next_n
    next_n_idx = pid_row % next_n

    context_len = tl.load(context_lens_ptr + batch_idx)
    # Absolute position of this query, treating the next_n queries as the
    # last next_n positions of the context; later KV positions are masked.
    query_seq_pos = context_len - next_n + next_n_idx

    kv_start = pid_kv_tile * BLOCK_KV
    if kv_start >= context_len:
        # The whole tile lies past the context: keep the caller's -inf.
        return

    offs_kv = tl.arange(0, BLOCK_KV)
    kv_global_pos = kv_start + offs_kv

    context_mask = kv_global_pos < context_len
    causal_mask = kv_global_pos <= query_seq_pos
    valid_mask = context_mask & causal_mask

    phys_block_idx = kv_global_pos // block_size
    intra_block_pos = kv_global_pos % block_size

    phys_block_ids = tl.load(
        block_tables_ptr + batch_idx * stride_btb + phys_block_idx * stride_bts,
        mask=valid_mask,
        other=0,
    )

    # Block ids are int32; widen before multiplying by the block stride so
    # byte offsets into a cache larger than 2**31 - 1 bytes do not wrap.
    kv_base = (
        phys_block_ids.to(tl.int64) * stride_kvblk
        + intra_block_pos * stride_kvpos
    )

    # The 4 bytes right after the dim FP8 values hold a float32 scale that is
    # applied to each position's dot products.
    scale_addr = kv_base + dim * stride_kvbyte
    scale_ptr = (kv_ptr + scale_addr).to(tl.pointer_type(tl.uint32, 1), bitcast=True)
    scale_u32 = tl.load(scale_ptr, mask=valid_mask, other=0)
    scale_f32 = scale_u32.to(tl.float32, bitcast=True)

    logit_accum = tl.zeros([BLOCK_KV], dtype=tl.float32)
    offs_d = tl.arange(0, BLOCK_D)
    q_base = q_ptr + batch_idx * stride_qb + next_n_idx * stride_qn
    w_row_ptr = weights_ptr + pid_row * stride_wrow

    if NUM_D_TILES == 1:
        d_mask = offs_d < dim

        kv_u8 = tl.load(
            kv_ptr + kv_base[:, None] + offs_d[None, :] * stride_kvbyte,
            mask=valid_mask[:, None] & d_mask[None, :],
            other=0,
        )
        # Reinterpret the raw bytes as FP8 e4m3 and widen to fp32 for tl.dot.
        kv_f32 = kv_u8.to(tl.float8e4nv, bitcast=True).to(tl.float32)

        for h_tile in tl.static_range(0, heads, BLOCK_H):
            offs_h = h_tile + tl.arange(0, BLOCK_H)
            h_mask = offs_h < heads

            q_vals = tl.load(
                q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd,
                mask=h_mask[:, None] & d_mask[None, :],
                other=0.0,
            ).to(tl.float32)
            # Masked heads load weight 0 and contribute nothing to the sum.
            weights = tl.load(w_row_ptr + offs_h * stride_wh, mask=h_mask, other=0.0)

            # [BLOCK_KV, BLOCK_D] x [BLOCK_D, BLOCK_H] -> per-head scores.
            partial_dot = tl.dot(kv_f32, tl.trans(q_vals), out_dtype=tl.float32)
            partial_dot = partial_dot * scale_f32[:, None]
            partial_dot = tl.maximum(partial_dot, 0.0)
            logit_accum += tl.sum(partial_dot * weights[None, :], axis=1)

    else:
        # Exactly two D tiles (enforced by the static asserts above).
        d_offs0 = offs_d
        d_mask0 = d_offs0 < dim
        d_offs1 = BLOCK_D + offs_d
        d_mask1 = d_offs1 < dim

        kv_u80 = tl.load(
            kv_ptr + kv_base[:, None] + d_offs0[None, :] * stride_kvbyte,
            mask=valid_mask[:, None] & d_mask0[None, :],
            other=0,
        )
        kv_f320 = kv_u80.to(tl.float8e4nv, bitcast=True).to(tl.float32)

        kv_u81 = tl.load(
            kv_ptr + kv_base[:, None] + d_offs1[None, :] * stride_kvbyte,
            mask=valid_mask[:, None] & d_mask1[None, :],
            other=0,
        )
        kv_f321 = kv_u81.to(tl.float8e4nv, bitcast=True).to(tl.float32)

        for h_tile in tl.static_range(0, heads, BLOCK_H):
            offs_h = h_tile + tl.arange(0, BLOCK_H)
            h_mask = offs_h < heads

            q_vals0 = tl.load(
                q_base + offs_h[:, None] * stride_qh + d_offs0[None, :] * stride_qd,
                mask=h_mask[:, None] & d_mask0[None, :],
                other=0.0,
            ).to(tl.float32)
            q_vals1 = tl.load(
                q_base + offs_h[:, None] * stride_qh + d_offs1[None, :] * stride_qd,
                mask=h_mask[:, None] & d_mask1[None, :],
                other=0.0,
            ).to(tl.float32)
            # Masked heads load weight 0 and contribute nothing to the sum.
            weights = tl.load(w_row_ptr + offs_h * stride_wh, mask=h_mask, other=0.0)

            # Accumulate both D halves before the ReLU, which is non-linear.
            partial_dot = tl.dot(kv_f320, tl.trans(q_vals0), out_dtype=tl.float32)
            partial_dot = tl.dot(
                kv_f321, tl.trans(q_vals1), acc=partial_dot, out_dtype=tl.float32
            )

            partial_dot = partial_dot * scale_f32[:, None]
            partial_dot = tl.maximum(partial_dot, 0.0)
            logit_accum += tl.sum(partial_dot * weights[None, :], axis=1)

    # Only valid positions are stored; masked-out lanes keep the caller's
    # -inf. The row offset is widened so rows * max_model_len cannot wrap.
    out_ptrs = (
        logits_ptr
        + pid_row.to(tl.int64) * stride_lrow
        + kv_global_pos * stride_lcol
    )
    out_mask = valid_mask & (kv_global_pos < max_model_len)
    tl.store(out_ptrs, logit_accum, mask=out_mask)
@triton.jit
def fill_neg_inf_kernel(
    out_ptr,
    n_elements,
    BLOCK: tl.constexpr,
):
    """Write ``-inf`` to the first ``n_elements`` entries of ``out_ptr``.

    Each program covers ``BLOCK`` consecutive elements; launch it on a 1-D
    grid of ``cdiv(n_elements, BLOCK)`` programs.
    """
    # program_id is int32; promote before scaling so offsets beyond
    # 2**31 - 1 (very large outputs) do not wrap to negative addresses.
    pid = tl.program_id(0).to(tl.int64)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(out_ptr + offs, float("-inf"), mask=mask)
def fp8_paged_mqa_logits(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute weighted ReLU MQA logits of ``q`` against a paged FP8 KV cache.

    For query row ``r = b * next_n + i`` and KV position ``t`` the result is
    ``sum_h weights[r, h] * relu(scale_t * dot(kv_t, q[b, i, h]))``, where
    ``kv_t`` are the FP8 (e4m3) bytes of token ``t`` and ``scale_t`` is its
    float32 scale. Positions at or past ``context_lens[b]``, or after
    ``context_lens[b] - next_n + i``, are ``-inf``.

    Args:
        q: ``(batch_size, next_n, heads, dim)`` CUDA tensor.
        kv_cache: ``(num_blocks, block_size, 1, dim + 4)`` uint8 CUDA tensor;
            for each (block, slot) the kernel reads ``dim`` FP8 bytes followed
            by a 4-byte float32 scale.
        weights: ``(batch_size * next_n, heads)`` per-head weights.
        context_lens: int32 tensor; entry ``b`` is batch entry ``b``'s context
            length.
        block_tables: int32 ``(batch_size, max_blocks_per_seq)`` table mapping
            a logical KV block index to a physical block of ``kv_cache``.
        max_model_len: Number of logit columns to produce.

    Returns:
        float32 tensor of shape ``(batch_size * next_n, max_model_len)``.
    """
    assert q.is_cuda and kv_cache.is_cuda and weights.is_cuda
    assert context_lens.is_cuda and block_tables.is_cuda

    batch_size, next_n, heads, dim = q.size()
    _num_blocks, block_size, one, dim_plus_4 = kv_cache.size()

    assert one == 1
    assert dim_plus_4 == dim + 4
    assert weights.shape == (batch_size * next_n, heads)
    assert kv_cache.dtype == torch.uint8
    assert context_lens.dtype == torch.int32
    assert block_tables.dtype == torch.int32
    # The kernel reads one context length and one block-table row per batch
    # entry; fewer would be an out-of-bounds device read.
    assert context_lens.numel() >= batch_size
    assert block_tables.dim() == 2 and block_tables.shape[0] >= batch_size

    q_contig = q.contiguous()
    kv_contig = kv_cache.contiguous()
    weights_contig = weights.contiguous()
    context_lens_contig = context_lens.contiguous()
    block_tables_contig = block_tables.contiguous()

    total_rows = batch_size * next_n

    logits = torch.empty(
        (total_rows, max_model_len),
        device=q.device,
        dtype=torch.float32,
    )
    n_elements = logits.numel()
    if n_elements == 0:
        # Nothing to compute; also avoids autotuning on an empty launch.
        return logits

    # The main kernel only writes valid positions, so everything starts -inf.
    FILL_BLOCK = 1024
    fill_grid = (cdiv(n_elements, FILL_BLOCK),)
    fill_neg_inf_kernel[fill_grid](logits, n_elements, BLOCK=FILL_BLOCK)

    max_context = block_tables_contig.shape[1] * block_size
    if max_context == 0:
        # No KV position is addressable: every logit stays -inf.
        return logits

    def grid(meta):
        # NOTE(review): CUDA caps grid dimension 1 at 65535 blocks, so very
        # long contexts with a small BLOCK_KV could exceed it.
        num_kv_tiles = cdiv(max_context, meta["BLOCK_KV"])
        return (total_rows, num_kv_tiles)

    fp8_paged_mqa_logits_kernel[grid](
        q_contig,
        kv_contig,
        weights_contig,
        logits,
        block_tables_contig,
        context_lens_contig,
        q_contig.stride(0),
        q_contig.stride(1),
        q_contig.stride(2),
        q_contig.stride(3),
        kv_contig.stride(0),
        kv_contig.stride(1),
        kv_contig.stride(2),
        kv_contig.stride(3),
        weights_contig.stride(0),
        weights_contig.stride(1),
        logits.stride(0),
        logits.stride(1),
        block_tables_contig.stride(0),
        block_tables_contig.stride(1),
        next_n,
        heads,
        dim,
        block_size,
        max_model_len,
        dim_plus_4,
    )

    return logits