Coverage for src/flag_gems/fused/add_rms_norm.py: 38%
93 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry
12from flag_gems.utils import triton_lang_extension as tle
14logger = logging.getLogger(__name__)
@triton.jit
def prev_multiple_of(a, b):
    """Return the largest multiple of ``b`` that is strictly below ``a``.

    Computed as ``(ceil(a / b) - 1) * b``; e.g. a=10, b=4 -> 8 and
    a=8, b=4 -> 4.
    """
    # Index of the last (possibly partial) tile, scaled back to an offset.
    last_tile = tl.cdiv(a, b) - 1
    return last_tile * b
@libentry()
@triton.jit(do_not_specialize=["eps"])
def add_rms_norm_kernel(
    out_ptr,  # pointer to the output
    in_ptr1,  # pointer to the first input
    in_ptr2,  # pointer to the second input
    w_ptr,  # pointer to the weights (indexed without a stride)
    y_stride_r,  # row stride of the output
    y_stride_c,  # column stride of the output
    x1_stride_r,  # row stride of the first input
    x1_stride_c,  # column stride of the first input
    x2_stride_r,  # row stride of the second input
    x2_stride_c,  # column stride of the second input
    N,  # number of columns in a row
    eps,  # added to the mean square before the square root
    BLOCK_SIZE: tl.constexpr,
):
    """Single-block fused add + RMS norm: one program handles one row.

    For the row selected by the program id this computes
    ``y = (x1 + x2) * rsqrt(mean((x1 + x2)^2) + eps) * w``.
    The whole row is loaded in a single block, so only columns below
    ``BLOCK_SIZE`` are processed; the caller must pass ``BLOCK_SIZE >= N``.
    """
    # float16 / bfloat16 inputs are computed in float32; every other input
    # dtype is computed in its own dtype.
    if tl.constexpr(in_ptr1.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr1.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr1.dtype.element_ty

    # Use the same program-id helper as add_rms_norm_loop_kernel so both
    # kernels compute row offsets the same way.
    # NOTE(review): tle.program_id is expected to return a 64-bit id, which
    # keeps pid * stride from wrapping on very large inputs - confirm.
    pid = tle.program_id(0)
    out_ptr += pid * y_stride_r
    in_ptr1 += pid * x1_stride_r
    in_ptr2 += pid * x2_stride_r

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    # Masked-out lanes load 0.0 so they do not contribute to the sum below.
    x1 = tl.load(in_ptr1 + cols * x1_stride_c, mask, other=0.0).to(cdtype)
    x2 = tl.load(in_ptr2 + cols * x2_stride_c, mask, other=0.0).to(cdtype)

    # Add the two inputs
    x = x1 + x2

    # Mean of squares over the N valid columns, then its reciprocal root.
    var = tl.sum(x * x, axis=0) / N
    rrms = 1 / tl.sqrt(var + eps)

    w = tl.load(w_ptr + cols, mask=mask, other=0.0)
    y = (x * rrms * w).to(cdtype)
    tl.store(out_ptr + cols * y_stride_c, y, mask=mask)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("add_rms_norm_loop"),
    key=["N"],
)
@triton.jit(do_not_specialize=["eps"])
def add_rms_norm_loop_kernel(
    out_ptr,  # pointer to the output
    in_ptr1,  # pointer to the first input
    in_ptr2,  # pointer to the second input
    w_ptr,  # pointer to the weights
    N,  # number of columns in a row; also used as the row stride
    eps,  # added to the mean square before the square root
    TILE_N: tl.constexpr,
):
    """Tiled fused add + RMS norm: one program handles one row in two passes.

    Pass 1 walks the row in ``TILE_N`` chunks and accumulates
    ``(x1 + x2)^2``. Pass 2 walks the same row from the last tile back to
    the first and writes ``(x1 + x2) * rsqrt(mean + eps) * w``.
    Rows are addressed as ``pid * N + column``, i.e. all three row buffers
    are indexed as densely packed rows of length ``N``.
    """
    # float16 / bfloat16 inputs are computed in float32; every other input
    # dtype is computed in its own dtype.
    if tl.constexpr(in_ptr1.dtype.element_ty == tl.float16) or tl.constexpr(
        in_ptr1.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = in_ptr1.dtype.element_ty

    pid = tle.program_id(0)
    row_start = pid * N

    # Pass 1: compute sum(x^2) in chunks.
    # Accumulate in cdtype (not a hard-coded float32) so the statistics use
    # the same precision as pass 2 and as add_rms_norm_kernel, and so the
    # names x1 / x2 / x keep one dtype across both passes.
    acc = tl.zeros((TILE_N,), dtype=cdtype)
    num_steps = tl.cdiv(N, TILE_N)

    # All tiles except the last are full, so they need no mask.
    for step in range(0, num_steps - 1):
        start_n = step * TILE_N
        n_offsets = start_n + tl.arange(0, TILE_N)
        x1 = tl.load(in_ptr1 + row_start + n_offsets).to(cdtype)
        x2 = tl.load(in_ptr2 + row_start + n_offsets).to(cdtype)
        x = x1 + x2
        acc += x * x

    # last step with mask
    start_n = (num_steps - 1) * TILE_N
    n_offsets = start_n + tl.arange(0, TILE_N)
    mask = n_offsets < N
    x1 = tl.load(in_ptr1 + row_start + n_offsets, mask=mask, other=0.0).to(cdtype)
    x2 = tl.load(in_ptr2 + row_start + n_offsets, mask=mask, other=0.0).to(cdtype)
    x = x1 + x2
    acc += x * x

    var = tl.sum(acc) / N
    rrms = 1 / tl.sqrt(var + eps)

    # Pass 2: normalize in reverse order (better L2 cache reuse)
    prev_multiple = prev_multiple_of(N, TILE_N)

    # first reverse step with mask (single iteration: start_n == 0, which
    # addresses the last, possibly partial, tile of the row)
    for start_n in range(0, TILE_N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        mask = n_offsets < N
        x1 = tl.load(
            in_ptr1 + row_start + n_offsets,
            mask=mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(cdtype)
        x2 = tl.load(
            in_ptr2 + row_start + n_offsets,
            mask=mask,
            other=0.0,
            eviction_policy="evict_first",
        ).to(cdtype)
        x = x1 + x2
        w = tl.load(w_ptr + n_offsets, mask=mask, other=0.0)
        y = (x * rrms * w).to(cdtype)
        tl.store(out_ptr + row_start + n_offsets, y, mask=mask)

    # Remaining tiles, walking back towards column 0; these are all full
    # tiles, so no mask is needed.
    for start_n in range(TILE_N, N, TILE_N):
        n_offsets = (prev_multiple - start_n) + tl.arange(0, TILE_N)
        x1 = tl.load(
            in_ptr1 + row_start + n_offsets,
            eviction_policy="evict_first",
        ).to(cdtype)
        x2 = tl.load(
            in_ptr2 + row_start + n_offsets,
            eviction_policy="evict_first",
        ).to(cdtype)
        x = x1 + x2
        w = tl.load(w_ptr + n_offsets)
        y = (x * rrms * w).to(cdtype)
        tl.store(out_ptr + row_start + n_offsets, y)
def add_rms_norm(x1, x2, normalized_shape, weight, eps=1e-5):
    """
    Add_RMSNorm: Add two inputs element-wise and apply RMS normalization.

    Args:
        x1: First input tensor
        x2: Second input tensor; must have the same shape as ``x1``
        normalized_shape: Shape to normalize over (the trailing dimensions
            of ``x1``); its product is the row length ``N``
        weight: Weight tensor with ``N`` elements, or ``None`` to apply no
            scaling (a vector of ones is used)
        eps: Epsilon value for numerical stability

    Returns:
        A new tensor shaped like ``x1`` holding the normalized sum.

    Raises:
        AssertionError: if ``x1`` and ``x2`` have different shapes.
    """
    logger.debug(
        "GEMS ADD_RMS_NORM FORWARD, [input1 shape]: %s, [input2 shape]: %s, [weight shape]: %s",
        x1.size(),
        x2.size(),
        weight.size() if weight is not None else None,
    )

    # Verify shapes match. Raised explicitly (rather than via `assert`) so
    # the check still runs under `python -O`; the exception type is the one
    # callers already see.
    if x1.shape != x2.shape:
        raise AssertionError(f"Input shapes must match: {x1.shape} vs {x2.shape}")

    dim = x1.ndim - len(normalized_shape)
    M = math.prod(x1.shape[:dim])  # number of rows
    N = math.prod(normalized_shape)  # elements per row

    # The kernels are launched with dense row strides, so make everything
    # contiguous first.
    x1 = x1.contiguous()
    x2 = x2.contiguous()
    if weight is None:
        # No weight supplied: scale by one, i.e. plain RMS normalization.
        weight = torch.ones(N, dtype=x1.dtype, device=x1.device)
    else:
        weight = weight.contiguous()
    y = torch.empty_like(x1)

    # Nothing to normalize for an empty input; skip the kernel launch.
    if y.numel() == 0:
        return y

    with torch_device_fn.device(x1.device):
        if N <= 4096:
            # Short rows fit in a single block per row.
            BLOCK_SIZE = triton.next_power_of_2(N)
            add_rms_norm_kernel[M,](
                y, x1, x2, weight, N, 1, N, 1, N, 1, N, eps, BLOCK_SIZE
            )
        else:
            # Long rows are processed tile by tile.
            add_rms_norm_loop_kernel[M,](y, x1, x2, weight, N, eps)

    return y