Coverage for src/flag_gems/ops/fp8_paged_mqa_logits.py: 9%

125 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5from flag_gems import runtime 

6 

7 

def cdiv(x: int, y: int) -> int:
    """Return ``x`` divided by ``y``, rounded up (ceiling division).

    Uses integer arithmetic only, so there is no float rounding. The
    formula gives the mathematical ceiling for ``x >= 0`` and ``y > 0``;
    ``y == 0`` raises ``ZeroDivisionError``.
    """
    return (x + y - 1) // y

10 

11 

@triton.autotune(
    configs=runtime.get_tuned_config("fp8_paged_mqa_logits"),
    key=["heads", "dim", "block_size"],
)
@triton.jit
def fp8_paged_mqa_logits_kernel(
    q_ptr,
    kv_ptr,
    weights_ptr,
    logits_ptr,
    block_tables_ptr,
    context_lens_ptr,
    stride_qb,
    stride_qn,
    stride_qh,
    stride_qd,
    stride_kvblk,
    stride_kvpos,
    stride_kvone,
    stride_kvbyte,
    stride_wrow,
    stride_wh,
    stride_lrow,
    stride_lcol,
    stride_btb,
    stride_bts,
    next_n: tl.constexpr,
    heads: tl.constexpr,
    dim: tl.constexpr,
    block_size: tl.constexpr,
    max_model_len,
    dim_plus_4: tl.constexpr,
    BLOCK_KV: tl.constexpr,
    BLOCK_D: tl.constexpr,
    NUM_D_TILES: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Compute weighted, ReLU-ed MQA logits for one (row, KV tile) pair.

    Launch grid: axis 0 indexes an output row
    (``batch_idx * next_n + next_n_idx``); axis 1 indexes a tile of
    ``BLOCK_KV`` consecutive logical KV positions.

    For each valid KV position the kernel:
      1. maps the logical position to a physical block via ``block_tables``;
      2. reads ``dim`` bytes of the KV row as ``float8e4nv`` and the 4 bytes
         at byte offset ``dim`` as one float32 scale;
      3. for every head tile, computes ``relu(scale * (kv . q_head))`` and
         accumulates it weighted by ``weights[row, head]``.

    A position is valid when it is ``< context_len`` and
    ``<= context_len - next_n + next_n_idx`` (causal). Only valid positions
    below ``max_model_len`` are stored by the main path. Tiles that start at
    or beyond ``context_len`` store ``-inf`` and exit early.

    ``stride_kvone`` and ``dim_plus_4`` are accepted but not read here.

    NOTE(review): only two D tiles (``[0, BLOCK_D)`` and
    ``[BLOCK_D, 2 * BLOCK_D)``) are ever processed, so the tuned configs
    presumably guarantee ``dim <= BLOCK_D`` when ``NUM_D_TILES == 1`` and
    ``dim <= 2 * BLOCK_D`` otherwise -- confirm against the config file.
    """
    pid_row = tl.program_id(0)
    pid_kv_tile = tl.program_id(1)

    batch_idx = pid_row // next_n
    next_n_idx = pid_row % next_n

    context_len = tl.load(context_lens_ptr + batch_idx)
    query_seq_pos = context_len - next_n + next_n_idx

    # Widen before multiplying: rows * stride_lrow can exceed the int32
    # range for large (total_rows, max_model_len) outputs.
    logits_row_off = pid_row.to(tl.int64) * stride_lrow

    kv_start = pid_kv_tile * BLOCK_KV
    if kv_start >= context_len:
        # Whole tile lies past this sequence's context: emit -inf and stop.
        offs_kv = tl.arange(0, BLOCK_KV)
        kv_pos = kv_start + offs_kv
        out_mask = kv_pos < max_model_len
        out_ptrs = logits_ptr + logits_row_off + kv_pos * stride_lcol
        tl.store(out_ptrs, float("-inf"), mask=out_mask)
        return

    offs_kv = tl.arange(0, BLOCK_KV)
    kv_global_pos = kv_start + offs_kv

    context_mask = kv_global_pos < context_len
    causal_mask = kv_global_pos <= query_seq_pos
    valid_mask = context_mask & causal_mask

    # Logical position -> (logical block, offset inside that block).
    phys_block_idx = kv_global_pos // block_size
    intra_block_pos = kv_global_pos % block_size

    phys_block_ids = tl.load(
        block_tables_ptr + batch_idx * stride_btb + phys_block_idx * stride_bts,
        mask=valid_mask,
        other=0,
    )

    # Widen the block id before scaling by the block stride: with a large
    # cache, block_id * stride_kvblk does not fit in 32 bits.
    kv_base = (
        phys_block_ids.to(tl.int64) * stride_kvblk
        + intra_block_pos * stride_kvpos
    )

    # The per-row scale is stored right after the `dim` payload bytes; load
    # it as a 32-bit word and reinterpret the bits as float32.
    scale_addr = kv_base + dim * stride_kvbyte
    scale_ptr = (kv_ptr + scale_addr).to(tl.pointer_type(tl.uint32, 1), bitcast=True)
    scale_u32 = tl.load(scale_ptr, mask=valid_mask, other=0)
    scale_f32 = scale_u32.to(tl.float32, bitcast=True)

    logit_accum = tl.zeros([BLOCK_KV], dtype=tl.float32)
    offs_d = tl.arange(0, BLOCK_D)
    q_base = q_ptr + batch_idx * stride_qb + next_n_idx * stride_qn

    if NUM_D_TILES == 1:
        d_mask = offs_d < dim

        # Raw bytes -> fp8 (bit reinterpretation) -> fp32 (value conversion).
        kv_byte_ptrs = kv_ptr + kv_base[:, None] + offs_d[None, :] * stride_kvbyte
        load_mask = valid_mask[:, None] & d_mask[None, :]
        kv_u8 = tl.load(kv_byte_ptrs, mask=load_mask, other=0)
        kv_fp8 = kv_u8.to(tl.float8e4nv, bitcast=True)
        kv_f32 = kv_fp8.to(tl.float32)

        for h_tile in tl.static_range(0, heads, BLOCK_H):
            offs_h = h_tile + tl.arange(0, BLOCK_H)
            h_mask = offs_h < heads

            q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
            q_vals = tl.load(
                q_ptrs, mask=h_mask[:, None] & d_mask[None, :], other=0.0
            ).to(tl.float32)
            weights = tl.load(
                weights_ptr + pid_row * stride_wrow + offs_h * stride_wh,
                mask=h_mask,
                other=0.0,
            )

            # [BLOCK_KV, BLOCK_D] x [BLOCK_D, BLOCK_H] -> per-head scores.
            q_tile = tl.trans(q_vals)
            partial_dot = tl.dot(kv_f32, q_tile, out_dtype=tl.float32)
            partial_dot = partial_dot * scale_f32[:, None]
            partial_dot = tl.maximum(partial_dot, 0.0)
            # Masked-out heads carry weight 0, so they add nothing.
            logit_accum += tl.sum(partial_dot * weights[None, :], axis=1)

    else:
        # Two-tile path: the feature axis is split into two BLOCK_D halves
        # whose dot products are accumulated before scaling.
        d_offs0 = offs_d
        d_mask0 = d_offs0 < dim
        d_offs1 = BLOCK_D + offs_d
        d_mask1 = d_offs1 < dim

        kv_byte_ptrs0 = kv_ptr + kv_base[:, None] + d_offs0[None, :] * stride_kvbyte
        load_mask0 = valid_mask[:, None] & d_mask0[None, :]
        kv_u80 = tl.load(kv_byte_ptrs0, mask=load_mask0, other=0)
        kv_fp80 = kv_u80.to(tl.float8e4nv, bitcast=True)
        kv_f320 = kv_fp80.to(tl.float32)

        kv_byte_ptrs1 = kv_ptr + kv_base[:, None] + d_offs1[None, :] * stride_kvbyte
        load_mask1 = valid_mask[:, None] & d_mask1[None, :]
        kv_u81 = tl.load(kv_byte_ptrs1, mask=load_mask1, other=0)
        kv_fp81 = kv_u81.to(tl.float8e4nv, bitcast=True)
        kv_f321 = kv_fp81.to(tl.float32)

        for h_tile in tl.static_range(0, heads, BLOCK_H):
            offs_h = h_tile + tl.arange(0, BLOCK_H)
            h_mask = offs_h < heads

            q_ptrs0 = (
                q_base + offs_h[:, None] * stride_qh + d_offs0[None, :] * stride_qd
            )
            q_vals0 = tl.load(
                q_ptrs0, mask=h_mask[:, None] & d_mask0[None, :], other=0.0
            ).to(tl.float32)

            q_ptrs1 = (
                q_base + offs_h[:, None] * stride_qh + d_offs1[None, :] * stride_qd
            )
            q_vals1 = tl.load(
                q_ptrs1, mask=h_mask[:, None] & d_mask1[None, :], other=0.0
            ).to(tl.float32)

            weights = tl.load(
                weights_ptr + pid_row * stride_wrow + offs_h * stride_wh,
                mask=h_mask,
                other=0.0,
            )

            q_T0 = tl.trans(q_vals0)
            q_T1 = tl.trans(q_vals1)

            partial_dot = tl.dot(kv_f320, q_T0, out_dtype=tl.float32)
            partial_dot = tl.dot(kv_f321, q_T1, acc=partial_dot, out_dtype=tl.float32)

            partial_dot = partial_dot * scale_f32[:, None]
            partial_dot = tl.maximum(partial_dot, 0.0)
            logit_accum += tl.sum(partial_dot * weights[None, :], axis=1)

    out_vals = tl.where(valid_mask, logit_accum, float("-inf"))
    out_ptrs = logits_ptr + logits_row_off + kv_global_pos * stride_lcol
    # Invalid positions are not written here; they keep whatever the output
    # buffer already holds.
    out_mask = valid_mask & (kv_global_pos < max_model_len)
    tl.store(out_ptrs, out_vals, mask=out_mask)

179 

180 

@triton.jit
def fill_neg_inf_kernel(
    out_ptr,
    n_elements,
    BLOCK: tl.constexpr,
):
    """Write ``-inf`` to the first ``n_elements`` elements at ``out_ptr``.

    Each program handles one run of ``BLOCK`` consecutive element offsets;
    offsets at or beyond ``n_elements`` are masked off. The buffer is
    addressed linearly, so the caller is expected to pass contiguous
    storage.
    """
    pid = tl.program_id(0)
    # Widen before multiplying so pid * BLOCK cannot wrap in int32 when the
    # buffer holds more than 2**31 elements.
    offs = pid.to(tl.int64) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(out_ptr + offs, float("-inf"), mask=mask)

191 

192 

def fp8_paged_mqa_logits(
    q: torch.Tensor,
    kv_cache: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    max_model_len: int,
) -> torch.Tensor:
    """Compute MQA logits over a paged, byte-packed FP8 KV cache.

    Args:
        q: CUDA tensor of shape ``(batch_size, next_n, heads, dim)``.
        kv_cache: CUDA ``uint8`` tensor of shape
            ``(num_blocks, block_size, 1, dim + 4)``. The kernel reads the
            first ``dim`` bytes of a row as FP8 values and the trailing 4
            bytes as a float32 scale.
        weights: CUDA tensor of shape ``(batch_size * next_n, heads)``.
        context_lens: CUDA ``int32`` tensor indexed by batch
            (presumably shape ``(batch_size,)`` -- not asserted here).
        block_tables: CUDA ``int32`` 2-D tensor; row ``b`` maps logical
            block indices of sequence ``b`` to physical block ids.
        max_model_len: Number of columns in the returned logits.

    Returns:
        A ``float32`` tensor of shape ``(batch_size * next_n, max_model_len)``
        on ``q``'s device. Entries the kernel does not compute are ``-inf``.

    Raises:
        AssertionError: If a device, shape or dtype precondition fails.
    """
    assert q.is_cuda and kv_cache.is_cuda and weights.is_cuda
    assert context_lens.is_cuda and block_tables.is_cuda

    batch_size, next_n, heads, dim = q.size()
    num_blocks, block_size, one, dim_plus_4 = kv_cache.size()

    assert one == 1
    assert dim_plus_4 == dim + 4
    assert weights.shape == (batch_size * next_n, heads)
    assert kv_cache.dtype == torch.uint8
    assert context_lens.dtype == torch.int32
    assert block_tables.dtype == torch.int32

    q_contig = q.contiguous()
    kv_contig = kv_cache.contiguous()
    weights_contig = weights.contiguous()
    context_lens_contig = context_lens.contiguous()
    block_tables_contig = block_tables.contiguous()

    total_rows = batch_size * next_n

    logits = torch.empty(
        (total_rows, max_model_len),
        device=q.device,
        dtype=torch.float32,
    )
    n_elements = total_rows * max_model_len
    if n_elements == 0:
        # Nothing to fill or compute; avoid launching zero-sized grids.
        return logits

    # Pre-fill with -inf: the main kernel only stores valid positions.
    FILL_BLOCK = 1024
    fill_grid = (cdiv(n_elements, FILL_BLOCK),)
    fill_neg_inf_kernel[fill_grid](logits, n_elements, BLOCK=FILL_BLOCK)

    # The block table bounds how many KV positions can exist, but the kernel
    # never stores at columns >= max_model_len, so tiles past that column
    # would do no useful work.
    max_context = min(block_tables_contig.shape[1] * block_size, max_model_len)
    if max_context == 0:
        return logits

    def grid(meta):
        BLOCK_KV = meta["BLOCK_KV"]
        num_kv_tiles = cdiv(max_context, BLOCK_KV)
        return (total_rows, num_kv_tiles)

    fp8_paged_mqa_logits_kernel[grid](
        q_contig,
        kv_contig,
        weights_contig,
        logits,
        block_tables_contig,
        context_lens_contig,
        q_contig.stride(0),
        q_contig.stride(1),
        q_contig.stride(2),
        q_contig.stride(3),
        kv_contig.stride(0),
        kv_contig.stride(1),
        kv_contig.stride(2),
        kv_contig.stride(3),
        weights_contig.stride(0),
        weights_contig.stride(1),
        logits.stride(0),
        logits.stride(1),
        block_tables_contig.stride(0),
        block_tables_contig.stride(1),
        next_n,
        heads,
        dim,
        block_size,
        max_model_len,
        dim_plus_4,
    )

    return logits