Coverage for src/flag_gems/ops/act_quant.py: 15%

91 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1from typing import Optional, Tuple 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7 

@triton.jit
def fast_log2_ceil(x):
    """Return ceil(log2(x)) as int32 by reading the float32 bit fields of x.

    Port of the TileLang reference
    ``T.Cast("int32", exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0))``.
    The exponent field is used as-is, so the result is only exact for positive,
    normal float32 inputs (zero, denormals, inf and NaN are not special-cased).
    """
    raw = x.cast(tl.uint32, bitcast=True)  # T.reinterpret("uint32", x)
    biased_exp = ((raw >> 23) & 0xFF).to(tl.int32)
    mantissa = raw & 0x7FFFFF
    # A non-zero mantissa means x lies strictly above 2**floor(log2(x)),
    # so the ceiling is one more than the unbiased exponent.
    round_up = tl.where(mantissa != 0, 1, 0)
    return biased_exp - 127 + round_up

16 

17 

@triton.jit
def fast_pow2(x):
    """Return 2**x as float32 by assembling its bit pattern directly.

    The biased exponent ``x + 127`` is shifted into bits 23..30 with a zero
    mantissa and the result is reinterpreted as float32 (``T.reinterpret`` in
    the TileLang reference). Assumes x is an integer in [-126, 127] so the
    biased exponent stays in the normal range.
    """
    biased = x + 127
    return (biased << 23).cast(tl.float32, bitcast=True)

23 

24 

@triton.jit
def fast_round_scale(amax, fp8_max_inv):
    """Round ``amax * fp8_max_inv`` up to the nearest power of two.

    Computes ``2 ** ceil(log2(amax * fp8_max_inv))`` through exponent-bit
    manipulation instead of log2/ceil/exp2; exact for positive normal float32
    products.
    """
    raw_scale = amax * fp8_max_inv
    exponent = fast_log2_ceil(raw_scale)
    return fast_pow2(exponent)

28 

29 

# @libentry()
@triton.jit(
    do_not_specialize=[
        "M",
    ]
)
def act_quant_triton_kernel(
    X_ptr,
    Y_ptr,
    S_ptr,
    M,
    N,
    stride_xm,
    stride_ym,
    stride_sm,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    ROUND_SCALE: tl.constexpr,
):
    """Quantize one (BLOCK_M x BLOCK_N) tile of a 2-D matrix to float8e4nv.

    Program (pid_m, pid_n) handles rows ``[pid_m * BLOCK_M, (pid_m + 1) * BLOCK_M)``
    and columns ``[pid_n * BLOCK_N, (pid_n + 1) * BLOCK_N)``. For each row of the
    tile it computes ``amax = max(max|x|, 1e-4)`` and ``scale = amax / 448``
    (rounded up to a power of two when ROUND_SCALE). It then stores
    ``clamp(x / scale, -448, 448)`` as FP8 into Y and ``scale`` into
    ``S[row, pid_n]``.

    Args:
        X_ptr: input matrix; elements within a row are assumed contiguous
            (column stride 1).
        Y_ptr: FP8 output matrix; column stride 1 assumed.
        S_ptr: float32 scale matrix with one entry per (row, column block);
            column stride 1 assumed.
        M, N: number of rows / columns.
        stride_xm, stride_ym, stride_sm: row strides (in elements) of X, Y, S.
        BLOCK_M: rows per program.
        BLOCK_N: columns per program (the quantization block size); must be a
            power of two because it is the extent of ``tl.arange``.
        ROUND_SCALE: if true, round each scale up to a power of two.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    cols = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    mask_row = rows < M
    mask_col = cols < N
    mask = mask_row[:, None] & mask_col[None, :]

    # Widen row indices before multiplying by a row stride: with int32 indices,
    # `row * stride` overflows once the tensor holds >= 2**31 elements and the
    # kernel would read/write out of bounds.
    rows_i64 = rows.to(tl.int64)

    x = tl.load(
        X_ptr + rows_i64[:, None] * stride_xm + cols[None, :],
        mask=mask,
        other=0.0,
    )

    # Per-row absolute max over this column block. Masked-out rows load zeros;
    # the 1e-4 floor keeps every scale strictly positive.
    amax = tl.max(tl.abs(x), axis=1)
    amax = tl.maximum(amax, 1e-4)

    FP8_MAX: tl.constexpr = 448.0
    FP8_MAX_INV: tl.constexpr = 1.0 / 448.0

    if ROUND_SCALE:
        # scale = 2 ** ceil(log2(amax / 448)), via exponent-bit manipulation
        # instead of log2 / ceil / exp2.
        scale = fast_round_scale(amax, FP8_MAX_INV)
    else:
        scale = amax * FP8_MAX_INV

    y = x / scale[:, None]
    y = tl.clamp(y, -FP8_MAX, FP8_MAX)

    tl.store(
        Y_ptr + rows_i64[:, None] * stride_ym + cols[None, :],
        y.to(tl.float8e4nv),
        mask=mask,
    )

    # One scale per (row, column block): S[row, pid_n].
    tl.store(S_ptr + rows_i64 * stride_sm + pid_n, scale, mask=mask_row)

93 

94 

def act_quant_triton(
    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantizes the input tensor `x` using block-wise quantization.

    The last dimension is split into blocks of `block_size` elements. Each
    (row, block) pair gets its own float32 scale; values are divided by that
    scale, clamped and cast to FP8.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and
            its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks for quantization. Must be
            a positive power of two (it is the extent of a Triton `tl.arange`).
            Default is 128.
        scale_fmt (Optional[str], optional): If not None, rounds scale to power of 2.
            Only its None-ness is inspected.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn` and the same
              shape as `x`.
            - A tensor of scaling factors with dtype `torch.float32` and shape
              `(*x.shape[:-1], x.size(-1) // block_size)`.

    Raises:
        AssertionError: if `x` is not contiguous, `block_size` is not a positive
            power of two, or the last dimension is not divisible by `block_size`.
    """
    assert x.is_contiguous(), "Input tensor must be contiguous"
    # Checked before the divisibility test so block_size == 0 gives a clear
    # message instead of ZeroDivisionError, and non-powers of two fail here
    # rather than inside Triton compilation.
    assert block_size > 0 and (block_size & (block_size - 1)) == 0, (
        f"block_size must be a positive power of two (block_size={block_size})"
    )
    assert (
        x.size(-1) % block_size == 0
    ), f"Last dimension size must be divisible by block_size (block_size={block_size})"

    N = x.size(-1)
    # Product of the leading dims. Unlike `x.view(-1, N)` this is well defined
    # when N == 0, and it is 1 for a 1-D input.
    M = x.shape[:-1].numel()
    n_blocks = N // block_size

    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], n_blocks, dtype=torch.float32)
    if y.numel() == 0:
        # Nothing to quantize; skip launching a kernel over an empty grid.
        return y, s

    x_2d = x.view(M, N)
    y_view = y.view(M, N)
    s_view = s.view(M, n_blocks)

    BLOCK_M = 32
    BLOCK_N = block_size
    grid = (triton.cdiv(M, BLOCK_M), n_blocks)
    act_quant_triton_kernel[grid](
        x_2d,
        y_view,
        s_view,
        M,
        N,
        x_2d.stride(0),
        y_view.stride(0),
        s_view.stride(0),
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        ROUND_SCALE=(scale_fmt is not None),
    )

    return y, s

158 

159 

if __name__ == "__main__":
    # Reference implementation from a local `kernel` module (labelled
    # "TileLang" in the benchmark output); run this from a directory that
    # provides it.
    from kernel import act_quant

    torch.manual_seed(2026)

    M = [1, 40, 164, 512, 3454, 12027, 38594]
    N = [128, 448, 2048, 8192]
    test_shape = [(m, n) for m in M for n in N]
    fmt = [None, "ue8m0"]
    block_sizes = [64, 128]

    # Expand the sweep once so the correctness and benchmark passes cover the
    # same valid configurations and the divisibility filter lives in one place.
    configs = []
    for scale_fmt in fmt:
        for shape in test_shape:
            for block_size in block_sizes:
                if shape[-1] % block_size != 0:
                    print(
                        f"Skipping shape {shape} with block_size {block_size} due to incompatible dimensions."
                    )
                    continue
                configs.append((scale_fmt, shape, block_size))

    # Correctness: compare quantized values and scales against the reference.
    for scale_fmt, shape, block_size in configs:
        x = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
        y_ref, s_ref = act_quant(x, block_size=block_size, scale_fmt=scale_fmt)
        y_triton, s_triton = act_quant_triton(
            x, block_size=block_size, scale_fmt=scale_fmt
        )
        torch.testing.assert_close(
            y_ref.float(), y_triton.float(), rtol=1e-2, atol=1e-2
        )
        torch.testing.assert_close(s_ref, s_triton, rtol=1e-5, atol=1e-5)
        print(
            f"Shape {str(shape):20s} | scale_fmt:{scale_fmt} | block_size:{block_size} | PASS"
        )

    print("=" * 60)

    # Performance: time both implementations on the same input. The lambdas
    # close over loop variables, which is safe because do_bench runs them
    # before the next iteration rebinds those names.
    su = []
    for scale_fmt, shape, block_size in configs:
        x = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
        ref_time = triton.testing.do_bench(
            lambda: act_quant(x, block_size=block_size, scale_fmt=scale_fmt),
            warmup=50,
            rep=200,
        )
        triton_time = triton.testing.do_bench(
            lambda: act_quant_triton(
                x, block_size=block_size, scale_fmt=scale_fmt
            ),
            warmup=50,
            rep=200,
        )
        speedup = ref_time / triton_time
        su.append(speedup)
        print(
            f"Shape {str(shape):20s}, Scale format: {scale_fmt}, "
            f"block_size: {block_size} | "
            f"TileLang: {ref_time:.3f} ms | Triton: {triton_time:.3f} ms | "
            f"Speedup: {speedup:.2f}x"
        )

    # Guard the summary: sum/len, max and min all fail on an empty list.
    if su:
        print(
            f"Average speedup: {sum(su) / len(su):.2f}x, max speedup: {max(su):.2f}x, min speedup: {min(su):.2f}x"
        )

274 # print(f"Speedup: {tilelang_time / triton_time:.2f}x")