Coverage for src/flag_gems/runtime/backend/_sunrise/fused/hc_head_fused_kernel.py: 0%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": 128}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 256}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 256}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=4, num_stages=1),
        triton.Config({"BLOCK_H": 512}, num_warps=8, num_stages=1),
        triton.Config({"BLOCK_H": 1024}, num_warps=8, num_stages=1),
    ],
    key=["H", "HC"],
)
@triton.jit
def _hc_head_apply_pre_mix_kernel(
    hs_ptr,
    pre_mix_ptr,
    out_ptr,
    T,
    H,
    hs_stride_t,
    hs_stride_m,
    hs_stride_h,
    pre_stride_t,
    pre_stride_m,
    out_stride_t,
    out_stride_h,
    HC: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Compute ``out[t, h] = sum_m pre_mix[t, m] * hs[t, m, h]`` for one tile.

    Program ``(pid_t, pid_h)`` handles token ``pid_t`` and hidden columns
    ``[pid_h * BLOCK_H, (pid_h + 1) * BLOCK_H)``, masked against ``H``.
    All inputs are read through the given strides, accumulation is done in
    float32, and the result is cast to ``out_ptr``'s element type on store.
    """
    pid_t = tl.program_id(0)
    pid_h = tl.program_id(1)

    # Defensive bound check; the launcher sizes grid axis 0 to exactly T.
    if pid_t >= T:
        return

    # Promote the token index to int64 before scaling by strides. With int32
    # arithmetic, pid_t * hs_stride_t overflows once T * HC * H exceeds 2**31
    # (e.g. HC=4, H=7168 overflows beyond ~74.9k tokens).
    pid_t64 = pid_t.to(tl.int64)

    h_off = pid_h * BLOCK_H + tl.arange(0, BLOCK_H)
    h_mask = h_off < H

    acc = tl.zeros([BLOCK_H], dtype=tl.float32)
    hs_t_base = pid_t64 * hs_stride_t
    pre_t_base = pid_t64 * pre_stride_t

    # HC is constexpr, so static_range fully unrolls this loop at compile time.
    for i_hc in tl.static_range(HC):
        # One scalar mixing weight per stream for this token.
        pre = tl.load(pre_mix_ptr + pre_t_base + i_hc * pre_stride_m).to(tl.float32)
        hs_ptrs = hs_ptr + hs_t_base + i_hc * hs_stride_m + h_off * hs_stride_h
        hs_vals = tl.load(hs_ptrs, mask=h_mask, other=0.0).to(tl.float32)
        acc += pre * hs_vals

    out_ptrs = out_ptr + pid_t64 * out_stride_t + h_off * out_stride_h
    tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=h_mask)

55 

56 

def hc_head_fused_kernel_ref(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
) -> torch.Tensor:
    """Pure-PyTorch implementation of the hc-head mixing, used as a fallback.

    Each token's ``hc_mult`` streams are flattened to a float32 row, projected
    through ``fn`` to one logit per stream, scaled by the row's inverse RMS,
    and turned into gates ``sigmoid(... * hc_scale[0] + hc_base) + hc_eps``.
    The gate-weighted sum of the streams is written into ``out`` (cast to
    ``out.dtype``).

    Assumes ``hs_flat`` is ``(num_tokens, hc_mult, hidden_size)``; this
    function itself does not validate shapes.

    Returns:
        ``out`` itself. With zero tokens, ``out`` is returned untouched.
    """
    num_tokens = hs_flat.shape[0]
    if num_tokens == 0:
        return out

    flat_width = hc_mult * hidden_size
    flat = hs_flat.reshape(num_tokens, flat_width).to(torch.float32)

    # One logit per stream, scaled by the inverse RMS of the flattened row.
    logits = torch.matmul(flat, fn.t())
    mean_sq = flat.square().sum(dim=-1, keepdim=True) / flat_width
    inv_rms = torch.rsqrt(mean_sq + rms_eps)
    gates = torch.sigmoid(logits * inv_rms * hc_scale[0] + hc_base) + hc_eps

    # Weighted sum over the stream axis (dim 1), accumulated in float32.
    streams = hs_flat.to(torch.float32)
    mixed = (gates.unsqueeze(-1) * streams).sum(dim=1)
    out.copy_(mixed.to(out.dtype))
    return out

80 

81 

def hc_head_fused_kernel(
    hs_flat: torch.Tensor,
    fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    out: torch.Tensor,
    hidden_size: int,
    rms_eps: float,
    hc_eps: float,
    hc_mult: int,
) -> torch.Tensor:
    """Mix the ``hc_mult`` per-token streams of ``hs_flat`` into ``out``.

    For each token ``t`` (``hs_flat`` is asserted to be
    ``(num_tokens, hc_mult, hidden_size)``)::

        x       = hs_flat[t].flatten() as float32
        inv_rms = rsqrt(mean(x ** 2) + rms_eps)
        pre_mix = sigmoid((fn @ x) * inv_rms * hc_scale[0] + hc_base) + hc_eps
        out[t]  = sum_m pre_mix[m] * hs_flat[t, m, :]

    The gates (``pre_mix``) are computed with PyTorch, and the weighted sum
    runs in a Triton kernel on ``cuda``/``ptpu`` devices. Every other device
    falls back to ``hc_head_fused_kernel_ref``.

    Returns:
        ``out``, written in place. With zero tokens it is returned untouched.

    Raises:
        AssertionError: on dtype or shape mismatch. Dtype checks run even for
            empty input; shape checks run only when ``num_tokens > 0``.
    """
    assert hs_flat.dtype in [torch.float32, torch.float16, torch.bfloat16]
    assert fn.dtype == torch.float32
    assert hc_scale.dtype == torch.float32
    assert hc_base.dtype == torch.float32

    num_tokens = hs_flat.shape[0]
    if num_tokens == 0:
        return out

    assert hs_flat.shape == (num_tokens, hc_mult, hidden_size)
    assert fn.shape == (hc_mult, hc_mult * hidden_size)
    assert hc_scale.shape == (1,)
    assert hc_base.shape == (hc_mult,)
    assert out.shape == (num_tokens, hidden_size)
    assert out.dtype == hs_flat.dtype

    # [sunrise fix] Devices without a Triton backend use the PyTorch reference.
    # Dispatch before computing pre_mix: the reference computes it itself, so
    # doing it here first would be discarded work.
    if hs_flat.device.type not in ("cuda", "ptpu"):
        return hc_head_fused_kernel_ref(
            hs_flat,
            fn,
            hc_scale,
            hc_base,
            out,
            hidden_size,
            rms_eps,
            hc_eps,
            hc_mult,
        )

    # Per-token, per-stream gates: shape (num_tokens, hc_mult), float32.
    x = hs_flat.reshape(num_tokens, hc_mult * hidden_size).to(torch.float32)
    mixes = torch.matmul(x, fn.t())
    sqrsum = x.square().sum(dim=-1, keepdim=True)
    rsqrt = torch.rsqrt(sqrsum / (hc_mult * hidden_size) + rms_eps)
    pre_mix = torch.sigmoid(mixes * rsqrt * hc_scale[0] + hc_base) + hc_eps

    hs_flat_c = hs_flat.contiguous()
    pre_mix_c = pre_mix.contiguous()
    # If `out` is already contiguous this is `out` itself; otherwise the kernel
    # writes into a temporary that is copied back below.
    out_c = out.contiguous()

    def grid(meta):
        # One program per (token, BLOCK_H-wide slice of the hidden dimension).
        return num_tokens, triton.cdiv(hidden_size, meta["BLOCK_H"])

    _hc_head_apply_pre_mix_kernel[grid](
        hs_flat_c,
        pre_mix_c,
        out_c,
        num_tokens,
        hidden_size,
        hs_flat_c.stride(0),
        hs_flat_c.stride(1),
        hs_flat_c.stride(2),
        pre_mix_c.stride(0),
        pre_mix_c.stride(1),
        out_c.stride(0),
        out_c.stride(1),
        HC=hc_mult,
    )

    if out.data_ptr() != out_c.data_ptr():
        out.copy_(out_c)
    return out