Coverage for src/flag_gems/runtime/backend/_ascend/ops/matmul_int8.py: 0%

56 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved. 

2# 

3# Permission is hereby granted, free of charge, to any person obtaining a copy 

4# of this software and associated documentation files (the "Software"), to deal 

5# in the Software without restriction, including without limitation the rights 

6# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 

7# copies of the Software, and to permit persons to whom the Software is 

8# furnished to do so, subject to the following conditions: 

9# 

10# The above copyright notice and this permission notice shall be included in 

11# all copies or substantial portions of the Software. 

12# 

13# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 

14# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 

15# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 

16# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 

17# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 

18# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 

19# THE SOFTWARE. 

20 

21""" 

INT8 Matrix Multiplication (Ascend NPU)
=======================================

24""" 

25 

26import torch 

27import torch_npu 

28import triton 

29import triton.language as tl 

30 

# Device string referenced by the commented-out self-test at the end of this file.
DEV = "npu"

32 

33 

def get_output_dtype(a_dtype, b_dtype):
    """Return the dtype of the tensor produced by ``matmul_int8``.

    Both input dtypes are part of the signature but are not consulted:
    the output is always ``torch.bfloat16``.
    """
    output_dtype = torch.bfloat16
    return output_dtype

36 

37 

def get_autotune_config():
    """Build the candidate tile configurations for ``matmul_kernel``.

    Every candidate uses the same edge length for the M, N and K tiles.

    Returns:
        list[triton.Config]: configs with square 64, 128 and 256 tiles,
        in that order.
    """
    tile_edges = (64, 128, 256)
    return [
        triton.Config(
            {"BLOCK_SIZE_M": edge, "BLOCK_SIZE_N": edge, "BLOCK_SIZE_K": edge}
        )
        for edge in tile_edges
    ]

44 

45 

# Autotuned over get_autotune_config(); the chosen config is cached per (M, N, K).
@triton.autotune(
    configs=get_autotune_config(),
    key=["M", "N", "K"],
)
@triton.jit
def matmul_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    # Matrix dimensions
    M,
    N,
    K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension.
    stride_am,
    stride_ak,  #
    stride_bk,
    stride_bn,  #
    stride_cm,
    stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,  #
):
    """Kernel for computing the matmul C = A x B.

    A has shape (M, K), B has shape (K, N) and C has shape (M, N).

    Each program instance computes one BLOCK_SIZE_M x BLOCK_SIZE_N tile of C.
    The program id is read from axis 0 only, so the launch grid is expected
    to be 1D with cdiv(M, BLOCK_SIZE_M) * cdiv(N, BLOCK_SIZE_N) programs.

    Products are accumulated in int32 and cast to C's element type on store.
    NOTE(review): the kernel does not check input dtypes; it presumably
    expects int8 A and B (hence the int32 accumulator) — confirm.
    """
    # Grouped tile ordering: GROUP_SIZE_M consecutive M-blocks are scheduled
    # together for each N-block, intended to improve cache reuse of B tiles.
    GROUP_SIZE_M: tl.constexpr = 8
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group can hold fewer than GROUP_SIZE_M rows of blocks.
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    # Within a group, consecutive pids step through M-blocks first and
    # move on to the next N-block after group_size_m of them.
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create block pointers for A, B, and C using make_block_ptr.
    # A's tile starts at row pid_m * BLOCK_SIZE_M, column 0 (first K-tile).
    a_block_ptr = tl.make_block_ptr(
        base=a_ptr,
        shape=(M, K),
        strides=(stride_am, stride_ak),
        offsets=(pid_m * BLOCK_SIZE_M, 0),
        block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K),
        order=(1, 0),
    )
    # B's tile starts at row 0 (first K-tile), column pid_n * BLOCK_SIZE_N.
    b_block_ptr = tl.make_block_ptr(
        base=b_ptr,
        shape=(K, N),
        strides=(stride_bk, stride_bn),
        offsets=(0, pid_n * BLOCK_SIZE_N),
        block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N),
        order=(1, 0),
    )
    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # Use int32 accumulator for int8 inputs.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Out-of-range elements of edge tiles load as 0, so they add nothing
        # to the dot product.
        a = tl.load(a_block_ptr, boundary_check=(0, 1), padding_option="zero")
        b = tl.load(b_block_ptr, boundary_check=(0, 1), padding_option="zero")
        accumulator = tl.dot(a, b, accumulator)
        # Step A right and B down by one K-tile.
        a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_SIZE_K))
        b_block_ptr = tl.advance(b_block_ptr, (BLOCK_SIZE_K, 0))
    c = accumulator.to(c_ptr.dtype.element_ty)
    # -----------------------------------------------------------
    # Write back the block of the output matrix C.
    c_block_ptr = tl.make_block_ptr(
        base=c_ptr,
        shape=(M, N),
        strides=(stride_cm, stride_cn),
        offsets=(pid_m * BLOCK_SIZE_M, pid_n * BLOCK_SIZE_N),
        block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
        order=(1, 0),
    )
    # boundary_check drops the parts of edge tiles that fall outside C.
    tl.store(c_block_ptr, c, boundary_check=(0, 1))

133 

134 

def torch_matmul(a, b):
    """Reference int8 matmul built on ``torch_npu.npu_quant_matmul``.

    Args:
        a: Tensor of shape (M, K). NOTE(review): presumably int8 on an NPU
            device — confirm with callers.
        b: Tensor of shape (N, K), the same layout ``matmul_int8`` takes.
            It is transposed to a contiguous (K, N) before the call.

    Returns:
        The (M, N) product as a ``torch.bfloat16`` tensor.
    """
    # Scale of 1: presumably leaves the integer products unscaled.
    scale = torch.ones(1, dtype=torch.float32, device=a.device)
    # npu_quant_matmul takes the right operand as (K, N), so transpose b.
    result = torch_npu.npu_quant_matmul(
        a, b.t().contiguous(), scale, output_dtype=torch.float16
    )
    # NOTE(review): the fp16 intermediate only represents magnitudes up to
    # 65504; large-K products may overflow to inf before the bf16 cast.
    return result.to(torch.bfloat16)

143 

144 

145# %% 

# We can now create a convenience wrapper function that only takes two input tensors
# and (1) checks the shape constraints, (2) allocates the output, and (3) launches the kernel above.

148 

149 

def matmul_int8(a, b):
    """Multiply ``a`` by ``b`` transposed using the Triton int8 kernel.

    Args:
        a: Tensor of shape (..., K). Inputs with more than two dimensions
            have all leading dims folded into M for the kernel, and the
            result is reshaped back to (..., N). NOTE(review): presumably
            int8, matching the kernel's int32 accumulator — confirm.
        b: Tensor of shape (N, K), i.e. the right operand stored transposed.
            It is copied into a contiguous (K, N) tensor before launch.

    Returns:
        Tensor of shape (..., N) with dtype
        ``get_output_dtype(a.dtype, b.dtype)`` (currently ``torch.bfloat16``).

    Raises:
        AssertionError: if the K dimensions of ``a`` and ``b`` differ
            (checked with ``assert``, so skipped under ``python -O``).
    """
    # Keep the caller's shape so batched inputs can be restored at the end.
    a_shape = a.shape
    if a.ndim > 2:
        # The kernel is 2D: fold every leading (batch) dim into M.
        a = a.contiguous().reshape(-1, a.shape[-1])
    # Copy only when neither dim is unit-stride; row- or column-major
    # inputs are consumed in place through their strides.
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    # b has shape (N, K); transpose to a contiguous (K, N) for the kernel.
    b = b.t().contiguous()
    # Check constraints. After transpose, b has shape (K, N).
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    M, K = a.shape
    N = b.shape[1]
    c_dtype = get_output_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=a.device, dtype=c_dtype)

    # An empty output has no tiles to compute; skip the zero-sized launch
    # (and the autotuning it would trigger).
    if c.numel() > 0:

        def grid(meta):
            # 1D launch: one program per BLOCK_SIZE_M x BLOCK_SIZE_N tile of c.
            return (
                triton.cdiv(M, meta["BLOCK_SIZE_M"])
                * triton.cdiv(N, meta["BLOCK_SIZE_N"]),
            )

        matmul_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
        )

    # Restore the leading (batch) dims of a batched input.
    if len(a_shape) > 2:
        c = c.reshape(*a_shape[:-1], N)
    return c

189 

190 

191# %% 

192# Unit Test 

193# --------- 

194# 

# We can test our custom int8 matrix multiplication against the native torch_npu implementation (``torch_matmul``, which calls ``torch_npu.npu_quant_matmul``).

196# if __name__ == "__main__": 

197# torch.npu.set_device(1) 

198# torch.manual_seed(0) 

199 

200# a = torch.randint(-5, 5, (1024, 10240), device=DEV, dtype=torch.int8) 

201# b = torch.randint(-5, 5, (2048, 10240), device=DEV, dtype=torch.int8) # (N, K) 

202# torch_output = torch_matmul(a, b) 

203# print(f"torch_output_with_int8_inputs={torch_output}") 

# triton_output = matmul_int8(a, b)

205# print(f"triton_output_with_int8_inputs={triton_output}")