Coverage for src/flag_gems/fused/mhc/mhc_post.py: 35%

93 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1""" 

2Triton implementation of mHC Post operator (optimized v3). 

3 

4Computes: 

5 out[n, i, h] = post_layer_mix[n, i] * x[n, h] 

6 + sum_j(comb_res_mix[n, j, i] * residual[n, j, h]) 

7 

8Key optimizations (v3): 

9- 2D grid = (N, cdiv(H, BLOCK_H)): high program count for latency hiding. 

10- @triton.autotune over BLOCK_H / num_warps / num_stages. 

11- Contiguous layout: stride math removed, enabling LDG.128. 

12- All 4 accumulators computed then stored (better ILP). 

13- BLOCK_H chosen to evenly divide H when possible (256 divides all targets). 

14""" 

15 

16import logging 

17 

18import torch 

19import triton 

20import triton.language as tl 

21 

22logger = logging.getLogger(__name__) 

23 

24 

@triton.autotune(
    configs=[
        # Small BLOCK_H: many programs, good for latency hiding
        triton.Config({"BLOCK_H": 128}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_H": 256}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_H": 256}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 256}, num_warps=8, num_stages=2),
        # Medium BLOCK_H
        triton.Config({"BLOCK_H": 512}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_H": 512}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=8, num_stages=2),
        # Large BLOCK_H
        triton.Config({"BLOCK_H": 1024}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 1024}, num_warps=8, num_stages=1),
    ],
    key=["H"],
)
@triton.jit
def mhc_post_kernel_hc_mult_4(
    a_ptr,  # comb_res_mix  : (N, 4, 4), read as a[n, j, i]
    b_ptr,  # residual      : (N, 4, H)
    c_ptr,  # post_layer_mix: (N, 4)
    d_ptr,  # x             : (N, H)
    out_ptr,  # output        : (N, 4, H)
    H: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Fused mHC post kernel specialised for exactly four streams.

    For token ``n``, output stream ``i`` and hidden index ``h`` this computes,
    with float32 accumulation::

        out[n, i, h] = c[n, i] * d[n, h] + sum_{j=0..3} a[n, j, i] * b[n, j, h]

    Every buffer is addressed as a densely packed row-major array: offsets are
    derived only from the program ids and ``H``, no strides are passed in.
    Callers must therefore supply contiguous tensors.

    Grid: (N, cdiv(H, BLOCK_H)).
    Each program handles one token x one h-tile x all 4 streams.

    Args:
        a_ptr: combination matrix, 16 values per token, indexed ``j * 4 + i``.
        b_ptr: residual streams, ``4 * H`` values per token.
        c_ptr: per-stream scale applied to ``d``, 4 values per token.
        d_ptr: layer output, ``H`` values per token.
        out_ptr: destination, ``4 * H`` values per token.
        H: hidden size (compile-time constant, also the autotune key).
        BLOCK_H: width of the hidden-dimension tile (chosen by the autotuner).
    """
    # program_id is int32; widen the token index so that the per-token base
    # offsets below (up to N * 4 * H) cannot wrap for buffers holding
    # 2**31 or more elements.
    pid_n = tl.program_id(0).to(tl.int64)
    pid_h = tl.program_id(1)

    h_off = pid_h * BLOCK_H + tl.arange(0, BLOCK_H)
    h_mask = h_off < H  # guards the ragged last tile when BLOCK_H does not divide H

    # Per-token base offsets for the packed layouts (all int64).
    a_base = pid_n * 16  # (N, 4, 4) -> 16 values per token
    c_base = pid_n * 4  # (N, 4)    -> 4 values per token
    b_base = pid_n * 4 * H  # (N, 4, H) -> 4*H values per token
    d_base = pid_n * H  # (N, H)    -> H values per token
    out_base = pid_n * 4 * H  # (N, 4, H) -> 4*H values per token

    # Load the 4 + 16 per-token scalars.
    c0 = tl.load(c_ptr + c_base + 0).to(tl.float32)
    c1 = tl.load(c_ptr + c_base + 1).to(tl.float32)
    c2 = tl.load(c_ptr + c_base + 2).to(tl.float32)
    c3 = tl.load(c_ptr + c_base + 3).to(tl.float32)

    # a{j}{i} == a[n, j, i]: first digit is the source stream j, second the
    # output stream i.
    a00 = tl.load(a_ptr + a_base + 0).to(tl.float32)
    a01 = tl.load(a_ptr + a_base + 1).to(tl.float32)
    a02 = tl.load(a_ptr + a_base + 2).to(tl.float32)
    a03 = tl.load(a_ptr + a_base + 3).to(tl.float32)
    a10 = tl.load(a_ptr + a_base + 4).to(tl.float32)
    a11 = tl.load(a_ptr + a_base + 5).to(tl.float32)
    a12 = tl.load(a_ptr + a_base + 6).to(tl.float32)
    a13 = tl.load(a_ptr + a_base + 7).to(tl.float32)
    a20 = tl.load(a_ptr + a_base + 8).to(tl.float32)
    a21 = tl.load(a_ptr + a_base + 9).to(tl.float32)
    a22 = tl.load(a_ptr + a_base + 10).to(tl.float32)
    a23 = tl.load(a_ptr + a_base + 11).to(tl.float32)
    a30 = tl.load(a_ptr + a_base + 12).to(tl.float32)
    a31 = tl.load(a_ptr + a_base + 13).to(tl.float32)
    a32 = tl.load(a_ptr + a_base + 14).to(tl.float32)
    a33 = tl.load(a_ptr + a_base + 15).to(tl.float32)

    # Load the h-tile of x and of each residual stream, upcast to float32.
    d_vals = tl.load(d_ptr + d_base + h_off, mask=h_mask, other=0.0).to(tl.float32)
    b0 = tl.load(b_ptr + b_base + 0 * H + h_off, mask=h_mask, other=0.0).to(tl.float32)
    b1 = tl.load(b_ptr + b_base + 1 * H + h_off, mask=h_mask, other=0.0).to(tl.float32)
    b2 = tl.load(b_ptr + b_base + 2 * H + h_off, mask=h_mask, other=0.0).to(tl.float32)
    b3 = tl.load(b_ptr + b_base + 3 * H + h_off, mask=h_mask, other=0.0).to(tl.float32)

    # Compute all 4 output streams; output i uses column i of a (a[j, i]).
    acc0 = c0 * d_vals + a00 * b0 + a10 * b1 + a20 * b2 + a30 * b3
    acc1 = c1 * d_vals + a01 * b0 + a11 * b1 + a21 * b2 + a31 * b3
    acc2 = c2 * d_vals + a02 * b0 + a12 * b1 + a22 * b2 + a32 * b3
    acc3 = c3 * d_vals + a03 * b0 + a13 * b1 + a23 * b2 + a33 * b3

    # Cast straight to the output element type: identical for a bfloat16
    # output, and avoids an intermediate bfloat16 rounding for any other dtype.
    tl.store(
        out_ptr + out_base + 0 * H + h_off,
        acc0.to(out_ptr.dtype.element_ty),
        mask=h_mask,
    )
    tl.store(
        out_ptr + out_base + 1 * H + h_off,
        acc1.to(out_ptr.dtype.element_ty),
        mask=h_mask,
    )
    tl.store(
        out_ptr + out_base + 2 * H + h_off,
        acc2.to(out_ptr.dtype.element_ty),
        mask=h_mask,
    )
    tl.store(
        out_ptr + out_base + 3 * H + h_off,
        acc3.to(out_ptr.dtype.element_ty),
        mask=h_mask,
    )

113 

114 

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 128}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_H": 256}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 256}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=8, num_stages=1),
    ],
    key=["H", "HC"],
)
@triton.jit
def mhc_post_kernel_generic(
    a_ptr,  # comb_res_mix  : (N, HC, HC), read as a[n, j, i]
    b_ptr,  # residual      : (N, HC, H)
    c_ptr,  # post_layer_mix: (N, HC)
    d_ptr,  # x             : (N, H)
    out_ptr,  # output        : (N, HC, H)
    H: tl.constexpr,
    HC: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Generic mHC post kernel for an arbitrary number of streams ``HC``.

    For token ``n``, output stream ``i`` and hidden index ``h`` this computes,
    with float32 accumulation::

        out[n, i, h] = c[n, i] * d[n, h] + sum_{j<HC} a[n, j, i] * b[n, j, h]

    Every buffer is addressed as a densely packed row-major array (offsets
    come only from the program ids, ``H`` and ``HC``), so callers must supply
    contiguous tensors.

    Grid: (N, HC, cdiv(H, BLOCK_H)).
    Each program handles one token x one output stream (i) x one h-tile.

    Args:
        a_ptr: combination matrix, ``HC * HC`` values per token, indexed
            ``j * HC + i``.
        b_ptr: residual streams, ``HC * H`` values per token.
        c_ptr: per-stream scale applied to ``d``, ``HC`` values per token.
        d_ptr: layer output, ``H`` values per token.
        out_ptr: destination, ``HC * H`` values per token.
        H: hidden size (compile-time constant, autotune key).
        HC: number of streams (compile-time constant, autotune key).
        BLOCK_H: width of the hidden-dimension tile (chosen by the autotuner).
    """
    # program_id is int32; widen the token index so that the per-token base
    # offsets below (up to N * HC * H) cannot wrap for buffers holding
    # 2**31 or more elements.
    pid_n = tl.program_id(0).to(tl.int64)
    pid_i = tl.program_id(1)
    pid_h = tl.program_id(2)

    h_off = pid_h * BLOCK_H + tl.arange(0, BLOCK_H)
    h_mask = h_off < H  # guards the ragged last tile when BLOCK_H does not divide H

    # Per-token base offsets for the packed layouts (all int64).
    a_base = pid_n * HC * HC
    b_base = pid_n * HC * H
    c_base = pid_n * HC
    d_base = pid_n * H
    out_base = pid_n * HC * H + pid_i * H

    d_vals = tl.load(d_ptr + d_base + h_off, mask=h_mask, other=0.0).to(tl.float32)
    c_i = tl.load(c_ptr + c_base + pid_i).to(tl.float32)

    # Start from the scaled layer output, then add each residual stream j
    # weighted by a[n, j, i]. static_range unrolls the loop at compile time.
    acc = c_i * d_vals
    for j in tl.static_range(0, HC):
        a_ji = tl.load(a_ptr + a_base + j * HC + pid_i).to(tl.float32)
        b_j = tl.load(b_ptr + b_base + j * H + h_off, mask=h_mask, other=0.0).to(
            tl.float32
        )
        acc += a_ji * b_j

    # Cast straight to the output element type: identical for a bfloat16
    # output, and avoids an intermediate bfloat16 rounding for any other dtype.
    tl.store(
        out_ptr + out_base + h_off,
        acc.to(out_ptr.dtype.element_ty),
        mask=h_mask,
    )

167 

168 

def mhc_post(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """Apply the mHC post-processing operator.

    Computes, for every token ``n``, stream ``i`` and hidden index ``h``::

        out[n, i, h] = post_layer_mix[n, i] * x[n, h]
                       + sum_j comb_res_mix[n, j, i] * residual[n, j, h]

    Args:
        x: (N, H) layer output.
        residual: (N, hc_mult, H) multi-stream residual.
        post_layer_mix: (N, hc_mult, 1) or (N, hc_mult) per-stream scale for x.
        comb_res_mix: (N, hc_mult, hc_mult) combination matrix, indexed [n, j, i].

    Returns:
        A new contiguous tensor of shape (N, hc_mult, H) with the dtype and
        device of ``residual``.

    Raises:
        ValueError: if ``residual`` is not 3-dimensional (from shape unpacking).
        AssertionError: if the other shapes are inconsistent with ``residual``.
    """
    logger.debug(
        "GEMS MHC_POST FORWARD, x=%s, residual=%s, post_layer_mix=%s, comb_res_mix=%s",
        x.shape,
        residual.shape,
        post_layer_mix.shape,
        comb_res_mix.shape,
    )

    N, hc, H = residual.shape
    assert x.shape == (N, H), f"x must be {(N, H)}, got {tuple(x.shape)}"
    assert post_layer_mix.shape in ((N, hc, 1), (N, hc)), (
        f"post_layer_mix must be {(N, hc, 1)} or {(N, hc)}, "
        f"got {tuple(post_layer_mix.shape)}"
    )
    assert comb_res_mix.shape == (N, hc, hc), (
        f"comb_res_mix must be {(N, hc, hc)}, got {tuple(comb_res_mix.shape)}"
    )

    # The kernels write with packed row-major offsets, so the output must be
    # contiguous. torch.empty_like(residual) would copy the strides of a dense
    # but non-contiguous residual (e.g. a transposed view) and the result
    # would land in the wrong logical positions.
    out = torch.empty((N, hc, H), dtype=residual.dtype, device=residual.device)
    if out.numel() == 0:
        # Nothing to compute; avoid launching a kernel over an empty grid.
        return out

    # reshape (rather than squeeze) keeps the (N, hc) rank even when hc == 1.
    c = post_layer_mix.reshape(N, hc).contiguous()  # (N, hc)
    a = comb_res_mix.contiguous()  # (N, hc, hc)
    b = residual.contiguous()  # (N, hc, H)
    d = x.contiguous()  # (N, H)

    if hc == 4:
        # Specialised kernel: one program covers all four streams of an h-tile.
        def grid_specialized(META):
            return (N, triton.cdiv(H, META["BLOCK_H"]))

        mhc_post_kernel_hc_mult_4[grid_specialized](
            a,
            b,
            c,
            d,
            out,
            H=H,
        )
    else:
        # Generic kernel: one program per (token, output stream, h-tile).
        def grid_generic(META):
            return (N, hc, triton.cdiv(H, META["BLOCK_H"]))

        mhc_post_kernel_generic[grid_generic](
            a,
            b,
            c,
            d,
            out,
            H=H,
            HC=hc,
        )
    return out

235 

236 

def mhc_post_ref(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """PyTorch reference implementation of the mHC post operator.

    Computes ``out[n, i, h] = post_layer_mix[n, i] * x[n, h]
    + sum_j comb_res_mix[n, j, i] * residual[n, j, h]``.

    Args:
        x: (N, H) layer output.
        residual: (N, hc, H) multi-stream residual.
        post_layer_mix: (N, hc, 1) or (N, hc) per-stream scale for ``x``.
        comb_res_mix: (N, hc, hc) combination matrix, indexed [n, j, i].

    Returns:
        Tensor of shape (N, hc, H) cast to the dtype of ``x``.
    """
    # mhc_post accepts post_layer_mix with or without the trailing singleton
    # dimension. Without it, the broadcast against x[:, None, :] would pair the
    # scale with the hidden axis instead of the stream axis, so normalise to
    # (N, hc, 1) first.
    if post_layer_mix.dim() == residual.dim() - 1:
        post_layer_mix = post_layer_mix.unsqueeze(-1)

    scaled_x = x.unsqueeze(-2) * post_layer_mix  # (N, hc, H)
    # comb_res_mix is indexed [n, j, i]; transposing gives [n, i, j] so that
    # bmm sums over the source stream j.
    mixed_residual = torch.bmm(comb_res_mix.mT, residual.float())  # (N, hc, H)
    return (scaled_x + mixed_residual).type_as(x)