Coverage for src/flag_gems/runtime/backend/_cambricon/ops/softmax.py: 0%

552 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import copy 

2import logging 

3import math 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry, libtuner 

12 

13from ..utils import MAX_NRAM_SIZE, TOTAL_CORE_NUM 

14from .zeros import zero_ 

15 

# Module logger, registered as a child of the "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
# Reduction-length threshold used by the config pruners: below it they
# derive configs whose tile covers the whole reduced axis (TILE_N/BLOCK_N = N).
MAX_N = 16 * 1024

18 

19 

def align(max_block):
    """Round ``max_block`` down to a power of two.

    Returns ``max_block`` unchanged when it already is a power of two
    (as judged by ``triton.next_power_of_2``), otherwise half of the next
    power of two above it.
    """
    ceil_pow2 = triton.next_power_of_2(max_block)
    if ceil_pow2 == max_block:
        return max_block
    return ceil_pow2 // 2

23 

24 

def config_prune1(configs, named_args, **kwargs):
    """Early-prune hook for ``softmax_kernel_non_inner``.

    Configs are deduplicated on ``(TILE_K, TILE_N, num_warps, num_stages)``;
    the first config registered for a key wins.

    When ``N < MAX_N`` each tuned config is replaced by three derived
    variants, all with ``TILE_N = N``:

    * ``TILE_K`` = per-core share of ``K``, ``num_stages = 1``;
    * ``TILE_K = align(MAX_NRAM_SIZE // 4 // (2 * N + 1))``, ``num_stages = 1``;
    * ``TILE_K = align(MAX_NRAM_SIZE // 4 // (c * N + 1))`` with ``c = 4`` for
      float32 input and ``c = 3`` otherwise, ``num_stages = 3``.

    Otherwise tuned configs are kept unchanged. Two extra configs, copied
    from the first surviving one, with ``TILE_K = 1``, ``TILE_N = N``,
    ``num_warps = 1`` and ``num_stages`` 3 and 1, are always appended.

    Args:
        configs: candidate autotuner configs (``kwargs``, ``num_warps``,
            ``num_stages`` are read/written on deep copies).
        named_args: kernel arguments by name; ``M``, ``N``, ``K`` and
            ``input_ptr`` are read.
        **kwargs: unused; part of the prune-hook signature.

    Returns:
        list: the pruned configs.

    Raises:
        IndexError: if ``configs`` is empty.
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]
    inp = named_args["input_ptr"]
    unique = {}

    def _register(cfg):
        # Keep only the first config for each tiling signature.
        signature = (
            cfg.kwargs["TILE_K"],
            cfg.kwargs["TILE_N"],
            cfg.num_warps,
            cfg.num_stages,
        )
        unique.setdefault(signature, cfg)

    def _derive(base, tile_k, stages):
        # Copy `base`, forcing a single tile over the whole N axis.
        cfg = copy.deepcopy(base)
        cfg.kwargs["TILE_N"] = N
        cfg.kwargs["TILE_K"] = tile_k
        cfg.num_stages = stages
        return cfg

    for config in configs:
        if N >= MAX_N:
            _register(config)
            continue
        per_core = _derive(config, math.ceil(K / max(TOTAL_CORE_NUM // M, 1)), 1)
        _register(per_core)
        unpiped = _derive(per_core, align(MAX_NRAM_SIZE // 4 // (2 * N + 1)), 1)
        _register(unpiped)
        coef = 4 if inp.dtype == torch.float32 else 3
        piped = _derive(unpiped, align(MAX_NRAM_SIZE // 4 // (coef * N + 1)), 3)
        _register(piped)

    pruned = list(unique.values())
    single_k_piped = copy.deepcopy(pruned[0])
    single_k_piped.kwargs["TILE_K"] = 1
    single_k_piped.kwargs["TILE_N"] = N
    single_k_piped.num_warps = 1
    single_k_piped.num_stages = 3
    single_k = copy.deepcopy(single_k_piped)
    single_k.num_stages = 1
    return pruned + [single_k_piped, single_k]

79 

80 

def softmax_tile_mode_for_non_inner(M, N, K, TILE_N, TILE_K):
    """Pick the TILE_MODE of a non-inner softmax launch.

    The K axis is considered covered when ``TILE_K`` times the number of
    K-programs (``max(TOTAL_CORE_NUM // M, 1)``) reaches ``K``.

    Returns:
        0 if one tile spans all of N and K is covered,
        1 if one tile spans all of N but K is not covered,
        2 if N needs more than one tile.
    """
    k_programs = max(TOTAL_CORE_NUM // M, 1)
    covers_k = TILE_K * k_programs >= K
    if TILE_N < N:
        return 2
    return 0 if covers_k else 1

90 

91 

# Forward softmax over the middle axis of a contiguous (M, N, K) view: element
# (m, n, k) lives at m * N * K + n * K + k and the reduction runs over n.
# Program (pid_m, pid_k) handles batch index pid_m and one slice of the K axis.
# TILE_MODE is not passed by the host, so it presumably comes from the
# "softmax_non_inner" heuristic (cf. softmax_tile_mode_for_non_inner) — confirm:
#   0 -> one TILE_N x TILE_K tile covers N and this program's K chunk,
#   1 -> one tile covers N; the program loops over its K slice,
#   2 -> N is tiled too: online max/sum pass, then a normalise-and-store pass.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "N",
        "K",
    ],
    prune_configs_by={"early_config_prune": config_prune1},
)
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    # Split K evenly across the programs of grid axis 1.
    p_k_num = tl.num_programs(axis=1)
    split_k = tl.cdiv(K, p_k_num)
    k_start = pid_k * split_k

    if TILE_MODE == 0:
        # Only K is masked, so this path assumes TILE_N == N.
        # NOTE(review): a TILE_N > N would read into the next batch slice.
        n_offset = tl.arange(0, TILE_N)
        k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
        mask = k_offset[None, :] < K
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        # Numerically stable softmax along axis 0 (the N axis).
        row_minus_max = inp - tl.max(inp, axis=0)[None, :]
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=0)[None, :]
        recip = 1.0 / denominator
        softmax_output = numerator * recip
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, softmax_output, mask=mask)
    elif TILE_MODE == 1:
        # The last step may run past this program's slice when TILE_K does
        # not divide split_k; those columns are still < K, so neighbouring
        # programs compute and store identical values for them.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            n_offset = tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
            mask = k_offset[None, :] < K
            input_ptrs = input_ptr + offset
            inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
            row_minus_max = inp - tl.max(inp, axis=0)[None, :]
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)[None, :]
            recip = 1.0 / denominator
            softmax_output = numerator * recip
            output_ptrs = output_ptr + offset
            tl.store(output_ptrs, softmax_output, mask=mask)
    else:
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            # Element-wise running max (m) and rescaled running sum (z).
            m = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
            z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                m_new = tl.maximum(m, inp)
                # Where everything seen so far is -inf, keep z to avoid
                # exp(-inf - -inf) = NaN.
                all_neg_inf = m_new == float("-inf")
                z = tl.where(
                    all_neg_inf, z, z * tl.exp(m - m_new) + tl.exp(inp - m_new)
                )
                m = m_new
            # Fold the TILE_N partial statistics into one max/sum per column.
            m_reduced = tl.max(m, 0)  # (TILE_K,)
            z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0)  # (TILE_K, )
            recip_z = 1.0 / z
            m = m_reduced
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                o = tl.exp(inp - m[None, :]) * recip_z[None, :]
                tl.store(output_ptr + offset, o, mask=mask)

182 

183 

def config_prune2(configs, named_args, **kwargs):
    """Early-prune hook for ``softmax_kernel_inner``.

    Configs are deduplicated on ``(BLOCK_M, BLOCK_N, num_warps, num_stages)``;
    the first config registered for a key wins.

    When ``N < MAX_N`` (no reduction loop needed) each tuned config is
    replaced by three derived variants, all with ``BLOCK_N = N``:

    * ``BLOCK_M = ceil(M / TOTAL_CORE_NUM)``, ``num_stages = 1``;
    * ``BLOCK_M = align(MAX_NRAM_SIZE // 4 // (2 * N + 1))``, ``num_stages = 1``;
    * ``BLOCK_M = align(MAX_NRAM_SIZE // 4 // (c * N + 1))`` with ``c = 6`` for
      float32 input and ``c = 4`` otherwise, ``num_stages = 3``.

    Otherwise tuned configs are kept unchanged. Two heuristic configs, copied
    from the first surviving one, with ``BLOCK_M = 1``, ``BLOCK_N = N``,
    ``num_warps = 1`` and ``num_stages`` 3 and 1, are always appended.

    Args:
        configs: candidate autotuner configs.
        named_args: kernel arguments by name; ``M``, ``N`` and ``input_ptr``
            are read.
        **kwargs: unused; part of the prune-hook signature.

    Returns:
        list: the pruned configs.

    Raises:
        IndexError: if ``configs`` is empty.
    """
    M = named_args["M"]
    N = named_args["N"]
    inp = named_args["input_ptr"]
    unique = {}

    def _register(cfg):
        # Only keep one config for the same key.
        signature = (
            cfg.kwargs["BLOCK_M"],
            cfg.kwargs["BLOCK_N"],
            cfg.num_warps,
            cfg.num_stages,
        )
        unique.setdefault(signature, cfg)

    def _derive(base, block_m, stages):
        # Copy `base` with a single BLOCK_N = N tile over each row.
        cfg = copy.deepcopy(base)
        cfg.kwargs["BLOCK_N"] = N
        cfg.kwargs["BLOCK_M"] = block_m
        cfg.num_stages = stages
        return cfg

    for config in configs:
        if N >= MAX_N:
            _register(config)
            continue
        per_core = _derive(config, math.ceil(M / TOTAL_CORE_NUM), 1)
        _register(per_core)
        unpiped = _derive(per_core, align(MAX_NRAM_SIZE // 4 // (2 * N + 1)), 1)
        _register(unpiped)
        coef = 6 if inp.dtype == torch.float32 else 4
        piped = _derive(unpiped, align(MAX_NRAM_SIZE // 4 // (coef * N + 1)), 3)
        _register(piped)

    pruned = list(unique.values())
    # Add heuristic configs.
    single_row_piped = copy.deepcopy(pruned[0])
    single_row_piped.kwargs["BLOCK_M"] = 1
    single_row_piped.kwargs["BLOCK_N"] = N
    single_row_piped.num_warps = 1
    single_row_piped.num_stages = 3
    single_row = copy.deepcopy(single_row_piped)
    single_row.num_stages = 1
    return pruned + [single_row_piped, single_row]

239 

240 

def softmax_tile_mode_for_inner(args):
    """Pick the TILE_MODE of an inner (last-axis) softmax launch.

    Args:
        args: kernel arguments by name; ``BLOCK_M``, ``M``, ``BLOCK_N`` and
            ``N`` are read.

    Returns:
        0 if one BLOCK_N tile spans the row and ``BLOCK_M * TOTAL_CORE_NUM``
        rows cover M, 1 if one tile spans the row but M is not covered,
        2 if a row needs more than one tile.
    """
    covers_m = args["BLOCK_M"] * TOTAL_CORE_NUM >= args["M"]
    if args["BLOCK_N"] < args["N"]:
        return 2
    return 0 if covers_m else 1

250 

251 

# Forward softmax over the last axis of a contiguous (M, N) view: element
# (m, n) lives at m * N + n. The host launches a 1-D grid and each program
# owns a contiguous slice of ceil(M / num_programs) rows. TILE_MODE is not
# passed by the host, so it presumably comes from the "softmax_inner"
# heuristic (cf. softmax_tile_mode_for_inner) — confirm:
#   0 -> a single BLOCK_M x BLOCK_N tile per program covers its rows,
#   1 -> one tile covers a full row; loop over the program's rows,
#   2 -> rows are tiled along N: online max/sum pass, then a write pass.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=[
        "M",
        "N",
    ],
    prune_configs_by={"early_config_prune": config_prune2},
)
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    split_m = tl.cdiv(M, pnum)
    m_start = pid_m * split_m

    if TILE_MODE == 0:
        # Only rows are masked, so this path assumes BLOCK_N == N.
        # NOTE(review): a BLOCK_N > N would read into the following row.
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        n_offset = tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        row_minus_max = inp - tl.max(inp, axis=1)[:, None]
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=1)[:, None]
        recip = 1.0 / denominator
        softmax_output = numerator * recip
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, softmax_output, mask=mask)
    elif TILE_MODE == 1:
        # The last step may overrun this program's slice when BLOCK_M does
        # not divide split_m; rows still < M are then recomputed and stored
        # with identical values by the neighbouring program.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            n_offset = tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N + n_offset[None, :]
            mask = m_offset[:, None] < M
            input_ptrs = input_ptr + offset
            inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
            # Reductions are done along axis 0 of the transposed tile;
            # presumably faster on this backend than axis 1 — confirm.
            trans_inp = tl.trans(inp)
            row_minus_max = trans_inp - tl.max(trans_inp, axis=0)[None, :]
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)[None, :]
            recip = 1.0 / denominator
            softmax_output = tl.trans(numerator * recip)
            output_ptrs = output_ptr + offset
            tl.store(output_ptrs, softmax_output, mask=mask)
    else:
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            # Element-wise running max and rescaled running sum.
            block_max = tl.full(
                [BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32
            )
            block_sum = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                cur_max = tl.maximum(block_max, inp)
                # Keep the sum where everything seen is -inf, avoiding
                # exp(-inf - -inf) = NaN.
                all_neg_inf = cur_max == float("-inf")
                block_sum = tl.where(
                    all_neg_inf,
                    block_sum,
                    block_sum * tl.exp(block_max - cur_max) + tl.exp(inp - cur_max),
                )
                block_max = cur_max

            # Fold the BLOCK_N partial statistics into one max/sum per row.
            trans_block_max = tl.trans(block_max)
            trans_block_sum = tl.trans(block_sum)
            max_reduced = tl.max(trans_block_max, 0)
            total_sum = tl.sum(
                trans_block_sum * tl.exp(trans_block_max - max_reduced[None, :]), 0
            )
            recip_total_sum = 1.0 / total_sum
            total_max = max_reduced

            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = m_offset[:, None] < M and n_offset[None, :] < N
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                o = tl.exp(inp - total_max[:, None]) * recip_total_sum[:, None]
                tl.store(output_ptr + offset, o, mask=mask)

349 

350 

# Pass 1 of the split-N softmax over the last axis of a contiguous (M, N)
# input. Work is a flat list of (row block, column tile) tasks, handed out
# to programs in contiguous chunks. For every task, the per-row max and the
# sum of exp(x - max) over that tile are written to the row-major (M, T)
# buffers max_buf / sum_buf at column tile_id.
@triton.jit
def softmax_kernel_inner_k_partial_stats(
    x_ptr,
    max_buf_ptr,
    sum_buf_ptr,
    M,
    N,
    T,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pnum = tl.num_programs(axis=0)
    pid = tl.program_id(0)
    # Round the row-block count UP: with a floor division, the trailing rows
    # would be skipped whenever BLOCK_M does not divide M. Rows >= M are
    # masked below.
    total_blocks = tl.cdiv(M, BLOCK_M) * T
    work_per_core = (total_blocks + pnum - 1) // pnum
    start = pid * work_per_core
    end = tl.minimum(start + work_per_core, total_blocks)

    for task in range(start, end):
        row_id = task // T
        tile_id = task % T

        offs_m = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

        tile = tl.load(
            x_ptr + offs_m[:, None] * N + offs_n[None, :],
            mask=mask,
            other=-float("inf"),
        ).to(tl.float32)

        tile_max = tl.max(tile, axis=1)
        # A fully -inf row tile contributes a zero sum (avoids NaN from
        # exp(-inf - -inf)).
        all_neg_inf = tile_max == -float("inf")

        tile_sum = tl.where(
            all_neg_inf,
            0.0,
            tl.sum(tl.exp(tile - tile_max[:, None]), axis=1),
        )

        tl.store(max_buf_ptr + offs_m * T + tile_id, tile_max, mask=(offs_m < M))
        tl.store(sum_buf_ptr + offs_m * T + tile_id, tile_sum, mask=(offs_m < M))

394 

395 

# Pass 2 of the split-N softmax: merge the T per-tile (max, sum) statistics
# of each row (row-major (M, T) buffers) into a global row max `gmax` and
# global row sum `gsum` = sum_t tile_sum[t] * exp(tile_max[t] - gmax).
# NOTE(review): M and T are constexpr, so every distinct value triggers a
# recompile; T is also used as a tl.arange extent (on stock Triton that must
# be a power of two — confirm what this backend accepts).
@triton.jit
def softmax_kernel_inner_k_merge_stats(
    max_buf_ptr,
    sum_buf_ptr,
    gmax_ptr,
    gsum_ptr,
    M: tl.constexpr,
    T: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # [BM]
    mask_m = offs_m < M
    tile_max = tl.load(
        max_buf_ptr + offs_m[:, None] * T + tl.arange(0, T)[None, :],
        mask=(offs_m[:, None] < M),
        other=-float("inf"),
    )
    tile_sum = tl.load(
        sum_buf_ptr + offs_m[:, None] * T + tl.arange(0, T)[None, :],
        mask=(offs_m[:, None] < M),
        other=0.0,
    ).to(tl.float32)

    gmax = tl.max(tile_max, axis=1)
    # Rescale each tile's sum to the global max.
    scale = tl.exp(tile_max - gmax[:, None])
    # Rows whose global max is -inf would get NaN scales; force them to 0.
    scale = tl.where(gmax[:, None] == -float("inf"), 0.0, scale)
    gsum = tl.sum(tile_sum * scale, axis=1)

    tl.store(gmax_ptr + offs_m, gmax, mask=mask_m)
    tl.store(gsum_ptr + offs_m, gsum, mask=mask_m)

427 

428 

# Pass 3 of the split-N softmax: write exp(x - gmax) / gsum for every element
# of a contiguous (M, N) input using the merged per-row statistics. Rows whose
# gsum is not positive are written as 0. Work is a flat list of
# (row block, column tile) tasks, handed out to programs in contiguous chunks.
@triton.jit
def softmax_kernel_inner_k_write_softmax(
    x_ptr,
    y_ptr,
    gmax_ptr,
    gsum_ptr,
    M,
    N,
    T,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pnum = tl.num_programs(axis=0)
    pid = tl.program_id(0)
    # Round the row-block count UP: with a floor division, the trailing rows
    # would never be written whenever BLOCK_M does not divide M. Rows >= M
    # are masked below.
    total_blocks = tl.cdiv(M, BLOCK_M) * T
    work_per_core = (total_blocks + pnum - 1) // pnum
    start = pid * work_per_core
    end = tl.minimum(start + work_per_core, total_blocks)

    for task in range(start, end):
        row_id = task // T
        tile_id = task % T

        offs_m = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

        # load global stats
        gmax = tl.load(gmax_ptr + offs_m, mask=(offs_m < M), other=-float("inf")).to(
            tl.float32
        )
        gsum = tl.load(gsum_ptr + offs_m, mask=(offs_m < M), other=0.0).to(tl.float32)

        # load tile
        tile = tl.load(
            x_ptr + offs_m[:, None] * N + offs_n[None, :],
            mask=mask,
            other=-float("inf"),
        ).to(tl.float32)

        valid = gsum[:, None] > 0

        out = tl.where(
            valid,
            tl.exp(tile - gmax[:, None]) / gsum[:, None],
            0.0,
        )

        tl.store(y_ptr + offs_m[:, None] * N + offs_n[None, :], out, mask=mask)

478 

479 

480# ------------------------ backward ------------------------------- 

481 

482 

def nram_usage_for_backward_non_inner(bn, bk, tile_mode, num_stages, dtype):
    """Estimate NRAM bytes used by ``softmax_backward_kernel_non_inner``.

    The estimate is ``(coef * bn + 1) * bk * 4`` where ``coef`` depends on
    the tile mode, whether the loop is pipelined and the element dtype:

    ============  ==========  ==================  ================
    tile_mode     stages==1   stages!=1, float32  stages!=1, other
    ============  ==========  ==================  ================
    0             3           3                   3
    1             3           7                   6
    anything else 5           13                  10
    ============  ==========  ==================  ================
    """
    if tile_mode == 0:
        coef = 3
    else:
        single_n_tile = tile_mode == 1
        if num_stages == 1:
            coef = 3 if single_n_tile else 5
        elif dtype == torch.float32:
            coef = 7 if single_n_tile else 13
        else:
            coef = 6 if single_n_tile else 10
    return 4 * bk * (coef * bn + 1)

504 

505 

def config_prune3(configs, named_args, **kwargs):
    """Early-prune hook for ``softmax_backward_kernel_non_inner``.

    If a single tile covering all of N and this core's share of K fits in
    NRAM (tile mode 0, one stage), only that config (a copy of
    ``configs[0]``) is returned. Otherwise a config survives when its
    ``TILE_K`` is a multiple of 256 bytes' worth of elements and its
    estimated NRAM usage lies in ``[MAX_NRAM_SIZE // 2, MAX_NRAM_SIZE]``.
    The result may be empty if no config passes both filters.

    Args:
        configs: candidate autotuner configs.
        named_args: kernel arguments by name; ``M``, ``N``, ``K`` and
            ``output_ptr`` are read.
        **kwargs: unused; part of the prune-hook signature.

    Returns:
        list: the surviving configs.
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]
    dtype = named_args["output_ptr"].dtype
    k_per_core = math.ceil(K / max(TOTAL_CORE_NUM // M, 1))

    # No need for any loop.
    if nram_usage_for_backward_non_inner(N, k_per_core, 0, 1, dtype) < MAX_NRAM_SIZE:
        whole = copy.deepcopy(configs[0])
        whole.kwargs["TILE_K"] = k_per_core
        whole.kwargs["TILE_N"] = N
        whole.num_stages = 1
        return [whole]

    # Elements per 256-byte line; the lowest dimension must be aligned to it.
    elem_bytes = 4 if dtype == torch.float32 else 2
    align_num = 256 // elem_bytes

    def _keep(cfg):
        tile_k = cfg.kwargs["TILE_K"]
        tile_n = cfg.kwargs["TILE_N"]
        if tile_k % align_num:
            return False
        mode = softmax_tile_mode_for_non_inner(M, N, K, tile_n, tile_k)
        nram = nram_usage_for_backward_non_inner(
            tile_n, tile_k, mode, cfg.num_stages, dtype
        )
        # NRAM usage must fit, and should not leave more than half unused.
        return MAX_NRAM_SIZE // 2 <= nram <= MAX_NRAM_SIZE

    return [cfg for cfg in configs if _keep(cfg)]

541 

542 

# Softmax backward over the middle axis of a contiguous (M, N, K) view:
# in_grad = out * (out_grad - sum_n(out * out_grad)). Program (pid_m, pid_k)
# handles batch index pid_m and one slice of K. TILE_MODE is presumably set
# by the "softmax_backward_non_inner" heuristic (cf.
# softmax_tile_mode_for_non_inner) — confirm.
#
# Every masked load uses other=0.0. Triton leaves masked lanes undefined
# when `other` is omitted, and in TILE_MODE 2 those lanes were accumulated
# into `scale`, corrupting valid columns whenever N % TILE_N != 0.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_non_inner_bw"),
    key=[
        "N",
        "K",
    ],
    prune_configs_by={"early_config_prune": config_prune3},
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def softmax_backward_kernel_non_inner(
    output_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    # Split K evenly across the programs of grid axis 1.
    p_k_num = tl.num_programs(axis=1)
    split_k = tl.cdiv(K, p_k_num)
    k_start = pid_k * split_k

    if TILE_MODE == 0:
        # One tile covers N and this program's K chunk (only K is masked).
        n_offset = tl.arange(0, TILE_N)
        k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
        mask = k_offset[None, :] < K
        out_tile = tl.load(output_ptr + offset, mask=mask, other=0.0).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask, other=0.0).to(
            tl.float32
        )
        scale = tl.sum(out_tile * out_grad_tile, axis=0)
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    elif TILE_MODE == 1:
        # One tile covers N; loop over this program's K slice.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            n_offset = tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
            mask = (k_offset[None, :] < K) & (n_offset[:, None] < N)
            out_tile = tl.load(output_ptr + offset, mask=mask, other=0.0).to(
                tl.float32
            )
            out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask, other=0.0).to(
                tl.float32
            )
            scale = tl.sum(out_tile * out_grad_tile, axis=0)
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    else:
        # N is tiled: accumulate the dot product, then write the gradient.
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
            # specialization does not improve performance in this example, as tested
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                out_tile = tl.load(output_ptr + offset, mask=mask, other=0.0).to(
                    tl.float32
                )
                out_grad_tile = tl.load(
                    out_grad_ptr + offset, mask=mask, other=0.0
                ).to(tl.float32)
                # Masked lanes are 0, so they add nothing to the reduction.
                scale += out_tile * out_grad_tile
            scale = tl.sum(scale, axis=0)
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                out_tile = tl.load(output_ptr + offset, mask=mask, other=0.0).to(
                    tl.float32
                )
                out_grad_tile = tl.load(
                    out_grad_ptr + offset, mask=mask, other=0.0
                ).to(tl.float32)
                in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
                tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)

614 

615 

def config_prune4(configs, named_args, **kwargs):
    """Early-prune hook for ``softmax_backward_kernel_inner``.

    Configs are deduplicated on ``(BLOCK_M, BLOCK_N, num_warps, num_stages)``;
    the first config registered for a key wins.

    When ``N < MAX_N`` (no reduction loop needed) each tuned config is
    replaced by three derived variants, all with ``BLOCK_N = N``:

    * ``BLOCK_M = ceil(M / TOTAL_CORE_NUM)``, ``num_stages = 1``;
    * ``BLOCK_M = align(MAX_NRAM_SIZE // 4 // (3 * N + 1))``, ``num_stages = 1``;
    * ``BLOCK_M = align(MAX_NRAM_SIZE // 4 // (c * N + 1))`` with ``c = 7`` for
      float32 output and ``c = 6`` otherwise, ``num_stages = 3``.

    Otherwise tuned configs are kept unchanged. Two heuristic configs, copied
    from the first surviving one, with ``BLOCK_M = 1``, ``BLOCK_N = N``,
    ``num_warps = 1`` and ``num_stages`` 3 and 1, are always appended.

    Args:
        configs: candidate autotuner configs.
        named_args: kernel arguments by name; ``M``, ``N`` and ``output_ptr``
            are read.
        **kwargs: unused; part of the prune-hook signature.

    Returns:
        list: the pruned configs.

    Raises:
        IndexError: if ``configs`` is empty.
    """
    M = named_args["M"]
    N = named_args["N"]
    output = named_args["output_ptr"]
    unique = {}

    def _register(cfg):
        # Only keep one config for the same key.
        signature = (
            cfg.kwargs["BLOCK_M"],
            cfg.kwargs["BLOCK_N"],
            cfg.num_warps,
            cfg.num_stages,
        )
        unique.setdefault(signature, cfg)

    def _derive(base, block_m, stages):
        # Copy `base` with a single BLOCK_N = N tile over each row.
        cfg = copy.deepcopy(base)
        cfg.kwargs["BLOCK_N"] = N
        cfg.kwargs["BLOCK_M"] = block_m
        cfg.num_stages = stages
        return cfg

    for config in configs:
        if N >= MAX_N:
            _register(config)
            continue
        per_core = _derive(config, math.ceil(M / TOTAL_CORE_NUM), 1)
        _register(per_core)
        unpiped = _derive(per_core, align(MAX_NRAM_SIZE // 4 // (3 * N + 1)), 1)
        _register(unpiped)
        coef = 7 if output.dtype == torch.float32 else 6
        piped = _derive(unpiped, align(MAX_NRAM_SIZE // 4 // (coef * N + 1)), 3)
        _register(piped)

    pruned = list(unique.values())
    # Add heuristic configs.
    single_row_piped = copy.deepcopy(pruned[0])
    single_row_piped.kwargs["BLOCK_M"] = 1
    single_row_piped.kwargs["BLOCK_N"] = N
    single_row_piped.num_warps = 1
    single_row_piped.num_stages = 3
    single_row = copy.deepcopy(single_row_piped)
    single_row.num_stages = 1
    return pruned + [single_row_piped, single_row]

671 

672 

# Softmax backward over the last axis of a contiguous (M, N) view:
# in_grad = out * (out_grad - sum_n(out * out_grad)). Each program of the
# 1-D grid owns a slice of ceil(M / num_programs) rows. TILE_MODE is
# presumably set by the "softmax_backward_inner" heuristic — confirm.
#
# Every masked load uses other=0.0. Triton leaves masked lanes undefined
# when `other` is omitted, and in TILE_MODE 2 those lanes were accumulated
# into `scale`, corrupting valid rows whenever N % BLOCK_N != 0.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_inner_bw"),
    key=[
        "M",
        "N",
    ],
    prune_configs_by={"early_config_prune": config_prune4},
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def softmax_backward_kernel_inner(
    output_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    split_m = tl.cdiv(M, pnum)
    m_start = pid_m * split_m

    if TILE_MODE == 0:
        # One tile per program covers its rows (only rows are masked).
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        n_offset = tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M
        out_tile = tl.load(output_ptr + offset, mask=mask, other=0.0).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask, other=0.0).to(
            tl.float32
        )
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    elif TILE_MODE == 1:
        # One tile covers a full row; loop over this program's rows.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            n_offset = tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N + n_offset[None, :]
            mask = m_offset[:, None] < M
            out_tile = tl.load(output_ptr + offset, mask=mask, other=0.0).to(
                tl.float32
            )
            out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask, other=0.0).to(
                tl.float32
            )
            scale = tl.sum(out_tile * out_grad_tile, 1)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    else:
        # Rows are tiled along N: accumulate the dot product, then write.
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            scale = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
                out_tile = tl.load(
                    output_ptr + offset,
                    mask=mask,
                    other=0.0,
                    eviction_policy="evict_last",
                ).to(tl.float32)
                out_grad_tile = tl.load(
                    out_grad_ptr + offset, mask=mask, other=0.0
                ).to(tl.float32)
                # Masked lanes are 0, so they add nothing to the reduction.
                scale += out_tile * out_grad_tile
            scale = tl.sum(scale, 1)
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
                out_tile = tl.load(
                    output_ptr + offset,
                    mask=mask,
                    other=0.0,
                    eviction_policy="evict_first",
                ).to(tl.float32)
                out_grad_tile = tl.load(
                    out_grad_ptr + offset, mask=mask, other=0.0
                ).to(tl.float32)
                in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
                tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)

746 

747 

def softmax(self, dim, half_to_float=False):
    """Softmax of ``self`` along ``dim`` on the Cambricon backend.

    The input is viewed as a contiguous (M, N, K) tensor where N is the size
    of ``dim``, M the product of the leading dims and K the product of the
    trailing dims. ``K > 1`` uses ``softmax_kernel_non_inner``. Otherwise the
    inner path is used, which, for M <= TOTAL_CORE_NUM and N >= 65536, runs
    a three-pass split-N scheme (partial stats, merge, write).

    Args:
        self: input tensor.
        dim: dimension to normalize over; negative values count from the end.
        half_to_float: if True the result is float32, otherwise it has the
            input dtype.

    Returns:
        A new tensor with the input's shape.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logger.debug("GEMS_CAMBRICON SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"

    # Output dtype, decided up front so the empty-tensor shortcut honours
    # half_to_float exactly like the main path.
    dtype = torch.float32 if half_to_float else self.dtype

    # special handling for empty tensor: nothing to compute, just return a
    # tensor of the right shape and dtype
    if self.numel() == 0:
        out = torch.empty(list(self.shape), dtype=dtype, device=self.device)
        zero_(out)
        return out

    dim = dim % self.ndim
    M = math.prod(self.shape[:dim])  # pre_dim
    N = self.shape[dim]
    self = self.contiguous()
    out = torch.empty_like(self, dtype=dtype)
    K = self.numel() // M // N  # post_dim

    with torch_device_fn.device(self.device):
        if K > 1:
            logger.debug("GEMS_CAMBRICON SOFTMAX USE NON INNER")
            grid = lambda meta: (M, max(TOTAL_CORE_NUM // M, 1), 1)
            softmax_kernel_non_inner[grid](
                out,
                self,
                M,
                N,
                K,
            )
        else:
            logger.debug("GEMS_CAMBRICON SOFTMAX USE INNER")
            if M > TOTAL_CORE_NUM or N < 1024 * 8 * 8:
                softmax_kernel_inner[TOTAL_CORE_NUM, 1, 1](
                    out,
                    self,
                    M,
                    N,
                )
            else:
                # Few, very long rows: split each row into N-tiles.
                block_m = 1
                block_n = 8192 * 4
                if dtype == torch.float32:
                    block_n = 8192 * 2
                # workspace
                T = (N + block_n - 1) // block_n
                max_buf = torch.empty((M, T), device=self.device, dtype=torch.float32)
                sum_buf = torch.empty((M, T), device=self.device, dtype=torch.float32)
                gmax = torch.empty((M,), device=self.device, dtype=torch.float32)
                gsum = torch.empty((M,), device=self.device, dtype=torch.float32)
                # kernel 1: per-tile stats
                softmax_kernel_inner_k_partial_stats[(TOTAL_CORE_NUM,)](
                    self,
                    max_buf,
                    sum_buf,
                    M,
                    N,
                    T,
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    bottleneck="simd",
                    num_stages=3,
                )
                # kernel 2: merge stats along N-tiles
                grid_merge = (triton.cdiv(M, block_m),)
                softmax_kernel_inner_k_merge_stats[grid_merge](
                    max_buf, sum_buf, gmax, gsum, M, T, BLOCK_M=block_m
                )
                # The write pass uses half-width tiles.
                block_n = block_n // 2
                T = (N + block_n - 1) // block_n
                # kernel 3: write normalized outputs
                softmax_kernel_inner_k_write_softmax[(TOTAL_CORE_NUM,)](
                    self,
                    out,
                    gmax,
                    gsum,
                    M,
                    N,
                    T,
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    bottleneck="simd",
                    num_stages=3,
                )
    return out

840 

841 

def softmax_backward(grad_output, output, dim, input_dtype):
    """Gradient of softmax w.r.t. its input on the Cambricon backend.

    Computes ``output * (grad_output - sum(output * grad_output, dim))``.

    Args:
        grad_output: gradient w.r.t. the softmax output.
        output: the softmax forward result.
        dim: dimension softmax was taken over (negative values allowed).
        input_dtype: dtype of the forward input. The returned gradient uses
            this dtype (it differs from ``output.dtype`` when the forward ran
            with ``half_to_float``).

    Returns:
        A new contiguous tensor of ``output``'s shape and ``input_dtype``.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logger.debug("GEMS_CAMBRICON SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim

    # The kernels index both operands with contiguous (M, N, K) offsets, so
    # non-contiguous inputs must be materialized first.
    grad_output = grad_output.contiguous()
    output = output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)

    # Nothing to compute; also avoids dividing by a zero M or N below.
    if output.numel() == 0:
        return in_grad

    M = math.prod(output.shape[:dim])
    N = output.shape[dim]
    K = output.numel() // M // N

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            logger.debug("GEMS_CAMBRICON SOFTMAX VJP USE NON INNER")
            grid = lambda meta: (M, max(TOTAL_CORE_NUM // M, 1), 1)
            softmax_backward_kernel_non_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                K,
            )
        else:
            logger.debug("GEMS_CAMBRICON SOFTMAX VJP USE INNER")
            softmax_backward_kernel_inner[TOTAL_CORE_NUM, 1, 1](
                output,
                grad_output,
                in_grad,
                M,
                N,
            )
    return in_grad