Coverage for src/flag_gems/fused/mhc/mhc_post.py: 35%
93 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1"""
2Triton implementation of mHC Post operator (optimized v3).
4Computes:
5 out[n, i, h] = post_layer_mix[n, i] * x[n, h]
6 + sum_j(comb_res_mix[n, j, i] * residual[n, j, h])
8Key optimizations (v3):
9- 2D grid = (N, cdiv(H, BLOCK_H)): high program count for latency hiding.
10- @triton.autotune over BLOCK_H / num_warps / num_stages.
11- Contiguous layout: stride math removed, enabling LDG.128.
12- All 4 accumulators computed then stored (better ILP).
13- BLOCK_H chosen to evenly divide H when possible (256 divides all targets).
14"""
16import logging
18import torch
19import triton
20import triton.language as tl
22logger = logging.getLogger(__name__)
@triton.autotune(
    configs=[
        # Small BLOCK_H: many programs, good for latency hiding
        triton.Config({"BLOCK_H": 128}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_H": 256}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 256}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_H": 256}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 256}, num_warps=8, num_stages=2),
        # Medium BLOCK_H
        triton.Config({"BLOCK_H": 512}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_H": 512}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=8, num_stages=2),
        # Large BLOCK_H
        triton.Config({"BLOCK_H": 1024}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 1024}, num_warps=8, num_stages=1),
    ],
    key=["H"],
)
@triton.jit
def mhc_post_kernel_hc_mult_4(
    a_ptr,  # comb_res_mix  : (N, 4, 4), contiguous — a[n, j, i]
    b_ptr,  # residual      : (N, 4, H), contiguous
    c_ptr,  # post_layer_mix: (N, 4), contiguous
    d_ptr,  # x             : (N, H), contiguous
    out_ptr,  # output      : (N, 4, H), contiguous
    H: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """mHC post kernel specialized for exactly 4 residual streams.

    Computes, in float32 and for every h in this program's tile:

        out[n, i, h] = c[n, i] * d[n, h] + sum_j(a[n, j, i] * b[n, j, h])

    Grid: (N, cdiv(H, BLOCK_H)).
    Each program handles one token x one h-tile x all 4 output streams.

    All pointers must reference contiguous buffers: offsets are computed
    from the logical shapes listed on the parameters, not from strides.
    The result is cast to the element type of ``out_ptr`` when stored.
    """
    # program_id is 32-bit; promote the token index to int64 so that the
    # per-token bases below (up to pid_n * 4 * H) cannot wrap around when
    # N * 4 * H reaches 2**31 elements.
    pid_n = tl.program_id(0).to(tl.int64)
    pid_h = tl.program_id(1)

    h_off = pid_h * BLOCK_H + tl.arange(0, BLOCK_H)
    h_mask = h_off < H  # guards the ragged last tile when BLOCK_H does not divide H

    # ── pointer bases (contiguous layout) ──
    a_base = pid_n * 16  # (N, 4, 4) → stride_n = 16
    c_base = pid_n * 4  # (N, 4)    → stride_n = 4
    b_base = pid_n * 4 * H  # (N, 4, H) → stride_n = 4*H
    d_base = pid_n * H  # (N, H)    → stride_n = H
    out_base = pid_n * 4 * H  # (N, 4, H) → stride_n = 4*H

    # ── load the 20 per-token scalars: 4 scales + 4x4 mixing matrix ──
    c0 = tl.load(c_ptr + c_base + 0).to(tl.float32)
    c1 = tl.load(c_ptr + c_base + 1).to(tl.float32)
    c2 = tl.load(c_ptr + c_base + 2).to(tl.float32)
    c3 = tl.load(c_ptr + c_base + 3).to(tl.float32)

    # a<j><i> is a[n, j, i], stored row-major at offset j * 4 + i.
    a00 = tl.load(a_ptr + a_base + 0).to(tl.float32)
    a01 = tl.load(a_ptr + a_base + 1).to(tl.float32)
    a02 = tl.load(a_ptr + a_base + 2).to(tl.float32)
    a03 = tl.load(a_ptr + a_base + 3).to(tl.float32)
    a10 = tl.load(a_ptr + a_base + 4).to(tl.float32)
    a11 = tl.load(a_ptr + a_base + 5).to(tl.float32)
    a12 = tl.load(a_ptr + a_base + 6).to(tl.float32)
    a13 = tl.load(a_ptr + a_base + 7).to(tl.float32)
    a20 = tl.load(a_ptr + a_base + 8).to(tl.float32)
    a21 = tl.load(a_ptr + a_base + 9).to(tl.float32)
    a22 = tl.load(a_ptr + a_base + 10).to(tl.float32)
    a23 = tl.load(a_ptr + a_base + 11).to(tl.float32)
    a30 = tl.load(a_ptr + a_base + 12).to(tl.float32)
    a31 = tl.load(a_ptr + a_base + 13).to(tl.float32)
    a32 = tl.load(a_ptr + a_base + 14).to(tl.float32)
    a33 = tl.load(a_ptr + a_base + 15).to(tl.float32)

    # ── load vectors (upcast to f32 for accumulation) ──
    d_vals = tl.load(d_ptr + d_base + h_off, mask=h_mask, other=0.0).to(tl.float32)
    b0 = tl.load(b_ptr + b_base + 0 * H + h_off, mask=h_mask, other=0.0).to(tl.float32)
    b1 = tl.load(b_ptr + b_base + 1 * H + h_off, mask=h_mask, other=0.0).to(tl.float32)
    b2 = tl.load(b_ptr + b_base + 2 * H + h_off, mask=h_mask, other=0.0).to(tl.float32)
    b3 = tl.load(b_ptr + b_base + 3 * H + h_off, mask=h_mask, other=0.0).to(tl.float32)

    # ── compute all 4 output streams; output i uses column i of a ──
    acc0 = c0 * d_vals + a00 * b0 + a10 * b1 + a20 * b2 + a30 * b3
    acc1 = c1 * d_vals + a01 * b0 + a11 * b1 + a21 * b2 + a31 * b3
    acc2 = c2 * d_vals + a02 * b0 + a12 * b1 + a22 * b2 + a32 * b3
    acc3 = c3 * d_vals + a03 * b0 + a13 * b1 + a23 * b2 + a33 * b3

    # ── store all 4 outputs ──
    # Cast straight to the output buffer's element type so a non-bf16 output
    # is rounded once instead of going through bf16 first.
    out_ty = out_ptr.dtype.element_ty
    tl.store(out_ptr + out_base + 0 * H + h_off, acc0.to(out_ty), mask=h_mask)
    tl.store(out_ptr + out_base + 1 * H + h_off, acc1.to(out_ty), mask=h_mask)
    tl.store(out_ptr + out_base + 2 * H + h_off, acc2.to(out_ty), mask=h_mask)
    tl.store(out_ptr + out_base + 3 * H + h_off, acc3.to(out_ty), mask=h_mask)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 128}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_H": 256}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 256}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=8, num_stages=1),
    ],
    key=["H", "HC"],
)
@triton.jit
def mhc_post_kernel_generic(
    a_ptr,  # comb_res_mix  : (N, HC, HC), contiguous — a[n, j, i]
    b_ptr,  # residual      : (N, HC, H), contiguous
    c_ptr,  # post_layer_mix: (N, HC), contiguous
    d_ptr,  # x             : (N, H), contiguous
    out_ptr,  # output      : (N, HC, H), contiguous
    H: tl.constexpr,
    HC: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Generic mHC post kernel for arbitrary HC.

    Computes, in float32 and for every h in this program's tile:

        out[n, i, h] = c[n, i] * d[n, h] + sum_j(a[n, j, i] * b[n, j, h])

    Grid: (N, HC, cdiv(H, BLOCK_H)).
    Each program handles one token x one output-stream(i) x one h-tile.

    All pointers must reference contiguous buffers: offsets are computed
    from the logical shapes listed on the parameters, not from strides.
    The result is cast to the element type of ``out_ptr`` when stored.
    """
    # program_id is 32-bit; promote the token index to int64 so that the
    # per-token bases below (up to pid_n * HC * H) cannot wrap around when
    # N * HC * H reaches 2**31 elements.
    pid_n = tl.program_id(0).to(tl.int64)
    pid_i = tl.program_id(1)
    pid_h = tl.program_id(2)

    h_off = pid_h * BLOCK_H + tl.arange(0, BLOCK_H)
    h_mask = h_off < H  # guards the ragged last tile when BLOCK_H does not divide H

    # Per-token element offsets into each contiguous buffer.
    a_base = pid_n * HC * HC
    b_base = pid_n * HC * H
    c_base = pid_n * HC
    d_base = pid_n * H
    out_base = pid_n * HC * H + pid_i * H

    d_vals = tl.load(d_ptr + d_base + h_off, mask=h_mask, other=0.0).to(tl.float32)
    c_i = tl.load(c_ptr + c_base + pid_i).to(tl.float32)

    # Start from the scaled layer output, then add the mixed residual streams.
    acc = c_i * d_vals
    # HC is constexpr, so static_range unrolls this loop at compile time.
    for j in tl.static_range(0, HC):
        # a[n, j, i]: column pid_i of row j, i.e. the transposed mixing matrix.
        a_ji = tl.load(a_ptr + a_base + j * HC + pid_i).to(tl.float32)
        b_j = tl.load(b_ptr + b_base + j * H + h_off, mask=h_mask, other=0.0).to(
            tl.float32
        )
        acc += a_ji * b_j

    # Cast straight to the output buffer's element type so a non-bf16 output
    # is rounded once instead of going through bf16 first.
    tl.store(
        out_ptr + out_base + h_off,
        acc.to(out_ptr.dtype.element_ty),
        mask=h_mask,
    )
def mhc_post(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """
    mHC post-processing operator.

    Computes
        out[n, i, h] = post_layer_mix[n, i] * x[n, h]
                       + sum_j(comb_res_mix[n, j, i] * residual[n, j, h])
    using a kernel specialized for hc_mult == 4, or the generic kernel
    otherwise.

    Args:
        x:              (N, H), bfloat16 — layer output
        residual:       (N, hc_mult, H), bfloat16 — multi-head residual
        post_layer_mix: (N, hc_mult, 1) or (N, hc_mult), float32 — per-stream
                        scale for x
        comb_res_mix:   (N, hc_mult, hc_mult), float32 — combination matrix

    Returns:
        out: (N, hc_mult, H), contiguous, same dtype and device as residual

    Raises:
        AssertionError: if the input shapes are inconsistent.
    """
    logger.debug(
        "GEMS MHC_POST FORWARD, x=%s, residual=%s, post_layer_mix=%s, comb_res_mix=%s",
        x.shape,
        residual.shape,
        post_layer_mix.shape,
        comb_res_mix.shape,
    )

    N, hc, H = residual.shape
    assert x.shape == (N, H)
    assert post_layer_mix.shape in ((N, hc, 1), (N, hc))
    assert comb_res_mix.shape == (N, hc, hc)

    # The kernels index with contiguous-layout offsets, so the output must be
    # contiguous. torch.empty_like(residual) would inherit the strides of a
    # dense but permuted `residual` (e.g. a transposed view) and the kernel
    # would then write every element to the wrong logical position.
    out = torch.empty((N, hc, H), dtype=residual.dtype, device=residual.device)
    if out.numel() == 0:
        # Nothing to compute; avoid autotuning/launching an empty grid.
        return out

    # reshape (not squeeze) so an (N, 1) input keeps its hc dimension.
    c = post_layer_mix.reshape(N, hc).contiguous()  # (N, hc)
    a = comb_res_mix.contiguous()  # (N, hc, hc)
    b = residual.contiguous()  # (N, hc, H)
    d = x.contiguous()  # (N, H)

    if hc == 4:

        def grid_specialized(META):
            # One program per (token, h-tile); all 4 streams handled together.
            return (N, triton.cdiv(H, META["BLOCK_H"]))

        mhc_post_kernel_hc_mult_4[grid_specialized](
            a,
            b,
            c,
            d,
            out,
            H=H,
        )
    else:

        def grid_generic(META):
            # One program per (token, output stream, h-tile).
            return (N, hc, triton.cdiv(H, META["BLOCK_H"]))

        mhc_post_kernel_generic[grid_generic](
            a,
            b,
            c,
            d,
            out,
            H=H,
            HC=hc,
        )
    return out
def mhc_post_ref(
    x: torch.Tensor,
    residual: torch.Tensor,
    post_layer_mix: torch.Tensor,
    comb_res_mix: torch.Tensor,
) -> torch.Tensor:
    """PyTorch reference implementation.

    Computes
        out[n, i, h] = post_layer_mix[n, i] * x[n, h]
                       + sum_j(comb_res_mix[n, j, i] * residual[n, j, h])

    Args:
        x:              (N, H) — layer output
        residual:       (N, hc_mult, H) — multi-head residual
        post_layer_mix: (N, hc_mult, 1) or (N, hc_mult) — per-stream scale
        comb_res_mix:   (N, hc_mult, hc_mult) — combination matrix

    Returns:
        Tensor of shape (N, hc_mult, H) with the dtype and device of ``x``.
    """
    n_tokens, hc, _ = residual.shape
    # Accept the same two layouts as mhc_post; a bare (N, hc) tensor would
    # otherwise broadcast incorrectly against x.unsqueeze(-2) of shape (N, 1, H).
    scale = post_layer_mix.reshape(n_tokens, hc, 1)
    scaled_x = x.unsqueeze(-2) * scale  # (N, hc, H)
    # bmm needs matching dtypes; accumulate the mixing term in float32.
    # The transpose makes the contraction run over the first stream index j.
    mixed = torch.bmm(comb_res_mix.float().mT, residual.float())  # (N, hc, H)
    return (scaled_x + mixed).type_as(x)