Coverage for src/flag_gems/runtime/backend/_ascend/ops/matmul_int8.py: 0%
56 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
2#
3# Permission is hereby granted, free of charge, to any person obtaining a copy
4# of this software and associated documentation files (the "Software"), to deal
5# in the Software without restriction, including without limitation the rights
6# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7# copies of the Software, and to permit persons to whom the Software is
8# furnished to do so, subject to the following conditions:
9#
10# The above copyright notice and this permission notice shall be included in
11# all copies or substantial portions of the Software.
12#
13# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19# THE SOFTWARE.
21"""
22Matrix Multiplication
23===============
24"""
26import torch
27import torch_npu
28import triton
29import triton.language as tl
# Device string for NPU tensors; in the visible part of this file it is only
# referenced by the commented-out self-test at the bottom.
DEV = "npu"
def get_output_dtype(a_dtype, b_dtype):
    """Return the dtype used for the matmul result.

    The result dtype is always ``torch.bfloat16``; ``a_dtype`` and ``b_dtype``
    are accepted for call-site symmetry but do not influence the choice.
    """
    output_dtype = torch.bfloat16
    return output_dtype
def get_autotune_config():
    """Return the candidate tile configurations for kernel autotuning.

    Every candidate uses the same edge length for the M, N and K tile
    dimensions; the candidates are 64, 128 and 256, in that order.
    """
    tile_edges = (64, 128, 256)
    return [
        triton.Config(
            {"BLOCK_SIZE_M": edge, "BLOCK_SIZE_N": edge, "BLOCK_SIZE_K": edge}
        )
        for edge in tile_edges
    ]
@triton.autotune(
    configs=get_autotune_config(),
    key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
    # Base pointers of the two inputs and the output
    a_ptr,
    b_ptr,
    c_ptr,
    # Problem size: A is (M, K), B is (K, N), C is (M, N)
    M,
    N,
    K,
    # Strides, in elements, for a step of one along each dimension
    stride_am,
    stride_ak,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Compile-time tile sizes supplied by the autotuner
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Compute one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C = A x B.

    A has shape (M, K), B has shape (K, N) and C has shape (M, N). Products
    are accumulated in int32 and converted to the element type of ``c_ptr``
    just before the store. Out-of-range parts of edge tiles are read as zero
    and are not written back.
    """
    # Program ids are handed out in groups of GROUP_SIZE_M row-tiles so that
    # consecutive programs work on the same columns of B (intended to improve
    # cache reuse).
    GROUP_SIZE_M: tl.constexpr = 8

    # -----------------------------------------------------------
    # Translate the linear program id into a (pid_m, pid_n) tile coordinate.
    pid = tl.program_id(axis=0)
    tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    tiles_per_group = GROUP_SIZE_M * tiles_n
    group_index = pid // tiles_per_group
    group_start_m = group_index * GROUP_SIZE_M
    # The last group may contain fewer than GROUP_SIZE_M row-tiles.
    rows_in_group = min(tiles_m - group_start_m, GROUP_SIZE_M)
    pid_in_group = pid % tiles_per_group
    pid_m = group_start_m + (pid_in_group % rows_in_group)
    pid_n = pid_in_group // rows_in_group

    row_start = pid_m * BLOCK_SIZE_M
    col_start = pid_n * BLOCK_SIZE_N

    # -----------------------------------------------------------
    # Block pointers to the first K-tile of this program's A rows / B columns.
    a_tile_ptr = tl.make_block_ptr(
        base=a_ptr,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(row_start, 0),
        block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
        order=(1, 0),
    )
    b_tile_ptr = tl.make_block_ptr(
        base=b_ptr,
        shape=(K, N),
        strides=(stride_bk, stride_bn),
        offsets=(0, col_start),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
        order=(1, 0),
    )

    # -----------------------------------------------------------
    # Walk along K, accumulating partial products in int32.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
    for k_tile in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a_tile = tl.load(a_tile_ptr, boundary_check=(0, 1), padding_option="zero")
        b_tile = tl.load(b_tile_ptr, boundary_check=(0, 1), padding_option="zero")
        acc = tl.dot(a_tile, b_tile, acc)
        a_tile_ptr = tl.advance(a_tile_ptr, (0, BLOCK_SIZE_K))
        b_tile_ptr = tl.advance(b_tile_ptr, (BLOCK_SIZE_K, 0))
    out_tile = acc.to(c_ptr.dtype.element_ty)

    # -----------------------------------------------------------
    # Store the finished tile into C, clipping at the matrix edges.
    c_tile_ptr = tl.make_block_ptr(
        base=c_ptr,
        shape=(M, N),
        strides=(stride_cm, stride_cn),
        offsets=(row_start, col_start),
        block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
        order=(1, 0),
    )
    tl.store(c_tile_ptr, out_tile, boundary_check=(0, 1))
def torch_matmul(a, b):
    """Reference matmul built on ``torch_npu.npu_quant_matmul``.

    Prints the dtypes of both inputs, multiplies ``a`` by the transpose of
    ``b`` using a unit float32 scale with float16 output, and returns the
    result cast to bfloat16.
    """
    print(f"{a.dtype=} {b.dtype=}")
    # b arrives as (N, K); hand the quantized matmul a contiguous (K, N) copy.
    b_kn = b.t().contiguous()
    unit_scale = torch.ones(1, dtype=torch.float32, device=a.device)
    fp16_result = torch_npu.npu_quant_matmul(
        a, b_kn, unit_scale, output_dtype=torch.float16
    )
    return fp16_result.to(torch.bfloat16)
145# %%
146# We can now create a convenience wrapper function that only takes two input tensors,
147# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel.
def matmul_int8(a, b):
    """Compute ``a @ b.T`` with the Triton ``matmul_kernel``.

    Args:
        a: Tensor of shape ``(..., K)`` with at least two dimensions. Any
            leading dimensions are flattened into a single row dimension for
            the kernel and restored on the result.
        b: 2-D tensor of shape ``(N, K)``; it is transposed to a contiguous
            ``(K, N)`` tensor before the kernel is launched.

    Returns:
        Tensor of shape ``(..., N)`` on ``a``'s device, with the dtype chosen
        by ``get_output_dtype``.

    Raises:
        AssertionError: If the K dimensions of ``a`` and ``b`` differ.
    """
    # Remember the caller's shape so leading (batch) dimensions can be restored.
    a_shape = a.shape
    is_batched = a.ndim > 2
    if is_batched:
        # Collapse every leading dimension into the row dimension M.
        a = a.contiguous().reshape(-1, a_shape[-1])
    # Neither dimension has unit stride: fall back to a contiguous copy.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    # b has shape (N, K); the kernel consumes a contiguous (K, N) operand.
    b = b.t().contiguous()
    # Check constraints. After the transpose, b has shape (K, N).
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]
    # Allocate the output.
    c_dtype = get_output_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=a.device, dtype=c_dtype)

    # An empty output needs no work; skip the launch rather than handing the
    # autotuner a zero-sized grid.
    if M > 0 and N > 0:

        def grid(meta):
            # 1D launch: one program per (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C.
            return (
                triton.cdiv(M, meta["BLOCK_SIZE_M"])
                * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
            )

        matmul_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
        )

    # Restore the leading dimensions of a batched input.
    if is_batched:
        c = c.reshape(*a_shape[:-1], N)
    return c
191# %%
192# Unit Test
193# ---------
194#
195# We can test our custom matrix multiplication operation against the native NPU implementation (torch_npu.npu_quant_matmul, wrapped by torch_matmul).
196# if __name__ == "__main__":
197# torch.npu.set_device(1)
198# torch.manual_seed(0)
200# a = torch.randint(-5, 5, (1024, 10240), device=DEV, dtype=torch.int8)
201# b = torch.randint(-5, 5, (2048, 10240), device=DEV, dtype=torch.int8) # (N, K)
202# torch_output = torch_matmul(a, b)
203# print(f"torch_output_with_int8_inputs={torch_output}")
204# triton_output = matmul_int8(a, b)
205# print(f"triton_output_with_int8_inputs={triton_output}")