Coverage for src/flag_gems/ops/bmm.py: 37%

95 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry, libtuner 

10from flag_gems.utils import triton_lang_extension as ext 

11 

12logger = logging.getLogger(__name__) 

13 

14 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("bmm"),
    key=["M", "N", "K", "stride_am", "stride_bk"],
    strategy=[
        "log",
        "log",
        "log",
        "align32",
        "align32",
    ],
    flagtune_op_name="bmm",
    flagtune_expand_op_name="bmm",
    flagtune_pre_hook=None,
)
@triton.heuristics(runtime.get_heuristic_config("bmm"))
@triton.jit
def bmm_kernel(
    A,
    B,
    O,
    M,
    N,
    K,
    stride_ab,
    stride_am,
    stride_ak,
    stride_bb,
    stride_bk,
    stride_bn,
    stride_ob,
    stride_om,
    stride_on,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    """Compute one (TILE_M, TILE_N) tile of ``O[b] = A[b] @ B[b]``.

    Program axis 2 selects the batch entry; axes 0 and 1 select the tile
    along M and N (optionally regrouped when ``GROUP_M != 1``).

    Args:
        A, B, O: base pointers of the two inputs and the output.
        M, N, K: problem sizes; A[b] is indexed as (M, K), B[b] as (K, N)
            and O[b] as (M, N).
        stride_ab, stride_am, stride_ak: element strides of A along the
            batch, M and K axes.
        stride_bb, stride_bk, stride_bn: element strides of B along the
            batch, K and N axes.
        stride_ob, stride_om, stride_on: element strides of O along the
            batch, M and N axes.
        TILE_M, TILE_N, TILE_K: compile-time tile sizes.
        GROUP_M: compile-time group width used for the tile reordering;
            1 disables the reordering.
        DIVISIBLE_M, DIVISIBLE_N, DIVISIBLE_K: compile-time flags; when
            set, the matching bounds mask is skipped. Presumably supplied
            by the "bmm" heuristic config as "size is a multiple of the
            tile" -- confirm against the runtime configuration.
        IS_FP64: accumulate in float64 instead of float32.
    """
    # batch offsets
    pid_b = ext.program_id(2)
    A += pid_b * stride_ab
    B += pid_b * stride_bb
    O += pid_b * stride_ob

    pidx = ext.program_id(0)
    pidy = ext.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # reorder CTAs: linearize the 2-D program id, then split it into
        # groups that each span up to GROUP_M tiles along M and all tiles
        # along N.
        gridx = ext.num_programs(0)
        gridy = ext.num_programs(1)
        pid = pidx + pidy * gridx

        num_CTA_per_group = gridy * GROUP_M

        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # The last group is narrower when gridx is not a multiple of GROUP_M.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Bounds masks are only materialized for the axes that need them.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = B + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
    o_ptrs = O + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on

    num_iters = tl.cdiv(K, TILE_K)
    if IS_FP64:
        o = tl.zeros((TILE_M, TILE_N), dtype=tl.float64)
    else:
        o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]

            # Lanes masked off here only feed output rows/columns that the
            # final store masks out as well, so no fill value is required.
            # (`other` must not be passed when the mask is None.)
            a = tl.load(a_ptrs, mask=mask_a)
            b = tl.load(b_ptrs, mask=mask_b)
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]

            # Lanes past K take part in the reduction of every valid output
            # element, so they must be exactly zero; a masked load without
            # `other` leaves them undefined.
            a = tl.load(a_ptrs, mask=mask_a, other=0.0)
            b = tl.load(b_ptrs, mask=mask_b, other=0.0)

        offs_k += TILE_K
        a_ptrs += TILE_K * stride_ak
        b_ptrs += TILE_K * stride_bk

        o += tl.dot(a, b, allow_tf32=False)

    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    elif DIVISIBLE_M and not DIVISIBLE_N:
        mask_c = mask_n[None, :]
    elif not DIVISIBLE_M and DIVISIBLE_N:
        mask_c = mask_m[:, None]
    else:
        mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(o_ptrs, o, mask_c)

141 

142 

def bmm(A, B):
    """Batched matrix product of ``A`` and ``B`` via ``bmm_kernel``.

    Args:
        A: tensor whose shape unpacks as (batch, M, K).
        B: tensor whose shape unpacks as (batch, K, N).

    Returns:
        A newly allocated tensor of shape (batch, M, N) with the dtype and
        device of ``A``.

    Raises:
        AssertionError: if the batch sizes differ or the inner (K)
            dimensions do not agree.
    """
    logger.debug("GEMS BMM")
    assert A.shape[0] == B.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    num_batches, M, K = A.shape
    _, _, N = B.shape
    result = torch.empty((num_batches, M, N), dtype=A.dtype, device=A.device)

    def launch_grid(meta):
        # One program per output tile, replicated across the batch axis.
        tiles_m = triton.cdiv(meta["M"], meta["TILE_M"])
        tiles_n = triton.cdiv(meta["N"], meta["TILE_N"])
        return (tiles_m, tiles_n, num_batches)

    # The kernel switches to a float64 accumulator for double inputs.
    use_fp64 = A.dtype == torch.float64

    with torch_device_fn.device(A.device):
        # Shapes were unpacked as 3-D above, so each stride() call yields
        # exactly (batch, row, col) strides in the order the kernel expects.
        bmm_kernel[launch_grid](
            A,
            B,
            result,
            M,
            N,
            K,
            *A.stride(),
            *B.stride(),
            *result.stride(),
            IS_FP64=use_fp64,
        )
    return result

176 

177 

def bmm_out(A, B, out):
    """Batched matrix product of ``A`` and ``B`` written into ``out``.

    Args:
        A: tensor whose shape unpacks as (batch, M, K).
        B: tensor whose shape unpacks as (batch, K, N).
        out: destination tensor; must already have shape (batch, M, N).
            It is written in place using its own strides.

    Returns:
        ``out``.

    Raises:
        AssertionError: if the batch sizes differ, the inner (K) dimensions
            do not agree, or ``out`` does not have shape (batch, M, N).
    """
    logger.debug("GEMS BMM_OUT")
    assert A.shape[0] == B.shape[0] == out.shape[0], "Batch dim mismatch"
    assert A.shape[2] == B.shape[1], "K dim mismatch"
    batch, M, K = A.shape
    _, _, N = B.shape
    # The kernel addresses `out` purely through M, N and out's strides; a
    # destination with different trailing dims would be written with the
    # wrong layout or past its end, so reject it before launching.
    assert tuple(out.shape) == (batch, M, N), "Output shape mismatch"

    grid_fn = lambda meta: (
        triton.cdiv(meta["M"], meta["TILE_M"]),
        triton.cdiv(meta["N"], meta["TILE_N"]),
        batch,
    )
    with torch_device_fn.device(A.device):
        bmm_kernel[grid_fn](
            A,
            B,
            out,
            M,
            N,
            K,
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            IS_FP64=A.dtype == torch.float64,
        )
    return out

209 return out