Coverage for src/flag_gems/ops/bmm.py: 38%

96 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems import runtime 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry, libtuner 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Module-scoped logger; bmm() and bmm_out() below emit DEBUG records through it.
logger = logging.getLogger(name=__name__)

14 

15 

@libentry()
@libtuner(
    configs=runtime.ops_get_configs("bmm", pre_hook=None)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("bmm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=runtime.get_expand_config("bmm")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else [
        "log",
        "log",
        "log",
        "align32",
        "align32",
    ],
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    stride_ab,
    stride_am,
    stride_ak,
    stride_bb,
    stride_bk,
    stride_bn,
    stride_ob,
    stride_om,
    stride_on,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Batched matrix multiply: O[b] = A[b] @ B[b] for each batch index b.

    Each program computes one TILE_M x TILE_N tile of O[b]. Launch axis 2
    selects b; axes 0 and 1 select the tile along M and N, optionally
    reordered into groups of GROUP_M row-tiles when GROUP_M != 1.

    The accumulator is float64 when IS_FP64, otherwise float32. The
    DIVISIBLE_M/N/K flags turn off bounds masking along the matching
    dimension (presumably filled in by the "bmm" heuristic config; confirm
    against runtime).
    """
    # batch offsets
    pid_b = tle.program_id(2)
    A += pid_b * stride_ab
    B += pid_b * stride_bb
    O += pid_b * stride_ob

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: consecutive linear program ids walk GROUP_M row-tiles
        # before advancing to the next column-tile
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # the last group holds only the leftover gridx % GROUP_M row-tiles
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    o_ptrs = O + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on

    num_iters = tl.cdiv(K, TILE_K)
    if IS_FP64:
        o = tl.zeros((TILE_M, TILE_N), dtype=tl.float64)
    else:
        o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
            # Masked-off rows of `a` / columns of `b` only feed output elements
            # that the final masked store discards, so their values don't
            # matter. `other` is not allowed here because the mask may be None.
            a = tl.load(a_ptrs, mask_a)
            b = tl.load(b_ptrs, mask_b)
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
            # Lanes with k >= K are reduced by tl.dot into *valid* outputs, so
            # they must read as zero. Without `other`, tl.load leaves
            # masked-off lanes undefined.
            a = tl.load(a_ptrs, mask=mask_a, other=0.0)
            b = tl.load(b_ptrs, mask=mask_b, other=0.0)

        offs_k += TILE_K
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

        o += tl.dot(a, b, allow_tf32=False)

    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)

143 

144 

def bmm(A, B):
    """Return a new tensor holding the batched product of ``A`` and ``B``.

    ``A`` is unpacked as ``(batch, M, K)`` and ``B`` as ``(batch, K, N)``. The
    result is a freshly allocated ``(batch, M, N)`` tensor with ``A``'s dtype
    and device, filled by ``bmm_kernel``.

    Raises:
        AssertionError: if the batch or K dimensions of ``A`` and ``B`` differ.
    """
    logger.debug("GEMS BMM")
    assert A.shape[0] == B.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    _, _, N = B.shape
    result = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

    def launch_grid(meta):
        # axis 0: M tiles, axis 1: N tiles, axis 2: one slice per batch entry
        rows = triton.cdiv(meta["M"], meta["TILE_M"])
        cols = triton.cdiv(meta["N"], meta["TILE_N"])
        return (rows, cols, batch)

    # All three tensors are 3-D here (enforced by the unpacking above), so each
    # stride() tuple supplies exactly the (batch, row, col) strides in order.
    strides = (*A.stride(), *B.stride(), *result.stride())
    use_fp64 = A.dtype == torch.float64

    with torch_device_fn.device(A.device):
        bmm_kernel[launch_grid](A, B, result, M, N, K, *strides, IS_FP64=use_fp64)
    return result

178 

179 

def bmm_out(A, B, out):
    """Compute the batched matrix product of ``A`` and ``B`` into ``out``.

    Args:
        A: tensor unpacked as ``(batch, M, K)``.
        B: tensor unpacked as ``(batch, K, N)``.
        out: preallocated tensor of shape ``(batch, M, N)``. It is written in
            place through its own strides.

    Returns:
        ``out``.

    Raises:
        AssertionError: if the batch or K dimensions disagree, or if ``out``
            is not ``(batch, M, N)``.
    """
    logger.debug("GEMS BMM_OUT")
    # NOTE(review): these checks use `assert`, like bmm() in this module, so
    # they are skipped under `python -O`.
    assert A.shape[0] == B.shape[0] == out.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    _, _, N = B.shape
    # The kernel stores every (m, n) element of every batch entry through
    # out's strides and never consults out's size, so a mis-shaped `out`
    # would be written out of bounds.
    assert out.shape == (batch, M, N), (
        f"Output shape mismatch: expected {(batch, M, N)}, got {tuple(out.shape)}"
    )

    def grid_fn(meta):
        # axis 0: M tiles, axis 1: N tiles, axis 2: one slice per batch entry
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](
            A,
            B,
            out,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            IS_FP64=A.dtype == torch.float64,
        )
    return out