Coverage for src/flag_gems/fused/mhc/hc_head_fused_kernel.py: 36%

112 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7logger = logging.getLogger(__name__) 

8 

9 

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_H": 1024}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_H": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_H": 2048}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_H": 2048}, num_warps=8, num_stages=3),
    ],
    key=["H", "HC"],
)
@triton.jit
def _hc_head_fused_kernel(
    residual_ptr,
    fn_ptr,
    hc_scale_ptr,
    hc_base_ptr,
    out_ptr,
    T,
    H: tl.constexpr,
    rms_eps,
    hc_eps,
    residual_stride_t,
    fn_stride_m,
    out_stride_t,
    HC: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Fused HC head: per-token mixing weights followed by a weighted sum.

    Each program handles the token selected by ``tl.program_id(0)``. For that
    token, with ``r_k`` the k-th chunk of ``H`` elements of the residual row
    and ``x`` the concatenation of all ``HC`` chunks:

        mix[m]     = sum(x * fn[m, :])
        pre_mix[m] = sigmoid(mix[m] * rsqrt(mean(x^2) + rms_eps) * hc_scale
                             + hc_base[m]) + hc_eps
        out[h]     = sum_m pre_mix[m] * r_m[h]

    Args:
        residual_ptr: input rows; row ``t`` starts at ``t * residual_stride_t``
            and chunk ``k`` is read at offsets ``k * H + h``.
        fn_ptr: mixing matrix; row ``m`` starts at ``m * fn_stride_m`` and is
            indexed with the same ``k * H + h`` offsets as the residual row.
        hc_scale_ptr: pointer to a single scale value.
        hc_base_ptr: pointer to ``HC`` bias values read with unit stride.
        out_ptr: output rows; row ``t`` starts at ``t * out_stride_t`` and
            receives ``H`` elements.
        T: number of tokens; programs with ``pid >= T`` exit immediately.
        H: elements per chunk (compile-time constant).
        rms_eps: epsilon added to the mean square before ``rsqrt``.
        hc_eps: epsilon added after the sigmoid.
        residual_stride_t, fn_stride_m, out_stride_t: row strides in elements.
        HC: number of chunks per token; only 2 and 4 are implemented.
        BLOCK_H: tile width along ``H`` (chosen by the autotuner).
    """
    # The body is hand-unrolled for two chunks plus an optional second pair;
    # any other HC would read past the rows, so reject it at compile time.
    tl.static_assert(HC == 2 or HC == 4, "_hc_head_fused_kernel supports HC == 2 or HC == 4 only")

    # Promote to int64 so pid_t * stride cannot wrap around 2**31 when the
    # input holds more than 2**31 elements.
    pid_t = tl.program_id(0).to(tl.int64)
    if pid_t >= T:
        return

    x_base = pid_t * residual_stride_t

    # Pass 1: iterate over H blocks to compute sqrsum and mixes.
    # Accumulators stay vector-shaped per lane; they are reduced once after
    # the loop.
    sqr_acc = tl.zeros([BLOCK_H], dtype=tl.float32)
    mix_acc0 = tl.zeros([BLOCK_H], dtype=tl.float32)
    mix_acc1 = tl.zeros([BLOCK_H], dtype=tl.float32)
    mix_acc2 = tl.zeros([BLOCK_H], dtype=tl.float32)
    mix_acc3 = tl.zeros([BLOCK_H], dtype=tl.float32)

    for h_start in range(0, H, BLOCK_H):
        h_off = h_start + tl.arange(0, BLOCK_H)
        h_mask = h_off < H

        r0 = tl.load(residual_ptr + x_base + 0 * H + h_off, mask=h_mask, other=0.0).to(
            tl.float32
        )
        r1 = tl.load(residual_ptr + x_base + 1 * H + h_off, mask=h_mask, other=0.0).to(
            tl.float32
        )
        sqr_acc += r0 * r0 + r1 * r1

        fn00 = tl.load(fn_ptr + 0 * fn_stride_m + 0 * H + h_off, mask=h_mask, other=0.0)
        fn01 = tl.load(fn_ptr + 0 * fn_stride_m + 1 * H + h_off, mask=h_mask, other=0.0)
        mix_acc0 += r0 * fn00 + r1 * fn01

        fn10 = tl.load(fn_ptr + 1 * fn_stride_m + 0 * H + h_off, mask=h_mask, other=0.0)
        fn11 = tl.load(fn_ptr + 1 * fn_stride_m + 1 * H + h_off, mask=h_mask, other=0.0)
        mix_acc1 += r0 * fn10 + r1 * fn11

        if HC > 2:
            r2 = tl.load(
                residual_ptr + x_base + 2 * H + h_off, mask=h_mask, other=0.0
            ).to(tl.float32)
            r3 = tl.load(
                residual_ptr + x_base + 3 * H + h_off, mask=h_mask, other=0.0
            ).to(tl.float32)
            sqr_acc += r2 * r2 + r3 * r3

            # Rows 0 and 1 of fn also need the contribution of chunks 2 and 3.
            mix_acc0 += r2 * tl.load(
                fn_ptr + 0 * fn_stride_m + 2 * H + h_off, mask=h_mask, other=0.0
            )
            mix_acc0 += r3 * tl.load(
                fn_ptr + 0 * fn_stride_m + 3 * H + h_off, mask=h_mask, other=0.0
            )

            mix_acc1 += r2 * tl.load(
                fn_ptr + 1 * fn_stride_m + 2 * H + h_off, mask=h_mask, other=0.0
            )
            mix_acc1 += r3 * tl.load(
                fn_ptr + 1 * fn_stride_m + 3 * H + h_off, mask=h_mask, other=0.0
            )

            fn20 = tl.load(
                fn_ptr + 2 * fn_stride_m + 0 * H + h_off, mask=h_mask, other=0.0
            )
            fn21 = tl.load(
                fn_ptr + 2 * fn_stride_m + 1 * H + h_off, mask=h_mask, other=0.0
            )
            fn22 = tl.load(
                fn_ptr + 2 * fn_stride_m + 2 * H + h_off, mask=h_mask, other=0.0
            )
            fn23 = tl.load(
                fn_ptr + 2 * fn_stride_m + 3 * H + h_off, mask=h_mask, other=0.0
            )
            mix_acc2 += r0 * fn20 + r1 * fn21 + r2 * fn22 + r3 * fn23

            fn30 = tl.load(
                fn_ptr + 3 * fn_stride_m + 0 * H + h_off, mask=h_mask, other=0.0
            )
            fn31 = tl.load(
                fn_ptr + 3 * fn_stride_m + 1 * H + h_off, mask=h_mask, other=0.0
            )
            fn32 = tl.load(
                fn_ptr + 3 * fn_stride_m + 2 * H + h_off, mask=h_mask, other=0.0
            )
            fn33 = tl.load(
                fn_ptr + 3 * fn_stride_m + 3 * H + h_off, mask=h_mask, other=0.0
            )
            mix_acc3 += r0 * fn30 + r1 * fn31 + r2 * fn32 + r3 * fn33

    # Reduce the per-lane partials and build the mixing weights.
    K = HC * H
    sqr_total = tl.sum(sqr_acc)
    rsqrt_val = tl.math.rsqrt(sqr_total / K + rms_eps)
    hc_scale = tl.load(hc_scale_ptr)

    mix0 = tl.sum(mix_acc0)
    mix1 = tl.sum(mix_acc1)
    hc_base0 = tl.load(hc_base_ptr + 0)
    hc_base1 = tl.load(hc_base_ptr + 1)
    pre_mix0 = tl.sigmoid(mix0 * rsqrt_val * hc_scale + hc_base0) + hc_eps
    pre_mix1 = tl.sigmoid(mix1 * rsqrt_val * hc_scale + hc_base1) + hc_eps

    if HC > 2:
        mix2 = tl.sum(mix_acc2)
        mix3 = tl.sum(mix_acc3)
        hc_base2 = tl.load(hc_base_ptr + 2)
        hc_base3 = tl.load(hc_base_ptr + 3)
        pre_mix2 = tl.sigmoid(mix2 * rsqrt_val * hc_scale + hc_base2) + hc_eps
        pre_mix3 = tl.sigmoid(mix3 * rsqrt_val * hc_scale + hc_base3) + hc_eps

    # Pass 2: weighted sum. The residual row is re-read because the weights
    # depend on a reduction over the whole row.
    out_base = pid_t * out_stride_t
    for h_start in range(0, H, BLOCK_H):
        h_off = h_start + tl.arange(0, BLOCK_H)
        h_mask = h_off < H
        r0 = tl.load(residual_ptr + x_base + 0 * H + h_off, mask=h_mask, other=0.0).to(
            tl.float32
        )
        r1 = tl.load(residual_ptr + x_base + 1 * H + h_off, mask=h_mask, other=0.0).to(
            tl.float32
        )
        acc = pre_mix0 * r0 + pre_mix1 * r1
        if HC > 2:
            r2 = tl.load(
                residual_ptr + x_base + 2 * H + h_off, mask=h_mask, other=0.0
            ).to(tl.float32)
            r3 = tl.load(
                residual_ptr + x_base + 3 * H + h_off, mask=h_mask, other=0.0
            ).to(tl.float32)
            acc += pre_mix2 * r2 + pre_mix3 * r3
        # Cast to whatever element type the output buffer holds instead of
        # hard-coding bfloat16.
        tl.store(
            out_ptr + out_base + h_off,
            acc.to(out_ptr.dtype.element_ty),
            mask=h_mask,
        )

162 

163 

def hc_head_fused_kernel_ref(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
) -> torch.Tensor:
    """Eager PyTorch version of the HC head, used as a correctness reference.

    All arithmetic is done in float32 and the result is cast to ``out.dtype``
    before being copied into ``out``.

    Args:
        hs_flat: input whose leading dimension is the token count and which
            can be reshaped to ``(num_tokens, hc_mult * hidden_size)``; it is
            reduced over ``dim=1`` at the end.
        fn: mixing matrix; its transpose is right-multiplied onto the
            flattened input.
        hc_scale: tensor whose first element scales the normalized mixes.
        hc_base: bias added to the scaled mixes before the sigmoid.
        out: destination tensor, written in place.
        hidden_size: per-stream width.
        rms_eps: epsilon added to the mean square before ``rsqrt``.
        hc_eps: epsilon added after the sigmoid.
        hc_mult: number of streams per token.

    Returns:
        ``out`` (the same object), untouched when there are no tokens.
    """
    num_tokens = hs_flat.shape[0]
    if num_tokens == 0:
        return out

    flat_width = hc_mult * hidden_size
    flat_f32 = hs_flat.reshape(num_tokens, flat_width).to(torch.float32)

    # Inverse RMS over the whole flattened row of each token.
    sum_sq = flat_f32.square().sum(dim=-1, keepdim=True)
    inv_rms = torch.rsqrt(sum_sq / flat_width + rms_eps)

    # One mixing weight per stream.
    logits = torch.matmul(flat_f32, fn.t())
    weights = torch.sigmoid(logits * inv_rms * hc_scale[0] + hc_base) + hc_eps

    # Weighted sum of the streams.
    weighted = weights.unsqueeze(-1) * hs_flat.to(torch.float32)
    combined = torch.sum(weighted, dim=1).to(out.dtype)

    out.copy_(combined)
    return out

188 

189 

def hc_head_fused_kernel(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
) -> torch.Tensor:
    """HC head fused kernel: fully fused Triton implementation.

    Validates the arguments, then either launches the Triton kernel (CUDA
    inputs with ``hc_mult`` of 2 or 4) or falls back to the PyTorch reference
    implementation. The result is written into ``out`` in place.

    Args:
        hs_flat: bfloat16 tensor of shape ``(num_tokens, hc_mult, hidden_size)``.
        fn: float32 tensor of shape ``(hc_mult, hc_mult * hidden_size)``.
        hc_scale: float32 tensor of shape ``(1,)``.
        hc_base: float32 tensor of shape ``(hc_mult,)``.
        out: bfloat16 tensor of shape ``(num_tokens, hidden_size)``; may be
            non-contiguous.
        hidden_size: per-stream width.
        rms_eps: epsilon added to the mean square before ``rsqrt``.
        hc_eps: epsilon added after the sigmoid.
        hc_mult: number of streams per token.

    Returns:
        ``out`` (the same object). Returned untouched when ``num_tokens == 0``.

    Raises:
        AssertionError: if a dtype or shape check fails.
    """
    logger.debug("GEMS HC_HEAD_FUSED")
    assert hs_flat.dtype == torch.bfloat16
    assert fn.dtype == torch.float32
    assert hc_scale.dtype == torch.float32
    assert hc_base.dtype == torch.float32

    num_tokens = hs_flat.shape[0]
    if num_tokens == 0:
        return out

    assert hs_flat.shape == (num_tokens, hc_mult, hidden_size)
    assert fn.shape == (hc_mult, hc_mult * hidden_size)
    assert hc_scale.shape == (1,)
    assert hc_base.shape == (hc_mult,)
    assert out.shape == (num_tokens, hidden_size)
    assert out.dtype == hs_flat.dtype

    # The Triton kernel is hand-unrolled for 2 or 4 streams only; any other
    # hc_mult would make it read out of bounds, so use the reference instead.
    if hs_flat.device.type != "cuda" or hc_mult not in (2, 4):
        return hc_head_fused_kernel_ref(
            hs_flat, fn, hc_scale, hc_base, out, hidden_size, rms_eps, hc_eps, hc_mult
        )

    H = hidden_size

    # The kernel indexes every operand with unit stride inside a row.
    residual_c = hs_flat.contiguous()
    fn_c = fn.contiguous()
    hc_base_c = hc_base.contiguous()
    if out.is_contiguous():
        out_c = out
    else:
        # empty_like defaults to preserve_format, which would copy the strides
        # of a dense but non-contiguous `out` (e.g. a transposed view); force a
        # row-major scratch buffer so the kernel's addressing is valid.
        out_c = torch.empty_like(out, memory_format=torch.contiguous_format)

    _hc_head_fused_kernel[(num_tokens,)](
        residual_c,
        fn_c,
        hc_scale,
        hc_base_c,
        out_c,
        num_tokens,
        H,
        rms_eps,
        hc_eps,
        residual_c.stride(0),
        fn_c.stride(0),
        out_c.stride(0),
        HC=hc_mult,
    )

    # Copy back only when a scratch buffer was used.
    if out_c is not out:
        out.copy_(out_c)

    return out