Coverage for src/flag_gems/ops/fp8_mqa_logits.py: 22%

54 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.utils import libentry, libtuner 

9 

# Module-scoped logger, named after this module so its output can be filtered per op.
logger = logging.getLogger(name=__name__)

11 

12 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("fp8_mqa_logits"),
    key=["M", "N", "D"],
)
@triton.jit
def _fp8_mqa_logits_kernel(
    Q,
    K,
    K_SCALES,
    WEIGHTS,
    CU_SEQLEN_KS,
    CU_SEQLEN_KE,
    LOGITS,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kd,
    M,
    H: tl.constexpr,
    D: tl.constexpr,
    N,
    CLEAN_LOGITS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """
    Triton kernel for FP8 MQA logits.

    Each program computes one BLOCK_M x BLOCK_N tile of

        logits[m, n] = sum_h(ReLU(score[m, h, n]) * weights[m, h])
        score[m, h, n] = k_scales[n] * sum_d(q[m, h, d] * k[n, d])

    When CLEAN_LOGITS is set, logits[m, n] is replaced by -inf for every n
    outside [cu_seqlen_ks[m], cu_seqlen_ke[m]).

    The per-row sequence bounds and per-column K scales are loaded once per
    program; the Q and K tiles are (re)loaded for every head and D-chunk.

    M and N are runtime integers (not constexpr) so a new sequence length does
    not force a fresh JIT compilation.

    Layout assumptions (not checked here; the caller must guarantee them):
      - Q and K are addressed through the explicit strides passed in.
      - WEIGHTS is a row-major (M, H) buffer.
      - K_SCALES has one contiguous entry per key row; CU_SEQLEN_KS /
        CU_SEQLEN_KE have one contiguous entry per query row.
      - LOGITS is a row-major (M, N) buffer.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)

    mask_m = offs_m < M
    mask_n = offs_n < N

    # 64-bit indices for pointer arithmetic: M * N (logits) or M * H * D (q)
    # can exceed 2**31 for long sequences and would overflow int32 offsets.
    rows = offs_m.to(tl.int64)
    cols = offs_n.to(tl.int64)

    ks_start = tl.load(CU_SEQLEN_KS + offs_m, mask=mask_m, other=0)
    ke_end = tl.load(CU_SEQLEN_KE + offs_m, mask=mask_m, other=N)

    k_scales = tl.load(K_SCALES + offs_n, mask=mask_n, other=1.0).to(tl.float32)

    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)

    for h_idx in range(H):
        weight_h = tl.load(WEIGHTS + rows * H + h_idx, mask=mask_m, other=0.0)
        weight_h = weight_h.to(tl.float32)

        score_h = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        for d_start in range(0, D, BLOCK_D):
            d_offs = d_start + offs_d
            d_mask = d_offs < D

            q_ptrs = (
                Q
                + rows[:, None] * stride_qm
                + h_idx * stride_qh
                + d_offs[None, :] * stride_qd
            )
            q = tl.load(q_ptrs, mask=mask_m[:, None] & d_mask[None, :], other=0.0)

            # Load the K tile directly as (BLOCK_D, BLOCK_N) so it can feed
            # tl.dot without an explicit transpose.
            k_ptrs = K + d_offs[:, None] * stride_kd + cols[None, :] * stride_kn
            k = tl.load(k_ptrs, mask=d_mask[:, None] & mask_n[None, :], other=0.0)

            # K is kept unscaled here. Triton's default dot precision for
            # float32 operands is TF32, so pre-multiplying K by its scale would
            # round the scaled values. If Q/K are FP8 (as the op name
            # suggests), their float32 conversions are exact even in TF32.
            score_h += tl.dot(q.to(tl.float32), k.to(tl.float32))

        # Apply the per-key scale once to the reduced score, then ReLU. This
        # is mathematically identical to scaling K inside the dot, including
        # for negative scales.
        score_h = tl.maximum(score_h * k_scales[None, :], 0.0)
        acc += score_h * weight_h[:, None]

    if CLEAN_LOGITS:
        # Keep only columns inside each row's [ks_start, ke_end) key window.
        n_valid = (offs_n[None, :] >= ks_start[:, None]) & (
            offs_n[None, :] < ke_end[:, None]
        )
        acc = tl.where(n_valid, acc, float("-inf"))

    out_ptrs = LOGITS + rows[:, None] * N + cols[None, :]
    tl.store(out_ptrs, acc, mask=mask_m[:, None] & mask_n[None, :])

105 

106 

def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
    clean_logits: bool,
) -> torch.Tensor:
    """Compute weighted-ReLU MQA logits against a scaled key cache.

    logits[m, n] = sum_h(ReLU(k_scales[n] * <q[m, h, :], k[n, :]>) * weights[m, h])

    Args:
        q: Query tensor of shape (M, H, D). Arbitrary strides are supported.
        kv: ``(k_fp8, k_scales)``. ``k_fp8`` has shape (N, D) (arbitrary
            strides); ``k_scales`` holds one scale per key row (N elements).
        weights: Per-query, per-head weights with M * H elements, interpreted
            as a row-major (M, H) matrix.
        cu_seqlen_ks: One entry per query row: inclusive start of its valid
            key range.
        cu_seqlen_ke: One entry per query row: exclusive end of its valid key
            range.
        clean_logits: If True, entries outside ``[ks, ke)`` are set to -inf.

    Returns:
        A float32 tensor of shape (M, N) on ``q.device``.
    """
    logger.debug("GEMS FP8_MQA_LOGITS")

    k_fp8, k_scales = kv

    M, H, D = q.shape
    N = k_fp8.shape[0]

    # The launch grid plus masks writes every in-bounds (m, n), so no
    # zero-fill is needed.
    logits = torch.empty((M, N), dtype=torch.float32, device=q.device)
    if M == 0 or N == 0:
        # Nothing to compute; avoid launching a zero-sized grid.
        return logits

    # The kernel addresses these buffers with implicit row-major / unit
    # strides (no strides are passed for them), so enforce that layout.
    # .contiguous() returns the tensor itself when it already is contiguous.
    weights = weights.contiguous()
    k_scales = k_scales.contiguous()
    cu_seqlen_ks = cu_seqlen_ks.contiguous()
    cu_seqlen_ke = cu_seqlen_ke.contiguous()

    grid = lambda META: (
        triton.cdiv(M, META["BLOCK_M"]),
        triton.cdiv(N, META["BLOCK_N"]),
    )

    _fp8_mqa_logits_kernel[grid](
        q,
        k_fp8,
        k_scales,
        weights,
        cu_seqlen_ks,
        cu_seqlen_ke,
        logits,
        q.stride(0),  # stride_qm
        q.stride(1),  # stride_qh
        q.stride(2),  # stride_qd
        k_fp8.stride(0),  # stride_kn
        k_fp8.stride(1),  # stride_kd
        M,
        H,
        D,
        N,
        clean_logits,
    )

    return logits