Coverage for src/flag_gems/runtime/backend/_ascend/ops/baddbmm.py: 0%

147 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.ops.mul import mul 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.runtime.backend._ascend import heuristics_config_utils as _hcu 

11from flag_gems.utils import libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as tle 

13 

14from .bmm import bmm 

15 

16logger = logging.getLogger(__name__) 

17 

18 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("baddbmm"),
    key=["M", "N", "K"],
    strategy=["align32", "align32", "align32"],
    warmup=5,
    rep=10,
)
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["baddbmm"])
@triton.jit(do_not_specialize=["alpha", "beta"])
def baddbmm_kernel(
    A,
    B,
    O,
    bias,
    alpha,
    beta,
    M,
    N,
    K,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    GROUP_M: tl.constexpr,
    DIVISIBLE_M: tl.constexpr,
    DIVISIBLE_N: tl.constexpr,
    DIVISIBLE_K: tl.constexpr,
    bias_batch_stride: tl.constexpr,
    bias_M_stride: tl.constexpr,
    bias_N_stride: tl.constexpr,
):
    # Computes one (TILE_M, TILE_N) tile of
    #     O[b] = alpha * (A[b] @ B[b]) + beta * bias[b]
    # for the batch index given by program axis 2.
    #
    # Layout expected by the pointer arithmetic below:
    #   A    : dense row-major, batch stride M*K, row stride K
    #   B    : dense row-major, batch stride K*N, row stride N
    #   O    : dense row-major, batch stride M*N, row stride N
    #   bias : addressed through the three explicit bias_*_stride arguments
    #
    # DIVISIBLE_M / DIVISIBLE_N / DIVISIBLE_K select, at compile time, whether
    # the corresponding bounds mask is needed.  They are presumably produced
    # by the heuristics config attached above -- TODO confirm.

    # batch offsets
    pid_b = tle.program_id(2)
    A += pid_b * M * K
    B += pid_b * K * N
    O += pid_b * M * N
    bias += pid_b * bias_batch_stride

    pidx = tle.program_id(0)
    pidy = tle.program_id(1)

    if GROUP_M == 1:
        pid_m, pid_n = pidx, pidy
    else:
        # Re-map the linearised program id so that programs are walked in
        # groups of GROUP_M row-tiles; the last group is narrower when the
        # number of row-tiles is not a multiple of GROUP_M.
        gridx = tle.num_programs(0)
        gridy = tle.num_programs(1)
        pid = pidx + pidy * gridx
        num_CTA_per_group = gridy * GROUP_M
        group_id = pid // num_CTA_per_group
        inner_group_id = pid % num_CTA_per_group
        GROUP_SIZE = tl.where(
            (group_id * GROUP_M + GROUP_M) > gridx, gridx % GROUP_M, GROUP_M
        )
        pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE
        pid_n = inner_group_id // GROUP_SIZE

    offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)
    offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)
    offs_k = tl.arange(0, TILE_K)

    if not DIVISIBLE_M:
        mask_m = offs_m < M
    if not DIVISIBLE_N:
        mask_n = offs_n < N

    a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]
    b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]
    o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]

    num_iters = tl.cdiv(K, TILE_K)
    # Accumulate in float32 regardless of the input element type.
    accumulator = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)
    for _ in range(num_iters):
        if DIVISIBLE_K:
            if DIVISIBLE_M:
                mask_a = None
            else:
                mask_a = mask_m[:, None]
            if DIVISIBLE_N:
                mask_b = None
            else:
                mask_b = mask_n[None, :]
            # Only rows/columns outside M/N can be masked here; those lanes
            # feed output elements that are never stored, so their value is
            # irrelevant.
            a = tl.load(a_ptrs, mask=mask_a)
            b = tl.load(b_ptrs, mask=mask_b)
        else:
            mask_k = offs_k < K
            if DIVISIBLE_M:
                mask_a = mask_k[None, :]
            else:
                mask_a = mask_m[:, None] & mask_k[None, :]
            if DIVISIBLE_N:
                mask_b = mask_k[:, None]
            else:
                mask_b = mask_k[:, None] & mask_n[None, :]
            # Lanes past the end of K take part in the reduction of *valid*
            # output elements.  A masked load without `other` leaves them
            # unspecified, so force them to zero to keep the dot product
            # exact.  (The masks are never None on this path, which is
            # required for `other` to be accepted.)
            a = tl.load(a_ptrs, mask=mask_a, other=0.0)
            b = tl.load(b_ptrs, mask=mask_b, other=0.0)
        accumulator += tl.dot(a, b, allow_tf32=False)
        # Advance to the next K tile.
        offs_k += TILE_K
        a_ptrs += TILE_K
        b_ptrs += TILE_K * N

    bias_ptrs = bias + offs_m[:, None] * bias_M_stride + offs_n[None, :] * bias_N_stride

    if DIVISIBLE_M and DIVISIBLE_N:
        mask_c = None
    else:
        mask_c = True
        if not DIVISIBLE_M:
            mask_c &= offs_m[:, None] < M
        if not DIVISIBLE_N:
            mask_c &= offs_n[None, :] < N

    bi = tl.load(bias_ptrs, mask=mask_c)
    out = accumulator * alpha + bi * beta
    # The result is converted to the bias element type before being stored.
    o = out.to(bi.dtype)
    tl.store(o_ptrs, o, mask=mask_c)

132 

133 

class BaddbmmFunction(torch.autograd.Function):
    """Autograd function wrapping ``baddbmm_kernel``.

    The forward pass launches the kernel, which evaluates
    ``alpha * (A @ B) + beta * bias`` per batch; the backward pass delegates
    to ``compute_bias_grad``, ``compute_A_grad`` and ``compute_B_grad``.
    """

    @staticmethod
    def forward(ctx, bias, A, B, beta, alpha):
        """Launch the kernel and return a new ``(batch, M, N)`` tensor.

        ``A`` must unpack to ``(batch, M, K)`` and ``B`` to three dimensions
        whose last one is ``N``.  ``bias`` is broadcast to the output shape
        and materialised contiguously before the launch.
        """
        logger.debug("GEMS_ASCEND BADDBMM FORWARD")

        # Stash everything backward() needs before the locals are rebound.
        ctx.save_for_backward(A, B, bias)
        ctx.alpha = alpha
        ctx.beta = beta

        batch, M, K = A.shape
        _batch_b, _k_b, N = B.shape
        A = A.contiguous()
        B = B.contiguous()
        out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)

        dense_bias = torch.broadcast_to(bias, (batch, M, N)).contiguous()
        stride_batch, stride_m, stride_n = dense_bias.stride()

        def launch_grid(meta):
            # One program per (row-tile, column-tile, batch element).
            tiles_m = triton.cdiv(meta["M"], meta["TILE_M"])
            tiles_n = triton.cdiv(meta["N"], meta["TILE_N"])
            return (tiles_m, tiles_n, batch)

        with torch_device_fn.device(A.device):
            baddbmm_kernel[launch_grid](
                A,
                B,
                out,
                dense_bias,
                alpha,
                beta,
                M,
                N,
                K,
                bias_batch_stride=stride_batch,
                bias_M_stride=stride_m,
                bias_N_stride=stride_n,
            )
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for ``(bias, A, B, beta, alpha)``.

        Only the gradients flagged in ``ctx.needs_input_grad`` are computed;
        the rest, and the two scalar slots, are ``None``.
        """
        logger.debug("GEMS_ASCEND BADDBMM BACKWARD")
        A, B, bias = ctx.saved_tensors
        wanted = ctx.needs_input_grad

        grad_bias = (
            compute_bias_grad(grad_output, ctx.beta, bias) if wanted[0] else None
        )
        grad_A = compute_A_grad(grad_output, B, ctx.alpha) if wanted[1] else None
        grad_B = compute_B_grad(A, grad_output, ctx.alpha) if wanted[2] else None

        return grad_bias, grad_A, grad_B, None, None

192 

193 

def compute_bias_grad(d_output, beta, bias):
    """Return ``beta * d_output`` reduced to the shape of ``bias``.

    When ``bias`` was broadcast in the forward pass, the scaled gradient is
    summed over every broadcast dimension so the result matches
    ``bias.shape``.
    """
    scaled = mul(d_output, beta)
    target_shape = bias.shape

    if scaled.shape == target_shape:
        return scaled.view(target_shape)

    # Collapse leading dimensions that ``bias`` does not have at all.
    for _ in range(scaled.dim() - bias.dim()):
        scaled = scaled.sum(dim=0)

    # Collapse dimensions where ``bias`` has size 1 but the gradient does not.
    for axis, extent in enumerate(target_shape):
        if extent == 1 and scaled.shape[axis] > 1:
            scaled = scaled.sum(dim=axis, keepdim=True)

    return scaled.view(target_shape)

204 

205 

def compute_A_grad(d_output, B, alpha):
    """Return ``alpha * (d_output @ B^T)``, the gradient with respect to ``A``.

    For float16 ``B`` the product is evaluated in float32 and the result is
    cast back to float16.
    """
    b_transposed = B.transpose(1, 2).contiguous()

    if B.dtype != torch.float16:
        return mul(bmm(d_output, b_transposed), alpha)

    # Half-precision inputs: do the matmul and the scaling in float32.
    b_fp32 = b_transposed.to(torch.float32)
    grad_out_fp32 = d_output.to(torch.float32)
    product = bmm(grad_out_fp32, b_fp32)
    return mul(product, alpha).to(torch.float16)

218 

219 

def compute_B_grad(A, d_output, alpha):
    """Return ``alpha * (A^T @ d_output)``, the gradient with respect to ``B``.

    For float16 ``A`` the product is evaluated in float32 and the result is
    cast back to float16.
    """
    a_transposed = A.transpose(1, 2).contiguous()

    if A.dtype != torch.float16:
        return mul(bmm(a_transposed, d_output), alpha)

    # Half-precision inputs: do the matmul and the scaling in float32.
    a_fp32 = a_transposed.to(torch.float32)
    grad_out_fp32 = d_output.to(torch.float32)
    product = bmm(a_fp32, grad_out_fp32)
    return mul(product, alpha).to(torch.float16)

232 

233 

def baddbmm(bias, A, B, beta=1.0, alpha=1.0):
    """Batched matrix multiply-add through ``BaddbmmFunction``.

    All three tensors are made contiguous before being handed to the
    autograd function; ``beta`` scales ``bias`` and ``alpha`` scales the
    batched product of ``A`` and ``B``.
    """
    dense_bias = bias.contiguous()
    dense_a = A.contiguous()
    dense_b = B.contiguous()
    return BaddbmmFunction.apply(dense_bias, dense_a, dense_b, beta, alpha)