Coverage for src/flag_gems/ops/renorm.py: 60%

107 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, tl_extra_shim 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Per-module logger, named after this module.
logger = logging.getLogger(name=__name__)

13 

14 

@libentry()
@triton.jit
def renorm_kernel_norms(
    X,
    norms_out,
    M,
    N,
    p_val,
    BLOCK_SIZE: tl.constexpr,
):
    """Compute the p-norm of one row per program.

    Program ``pid`` reads the ``N`` contiguous elements starting at
    ``X + pid * N`` and writes ``(sum |x| ** p_val) ** (1 / p_val)`` to
    ``norms_out[pid]``. ``p_val == 2`` takes a square / sqrt fast path.
    Only finite ``p_val`` gives a meaningful result here. ``M`` is accepted
    for interface symmetry but is not read.
    """
    pid = tle.program_id(0)

    # Half-precision inputs are accumulated in float32.
    if tl.constexpr(X.dtype.element_ty == tl.float16) or tl.constexpr(
        X.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = X.dtype.element_ty

    # Form the row offset in int64 so pid * N cannot wrap around for inputs
    # with 2**31 or more elements (a no-op if program_id is already int64).
    row_start = pid.to(tl.int64) * N
    x_row = X + row_start

    acc = tl.zeros([BLOCK_SIZE], dtype=cdtype)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Masked lanes load 0, which adds nothing to the sum for p > 0.
        x = tl.load(x_row + cols, mask=mask, other=0.0).to(cdtype)
        if p_val == 2.0:
            powered = x * x
        else:
            powered = tl_extra_shim.pow(tl.abs(x), p_val)
        acc += powered

    total = tl.sum(acc)
    if p_val == 2.0:
        norm = tl_extra_shim.sqrt(total)
    else:
        norm = tl_extra_shim.pow(total, 1.0 / p_val)

    # tl.store converts the value to norms_out's element type.
    tl.store(norms_out + pid, norm)

59 

60 

@libentry()
@triton.jit
def renorm_kernel_scale(
    X,
    norms_in,
    Y,
    M,
    N,
    p_val,
    maxnorm,
    BLOCK_SIZE: tl.constexpr,
):
    """Rescale one row per program using its precomputed norm.

    Program ``pid`` reads ``norms_in[pid]`` and the ``N`` contiguous elements
    at ``X + pid * N``, and writes the result to the same offsets in ``Y``.
    Rows whose norm is strictly greater than ``maxnorm`` are multiplied by
    ``maxnorm / (norm + 1e-7)``; every other row (including a NaN norm) is
    copied unchanged. Each element is loaded before it is stored at the same
    offset by the same program, so ``Y`` may alias ``X``.
    ``M`` and ``p_val`` are accepted for interface symmetry but are not read.
    """
    pid = tle.program_id(0)

    # Half-precision inputs are scaled in float32, then cast back on store.
    if tl.constexpr(X.dtype.element_ty == tl.float16) or tl.constexpr(
        X.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = X.dtype.element_ty

    # int64 row offset so pid * N cannot wrap for very large tensors.
    row_start = pid.to(tl.int64) * N
    x_row = X + row_start
    y_row = Y + row_start
    norm = tl.load(norms_in + pid)

    # Test `norm > maxnorm` rather than `norm <= maxnorm`: a NaN norm must
    # fall through to the copy branch (leaving the row as-is, as torch.renorm
    # does) instead of scaling the whole row by maxnorm / NaN.
    if norm > maxnorm:
        # The small epsilon matches the scale factor torch.renorm uses.
        scale = maxnorm / (norm + 1e-7)
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x = tl.load(x_row + cols, mask=mask, other=0.0).to(cdtype)
            tl.store(y_row + cols, (x * scale).to(X.dtype.element_ty), mask=mask)
    else:
        for off in range(0, N, BLOCK_SIZE):
            cols = off + tl.arange(0, BLOCK_SIZE)
            mask = cols < N
            x = tl.load(x_row + cols, mask=mask, other=0.0)
            tl.store(y_row + cols, x, mask=mask)

102 

103 

def _renorm_check_args(input, p, dim, maxnorm):
    """Validate renorm arguments and return ``dim`` wrapped into ``[0, ndim)``.

    The checks are intended to mirror torch.renorm's argument validation.

    Raises:
        RuntimeError: ``input`` has fewer than 2 dims, ``p`` is not > 0, or
            ``maxnorm`` is not >= 0 (a NaN fails both comparisons).
        IndexError: ``dim`` is outside ``[-ndim, ndim)``.
    """
    ndim = input.ndim
    if ndim < 2:
        raise RuntimeError(
            f"renorm: input needs at least 2 dimensions, got {ndim} dimensions"
        )
    if not (p > 0):
        raise RuntimeError("renorm: non-positive-norm not supported")
    if not (maxnorm >= 0):
        raise RuntimeError(f"renorm: expected maxnorm to be >= 0 but got {maxnorm}")
    if not (-ndim <= dim < ndim):
        raise IndexError(
            "Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    return dim % ndim


def _renorm_rows(x, p, maxnorm, out):
    """Renormalize every slice ``x[i]`` of a contiguous, non-empty ``x`` into ``out``.

    ``out`` must be contiguous with ``x``'s shape; it may be ``x`` itself for
    in-place use. ``p`` and ``maxnorm`` are Python floats, ``p > 0``.
    """
    M = x.shape[0]
    N = x.numel() // M
    # Keep per-row norms in float32 for half-precision inputs so they are not
    # rounded (or overflowed) to fp16/bf16 before the maxnorm comparison.
    if x.dtype in (torch.float16, torch.bfloat16):
        norm_dtype = torch.float32
    else:
        norm_dtype = x.dtype

    BLOCK = min(triton.next_power_of_2(N), 128)
    grid = (M,)

    with torch_device_fn.device(x.device):
        if p == float("inf"):
            # The norms kernel only handles finite p (pow(sum, 1/inf) would
            # always give 1); the inf-norm is simply max |x| over the row.
            norms = x.reshape(M, N).abs().amax(dim=1).to(norm_dtype)
        else:
            norms = torch.empty((M,), dtype=norm_dtype, device=x.device)
            renorm_kernel_norms[grid](
                x,
                norms,
                M,
                N,
                p,
                BLOCK_SIZE=BLOCK,
            )
        renorm_kernel_scale[grid](
            x,
            norms,
            out,
            M,
            N,
            p,
            maxnorm,
            BLOCK_SIZE=BLOCK,
        )


def renorm(input, p, dim, maxnorm):
    """Return a copy of ``input`` whose slices along ``dim`` have p-norm <= ``maxnorm``.

    Each slice ``input.select(dim, i)`` whose p-norm exceeds ``maxnorm`` is
    scaled down; other slices are copied unchanged. For ``dim != 0`` the
    result is returned as a permuted view of a contiguous buffer.
    """
    logger.debug("GEMS RENORM")

    dim = _renorm_check_args(input, p, dim, maxnorm)
    # Pass plain floats to Triton so int arguments don't change the kernel's
    # scalar types (or trigger int specialization).
    p, maxnorm = float(p), float(maxnorm)
    if input.numel() == 0:
        # Nothing to normalize; avoids M == 0 division and a zero-size BLOCK.
        return torch.empty_like(input)

    # Bring `dim` to the front; each row of the contiguous copy is one slice.
    x = input.movedim(dim, 0).contiguous()
    out = torch.empty_like(x)
    _renorm_rows(x, p, maxnorm, out)
    return out if dim == 0 else out.movedim(0, dim)


def renorm_(input, p, dim, maxnorm):
    """In-place variant of :func:`renorm`; modifies and returns ``input``."""
    logger.debug("GEMS RENORM_")

    dim = _renorm_check_args(input, p, dim, maxnorm)
    p, maxnorm = float(p), float(maxnorm)
    if input.numel() == 0:
        return input

    # `x` is a view of `input`, so writing to it updates `input`.
    x = input.movedim(dim, 0)
    if x.is_contiguous():
        _renorm_rows(x, p, maxnorm, x)
    else:
        # The kernels need a dense row-major buffer: work on a contiguous copy,
        # then write back through the view so the caller's tensor is actually
        # modified (returning the copy would leave `input` untouched).
        work = x.contiguous()
        _renorm_rows(work, p, maxnorm, work)
        x.copy_(work)
    return input