Coverage for src/flag_gems/ops/act_quant.py: 15%
91 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1from typing import Optional, Tuple
3import torch
4import triton
5import triton.language as tl
@triton.jit
def fast_log2_ceil(x):
    """Return ceil(log2(x)) as int32 by inspecting the float's raw bits.

    The input is reinterpreted as a 32-bit pattern laid out like an IEEE-754
    single-precision float: 8 exponent bits (bias 127) above 23 mantissa bits.
    The unbiased exponent is floor(log2(x)); it is bumped by one whenever any
    mantissa bit is set, i.e. whenever x is not an exact power of two.

    The sign bit is masked away, and zero, subnormal, inf and NaN inputs are
    not special-cased.
    """
    raw = x.cast(tl.uint32, bitcast=True)
    biased_exp = (raw >> 23) & 0xFF
    mantissa = raw & 0x7FFFFF  # low 23 bits, same value as (1 << 23) - 1
    round_up = tl.where(mantissa != 0, 1, 0)
    return (biased_exp - 127 + round_up).cast(tl.int32)
@triton.jit
def fast_pow2(x):
    """Build the float32 value 2**x directly from an integer exponent.

    The biased exponent (x + 127) is shifted into the exponent field of a
    32-bit word whose sign and mantissa bits are zero, and that word is then
    reinterpreted as float32. No range check is performed on x.
    """
    biased_exp = x + 127
    word = biased_exp << 23
    return word.cast(tl.float32, bitcast=True)
@triton.jit
def fast_round_scale(amax, fp8_max_inv):
    """Round the scale amax * fp8_max_inv up to the nearest power of two.

    Equivalent to 2 ** ceil(log2(amax * fp8_max_inv)), computed with the
    bit-level helpers instead of floating-point log2/exp2.
    """
    raw_scale = amax * fp8_max_inv
    exponent = fast_log2_ceil(raw_scale)
    return fast_pow2(exponent)
# @libentry()
@triton.jit(
    do_not_specialize=[
        "M",
    ]
)
def act_quant_triton_kernel(
    X_ptr,
    Y_ptr,
    S_ptr,
    M,
    N,
    stride_xm,
    stride_ym,
    stride_sm,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    ROUND_SCALE: tl.constexpr,
):
    """Quantize one (BLOCK_M, BLOCK_N) tile of X to FP8 with a per-row scale.

    Launch grid: axis 0 tiles the M rows in steps of BLOCK_M, axis 1 tiles the
    N columns in steps of BLOCK_N. For each row of the tile the kernel:

      1. takes the absolute maximum of the row's BLOCK_N values, floored at
         1e-4 so the scale is never zero;
      2. derives the scale as amax / 448, rounded up to a power of two when
         ROUND_SCALE is set;
      3. stores x / scale, clamped to [-448, 448], into Y as float8e4nv;
      4. stores the scale into S at (row, column-tile index).

    Args:
        X_ptr: Input pointer; rows are `stride_xm` elements apart and columns
            are addressed with unit stride.
        Y_ptr: Output pointer for the quantized values (row stride
            `stride_ym`, unit column stride).
        S_ptr: Output pointer for the scales (row stride `stride_sm`, one
            entry per column tile).
        M: Number of rows (excluded from specialization).
        N: Number of columns.
        stride_xm, stride_ym, stride_sm: Row strides, in elements.
        BLOCK_M: Rows handled per program.
        BLOCK_N: Columns handled per program, i.e. the quantization block.
        ROUND_SCALE: If true, round each scale up to a power of two.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # tl.arange yields int32 indices. Promote the row indices to int64 before
    # they are multiplied by a row stride so the element offset cannot wrap
    # around 32-bit arithmetic on tensors with 2**31 or more elements.
    row_offset = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    col_offsets = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

    mask_row = row_offset < M
    mask_col = col_offsets < N
    mask = mask_row[:, None] & mask_col[None, :]

    # Out-of-range lanes read 0.0, which cannot raise a row's absolute max.
    x = tl.load(
        X_ptr + row_offset[:, None] * stride_xm + col_offsets[None, :],
        mask=mask,
        other=0.0,
    )

    amax = tl.max(tl.abs(x), axis=1)
    # Lower bound keeps the scale strictly positive for all-zero rows.
    amax = tl.maximum(amax, 1e-4)

    FP8_MAX: tl.constexpr = 448.0
    FP8_MAX_INV: tl.constexpr = 1.0 / 448.0

    if ROUND_SCALE:
        # scale = 2 ** ceil(log2(amax / 448)), via exponent-bit manipulation
        # rather than floating-point log2 / ceil / exp2.
        scale = fast_round_scale(amax, FP8_MAX_INV)
    else:
        scale = amax * FP8_MAX_INV

    y = x / scale[:, None]
    y = tl.clamp(y, -FP8_MAX, FP8_MAX)

    tl.store(
        Y_ptr + row_offset[:, None] * stride_ym + col_offsets[None, :],
        y.to(tl.float8e4nv),
        mask=mask,
    )

    # One scale per (row, column tile); the column-tile index is pid_n.
    tl.store(S_ptr + row_offset * stride_sm + pid_n, scale, mask=mask_row)
def act_quant_triton(
    x: torch.Tensor, block_size: int = 128, scale_fmt: Optional[str] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Block-wise quantize `x` to FP8 along its last dimension.

    The tensor is viewed as a 2-D matrix of shape (-1, x.size(-1)); every run
    of `block_size` consecutive elements in a row gets its own scale.

    Args:
        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and
            its last dimension size must be divisible by `block_size`.
        block_size (int, optional): The size of the blocks for quantization. Default is 128.
        scale_fmt (Optional[str], optional): If not None, rounds scale to power of 2.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
            - The quantized tensor with dtype `torch.float8_e4m3fn` and the
              same shape as `x`.
            - A tensor of scaling factors with dtype `torch.float32` and shape
              `(*x.shape[:-1], x.size(-1) // block_size)`.
    """
    assert x.is_contiguous(), "Input tensor must be contiguous"
    last_dim = x.size(-1)
    assert (
        last_dim % block_size == 0
    ), f"Last dimension size must be divisible by block_size (block_size={block_size})"

    # Collapse all leading dimensions into a single row axis.
    rows_in = x.view(-1, last_dim)
    num_rows = rows_in.size(0)

    rows_per_program = 32
    blocks_per_row = last_dim // block_size

    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    s = x.new_empty(*x.size()[:-1], blocks_per_row, dtype=torch.float32)
    rows_out = y.view(-1, last_dim)
    scales_out = s.view(-1, blocks_per_row)

    # One program per (row tile, quantization block).
    launch_grid = (triton.cdiv(num_rows, rows_per_program), blocks_per_row)
    act_quant_triton_kernel[launch_grid](
        rows_in,
        rows_out,
        scales_out,
        num_rows,
        last_dim,
        rows_in.stride(0),
        rows_out.stride(0),
        scales_out.stride(0),
        BLOCK_M=rows_per_program,
        BLOCK_N=block_size,
        ROUND_SCALE=(scale_fmt is not None),
    )

    return y, s
if __name__ == "__main__":
    # Reference implementation used for validation and benchmarking; it is
    # imported from a `kernel` module and reported as "TileLang" below.
    from kernel import act_quant

    torch.manual_seed(2026)

    # Alternative shape sets kept for manual experiments:
    # test_shape = [
    #     (16, 128, 128),
    #     (32, 128, 512),
    #     (64, 128, 2048),
    #     (128, 128, 8192),
    #     (256, 128, 32768),
    #     # [1, 12, 4096],
    #     # [1, 12, 1024],
    #     # [1, 12, 448],
    #     # [1, 12, 2048],
    #     # [2, 4096],
    #     # [1, 2048],
    # ]
    # M = [1, 64, 128, 512, 4096, 4096*4, 4096*16]
    M = [1, 40, 164, 512, 3454, 12027, 38594]
    N = [128, 448, 2048, 8192]
    test_shape = [(m, n) for m in M for n in N]
    fmt = [None, "ue8m0"]
    block_sizes = [64, 128]

    # Every (scale_fmt, shape, block_size) combination, in the order
    # scale_fmt -> shape -> block_size (innermost varies fastest).
    cases = [
        (scale_fmt, shape, block_size)
        for scale_fmt in fmt
        for shape in test_shape
        for block_size in block_sizes
    ]

    # ---- Correctness: compare against the reference on random inputs ----
    for scale_fmt, shape, block_size in cases:
        if shape[-1] % block_size:
            print(
                f"Skipping shape {shape} with block_size {block_size} due to incompatible dimensions."
            )
            continue

        x = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
        y_ref, s_ref = act_quant(x, block_size=block_size, scale_fmt=scale_fmt)
        y_triton, s_triton = act_quant_triton(
            x, block_size=block_size, scale_fmt=scale_fmt
        )

        torch.testing.assert_close(
            y_ref.float(), y_triton.float(), rtol=1e-2, atol=1e-2
        )
        torch.testing.assert_close(s_ref, s_triton, rtol=1e-5, atol=1e-5)
        print(
            f"Shape {str(shape):20s} | scale_fmt:{scale_fmt} | block_size:{block_size} | PASS"
        )

    print("=" * 60)

    # ---- Benchmark: reference vs. Triton, collecting per-case speedups ----
    su = []
    for scale_fmt, shape, block_size in cases:
        if shape[-1] % block_size:
            print(
                f"Skipping shape {shape} with block_size {block_size} due to incompatible dimensions."
            )
            continue

        x = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
        ref_time = triton.testing.do_bench(
            lambda: act_quant(x, block_size=block_size, scale_fmt=scale_fmt),
            warmup=50,
            rep=200,
        )
        triton_time = triton.testing.do_bench(
            lambda: act_quant_triton(x, block_size=block_size, scale_fmt=scale_fmt),
            warmup=50,
            rep=200,
        )

        su.append(ref_time / triton_time)
        print(
            f"Shape {str(shape):20s}, Scale format: {scale_fmt}, "
            f"block_size: {block_size} | "
            f"TileLang: {ref_time:.3f} ms | Triton: {triton_time:.3f} ms | "
            f"Speedup: {ref_time / triton_time:.2f}x"
        )

    print(
        f"Average speedup: {sum(su) / len(su):.2f}x, max speedup: {max(su):.2f}x, min speedup: {min(su):.2f}x"
    )

    # Disabled manual wall-clock timing, kept for reference:
    # x = torch.randn(4096*4, 40960, dtype=torch.bfloat16, device="cuda")
    #
    # # Warmup
    # for _ in range(10):
    #     _ = act_quant(x)
    #     _ = act_quant_triton(x)
    #
    # torch.cuda.synchronize()
    #
    # import time
    #
    # # TileLang
    # torch.cuda.synchronize()
    # start = time.perf_counter()
    # for _ in range(100):
    #     _ = act_quant(x)
    # torch.cuda.synchronize()
    # tilelang_time = (time.perf_counter() - start) / 100 * 1000
    #
    # # Triton
    # torch.cuda.synchronize()
    # start = time.perf_counter()
    # for _ in range(100):
    #     _ = act_quant_triton(x)
    # torch.cuda.synchronize()
    # triton_time = (time.perf_counter() - start) / 100 * 1000
    #
    # print(f"TileLang: {tilelang_time:.3f} ms")
    # print(f"Triton: {triton_time:.3f} ms")
    # print(f"Speedup: {tilelang_time / triton_time:.2f}x")