Coverage for src/flag_gems/runtime/backend/_sunrise/fused/skip_layernorm.py: 0%

85 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as ext 

11 

# Module logger, created as a child of the package-wide "flag_gems" logger so
# its records propagate to whatever handlers/levels are configured there.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_kernel(
    Y,  # output base pointer
    X,  # input base pointer
    R,  # residual base pointer
    W,  # weight (gamma) pointer, read with unit stride
    B,  # bias (beta) pointer, read with unit stride
    y_stride_r,  # output: pointer step between rows
    y_stride_c,  # output: pointer step between columns
    x_stride_r,  # input: pointer step between rows
    x_stride_c,  # input: pointer step between columns
    r_stride_r,  # residual: pointer step between rows
    r_stride_c,  # residual: pointer step between columns
    N,  # row length (number of normalized elements)
    eps,  # added to the variance before the square root
    BLOCK_SIZE: tl.constexpr,
):
    """Single-tile ``layer_norm(X + R)``: one program per row.

    The whole row is read as one ``BLOCK_SIZE``-wide tile, so the caller must
    choose ``BLOCK_SIZE >= N``. Lanes at or past ``N`` are masked out of the
    loads, the statistics and the store. Mean and variance are computed in
    fp32 (mean first, then the mean of squared deviations), and the output is
    cast to ``Y``'s element type before storing.
    """
    row = ext.program_id(0)
    y_row = Y + row * y_stride_r
    x_row = X + row * x_stride_r
    r_row = R + row * r_stride_r

    offs = tl.arange(0, BLOCK_SIZE)
    in_row = offs < N

    # Pre-norm activations h = x + residual, promoted to fp32.
    h = tl.load(x_row + offs * x_stride_c, in_row, other=0.0).to(tl.float32)
    h += tl.load(r_row + offs * r_stride_c, in_row, other=0.0).to(tl.float32)

    mean = tl.sum(h, axis=0) / N
    # Zero the padding lanes so they do not contribute to the variance.
    centered = tl.where(in_row, h - mean, 0.0)
    var = tl.sum(centered * centered, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    gamma = tl.load(W + offs, mask=in_row, other=0.0).to(tl.float32)
    beta = tl.load(B + offs, mask=in_row, other=0.0).to(tl.float32)
    out = gamma * (centered * rstd) + beta
    tl.store(y_row + offs * y_stride_c, out.to(y_row.dtype.element_ty), mask=in_row)

60 

61 

@libentry()
@triton.jit(do_not_specialize=["eps"])
def skip_layer_norm_c_split_kernel(
    Y,  # pointer to the output
    X,  # pointer to the input
    R,  # pointer to the residual
    W,  # pointer to the weights (read with unit stride)
    B,  # pointer to the biases (read with unit stride)
    y_stride_r,  # how much to increase the output pointer per row
    y_stride_c,  # how much to increase the output pointer per col
    x_stride_r,  # how much to increase the pointer when moving by 1 row
    x_stride_c,  # how much to increase the pointer when moving by 1 col
    r_stride_r,  # how much to increase the pointer when moving by 1 row
    r_stride_c,  # how much to increase the pointer when moving by 1 col
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,
):
    """Row-wise ``layer_norm(X + R)`` for rows wider than one tile.

    One program handles one row (``ext.program_id(0)``) and walks it in
    ``BLOCK_SIZE``-wide tiles. A first sweep accumulates the row statistics.
    A second sweep normalizes, applies ``W``/``B`` and stores to ``Y``.

    The variance uses the shifted-data identity
    ``Var(v) = E[d^2] - E[d]^2`` with ``d = v - shift``, where ``shift`` is
    the mean of the first tile. The unshifted ``E[v^2] - E[v]^2`` loses most
    of its fp32 precision when |mean| >> std and can even come out negative,
    which turns ``sqrt(var + eps)`` into NaN. For the same reason the result
    is also clamped at 0.
    """
    pid = ext.program_id(0)
    Y += pid * y_stride_r
    X += pid * x_stride_r
    R += pid * r_stride_r

    # First tile: estimate the shift from its mean and seed the accumulators.
    cols0 = tl.arange(0, BLOCK_SIZE)
    mask0 = cols0 < N
    v0 = tl.load(X + cols0 * x_stride_c, mask0, other=0.0).to(tl.float32)
    v0 += tl.load(R + cols0 * r_stride_c, mask0, other=0.0).to(tl.float32)
    # Divide by the number of valid lanes, not BLOCK_SIZE, in case N < BLOCK_SIZE.
    shift = tl.sum(v0, axis=0) / tl.sum(mask0.to(tl.float32), axis=0)
    d0 = tl.where(mask0, v0 - shift, 0.0)
    _sum = tl.sum(d0, axis=0)
    _sq_sum = tl.sum(d0 * d0, axis=0)

    # Remaining tiles: accumulate shifted sums and sums of squares.
    for off in range(BLOCK_SIZE, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        v = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        v += tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        # Padding lanes must contribute 0, not -shift.
        d = tl.where(mask, v - shift, 0.0)
        _sum += tl.sum(d, axis=0)
        _sq_sum += tl.sum(d * d, axis=0)

    mean_d = _sum / N
    mean = shift + mean_d
    var = tl.maximum(_sq_sum / N - mean_d * mean_d, 0.0)
    rstd = 1 / tl.sqrt(var + eps)

    # Second sweep: re-read x + residual, normalize, apply affine, store.
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask, other=0.0).to(tl.float32)
        b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
        x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)
        r = tl.load(R + cols * r_stride_c, mask, other=0.0).to(tl.float32)
        x_hat = (x + r - mean) * rstd
        y = w * x_hat + b
        tl.store(Y + cols * y_stride_c, y.to(Y.dtype.element_ty), mask=mask)

113 

114 

class SkipLayerNorm(torch.autograd.Function):
    """Fused residual-add + layer norm: ``layer_norm(x + residual)``.

    Only ``forward`` is defined here. It saves nothing on ``ctx``, and no
    ``backward`` is provided.
    """

    @staticmethod
    def forward(ctx, x, residual, normalized_shape, weight, bias, eps=1e-5):
        """Normalize ``x + residual`` over its trailing ``normalized_shape`` dims.

        Args:
            ctx: autograd context (unused).
            x: input tensor whose trailing dims are ``normalized_shape``.
            residual: tensor added to ``x`` before normalization.
                NOTE(review): assumed to have the same shape as ``x``. It is
                made contiguous and indexed exactly like ``x`` but is not
                validated.
            normalized_shape: trailing shape to normalize over. A bare ``int``
                is accepted as a 1-D shape.
            weight: affine scale, read as ``prod(normalized_shape)`` contiguous
                values (not validated).
            bias: affine shift, read the same way as ``weight``.
            eps: added to the variance before the square root.

        Returns:
            A new tensor with ``x``'s shape and dtype.
        """
        logger.debug("GEMS SKIP LAYERNORM FORWARD")
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        dim = x.ndim - len(normalized_shape)
        M = math.prod(x.shape[:dim])  # number of rows (one program each)
        N = math.prod(normalized_shape)  # elements normalized per row

        x = x.contiguous()
        residual = residual.contiguous()
        weight = weight.contiguous()
        bias = bias.contiguous()
        y = torch.empty_like(x)
        if M == 0 or N == 0:
            # Nothing to compute. This also avoids an empty grid and
            # next_power_of_2(0) == 0, which is not a valid tl.arange extent.
            return y

        BLOCK_SIZE = triton.next_power_of_2(N)
        with torch_device_fn.device(x.device):
            if BLOCK_SIZE <= 1024:
                # The whole row fits in one tile.
                skip_layer_norm_kernel[M,](
                    y, x, residual, weight, bias, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE
                )
            else:
                # Wide rows: loop over each row in fixed 1024-wide tiles.
                BLOCK_SIZE = 1024
                skip_layer_norm_c_split_kernel[M,](
                    y,
                    x,
                    residual,
                    weight,
                    bias,
                    N,
                    1,
                    N,
                    1,
                    N,
                    1,
                    N,
                    eps,
                    BLOCK_SIZE,
                    num_warps=16,
                )
        return y

155 

156 

def skip_layer_norm(x, residual, normalized_shape, weight, bias, eps=1e-5):
    """Return ``layer_norm(x + residual)`` computed by the fused Triton path.

    Functional entry point. All arguments are forwarded unchanged, in order, to
    ``SkipLayerNorm.apply``.
    """
    forwarded = (x, residual, normalized_shape, weight, bias, eps)
    return SkipLayerNorm.apply(*forwarded)