Coverage for src/flag_gems/runtime/backend/_sunrise/fused/hc_head_fused_kernel.py: 0%
63 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import torch
2import triton
3import triton.language as tl
# Triton kernel: out[t, h] = sum_{m < HC} pre_mix[t, m] * hs[t, m, h].
#
# Launch grid is (T, cdiv(H, BLOCK_H)): program (pid_t, pid_h) handles one
# token row and one BLOCK_H-wide tile of the hidden dimension. All tensors are
# addressed through explicit element strides:
#   hs[t, m, h]   -> hs_ptr + t*hs_stride_t + m*hs_stride_m + h*hs_stride_h
#   pre_mix[t, m] -> pre_mix_ptr + t*pre_stride_t + m*pre_stride_m
#   out[t, h]     -> out_ptr + t*out_stride_t + h*out_stride_h
# Accumulation is done in float32. Every config fully overwrites its output
# tile, so autotuning's repeated launches need no output restoration.
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 128}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 256}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 256}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 1024}, num_warps=8, num_stages=1),
    ],
    key=["H", "HC"],
)
@triton.jit
def _hc_head_apply_pre_mix_kernel(
    hs_ptr,
    pre_mix_ptr,
    out_ptr,
    T,
    H,
    hs_stride_t,
    hs_stride_m,
    hs_stride_h,
    pre_stride_t,
    pre_stride_m,
    out_stride_t,
    out_stride_h,
    HC: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    # tl.program_id is int32. Promote the token index to int64 so that
    # pid_t * stride cannot wrap around when T * HC * H exceeds 2**31 - 1.
    pid_t = tl.program_id(0).to(tl.int64)
    pid_h = tl.program_id(1)

    # Defensive guard: programs beyond the token count do nothing.
    if pid_t >= T:
        return

    h_off = pid_h * BLOCK_H + tl.arange(0, BLOCK_H)
    h_mask = h_off < H  # the last tile may run past H

    acc = tl.zeros([BLOCK_H], dtype=tl.float32)
    hs_t_base = pid_t * hs_stride_t
    pre_t_base = pid_t * pre_stride_t

    # HC is constexpr, so this loop is fully unrolled at compile time.
    for i_hc in tl.static_range(HC):
        # One scalar mixing weight per (token, m), broadcast over the h-tile.
        pre = tl.load(pre_mix_ptr + pre_t_base + i_hc * pre_stride_m).to(tl.float32)
        hs_ptrs = hs_ptr + hs_t_base + i_hc * hs_stride_m + h_off * hs_stride_h
        hs_vals = tl.load(hs_ptrs, mask=h_mask, other=0.0).to(tl.float32)
        acc += pre * hs_vals

    # acc is float32. tl.store converts it to out's element type on store.
    out_ptrs = out_ptr + pid_t * out_stride_t + h_off * out_stride_h
    tl.store(out_ptrs, acc, mask=h_mask)
def hc_head_fused_kernel_ref(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
) -> torch.Tensor:
    """Pure-PyTorch reference path; writes the result into ``out`` and returns it.

    ``hs_flat`` is expected in the ``(num_tokens, hc_mult, hidden_size)`` layout
    that ``hc_head_fused_kernel`` asserts. All math is done in float32. For each
    token ``t``, ``x[t]`` is its ``hc_mult * hidden_size`` values flattened:

        inv_rms[t]    = rsqrt(sum(x[t] ** 2) / (hc_mult * hidden_size) + rms_eps)
        pre_mix[t, m] = sigmoid((x[t] @ fn.T)[m] * inv_rms[t] * hc_scale[0]
                                + hc_base[m]) + hc_eps
        out[t, :]     = sum_m pre_mix[t, m] * hs_flat[t, m, :]

    The result is cast to ``out.dtype`` and copied into ``out``. If there are
    zero tokens, ``out`` is returned untouched.
    """
    n_tok = hs_flat.shape[0]
    if n_tok == 0:
        return out

    width = hc_mult * hidden_size
    x = hs_flat.reshape(n_tok, width).to(torch.float32)

    # Per-token RMS normaliser over the whole flattened row.
    inv_rms = (x.square().sum(dim=-1, keepdim=True) / width + rms_eps).rsqrt()

    # Gate logits: projection, RMS-normalised, scaled, then biased.
    logits = (x @ fn.t()) * inv_rms * hc_scale[0] + hc_base
    pre_mix = logits.sigmoid() + hc_eps

    # Gate-weighted sum over dim 1 (the hc_mult axis).
    weighted = pre_mix.unsqueeze(-1) * hs_flat.to(torch.float32)
    out.copy_(weighted.sum(dim=1).to(out.dtype))
    return out
def hc_head_fused_kernel(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
) -> torch.Tensor:
    """Compute gated pre-mix weights and apply them to ``hs_flat``, writing into ``out``.

    Per token ``t``, with ``x[t]`` = ``hs_flat[t]`` flattened to float32:

        inv_rms[t]    = rsqrt(sum(x[t] ** 2) / (hc_mult * hidden_size) + rms_eps)
        pre_mix[t, m] = sigmoid((x[t] @ fn.T)[m] * inv_rms[t] * hc_scale[0]
                                + hc_base[m]) + hc_eps
        out[t, :]     = sum_m pre_mix[t, m] * hs_flat[t, m, :]

    ``pre_mix`` is computed with PyTorch ops. The weighted sum runs in the Triton
    kernel on ``cuda``/``ptpu`` devices. On any other device the whole
    computation is delegated to ``hc_head_fused_kernel_ref``.

    Args:
        hs_flat: ``(num_tokens, hc_mult, hidden_size)``; float32, float16 or bfloat16.
        fn: ``(hc_mult, hc_mult * hidden_size)`` float32 projection.
        hc_scale: ``(1,)`` float32 scale.
        hc_base: ``(hc_mult,)`` float32 bias.
        out: ``(num_tokens, hidden_size)`` output with the same dtype as ``hs_flat``.
        hidden_size: size of the last dimension of ``hs_flat``.
        rms_eps: epsilon added inside the RMS rsqrt.
        hc_eps: constant added after the sigmoid.
        hc_mult: size of the middle dimension of ``hs_flat``.

    Returns:
        ``out``, filled in place. If there are zero tokens it is returned untouched.

    Raises:
        AssertionError: on dtype or shape mismatch.
    """
    assert hs_flat.dtype in [torch.float32, torch.float16, torch.bfloat16]
    assert fn.dtype == torch.float32
    assert hc_scale.dtype == torch.float32
    assert hc_base.dtype == torch.float32

    num_tokens = hs_flat.shape[0]
    if num_tokens == 0:
        return out

    assert hs_flat.shape == (num_tokens, hc_mult, hidden_size)
    assert fn.shape == (hc_mult, hc_mult * hidden_size)
    assert hc_scale.shape == (1,)
    assert hc_base.shape == (hc_mult,)
    assert out.shape == (num_tokens, hidden_size)
    assert out.dtype == hs_flat.dtype

    # Dispatch to the reference before computing pre_mix. The reference
    # recomputes everything itself, so work done earlier would be discarded.
    if hs_flat.device.type not in ["cuda", "ptpu"]:  # [sunrise fix]
        return hc_head_fused_kernel_ref(
            hs_flat,
            fn,
            hc_scale,
            hc_base,
            out,
            hidden_size,
            rms_eps,
            hc_eps,
            hc_mult,
        )

    # Gate weights of shape (num_tokens, hc_mult), computed in float32.
    x = hs_flat.reshape(num_tokens, hc_mult * hidden_size).to(torch.float32)
    mixes = torch.matmul(x, fn.t())
    sqrsum = x.square().sum(dim=-1, keepdim=True)
    rsqrt = torch.rsqrt(sqrsum / (hc_mult * hidden_size) + rms_eps)
    pre_mix = torch.sigmoid(mixes * rsqrt * hc_scale[0] + hc_base) + hc_eps

    hs_flat_c = hs_flat.contiguous()
    pre_mix_c = pre_mix.contiguous()
    # If out is not contiguous, out_c is a temporary; results are copied back below.
    out_c = out.contiguous()

    def grid(meta):
        # One program per (token, BLOCK_H-wide hidden tile).
        return num_tokens, triton.cdiv(hidden_size, meta["BLOCK_H"])

    _hc_head_apply_pre_mix_kernel[grid](
        hs_flat_c,
        pre_mix_c,
        out_c,
        num_tokens,
        hidden_size,
        hs_flat_c.stride(0),
        hs_flat_c.stride(1),
        hs_flat_c.stride(2),
        pre_mix_c.stride(0),
        pre_mix_c.stride(1),
        out_c.stride(0),
        out_c.stride(1),
        HC=hc_mult,
    )

    if out.data_ptr() != out_c.data_ptr():
        out.copy_(out_c)
    return out