Coverage for src/flag_gems/ops/baddbmm.py: 31%

160 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from .. import runtime 

9from ..runtime import torch_device_fn 

10from ..utils import libentry, libtuner 

11from ..utils import triton_lang_extension as ext 

12from .bmm import bmm 

13from .mul import mul 

14 

15logger = logging.getLogger(__name__) 

16 

17 

@libentry()
@libtuner(
    # Tuning configs and expansion strategy are taken from the FlagTune
    # runtime when USE_FLAGTUNE=1, otherwise from the pre-tuned config table
    # with a fixed "align32" strategy for each of the M/N/K keys.
    configs=runtime.ops_get_configs("baddbmm", pre_hook=None)
    if os.environ.get("USE_FLAGTUNE") == "1"
    else runtime.get_tuned_config("baddbmm"),
    key=["M", "N", "K"],
    strategy=runtime.get_expand_config("baddbmm")["strategy"]
    if os.environ.get("USE_FLAGTUNE") == "1"
    else ["align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.heuristics(runtime.get_heuristic_config("baddbmm"))
# alpha/beta are excluded from value specialization.
@triton.jit(do_not_specialize=["alpha", "beta"])
def baddbmm_kernel(
    A,
    B,
    O,
    bias,
    alpha,
    beta,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    bias_batch_stride: tl.constexpr,
    bias_M_stride: tl.constexpr,
    bias_N_stride: tl.constexpr,
    IS_FP64: tl.constexpr = False,
):
    # Each program computes one TILE_M x TILE_N tile of
    #     O[b] = alpha * (A[b] @ B[b]) + beta * bias[b]
    # for the batch index b given by program axis 2.
    #
    # A, B and O are addressed with hard-coded row-major offsets
    # (batch*M*K + m*K + k, batch*K*N + k*N + n, batch*M*N + m*N + n), so the
    # caller must pass them as dense buffers of that layout. Only `bias` is
    # addressed through explicit strides.

    # batch offsets
    pid_b = ext.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N
    bias += pid_b * bias_batch_stride

    pidx = ext.program_id(0)
    pidy = ext.program_id(1)

    if GROUP_M == 1:
        # No regrouping: program axes 0/1 map directly to tile row/column.
        pid_m, pid_n = pidx, pidy
    else:
        # Flatten the 2-D program id and remap it so that consecutive ids walk
        # GROUP_M tile-rows at a time across all tile-columns.
        # NOTE(review): presumably done for cache reuse between neighbouring
        # programs — the motivation is not stated in this file.
        gridx = ext.num_programs(0)
        gridy = ext.num_programs(1)
        pid = pidx + pidy * gridx
        num_CTA_per_group = gridy * GROUP_M
        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        # A group that would run past gridx is given gridx % GROUP_M rows
        # instead of GROUP_M.
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    # Row / column / reduction indices covered by this tile.
    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    # Bounds masks are only materialised for dimensions flagged as not
    # divisible; mask_m / mask_n are undefined otherwise and must not be used.
    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    # Accumulate in float64 when IS_FP64 is set, otherwise in float32.
    if IS_FP64:
        accumulator = tl.zeros((TILE_M, TILE_N), dtype=tl.float64)
    else:
        accumulator = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        # Select the load masks for this K-step from the compile-time
        # divisibility flags; a mask of None means an unmasked load.
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
        else:
            # offs_k advances every iteration, so the K mask is recomputed.
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
        # NOTE(review): these masked loads pass no `other=` fill value, yet the
        # K-tail lanes feed tl.dot below — confirm masked-off lanes are
        # guaranteed to read as zero on every supported backend.
        a = tl.load(a_ptrs, mask=mask_a)
        b = tl.load(b_ptrs, mask=mask_b)
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Step to the next TILE_K slice: A moves along its columns, B along
        # its rows.
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

    bias_ptrs = bias + offs_m[:, None] * bias_M_stride + offs_n[None, :] * bias_N_stride

    # Output/bias mask: None when both M and N are divisible, otherwise the
    # conjunction of whichever bounds checks are needed.
    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    else:
        mask_c = True
        if not DIVISIBLE_M:
            mask_c &= offs_m[:, None] < M
        if not DIVISIBLE_N:
            mask_c &= offs_n[None, :] < N

    # Epilogue: scale the product and the bias, cast to the bias element
    # type, and store the tile.
    bi = tl.load(bias_ptrs, mask=mask_c)
    out = accumulator * alpha + bi * beta
    o = out.to(bi.dtype)
    tl.store(o_ptrs, o, mask=mask_c)

139 

140 

def _baddbmm_launch(bias, A, B, beta, alpha, out):
    """Launch ``baddbmm_kernel`` to write ``alpha * (A @ B) + beta * bias`` into ``out``.

    Args:
        bias: Tensor broadcastable to ``(batch, M, N)``.
        A: 3-D tensor; its shape is unpacked as ``(batch, M, K)``.
        B: 3-D tensor; its last dimension is taken as ``N``.
        beta: Scale applied to ``bias``.
        alpha: Scale applied to the batched matrix product.
        out: Destination tensor of shape ``(batch, M, N)``; filled in place.
            It may have any memory layout.
    """
    batch, M, K = A.shape
    _, _, N = B.shape

    # The kernel indexes A and B with hard-coded row-major offsets.
    A = A.contiguous()
    B = B.contiguous()

    # Materialise the broadcast so the bias can be read with plain strides.
    bbias = torch.broadcast_to(bias, (batch, M, N)).contiguous()
    bias_batch_stride = bbias.stride(0)
    bias_M_stride = bbias.stride(1)
    bias_N_stride = bbias.stride(-1)

    # The kernel also writes the output with hard-coded row-major offsets
    # (batch*M*N + m*N + n). Handing it a strided ``out`` would scatter the
    # result to the wrong addresses, so stage through a dense temporary and
    # copy back afterwards.
    if out.is_contiguous():
        dense_out = out
    else:
        dense_out = torch.empty((batch, M, N), dtype=out.dtype, device=out.device)

    def grid(meta):
        # One program per (TILE_M x TILE_N) output tile per batch entry.
        return (
            triton.cdiv(meta["M"], meta["TILE_M"]),
            triton.cdiv(meta["N"], meta["TILE_N"]),
            batch,
        )

    with torch_device_fn.device(A.device):
        baddbmm_kernel[grid](
            A,
            B,
            dense_out,
            bbias,
            alpha,
            beta,
            M,
            N,
            K,
            bias_batch_stride=bias_batch_stride,
            bias_M_stride=bias_M_stride,
            bias_N_stride=bias_N_stride,
        )

    if dense_out is not out:
        out.copy_(dense_out)

171 

172 

class BaddbmmFunction(torch.autograd.Function):
    """Autograd function pairing the baddbmm kernel launch with its gradients."""

    @staticmethod
    def forward(ctx, bias, A, B, beta, alpha):
        """Allocate the ``(batch, M, N)`` result and fill it via ``_baddbmm_launch``."""
        logger.debug("GEMS BADDBMM FORWARD")

        # Keep the operands and both scalars around for backward().
        ctx.save_for_backward(A, B, bias)
        ctx.alpha = alpha
        ctx.beta = beta

        n_batch, n_rows, _ = A.shape
        _, _, n_cols = B.shape
        result = torch.empty(
            (n_batch, n_rows, n_cols), dtype=A.dtype, device=A.device
        )
        _baddbmm_launch(bias, A, B, beta, alpha, result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(bias, A, B, beta, alpha)``; scalars get ``None``."""
        logger.debug("GEMS BADDBMM BACKWARD")
        A, B, bias = ctx.saved_tensors
        wants_bias, wants_A, wants_B = ctx.needs_input_grad[:3]

        # Only compute the gradients autograd actually asked for.
        grad_bias = (
            compute_bias_grad(grad_output, ctx.beta, bias) if wants_bias else None
        )
        grad_A = compute_A_grad(grad_output, B, ctx.alpha) if wants_A else None
        grad_B = compute_B_grad(A, grad_output, ctx.alpha) if wants_B else None

        return grad_bias, grad_A, grad_B, None, None

204 

205 

def compute_bias_grad(d_output, beta, bias):
    """Return ``beta * d_output`` reduced back to the shape of ``bias``.

    When ``bias`` was broadcast in the forward pass, the scaled gradient is
    summed over every broadcast dimension so the result can be viewed with
    ``bias.shape``.
    """
    target_shape = bias.shape
    scaled = mul(d_output, beta)

    if scaled.shape == target_shape:
        return scaled.view(target_shape)

    # Collapse leading dimensions that ``bias`` does not have at all.
    while scaled.dim() > bias.dim():
        scaled = scaled.sum(dim=0)

    # Collapse dimensions where ``bias`` has extent 1 but the gradient does not.
    for axis, extent in enumerate(target_shape):
        if extent == 1 and scaled.shape[axis] > 1:
            scaled = scaled.sum(dim=axis, keepdim=True)

    return scaled.view(target_shape)

216 

217 

def compute_A_grad(d_output, B, alpha):
    """Return ``alpha * (d_output @ B^T)``, the gradient with respect to ``A``.

    For float16 ``B`` the product is evaluated on float32 copies and the
    result is cast back to float16.
    """
    B_transposed = B.transpose(1, 2)

    if B.dtype != torch.float16:
        return mul(bmm(d_output, B_transposed), alpha)

    # float16 path: do the matmul and the scaling in float32.
    B_fp32 = B_transposed.to(torch.float32)
    d_fp32 = d_output.to(torch.float32)
    scaled = mul(bmm(d_fp32, B_fp32), alpha)
    return scaled.to(torch.float16)

230 

231 

def compute_B_grad(A, d_output, alpha):
    """Return ``alpha * (A^T @ d_output)``, the gradient with respect to ``B``.

    For float16 ``A`` the product is evaluated on float32 copies and the
    result is cast back to float16.
    """
    A_transposed = A.transpose(1, 2)

    if A.dtype != torch.float16:
        return mul(bmm(A_transposed, d_output), alpha)

    # float16 path: do the matmul and the scaling in float32.
    A_fp32 = A_transposed.to(torch.float32)
    d_fp32 = d_output.to(torch.float32)
    scaled = mul(bmm(A_fp32, d_fp32), alpha)
    return scaled.to(torch.float16)

244 

245 

def baddbmm_out(bias, A, B, *, beta=1.0, alpha=1.0, out):
    """Compute ``alpha * (A @ B) + beta * bias`` into a caller-supplied tensor.

    Args:
        bias: Tensor broadcastable to ``(batch, M, N)``.
        A: 3-D tensor; its shape is unpacked as ``(batch, M, K)``.
        B: 3-D tensor; its last dimension is taken as ``N``.
        beta: Scale applied to ``bias``.
        alpha: Scale applied to the batched matrix product.
        out: Destination tensor; must have shape ``(batch, M, N)`` and the
            same dtype as ``A``.

    Returns:
        ``out``, after it has been filled.

    Raises:
        AssertionError: If ``out`` has the wrong shape or dtype.
    """
    logger.debug("GEMS BADDBMM_OUT")
    batch, M, _ = A.shape
    _, _, N = B.shape

    # Explicit check rather than an ``assert`` statement so the validation is
    # not stripped under ``python -O``. AssertionError is kept so callers see
    # the same exception type and message as before.
    if out.shape != (batch, M, N) or out.dtype != A.dtype:
        raise AssertionError("Incompatible output shape or dtype for baddbmm.out")

    _baddbmm_launch(
        bias.contiguous(),
        A.contiguous(),
        B.contiguous(),
        beta,
        alpha,
        out,
    )
    return out

262 

263 

def baddbmm(bias, A, B, beta=1.0, alpha=1.0):
    """Differentiable entry point: apply ``BaddbmmFunction`` to contiguous inputs."""
    # Contiguous copies are taken before the tensors enter the autograd function.
    dense_operands = (bias.contiguous(), A.contiguous(), B.contiguous())
    return BaddbmmFunction.apply(*dense_operands, beta, alpha)