Coverage for src/flag_gems/ops/renorm.py: 60%
107 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry, tl_extra_shim
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def renorm_kernel_norms(
    X,
    norms_out,
    M,
    N,
    p_val,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute the p-norm of one row of ``X`` per program instance.

    ``X`` is addressed as a flat buffer in which row ``r`` starts at element
    ``r * N`` and holds ``N`` consecutive elements.  The program with id ``r``
    writes that row's norm to ``norms_out[r]``.

    ``M`` is accepted for signature symmetry but is not read here; the launch
    grid determines how many rows are processed.
    """
    row = tle.program_id(0)

    # Accumulate half-precision inputs in float32; other dtypes accumulate
    # in their own element type.
    if tl.constexpr(X.dtype.element_ty == tl.float16) or tl.constexpr(
        X.dtype.element_ty == tl.bfloat16
    ):
        acc_dtype = tl.float32
    else:
        acc_dtype = X.dtype.element_ty

    row_base = X + row * N
    out_ptr = norms_out + row

    # Per-lane partial sums of |x|^p, reduced once after the sweep.
    partial = tl.zeros([BLOCK_SIZE], dtype=acc_dtype)

    for start in range(0, N, BLOCK_SIZE):
        idx = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = idx < N
        # Out-of-range lanes load 0.0 and therefore add nothing to the sum.
        vals = tl.load(row_base + idx, mask=in_bounds, other=0.0).to(acc_dtype)
        if p_val == 2.0:
            # x * x already equals |x|^2, so no abs/pow is needed.
            term = vals * vals
        else:
            term = tl_extra_shim.pow(tl.abs(vals), p_val)
        partial += term

    total = tl.sum(partial)
    if p_val == 2.0:
        result = tl_extra_shim.sqrt(total)
    else:
        result = tl_extra_shim.pow(total, 1.0 / p_val)

    tl.store(out_ptr, result)
@libentry()
@triton.jit
def renorm_kernel_scale(
    X,
    norms_in,
    Y,
    M,
    N,
    p_val,
    maxnorm,
    BLOCK_SIZE: tl.constexpr,
):
    """Write one row of ``X`` to ``Y``, rescaled if its norm exceeds ``maxnorm``.

    The program with id ``r`` handles the ``N`` elements starting at flat
    offset ``r * N`` and reads the row's precomputed norm from ``norms_in[r]``.

    * If ``norm > maxnorm`` the row is multiplied by
      ``maxnorm / (norm + 1e-7)`` (the scale-factor formula used by the
      PyTorch reference implementation of ``renorm``).
    * Otherwise -- including when the norm is NaN -- the row is copied
      unchanged, so a NaN norm does not turn every element of the row into
      NaN.

    ``M`` and ``p_val`` are accepted for signature symmetry with the norm
    kernel but are not read here.  ``X`` and ``Y`` may be the same buffer:
    each element is loaded before the same position is stored.
    """
    pid = tle.program_id(0)

    # Half-precision inputs are scaled in float32; other dtypes are scaled
    # in their own element type.
    if tl.constexpr(X.dtype.element_ty == tl.float16) or tl.constexpr(
        X.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = X.dtype.element_ty

    row_offset = pid * N
    x_ptr_row = X + row_offset
    y_ptr_row = Y + row_offset

    # Promote the stored norm so the comparison and division do not happen
    # at half precision when the norm buffer is fp16/bf16.
    norm = tl.load(norms_in + pid).to(cdtype)

    # `norm > maxnorm` is False for NaN, which routes NaN norms to the
    # pass-through branch instead of the scaling branch.
    if norm > maxnorm:
        scale = maxnorm / (norm + 1e-7)
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x_vals = tl.load(x_ptr_row + cols, mask=mask, other=0.0).to(cdtype)
            y_vals = x_vals * scale
            tl.store(y_ptr_row + cols, y_vals.to(X.dtype.element_ty), mask=mask)
    else:
        # Plain copy with no dtype round trip, so values are preserved exactly.
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x_vals = tl.load(x_ptr_row + cols, mask=mask, other=0.0)
            tl.store(y_ptr_row + cols, x_vals, mask=mask)
def renorm(input, p, dim, maxnorm):
    """Return a renormalized copy of ``input``.

    Every sub-tensor obtained by fixing an index along ``dim`` whose p-norm
    exceeds ``maxnorm`` is rescaled by the scale kernel; all other
    sub-tensors are copied unchanged.  ``input`` itself is not modified.

    Args:
        input: Tensor to renormalize.
        p: Exponent of the norm.
        dim: Dimension that indexes the sub-tensors; a negative value is
            offset by ``input.ndim``.
        maxnorm: Upper bound for each sub-tensor's p-norm.

    Returns:
        A new tensor with the same shape as ``input``.  For ``dim != 0`` the
        result is a permuted view of a contiguous buffer.
    """
    logger.debug("GEMS RENORM")

    if dim < 0:
        dim = input.ndim + dim

    if dim == 0:
        inv_perm = None
        rows = input
    else:
        # Bring `dim` to the front so each sub-tensor becomes one row of a
        # (M, N) matrix, and remember how to undo the permutation.
        ndim = input.ndim
        perm = list(range(ndim))
        perm.remove(dim)
        perm.insert(0, dim)
        inv_perm = [perm.index(i) for i in range(ndim)]
        rows = input.permute(perm)

    # Nothing to normalize.  This also avoids `numel() // 0` when the leading
    # size is zero and a zero BLOCK size when the rows are empty.
    if rows.numel() == 0:
        return torch.empty_like(input)

    rows = rows.contiguous()
    M = rows.shape[0]
    N = rows.numel() // M

    # Keep norms in float32 for half-precision inputs so the comparison
    # against `maxnorm` and the scale are not computed from a rounded norm.
    if rows.dtype in (torch.float16, torch.bfloat16):
        norm_dtype = torch.float32
    else:
        norm_dtype = rows.dtype
    norms = torch.empty((M,), dtype=norm_dtype, device=rows.device)
    output = torch.empty_like(rows)

    BLOCK = min(triton.next_power_of_2(N), 128)
    grid = (M,)  # one program per row

    with torch_device_fn.device(rows.device):
        renorm_kernel_norms[grid](
            rows,
            norms,
            M,
            N,
            p,
            BLOCK_SIZE=BLOCK,
        )
        renorm_kernel_scale[grid](
            rows,
            norms,
            output,
            M,
            N,
            p,
            maxnorm,
            BLOCK_SIZE=BLOCK,
        )

    if inv_perm is None:
        return output
    return output.permute(inv_perm)
def renorm_(input, p, dim, maxnorm):
    """Renormalize ``input`` in place and return it.

    Every sub-tensor obtained by fixing an index along ``dim`` whose p-norm
    exceeds ``maxnorm`` is rescaled by the scale kernel; all other
    sub-tensors are left as they are.

    The kernels need a row-contiguous buffer.  When the (possibly permuted)
    view of ``input`` is not contiguous, the work is done on a contiguous
    staging copy and the result is copied back, so the caller's tensor is
    always the one that is updated and returned -- including the
    ``dim == 0`` case with a non-contiguous ``input``.

    Args:
        input: Tensor to renormalize in place.
        p: Exponent of the norm.
        dim: Dimension that indexes the sub-tensors; a negative value is
            offset by ``input.ndim``.
        maxnorm: Upper bound for each sub-tensor's p-norm.

    Returns:
        ``input`` itself.
    """
    logger.debug("GEMS RENORM_")

    if dim < 0:
        dim = input.ndim + dim

    if dim == 0:
        target = input
    else:
        # View of `input` with `dim` moved to the front; writes to this view
        # land in `input`'s storage.
        ndim = input.ndim
        perm = list(range(ndim))
        perm.remove(dim)
        perm.insert(0, dim)
        target = input.permute(perm)

    # Nothing to normalize.  This also avoids `numel() // 0` when the leading
    # size is zero and a zero BLOCK size when the rows are empty.
    if target.numel() == 0:
        return input

    # `contiguous()` would silently hand back a copy for a non-contiguous
    # view, so track whether the result must be written back afterwards.
    staged = not target.is_contiguous()
    work = target.contiguous() if staged else target

    M = work.shape[0]
    N = work.numel() // M

    # Keep norms in float32 for half-precision inputs so the comparison
    # against `maxnorm` and the scale are not computed from a rounded norm.
    if work.dtype in (torch.float16, torch.bfloat16):
        norm_dtype = torch.float32
    else:
        norm_dtype = work.dtype
    norms = torch.empty((M,), dtype=norm_dtype, device=work.device)

    BLOCK = min(triton.next_power_of_2(N), 128)
    grid = (M,)  # one program per row

    with torch_device_fn.device(work.device):
        renorm_kernel_norms[grid](
            work,
            norms,
            M,
            N,
            p,
            BLOCK_SIZE=BLOCK,
        )
        # Source and destination are the same buffer: scale in place.
        renorm_kernel_scale[grid](
            work,
            norms,
            work,
            M,
            N,
            p,
            maxnorm,
            BLOCK_SIZE=BLOCK,
        )

    if staged:
        target.copy_(work)

    return input