Coverage for src/flag_gems/fused/mhc/hc_head_fused_kernel.py: 36%
112 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7logger = logging.getLogger(__name__)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 512}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_H": 1024}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK_H": 1024}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_H": 2048}, num_warps=8, num_stages=2),
        triton.Config({"BLOCK_H": 2048}, num_warps=8, num_stages=3),
    ],
    key=["H", "HC"],
)
@triton.jit
def _hc_head_fused_kernel(
    residual_ptr,
    fn_ptr,
    hc_scale_ptr,
    hc_base_ptr,
    out_ptr,
    T,
    H: tl.constexpr,
    rms_eps,
    hc_eps,
    residual_stride_t,
    fn_stride_m,
    out_stride_t,
    HC: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Fused HC head: one program per token, two passes over the hidden dim.

    Addressing used by this kernel (element offsets):
      * residual: token ``t``, slice ``k``, column ``h`` is read at
        ``t * residual_stride_t + k * H + h``.
      * fn: row ``m``, slice ``k``, column ``h`` is read at
        ``m * fn_stride_m + k * H + h``.
      * hc_scale: a single scalar at ``hc_scale_ptr``.
      * hc_base: entry ``m`` at ``hc_base_ptr + m``.
      * out: token ``t``, column ``h`` is written at ``t * out_stride_t + h``
        and is stored as bfloat16.

    Pass 1 accumulates, in float32, the sum of squares of all ``HC * H``
    residual values and the ``HC`` dot products between the token's residual
    values and the rows of ``fn``. Each dot product is scaled by
    ``rsqrt(sum_sq / (HC * H) + rms_eps) * hc_scale``, shifted by ``hc_base``,
    passed through a sigmoid and offset by ``hc_eps``. Pass 2 writes the sum
    of the ``HC`` residual slices weighted by those values.

    Only ``HC == 2`` and ``HC == 4`` are handled: the ``HC > 2`` branches load
    exactly slices/rows 2 and 3 in addition to 0 and 1, so callers must not
    launch this kernel with any other value.
    """
    # Widen the program id to 64 bits: the per-token base offsets below are
    # products of the token index and a row stride, which wrap in 32-bit
    # arithmetic once the tensor holds more than 2**31 - 1 elements.
    pid_t = tl.program_id(0).to(tl.int64)
    if pid_t >= T:
        return

    x_base = pid_t * residual_stride_t

    # Pass 1: iterate over H blocks to compute sqrsum and mixes
    sqr_acc = tl.zeros([BLOCK_H], dtype=tl.float32)
    mix_acc0 = tl.zeros([BLOCK_H], dtype=tl.float32)
    mix_acc1 = tl.zeros([BLOCK_H], dtype=tl.float32)
    mix_acc2 = tl.zeros([BLOCK_H], dtype=tl.float32)
    mix_acc3 = tl.zeros([BLOCK_H], dtype=tl.float32)

    for h_start in range(0, H, BLOCK_H):
        h_off = h_start + tl.arange(0, BLOCK_H)
        # Masked lanes load 0.0, so they contribute nothing to any accumulator.
        h_mask = h_off < H

        r0 = tl.load(residual_ptr + x_base + 0 * H + h_off, mask=h_mask, other=0.0).to(
            tl.float32
        )
        r1 = tl.load(residual_ptr + x_base + 1 * H + h_off, mask=h_mask, other=0.0).to(
            tl.float32
        )
        sqr_acc += r0 * r0 + r1 * r1

        fn00 = tl.load(fn_ptr + 0 * fn_stride_m + 0 * H + h_off, mask=h_mask, other=0.0)
        fn01 = tl.load(fn_ptr + 0 * fn_stride_m + 1 * H + h_off, mask=h_mask, other=0.0)
        mix_acc0 += r0 * fn00 + r1 * fn01

        fn10 = tl.load(fn_ptr + 1 * fn_stride_m + 0 * H + h_off, mask=h_mask, other=0.0)
        fn11 = tl.load(fn_ptr + 1 * fn_stride_m + 1 * H + h_off, mask=h_mask, other=0.0)
        mix_acc1 += r0 * fn10 + r1 * fn11

        # HC is a constexpr, so this branch is resolved at compile time.
        if HC > 2:
            r2 = tl.load(
                residual_ptr + x_base + 2 * H + h_off, mask=h_mask, other=0.0
            ).to(tl.float32)
            r3 = tl.load(
                residual_ptr + x_base + 3 * H + h_off, mask=h_mask, other=0.0
            ).to(tl.float32)
            sqr_acc += r2 * r2 + r3 * r3

            # Rows 0 and 1 of fn also have slices 2 and 3 when HC > 2.
            mix_acc0 += r2 * tl.load(
                fn_ptr + 0 * fn_stride_m + 2 * H + h_off, mask=h_mask, other=0.0
            )
            mix_acc0 += r3 * tl.load(
                fn_ptr + 0 * fn_stride_m + 3 * H + h_off, mask=h_mask, other=0.0
            )

            mix_acc1 += r2 * tl.load(
                fn_ptr + 1 * fn_stride_m + 2 * H + h_off, mask=h_mask, other=0.0
            )
            mix_acc1 += r3 * tl.load(
                fn_ptr + 1 * fn_stride_m + 3 * H + h_off, mask=h_mask, other=0.0
            )

            fn20 = tl.load(
                fn_ptr + 2 * fn_stride_m + 0 * H + h_off, mask=h_mask, other=0.0
            )
            fn21 = tl.load(
                fn_ptr + 2 * fn_stride_m + 1 * H + h_off, mask=h_mask, other=0.0
            )
            fn22 = tl.load(
                fn_ptr + 2 * fn_stride_m + 2 * H + h_off, mask=h_mask, other=0.0
            )
            fn23 = tl.load(
                fn_ptr + 2 * fn_stride_m + 3 * H + h_off, mask=h_mask, other=0.0
            )
            mix_acc2 += r0 * fn20 + r1 * fn21 + r2 * fn22 + r3 * fn23

            fn30 = tl.load(
                fn_ptr + 3 * fn_stride_m + 0 * H + h_off, mask=h_mask, other=0.0
            )
            fn31 = tl.load(
                fn_ptr + 3 * fn_stride_m + 1 * H + h_off, mask=h_mask, other=0.0
            )
            fn32 = tl.load(
                fn_ptr + 3 * fn_stride_m + 2 * H + h_off, mask=h_mask, other=0.0
            )
            fn33 = tl.load(
                fn_ptr + 3 * fn_stride_m + 3 * H + h_off, mask=h_mask, other=0.0
            )
            mix_acc3 += r0 * fn30 + r1 * fn31 + r2 * fn32 + r3 * fn33

    # Reduce the per-lane partial sums to scalars and build the mix weights.
    K = HC * H
    sqr_total = tl.sum(sqr_acc)
    rsqrt_val = tl.math.rsqrt(sqr_total / K + rms_eps)
    hc_scale = tl.load(hc_scale_ptr)

    mix0 = tl.sum(mix_acc0)
    mix1 = tl.sum(mix_acc1)
    hc_base0 = tl.load(hc_base_ptr + 0)
    hc_base1 = tl.load(hc_base_ptr + 1)
    pre_mix0 = tl.sigmoid(mix0 * rsqrt_val * hc_scale + hc_base0) + hc_eps
    pre_mix1 = tl.sigmoid(mix1 * rsqrt_val * hc_scale + hc_base1) + hc_eps

    if HC > 2:
        mix2 = tl.sum(mix_acc2)
        mix3 = tl.sum(mix_acc3)
        hc_base2 = tl.load(hc_base_ptr + 2)
        hc_base3 = tl.load(hc_base_ptr + 3)
        pre_mix2 = tl.sigmoid(mix2 * rsqrt_val * hc_scale + hc_base2) + hc_eps
        pre_mix3 = tl.sigmoid(mix3 * rsqrt_val * hc_scale + hc_base3) + hc_eps

    # Pass 2: weighted sum
    out_base = pid_t * out_stride_t
    for h_start in range(0, H, BLOCK_H):
        h_off = h_start + tl.arange(0, BLOCK_H)
        h_mask = h_off < H
        r0 = tl.load(residual_ptr + x_base + 0 * H + h_off, mask=h_mask, other=0.0).to(
            tl.float32
        )
        r1 = tl.load(residual_ptr + x_base + 1 * H + h_off, mask=h_mask, other=0.0).to(
            tl.float32
        )
        acc = pre_mix0 * r0 + pre_mix1 * r1
        if HC > 2:
            r2 = tl.load(
                residual_ptr + x_base + 2 * H + h_off, mask=h_mask, other=0.0
            ).to(tl.float32)
            r3 = tl.load(
                residual_ptr + x_base + 3 * H + h_off, mask=h_mask, other=0.0
            ).to(tl.float32)
            acc += pre_mix2 * r2 + pre_mix3 * r3
        tl.store(out_ptr + out_base + h_off, acc.to(tl.bfloat16), mask=h_mask)
def hc_head_fused_kernel_ref(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
) -> torch.Tensor:
    """Pure PyTorch reference implementation for correctness testing.

    Args:
        hs_flat: Input whose first dimension is the token count and whose
            remaining elements per token number ``hc_mult * hidden_size``.
        fn: Matrix multiplied (transposed) against the flattened per-token
            input; one row per mix weight.
        hc_scale: Tensor whose element 0 scales the normalised mixes.
        hc_base: Bias added to the scaled mixes before the sigmoid.
        out: Destination tensor; it is overwritten in place with the result.
        hidden_size: Size of each of the ``hc_mult`` slices of a token.
        rms_eps: Added to the mean square before the reciprocal square root.
        hc_eps: Added to the sigmoid output.
        hc_mult: Number of slices per token.

    Returns:
        ``out``, filled with the weighted sum over the ``hc_mult`` slices
        (returned untouched when there are no tokens).
    """
    num_tokens = hs_flat.shape[0]
    if num_tokens == 0:
        return out

    width = hc_mult * hidden_size
    # Single float32 copy of the input, reused for every step below.
    x = hs_flat.reshape(num_tokens, width).to(torch.float32)

    mixes = torch.matmul(x, fn.t())
    sqrsum = x.square().sum(dim=-1, keepdim=True)
    rsqrt = torch.rsqrt(sqrsum / width + rms_eps)
    pre_mix = torch.sigmoid(mixes * rsqrt * hc_scale[0] + hc_base) + hc_eps

    # View the same float32 data per slice instead of converting hs_flat again.
    slices = x.reshape(num_tokens, hc_mult, hidden_size)
    result = torch.sum(pre_mix.unsqueeze(-1) * slices, dim=1).to(out.dtype)
    out.copy_(result)
    return out
def hc_head_fused_kernel(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
) -> torch.Tensor:
    """HC head fused kernel: fully fused Triton implementation.

    Args:
        hs_flat: bfloat16 tensor of shape ``(num_tokens, hc_mult, hidden_size)``.
        fn: float32 tensor of shape ``(hc_mult, hc_mult * hidden_size)``.
        hc_scale: float32 tensor of shape ``(1,)``.
        hc_base: float32 tensor of shape ``(hc_mult,)``.
        out: bfloat16 tensor of shape ``(num_tokens, hidden_size)``; it is
            filled in place and also returned.
        hidden_size: Size of the last dimension of ``hs_flat`` and ``out``.
        rms_eps: Epsilon forwarded to the RMS normalisation.
        hc_eps: Epsilon added to the sigmoid output.
        hc_mult: Number of slices per token.

    Returns:
        ``out``. With zero tokens it is returned untouched, before any shape
        check runs.

    Raises:
        AssertionError: If a dtype or shape does not match the above.

    The PyTorch reference implementation is used when the input is not on a
    CUDA device, or when ``hc_mult`` is not 2 or 4 (the only values the Triton
    kernel handles).
    """
    logger.debug("GEMS HC_HEAD_FUSED")
    assert hs_flat.dtype == torch.bfloat16
    assert fn.dtype == torch.float32
    assert hc_scale.dtype == torch.float32
    assert hc_base.dtype == torch.float32

    num_tokens = hs_flat.shape[0]
    if num_tokens == 0:
        return out

    assert hs_flat.shape == (num_tokens, hc_mult, hidden_size)
    assert fn.shape == (hc_mult, hc_mult * hidden_size)
    assert hc_scale.shape == (1,)
    assert hc_base.shape == (hc_mult,)
    assert out.shape == (num_tokens, hidden_size)
    assert out.dtype == hs_flat.dtype

    # The Triton kernel unconditionally reads slices 0 and 1 and, when
    # hc_mult > 2, exactly slices 2 and 3. Any other hc_mult would make it
    # read out of bounds, so route those cases to the reference path.
    if hs_flat.device.type != "cuda" or hc_mult not in (2, 4):
        return hc_head_fused_kernel_ref(
            hs_flat, fn, hc_scale, hc_base, out, hidden_size, rms_eps, hc_eps, hc_mult
        )

    H = hidden_size

    # The kernel addresses every operand with unit inner stride.
    residual_c = hs_flat.contiguous()
    fn_c = fn.contiguous()
    hc_base_c = hc_base.contiguous()
    # Request a contiguous layout explicitly: the default preserve_format
    # would copy the strides of a dense but permuted `out` (e.g. a transposed
    # view), and the kernel would then write rows to the wrong locations.
    if out.is_contiguous():
        out_c = out
    else:
        out_c = torch.empty_like(out, memory_format=torch.contiguous_format)

    # Launch on the device that owns the data, not whichever is current.
    with torch.cuda.device(hs_flat.device):
        _hc_head_fused_kernel[(num_tokens,)](
            residual_c,
            fn_c,
            hc_scale,
            hc_base_c,
            out_c,
            num_tokens,
            H,
            rms_eps,
            hc_eps,
            residual_c.stride(0),
            fn_c.stride(0),
            out_c.stride(0),
            HC=hc_mult,
        )

    # Copy back only when a staging buffer was used.
    if out_c is not out:
        out.copy_(out_c)

    return out