Coverage for src/flag_gems/runtime/backend/_ascend/ops/mm.py: 0%

92 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.runtime.backend._ascend import heuristics_config_utils as _hcu 

10from flag_gems.utils import libentry, libtuner 

11 

# Module logger, namespaced under the Ascend backend ops and suffixed with the
# last dotted component of this module's __name__.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rpartition(".")[2]
)

13 

14 

# Tiled GEMM kernel computing C = A @ B, where A is indexed as (M, K) and B as
# (K, N) through the explicit element strides passed by the launcher.
# Program-id axis 0 selects a (BLOCK_M x BLOCK_N) output tile; axis 1 (pid_z)
# selects which of the SPLIT_K slices of the K reduction this program handles.
# NOTE(review): M/N/K and all strides are declared tl.constexpr, so every
# distinct shape/stride combination compiles a separate kernel.
# NOTE(review): EVEN_K is not passed by mm()/mm_out(); presumably it is derived
# by _hcu.HEURISTICS_CONFIGS["mm"] (likely K % (BLOCK_K * SPLIT_K) == 0) —
# confirm against heuristics_config_utils.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("mm"),
    key=["M", "N", "K"],
)
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["mm"])
@triton.jit
def mm_kernel(
    A,
    B,
    C,
    M: tl.constexpr,
    N: tl.constexpr,
    K: tl.constexpr,
    stride_am: tl.constexpr,
    stride_ak: tl.constexpr,
    stride_bk: tl.constexpr,
    stride_bn: tl.constexpr,
    stride_cm: tl.constexpr,
    stride_cn: tl.constexpr,
    dot_out_dtype: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    SPLIT_K: tl.constexpr,
    EVEN_K: tl.constexpr,
):
    pid = tl.program_id(0)
    pid_z = tl.program_id(1)
    # Map the linear tile id onto (pid_m, pid_n), visiting tiles in groups of
    # GROUP_M tile-rows (the grouped ordering from the Triton matmul tutorial,
    # intended to improve cache reuse).
    grid_m = tl.cdiv(M, BLOCK_M)
    grid_n = tl.cdiv(N, BLOCK_N)
    width = GROUP_M * grid_n
    group_id = pid // width
    # The last group may contain fewer than GROUP_M tile-rows.
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    # Row indices of A, column indices of B, and this split-K slice's first
    # BLOCK_K reduction indices.
    ram = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rbn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)
    # Accumulator held in dot_out_dtype (the launchers pass tl.float32).
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=dot_out_dtype)
    # Each iteration consumes BLOCK_K reduction elements; the SPLIT_K slices
    # interleave, so the K stride per iteration is BLOCK_K * SPLIT_K.
    for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
        if EVEN_K:
            # No K-tail: only rows past M / columns past N need masking.
            a = tl.load(A, mask=(ram < M)[:, None], other=0.0)
            b = tl.load(B, mask=(rbn < N)[None, :], other=0.0)
        else:
            # Mask the K tail as well; out-of-range elements load as 0 so they
            # contribute nothing to the dot product.
            k_remaining = K - k * (BLOCK_K * SPLIT_K)
            a = tl.load(
                A,
                mask=(rk[None, :] < k_remaining) & (ram < M)[:, None],
                other=0.0,
            )
            b = tl.load(
                B,
                mask=(rk[:, None] < k_remaining) & (rbn < N)[None, :],
                other=0.0,
            )
        # Mixed input dtypes: cast both operands to C's element type so
        # tl.dot receives matching operand types.
        if a.dtype != b.dtype:
            a = a.to(C.dtype.element_ty)
            b = b.to(C.dtype.element_ty)
        acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=False)
        # Advance past the K range covered by all SPLIT_K slices this step.
        A += BLOCK_K * SPLIT_K * stride_ak
        B += BLOCK_K * SPLIT_K * stride_bk
    acc = acc.to(C.dtype.element_ty)
    # Output tile coordinates (same values as ram / rbn above).
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)
    mask = (rm < M)[:, None] & (rn < N)[None, :]
    if SPLIT_K == 1:
        tl.store(C, acc, mask=mask)
    else:
        # Each split-K slice adds its partial sum into C.
        # NOTE(review): this is only correct if C is zero before launch (and
        # re-zeroed between autotuning trials); mm() allocates C with
        # torch.empty, so presumably the tuned configs' pre_hook zeroes it —
        # confirm.
        tl.atomic_add(C, acc, mask=mask)

89 

90 

# Floating-point dtypes accepted by get_higher_dtype, ordered from lowest to
# highest precedence.
_ordered_datatypes = [torch.float16, torch.bfloat16, torch.float32]


def get_higher_dtype(a, b):
    """Return whichever of the dtypes ``a`` and ``b`` ranks higher.

    Precedence follows ``_ordered_datatypes``: float16 < bfloat16 < float32.
    If ``a`` and ``b`` are the same object it is returned without validation;
    otherwise both must appear in ``_ordered_datatypes`` (checked with
    ``assert``).

    NOTE: mixing float16 and bfloat16 yields bfloat16 here, whereas
    ``torch.promote_types`` would give float32.
    """
    if a is b:
        return a

    assert a in _ordered_datatypes
    assert b in _ordered_datatypes

    rank_a = _ordered_datatypes.index(a)
    rank_b = _ordered_datatypes.index(b)
    return a if rank_a > rank_b else b

106 

107 

def _as_kernel_operand(t):
    """Return ``t``, or a contiguous copy if neither of its strides is <= 1.

    ``mm_kernel`` receives explicit strides, so layouts with a unit (or zero)
    stride in either dimension — row-major, transposed, broadcast — are passed
    through unchanged. Presumably the copy exists so the kernel always has a
    unit-stride dimension to load along.
    """
    if t.stride(0) > 1 and t.stride(1) > 1:
        return t.contiguous()
    return t


def _launch_mm_kernel(a, b, c, M, N, K):
    """Launch ``mm_kernel`` writing ``a @ b`` into the (M, N) tensor ``c``.

    ``a`` is (M, K) and ``b`` is (K, N); all three tensors' strides are
    forwarded to the kernel. Accumulation is done in float32.
    """

    def grid(META):
        # Axis 0: one program per output tile; axis 1: one per split-K slice.
        return (
            triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
            META.get("SPLIT_K", 1),
        )

    with torch_device_fn.device(a.device):
        mm_kernel[grid](
            a,
            b,
            c,
            M,
            N,
            K,
            a.stride(0),
            a.stride(1),
            b.stride(0),
            b.stride(1),
            c.stride(0),
            c.stride(1),
            dot_out_dtype=tl.float32,
            GROUP_M=8,
        )


def mm(a, b):
    """Return the matrix product of the 2-D tensors ``a`` (M, K) and ``b`` (K, N).

    The result has shape (M, N) and dtype ``get_higher_dtype(a.dtype, b.dtype)``,
    and is allocated on ``a``'s device.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]``, or if the dtypes are
            not supported by ``get_higher_dtype``.
    """
    logger.debug("GEMS_ASCEND MM")
    a = _as_kernel_operand(a)
    b = _as_kernel_operand(b)
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    c_dtype = get_higher_dtype(a.dtype, b.dtype)
    c = torch.empty((M, N), device=a.device, dtype=c_dtype)
    _launch_mm_kernel(a, b, c, M, N, K)
    return c


def mm_out(a, b, *, out):
    """Write the matrix product of ``a`` (M, K) and ``b`` (K, N) into ``out``.

    ``out`` must already have shape (M, N); it is written through its own
    strides and returned. The kernel casts results to ``out``'s element type.

    Raises:
        AssertionError: if ``a.shape[1] != b.shape[0]`` or ``out.shape`` is
            not (M, N).
    """
    logger.debug("GEMS_ASCEND MM_OUT")
    a = _as_kernel_operand(a)
    b = _as_kernel_operand(b)
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # The kernel masks stores using M/N but addresses ``out`` via its own
    # strides, so a smaller ``out`` would be written out of bounds and a larger
    # one would be silently only partly filled.
    assert out.shape == (M, N), "out has incorrect shape"
    _launch_mm_kernel(a, b, out, M, N, K)
    return out