Coverage for src/flag_gems/runtime/backend/_ascend/ops/softmax.py: 0%

256 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.runtime.backend._ascend import heuristics_config_utils as _hcu 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as ext 

12 

# Module-level logger; each softmax entry point in this file emits a debug
# message through it when called.
logger = logging.getLogger(__name__)

14 

15 

@libentry()
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["softmax_non_inner"])
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax along the middle axis of an (M, N, K) dense layout.

    Element (m, n, k) is addressed as ``m * N * K + n * K + k``. The result is
    normalized over n for every (m, k) and written to the same offsets of
    ``output_ptr``. Program (pid_m, pid_k) owns one m index and a TILE_K-wide
    strip of k columns.

    Args:
        output_ptr: destination buffer.
        input_ptr: source buffer.
        M: outer extent (not read inside the kernel body).
        N: extent of the normalized axis.
        K: inner extent.
        TILE_N, TILE_K, ONE_TILE_PER_CTA: filled in by the
            ``softmax_non_inner`` heuristics. With ONE_TILE_PER_CTA the n axis
            is covered by a single tile of TILE_N rows. That is only correct
            when TILE_N >= N, which is presumably guaranteed by the heuristics
            (confirm). Otherwise an online max/sum pass over n is followed by
            a normalize-and-store pass.
    """
    pid_k = ext.program_id(1)
    pid_m = ext.program_id(0)

    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets
        mask = (n_offsets[:, None] < N) & (k_offsets < K)
        input_ptrs = input_ptr + offset
        # Padding loads as -inf, so it contributes exp(-inf) = 0 to the sum.
        # Fully masked k columns produce NaN here, but their store is masked.
        # NOTE(review): unlike the single-tile path of softmax_kernel_inner,
        # the loaded values are not cast before max/exp/sum. Confirm this
        # precision is intended for fp16/bf16 inputs.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf"))
        m = tl.max(inp, 0)  # per-column max over n, shape (TILE_K,)
        e = tl.exp(inp - m[None, :])
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Per-lane running max (m) and rescaled sum of exponentials (z), fp32.
        m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)

        # specialization does not improve performance in this example, as tested
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            # When m_new is -inf, both m and inp are -inf and m - m_new would
            # be NaN, so z is left unchanged for those lanes.
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Fold the TILE_N lanes into one max and one exp-sum per k column.
        m_reduced = tl.max(m, 0)  # (TILE_K,)
        z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
        m = m_reduced

        # specialization does not improve performance in this example, as tested
        # previous_multiple is the start of the last n tile. This pass visits
        # the tiles from last to first, presumably so the tiles read most
        # recently in the first pass are revisited first (cache reuse,
        # unverified).
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets
            mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)
            inp = tl.load(input_ptr + offsets, mask=mask, other=-float("inf"))
            o = tl.exp(inp - m[None, :]) / z[None, :]
            tl.store(output_ptr + offsets, o, mask=mask)

74 

75 

@triton.jit
def next_multiple_of(a, b):
    """Return the smallest x >= a such that x % b == 0 (for a >= 0, b > 0).

    Computed as ``ceil(a / b) * b`` via ``tl.cdiv``.
    """
    return tl.cdiv(a, b) * b

80 

81 

@triton.jit
def prev_multiple_of(a, b):
    """Return the start offset of the last ``b``-wide tile covering ``[0, a)``.

    This equals ``(ceil(a / b) - 1) * b``. For positive ``a`` it is the
    largest multiple of ``b`` strictly below ``a``: a=10, b=4 gives 8, and
    a=8, b=4 gives 4. For a == 0 the result is -b.
    """
    tiles = tl.cdiv(a, b)
    return (tiles - 1) * b

86 

87 

@libentry()
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["softmax_inner"])
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax along the last axis of an (M, N) row-major layout.

    Program ``pid_m`` normalizes the N consecutive elements starting at offset
    ``pid_m * N`` and writes them to the same offsets of ``output_ptr``.

    Args:
        output_ptr: destination buffer.
        input_ptr: source buffer.
        M: number of rows (not read inside the kernel body).
        N: row length, i.e. the normalized axis.
        TILE_N, ONE_TILE_PER_CTA: filled in by the ``softmax_inner``
            heuristics. With ONE_TILE_PER_CTA the row is handled as one tile
            of TILE_N elements. That is only correct when TILE_N >= N, which
            is presumably guaranteed by the heuristics (confirm). Otherwise
            the row is streamed twice in TILE_N chunks: an online max/sum
            pass, then a normalize pass.
    """
    pid_m = ext.program_id(0)
    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offset = pid_m * N + n_offsets
        input_ptrs = input_ptr + offset
        mask = n_offsets < N
        # Padding loads as -inf (exp -> 0). Values are cast to the output
        # element type before the reduction.
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(
            output_ptr.dtype.element_ty
        )
        m = tl.max(inp, 0)
        e = tl.exp(inp - m)
        z = tl.sum(e, 0)
        out = e / z
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, out, mask=mask)
    else:
        # Per-lane running max and rescaled exp-sum, kept in fp32.
        m = tl.full([TILE_N], value=float("-inf"), dtype=tl.float32)
        z = tl.full([TILE_N], value=0.0, dtype=tl.float32)
        # Advance both base pointers to the start of this program's row.
        input_ptr += pid_m * N
        output_ptr += pid_m * N

        # Start of the last (possibly partial) tile. Every tile before it is
        # full, so those loads need no mask.
        previous_multiple = prev_multiple_of(N, TILE_N)
        for start_n in range(0, previous_multiple, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets)
            m_new = tl.maximum(m, inp)
            # it is possible that there are -inf's in the input
            # (when m_new is -inf, m - m_new would be NaN, so z is kept as is)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new
        # specialize the last iteration (the only tile that needs a mask)
        for start_n in range(previous_multiple, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(input_ptr + n_offsets, mask=mask, other=-float("inf"))
            m_new = tl.maximum(m, inp)
            all_neg_inf = m_new == float("-inf")
            z = tl.where(all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new))
            m = m_new

        # Fold the TILE_N lanes into one row max and one row exp-sum.
        m_reduced = tl.max(m, 0)
        z = tl.sum(z * tl.exp(m - m_reduced), 0)
        m = m_reduced

        previous_multiple = prev_multiple_of(N, TILE_N)
        # specialize the first iteration: this loop runs exactly once and
        # writes the last, masked tile. The loop after it walks the remaining
        # full tiles from back to front without a mask.
        for start_n in range(0, TILE_N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            mask = n_offsets < N
            inp = tl.load(
                input_ptr + n_offsets,
                mask=mask,
                other=-float("inf"),
                eviction_policy="evict_first",
            )
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o, mask=mask)
        for start_n in range(TILE_N, N, TILE_N):
            n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)
            inp = tl.load(input_ptr + n_offsets, eviction_policy="evict_first")
            o = tl.exp(inp - m) / z
            tl.store(output_ptr + n_offsets, o)

161 

162 

163# ------------------------ backward ------------------------------- 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "M",
        "N",
        "K",
    ],
)
@triton.heuristics(_hcu.HEURISTICS_CONFIGS["softmax_backward_non_inner"])
@triton.jit
def softmax_backward_kernel_non_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax backward along the middle axis of an (M, N, K) dense layout.

    For every (m, k) this computes
    ``in_grad = out * (out_grad - sum_n(out * out_grad))``, addressing
    element (m, n, k) at ``m * N * K + n * K + k``. Program (pid_m, pid_k)
    owns one m index and a TILE_K-wide strip of k. Arithmetic is done in
    fp32.

    Args:
        out_ptr: forward softmax output.
        out_grad_ptr: gradient w.r.t. the forward output.
        in_grad_ptr: destination for the gradient w.r.t. the forward input.
        M: outer extent (autotune key; not read in the body).
        N: extent of the softmax axis.
        K: inner extent.
        TILE_N, TILE_K, ONE_TILE_PER_CTA: supplied by the autotune configs
            and the ``softmax_backward_non_inner`` heuristics. With
            ONE_TILE_PER_CTA one tile spans n, which presumably requires
            TILE_N >= N (confirm). Otherwise n is streamed twice: once to
            accumulate the dot product and once to write the gradient.
    """
    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)
    offsets_k = pid_k * TILE_K + tl.arange(0, TILE_K)

    if ONE_TILE_PER_CTA:
        offsets_n = tl.arange(0, TILE_N)
        offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
        mask = (offsets_n < N)[:, None] & (offsets_k < K)
        # other=0.0: masked-out n lanes take part in the sum over axis 0, so
        # they must contribute exactly zero instead of undefined values.
        out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
            tl.float32
        )
        scale = tl.sum(out_tile * out_grad_tile, axis=0)
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: accumulate out * out_grad per lane, then reduce over n.
        scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
        for off in range(0, N, TILE_N):
            offsets_n = tl.arange(0, TILE_N) + off
            offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(
                tl.float32
            )
            out_grad_tile = tl.load(
                out_grad_ptr + offsets, mask=mask, other=0.0
            ).to(tl.float32)
            scale += out_tile * out_grad_tile
        scale = tl.sum(scale, axis=0)  # (TILE_K)

        # Pass 2: write the gradient tile by tile.
        for off in range(0, N, TILE_N):
            offsets_n = tl.arange(0, TILE_N) + off
            offsets = pid_m * N * K + offsets_n[:, None] * K + offsets_k
            mask = (offsets_n < N)[:, None] & (offsets_k < K)
            out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(
                tl.float32
            )
            out_grad_tile = tl.load(
                out_grad_ptr + offsets, mask=mask, other=0.0
            ).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)

222 

223 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=["M", "N"],
)
@triton.heuristics(
    values=_hcu.HEURISTICS_CONFIGS["softmax_backward_inner"],
)
@triton.jit
def softmax_backward_kernel_inner(
    out_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    TILE_M: tl.constexpr,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Softmax backward along the last axis of an (M, N) row-major layout.

    For every row this computes
    ``in_grad = out * (out_grad - sum_n(out * out_grad))``. Program ``pid_m``
    handles TILE_M consecutive rows. Arithmetic is done in fp32.

    Args:
        out_ptr: forward softmax output.
        out_grad_ptr: gradient w.r.t. the forward output.
        in_grad_ptr: destination for the gradient w.r.t. the forward input.
        M: number of rows.
        N: row length (the softmax axis).
        TILE_M, TILE_N, ONE_TILE_PER_CTA: supplied by the autotune configs
            and the ``softmax_backward_inner`` heuristics. With
            ONE_TILE_PER_CTA one tile spans the row, which presumably
            requires TILE_N >= N (confirm). Otherwise rows are streamed twice.
    """
    pid_m = ext.program_id(0)
    m_offsets = pid_m * TILE_M + tl.arange(0, TILE_M)
    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        offsets = m_offsets[:, None] * N + n_offsets
        mask = (m_offsets[:, None] < M) & (n_offsets < N)
        # other=0.0: masked-out n lanes take part in the row sum below, so
        # they must contribute exactly zero instead of undefined values.
        out_tile = tl.load(out_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offsets, mask=mask, other=0.0).to(
            tl.float32
        )
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)
    else:
        # Pass 1: per-row dot product of out and out_grad.
        scale = tl.zeros([TILE_M], dtype=tl.float32)
        for off in range(0, N, TILE_N):
            n_offsets = tl.arange(0, TILE_N) + off
            offsets = m_offsets[:, None] * N + n_offsets
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, other=0.0, eviction_policy="evict_last"
            ).to(tl.float32)
            out_grad_tile = tl.load(
                out_grad_ptr + offsets, mask=mask, other=0.0
            ).to(tl.float32)
            scale += tl.sum(out_tile * out_grad_tile, axis=1)

        # Pass 2: write the gradient tile by tile.
        for off in range(0, N, TILE_N):
            n_offsets = tl.arange(0, TILE_N) + off
            offsets = m_offsets[:, None] * N + n_offsets
            mask = (m_offsets[:, None] < M) & (n_offsets < N)
            out_tile = tl.load(
                out_ptr + offsets, mask=mask, other=0.0, eviction_policy="evict_first"
            ).to(tl.float32)
            out_grad_tile = tl.load(
                out_grad_ptr + offsets, mask=mask, other=0.0
            ).to(tl.float32)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offsets, in_grad_tile, mask=mask)

284 

285 

def softmax(self, dim, half_to_float=False):
    """Return softmax of ``self`` along ``dim`` using the Ascend Triton kernels.

    The contiguous input is viewed as (outer, reduce_len, inner), where outer
    is the product of the sizes before ``dim`` and inner is the product of
    those after it. If inner is 1 the row kernel is launched; otherwise the
    middle-axis kernel is used.

    Args:
        self: input tensor.
        dim: axis to normalize over; negative values count from the end.
        half_to_float: when true the result is float32, otherwise it has
            ``self``'s dtype.

    Returns:
        A new tensor shaped like ``self``. It is returned unfilled when the
        input has no elements.

    Raises:
        AssertionError: if ``dim`` is out of range for ``self.ndim``.
    """
    logger.debug("GEMS_ASCEND SOFTMAX")

    assert -self.ndim <= dim < self.ndim, "Invalid dim"
    dim %= self.ndim
    outer = self.shape[:dim].numel()
    reduce_len = self.shape[dim]

    src = self.contiguous()
    result_dtype = torch.float32 if half_to_float else src.dtype
    result = torch.empty_like(src, dtype=result_dtype)
    if reduce_len == 0 or src.numel() == 0:
        return result

    inner = src.numel() // outer // reduce_len
    with torch_device_fn.device(src.device):
        if inner <= 1:
            row_grid = (outer, 1, 1)
            softmax_kernel_inner[row_grid](result, src, outer, reduce_len)
        else:

            def strip_grid(meta):
                return (outer, triton.cdiv(inner, meta["TILE_K"]), 1)

            softmax_kernel_non_inner[strip_grid](
                result, src, outer, reduce_len, inner
            )
    return result

310 

311 

def softmax_out(self, dim, half_to_float=False, *, out):
    """Compute softmax of ``self`` along ``dim`` into the preallocated ``out``.

    Args:
        self: input tensor.
        dim: axis to normalize over; negative values count from the end.
        half_to_float: accepted for signature compatibility but not read. The
            result takes whatever dtype ``out`` already has.
        out: destination tensor. It is resized to ``self.shape`` when its
            shape differs. A non-contiguous ``out`` is filled through a
            contiguous staging buffer.

    Returns:
        ``out``.

    Raises:
        AssertionError: if ``dim`` is out of range for ``self.ndim``.
    """
    logger.debug("GEMS_ASCEND SOFTMAX_OUT")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"

    if tuple(out.shape) != tuple(self.shape):
        out.resize_(self.shape)
    if self.numel() == 0:
        return out

    dim = dim % self.ndim
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]
    self = self.contiguous()
    K = self.numel() // M // N

    # The kernels address the destination with contiguous (M, N, K) offsets.
    # Writing straight into a strided `out` (e.g. a transposed tensor of the
    # right shape, which resize_ leaves untouched) would put values into the
    # wrong elements, so stage through a contiguous buffer in that case.
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty_like(out, memory_format=torch.contiguous_format)

    with torch_device_fn.device(self.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_kernel_non_inner[grid](dst, self, M, N, K)
        else:
            grid = (M, 1, 1)
            softmax_kernel_inner[grid](dst, self, M, N)

    if dst is not out:
        out.copy_(dst)
    return out

340 

341 

def softmax_backward(grad_output, output, dim, input_dtype):
    """Gradient of softmax w.r.t. its input, given the forward ``output``.

    Computes ``output * (grad_output - sum_dim(output * grad_output))``.

    Args:
        grad_output: gradient w.r.t. the softmax output.
        output: the softmax forward result.
        dim: axis the softmax was taken over; negative values count from the
            end.
        input_dtype: dtype of the returned gradient.

    Returns:
        A new contiguous tensor shaped like ``output`` with dtype
        ``input_dtype``. It is returned unfilled when ``output`` has no
        elements.

    Raises:
        AssertionError: if ``dim`` is out of range for ``output.ndim``.
    """
    logger.debug("GEMS_ASCEND SOFTMAX BACKWARD")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    # Both kernels index every operand with contiguous (M, N, K) offsets, so
    # `output` must be contiguous as well as `grad_output`. Otherwise it is
    # read in the wrong order, and empty_like would copy its strides to in_grad.
    output = output.contiguous()
    grad_output = grad_output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)
    # Zero-size input: nothing to compute, and M or N may be 0, which would
    # make the K division below raise ZeroDivisionError.
    if in_grad.numel() == 0:
        return in_grad
    K = output.numel() // M // N

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_backward_kernel_non_inner[grid](
                output, grad_output, in_grad, M, N, K
            )
        else:
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            softmax_backward_kernel_inner[grid](output, grad_output, in_grad, M, N)
    return in_grad

366 

367 

def softmax_backward_out(grad_output, output, dim, input_dtype, *, grad_input):
    """Softmax backward written into the preallocated ``grad_input``.

    Computes ``output * (grad_output - sum_dim(output * grad_output))``.

    Args:
        grad_output: gradient w.r.t. the softmax output.
        output: the softmax forward result.
        dim: axis the softmax was taken over; negative values count from the
            end.
        input_dtype: accepted for signature compatibility but not read. The
            result takes ``grad_input``'s dtype.
        grad_input: destination. It is resized to ``output.shape`` when its
            shape differs. A non-contiguous ``grad_input`` is filled through a
            contiguous staging buffer.

    Returns:
        ``grad_input``.

    Raises:
        AssertionError: if ``dim`` is out of range for ``output.ndim``.
    """
    logger.debug("GEMS_ASCEND SOFTMAX BACKWARD_OUT")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    # The kernels index every operand with contiguous (M, N, K) offsets.
    output = output.contiguous()
    grad_output = grad_output.contiguous()
    if tuple(grad_input.shape) != tuple(output.shape):
        grad_input.resize_(output.shape)
    # Zero-size input: nothing to compute, and M or N may be 0, which would
    # make the K division below raise ZeroDivisionError.
    if output.numel() == 0:
        return grad_input
    K = output.numel() // M // N

    # A same-shape but strided grad_input keeps its strides through resize_,
    # so write into a contiguous buffer and copy back.
    if grad_input.is_contiguous():
        dst = grad_input
    else:
        dst = torch.empty_like(grad_input, memory_format=torch.contiguous_format)

    with torch_device_fn.device(grad_input.device):
        if K > 1:
            grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
            softmax_backward_kernel_non_inner[grid](
                output, grad_output, dst, M, N, K
            )
        else:
            grid = lambda meta: (triton.cdiv(M, meta["TILE_M"]), 1, 1)
            softmax_backward_kernel_inner[grid](output, grad_output, dst, M, N)

    if dst is not grad_input:
        grad_input.copy_(dst)
    return grad_input