Coverage for src/flag_gems/runtime/backend/_sunrise/heuristics_config_utils.py: 0%

260 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import torch # noqa: F401 

2import triton 

3 

# Smallest TILE_N the mean non-inner heuristics will return.
_MIN_TILE_N = 64
# Element budget per row: TILE_N is limited to this value // TILE_K.
_MAX_TILE_N_PER_ROW = 4 * 1024
# Upper bound on the TILE_N requested by the mean non-inner heuristic.
_MAX_ONE_TILE_N = 2 * 1024

7 

8 

def simple_elementwise_blocksize_heur(args):
    """Return the fixed BLOCK_SIZE for generic elementwise kernels.

    ``args`` is accepted so the signature matches the other heuristics;
    it is not inspected.
    """
    block_size = 1024
    return block_size

11 

12 

def argmax_heur_tile_k(args):
    """Pick TILE_K for the non-inner argmax kernel.

    Args:
        args: kernel-argument dict. Reads ``"K"`` and ``"M"``. Reads
            ``args["inp"].dtype`` only for the hand-tuned (M=64, K=512) shape.

    Returns:
        A power-of-two tile size along K.

    The device multiprocessor count is queried only on the path that uses
    it, so small-K launches skip the device-property lookup entirely.
    """
    MAX_TILE_K = 512
    K = args["K"]
    M = args["M"]

    # Hand-tuned shape: float32 gets a smaller tile than every other dtype.
    if M == 64 and K == 512:
        return 64 if args["inp"].dtype == torch.float32 else 128

    # Small K: the largest power of two not exceeding K (1 when K <= 0).
    if K <= 128:
        return 1 << (K.bit_length() - 1) if K > 0 else 1

    NUM_SMS = torch.ptpu.get_device_properties(
        torch.ptpu.current_device()
    ).multi_processor_count

    tile_k = 64
    upper_bound = min(K, MAX_TILE_K)
    # Grow the tile while the launch would still need more than one wave of
    # blocks across the SMs and the doubled tile stays within the bound.
    while tile_k <= upper_bound:
        num_blocks = M * triton.cdiv(K, tile_k)
        num_waves = num_blocks / NUM_SMS
        if num_waves > 1 and tile_k * 2 <= upper_bound:
            tile_k *= 2
        else:
            break
    return tile_k

42 

43 

def argmax_heur_tile_n_non_inner(args):
    """Pick TILE_N for the non-inner argmax kernel.

    If N <= 128, one tile of exactly N is used. Otherwise the tile is the
    next power of two of min(N, 8192), clamped to [64, 4096]. It is then
    shrunk toward a TILE_N * TILE_K budget of 32768 elements, but never
    below 64.
    """
    reduce_len = args["N"]
    k_tile = args["TILE_K"]

    if reduce_len > 128:
        candidate = triton.next_power_of_2(min(8192, reduce_len))
        candidate = max(64, min(candidate, 4096))
        if candidate * k_tile > 32768:
            candidate = max(64, 32768 // k_tile)
        return candidate
    return reduce_len

59 

60 

def argmax_heur_one_tile_per_cta(args):
    """Return True when a single TILE_N-wide tile spans the whole N dimension."""
    tile_n, n = args["TILE_N"], args["N"]
    return n <= tile_n

63 

64 

def argmax_heur_num_warps_non_inner(args):
    """Pick num_warps for the non-inner argmax kernel.

    TILE_N <= 32 gives 2 warps, TILE_N <= 128 gives 4, and anything larger
    gives 8. float32 inputs are capped at 4 warps.
    """
    tile_n = args["TILE_N"]
    is_fp32 = args["inp"].dtype == torch.float32

    if tile_n <= 32:
        warps = 2
    elif tile_n <= 128:
        warps = 4
    else:
        warps = 8

    if is_fp32:
        return min(warps, 4)
    return warps

82 

83 

def argmax_heur_tile_n_inner(args):
    """Pick TILE_N for the inner argmax kernel.

    If N <= 32 * 1024, one power-of-two tile covers all of N. Larger N uses
    a fixed tile of 4096.
    """
    n = args["N"]
    if n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(n)

89 

90 

def argmax_heur_num_warps_inner(args):
    """Map the inner argmax TILE_N to num_warps: <2048 -> 4, <4096 -> 8, else 16."""
    tile_n = args["TILE_N"]
    for limit, warps in ((2048, 4), (4096, 8)):
        if tile_n < limit:
            return warps
    return 16

99 

100 

def argmax_heur_block_m(args):
    """Return BLOCK_M for argmax: 8 rows when M >= 4096, otherwise 1."""
    if args["M"] >= 4096:
        return 8
    return 1

103 

104 

def argmax_heur_block_n(args):
    """Return BLOCK_N for argmax: the next power of two of N, capped at 4096."""
    block_n = triton.next_power_of_2(args["N"])
    return min(4096, block_n)

107 

108 

def argmin_heur_block_m(args):
    """Return BLOCK_M for argmin: 8 rows when M >= 4096, otherwise 4."""
    if args["M"] >= 4096:
        return 8
    return 4

111 

112 

def argmin_heur_block_n(args):
    """Return BLOCK_N for argmin: the next power of two of N, capped at 4096."""
    block_n = triton.next_power_of_2(args["N"])
    return min(4096, block_n)

115 

116 

def bmm_heur_divisible_m(args):
    """Return True when M is an exact multiple of TILE_M."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0

119 

120 

def bmm_heur_divisible_n(args):
    """Return True when N is an exact multiple of TILE_N."""
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0

123 

124 

def bmm_heur_divisible_k(args):
    """Return True when K is an exact multiple of TILE_K."""
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0

127 

128 

def dropout_heur_block(args):
    """Return BLOCK for dropout: 256 when N <= 512, otherwise 512."""
    return 256 if args["N"] <= 512 else 512

134 

135 

def dropout_heur_num_warps(args):
    """Return num_warps for dropout: 2 for N <= 512, 4 for N <= 2048, else 8."""
    n = args["N"]
    if n > 2048:
        return 8
    if n > 512:
        return 4
    return 2

143 

144 

def exponential_heur_block(args):
    """Return BLOCK for exponential_: 512 when N <= 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

150 

151 

def exponential_heur_num_warps(args):
    """Return num_warps for exponential_: 4 for N <= 512, 8 for N <= 1024, else 16."""
    n = args["N"]
    for bound, warps in ((512, 4), (1024, 8)):
        if n <= bound:
            return warps
    return 16

159 

160 

def gather_heur_block_m(args):
    """Return BLOCK_M for gather.

    NOTE(review): because of the ``min(1, ...)`` cap, this returns 1 for any
    N >= 1. If pinning BLOCK_M to 1 is intentional for this backend, the
    power-of-two computation is redundant. Confirm before simplifying.
    """
    rows_needed = triton.cdiv(args["N"], 2048)
    return min(1, triton.next_power_of_2(rows_needed))

163 

164 

def gather_heur_block_n(args):
    """Return BLOCK_N for gather: the next power of two of N, capped at 2048."""
    block_n = triton.next_power_of_2(args["N"])
    return min(2048, block_n)

167 

168 

def index_select_heur_block_m(args):
    """Return BLOCK_M for index_select.

    The value is ceil(256 / N) rounded up to a power of two and capped at 4.
    Presumably N must be non-zero (the ceil-division divides by N).
    """
    rows = triton.next_power_of_2(triton.cdiv(256, args["N"]))
    return min(4, rows)

171 

172 

def index_select_heur_block_n(args):
    """Return BLOCK_N for index_select.

    The value is the next power of two of ceil(N / 16), clamped to the
    range [16, 512].
    """
    block_n = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    return max(min(block_n, 512), 16)

176 

177 

def mm_heur_even_k(args):
    """Return True when K splits evenly into chunks of BLOCK_K * SPLIT_K."""
    k = args["K"]
    chunk = args["BLOCK_K"] * args["SPLIT_K"]
    return k % chunk == 0

180 

181 

def rand_heur_block(args):
    """Return BLOCK for rand: 512 when N <= 512, otherwise 1024."""
    if args["N"] > 512:
        return 1024
    return 512

187 

188 

def rand_heur_num_warps(args):
    """Return num_warps for rand: 4 for N <= 512, 8 for N <= 1024, else 16."""
    n = args["N"]
    if n > 1024:
        return 16
    if n > 512:
        return 8
    return 4

196 

197 

def randn_heur_block(args):
    """Return BLOCK for randn: 512 when N <= 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

203 

204 

def randn_heur_num_warps(args):
    """Return num_warps for randn: 4 for N <= 512, 8 for N <= 1024, else 16."""
    n = args["N"]
    for bound, warps in ((512, 4), (1024, 8)):
        if n <= bound:
            return warps
    return 16

212 

213 

def softmax_heur_tile_k(args):
    """Pick TILE_K for the non-inner softmax kernel.

    TILE_K starts at 1 and doubles while both of these hold:
    - the launch (M * ceil(K / TILE_K) blocks) would need more than one wave
      over the assumed SM count;
    - the doubled tile stays within min(K, 512).
    """
    max_tile_k = 512
    # The device SM count is not queried on this backend; 32 is assumed.
    num_sms = 32

    limit = min(args["K"], max_tile_k)
    tile_k = 1
    while tile_k <= limit:
        waves = args["M"] * triton.cdiv(args["K"], tile_k) / num_sms
        if waves <= 1 or tile_k * 2 > limit:
            break
        tile_k *= 2
    return tile_k

231 

232 

def softmax_heur_tile_n_non_inner(args):
    """Return TILE_N = ceil(8192 / TILE_K) for the non-inner softmax kernel."""
    tile_k = args["TILE_K"]
    return triton.cdiv(8192, tile_k)

235 

236 

def softmax_heur_one_tile_per_cta(args):
    """Return True when a single TILE_N-wide tile spans the whole N dimension."""
    tile_n, n = args["TILE_N"], args["N"]
    return n <= tile_n

239 

240 

def softmax_heur_num_warps_non_inner(args):
    """Choose num_warps from the TILE_N * TILE_K tile area.

    An area below 2048 gives 4 warps, below 4096 gives 8, and anything
    larger gives 16.
    """
    area = args["TILE_N"] * args["TILE_K"]
    if area >= 4096:
        return 16
    if area >= 2048:
        return 8
    return 4

249 

250 

def softmax_heur_tile_n_inner(args):
    """Pick TILE_N for the inner softmax kernel.

    N <= 32 uses the next power of two of N, N <= 1024 uses 256, and larger
    N uses 512.
    """
    n = args["N"]
    if n > 1024:
        return 512
    if n > 32:
        return 256
    return triton.next_power_of_2(n)

258 

259 

def softmax_heur_num_warps_inner(args):
    """Map the inner softmax TILE_N to num_warps.

    The mapping is <64 -> 2, <2048 -> 4, <4096 -> 8, else 16.
    """
    tile_n = args["TILE_N"]
    for limit, warps in ((64, 2), (2048, 4), (4096, 8)):
        if tile_n < limit:
            return warps
    return 16

270 

271 

def softmax_heur_tile_n_bwd_non_inner(args):
    """Return TILE_N for non-inner softmax backward: 1024 // TILE_K, at least 1."""
    tile_n = 1024 // args["TILE_K"]
    return tile_n if tile_n > 1 else 1

274 

275 

def softmax_heur_tile_m(args):
    """Return TILE_M for inner softmax backward: 1024 // TILE_N, at least 1."""
    tile_m = 1024 // args["TILE_N"]
    return tile_m if tile_m > 1 else 1

278 

279 

def uniform_heur_block(args):
    """Return BLOCK for uniform: 512 when N <= 512, otherwise 1024."""
    if args["N"] > 512:
        return 1024
    return 512

285 

286 

def uniform_heur_num_warps(args):
    """Return num_warps for uniform: 4 for N <= 512, 8 for N <= 1024, else 16."""
    n = args["N"]
    if n > 1024:
        return 16
    return 8 if n > 512 else 4

294 

295 

def var_mean_heur_block_n(args):
    """Return BLOCK_N for var_mean: the next power of two of BLOCK_NUM."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)

298 

299 

def upsample_nearest2d_SAME_H(args):
    """Return True when the output height OH equals the input height IH."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h

302 

303 

def upsample_nearest2d_SAME_W(args):
    """Return True when the output width OW equals the input width IW."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w

306 

307 

def upsample_nearest2d_USE_INT32_IDX(args):
    """Return True when N * C * OH * OW fits within the signed 32-bit maximum."""
    int32_max = 2**31 - 1
    total = args["N"] * args["C"] * args["OH"] * args["OW"]
    return total <= int32_max

310 

311 

def batch_norm_heur_block_m(args):
    """Return BLOCK_M for batch_norm: the next power of two of batch_dim, capped at 256."""
    block_m = triton.next_power_of_2(args["batch_dim"])
    return min(256, block_m)

314 

315 

def batch_norm_heur_block_n(args):
    """Return BLOCK_N for batch_norm.

    BLOCK_N is the next power of two of ``spatial_dim``, reduced so that a
    BLOCK_M x BLOCK_N tile holds at most 2**10 (1024) elements. BLOCK_M comes
    from ``batch_norm_heur_block_m``. The budget-derived limit never drops
    below 1.
    """
    # Per-tile element budget (BLOCK_M * BLOCK_N) used on this backend.
    max_tile_elements = 2**10
    BLOCK_M = batch_norm_heur_block_m(args)
    BLOCK_N = triton.next_power_of_2(args["spatial_dim"])
    return min(BLOCK_N, max(1, max_tile_elements // BLOCK_M))

321 

322 

def vdot_heur_block_size(args):
    """Return BLOCK_SIZE for vdot from n_elements.

    The mapping is <1024 -> 32, <8192 -> 256, else 1024.
    """
    n = args["n_elements"]
    for bound, size in ((1024, 32), (8192, 256)):
        if n < bound:
            return size
    return 1024

331 

332 

def mean_heur_tile_k(args):
    """Pick TILE_K for the non-inner mean kernel.

    TILE_K starts at 1 and doubles while both of these hold:
    - the launch (M * ceil(K / TILE_K) blocks) would need more than one wave
      over the device's multiprocessors;
    - the doubled tile stays within min(K, 512, _MAX_TILE_N_PER_ROW //
      _MIN_TILE_N).

    The last bound leaves room for a TILE_N of at least _MIN_TILE_N.
    """
    num_sms = torch.ptpu.get_device_properties(
        torch.ptpu.current_device()
    ).multi_processor_count

    limit = min(
        args["K"],
        512,
        max(1, _MAX_TILE_N_PER_ROW // _MIN_TILE_N),
    )
    tile_k = 1
    while tile_k <= limit:
        waves = args["M"] * triton.cdiv(args["K"], tile_k) / num_sms
        if waves <= 1 or tile_k * 2 > limit:
            break
        tile_k *= 2
    return tile_k

350 

351 

def sum_heur_num_warps_inner(args):
    """Map the inner sum TILE_N to num_warps.

    The mapping is <64 -> 2, <2048 -> 4, <4096 -> 8, else 16.
    """
    tile_n = args["TILE_N"]
    if tile_n >= 4096:
        return 16
    if tile_n >= 2048:
        return 8
    if tile_n >= 64:
        return 4
    return 2

362 

363 

def sum_heur_tile_n_inner(args):
    """Pick TILE_N for the inner sum kernel.

    N <= 32 uses the next power of two of N, N <= 1024 uses 128, and larger
    N uses 256.
    """
    n = args["N"]
    if n > 1024:
        return 256
    if n > 32:
        return 128
    return triton.next_power_of_2(n)

371 

372 

def sum_heur_one_tile_per_cta(args):
    """Return True when a single TILE_N-wide tile spans the whole N dimension."""
    tile_n, n = args["TILE_N"], args["N"]
    return n <= tile_n

375 

376 

def sum_heur_tile_k(args):
    """Pick TILE_K for the non-inner sum kernel.

    TILE_K starts at 1 and doubles while both of these hold:
    - the launch (M * ceil(K / TILE_K) blocks) would need more than one wave
      over the assumed SM count;
    - the doubled tile stays within min(K, 128).
    """
    max_tile_k = 128
    # The device SM count is not queried on this backend; 32 is assumed.
    num_sms = 32

    limit = min(args["K"], max_tile_k)
    tile_k = 1
    while tile_k <= limit:
        waves = args["M"] * triton.cdiv(args["K"], tile_k) / num_sms
        if waves <= 1 or tile_k * 2 > limit:
            break
        tile_k *= 2
    return tile_k

394 

395 

def mean_heur_tile_n_non_inner(args):
    """Pick TILE_N for the non-inner mean kernel.

    The requested width is N, floored at _MIN_TILE_N and capped by both
    _MAX_ONE_TILE_N and the per-row budget _MAX_TILE_N_PER_ROW // TILE_K
    (at least 1). It is rounded up to a power of two, clipped back to the
    budget, and finally floored at _MIN_TILE_N. Missing ``TILE_K`` / ``N``
    keys default to 1.
    """
    tile_k = args.get("TILE_K", 1)
    cap = max(1, _MAX_TILE_N_PER_ROW // tile_k)
    n = args.get("N", 1)

    wanted = min(max(n, _MIN_TILE_N), _MAX_ONE_TILE_N, cap)
    tile_n = min(triton.next_power_of_2(wanted), cap)
    return max(tile_n, _MIN_TILE_N)

407 

408 

def mean_heur_one_tile_per_cta(args):
    """Return True when a single TILE_N-wide tile spans the whole N dimension."""
    tile_n, n = args["TILE_N"], args["N"]
    return n <= tile_n

411 

412 

def sum_heur_tile_n_non_inner(args):
    """Return TILE_N = ceil(256 / TILE_K) for the non-inner sum kernel."""
    tile_k = args["TILE_K"]
    return triton.cdiv(256, tile_k)

415 

416 

# Per-operator heuristic tables for this backend. Each entry maps a kernel
# meta-parameter name to a callable that takes the ``args`` dict and returns
# that parameter's value.
HEURISTICS_CONFIGS = {
    "argmax_non_inner": {
        "TILE_K": argmax_heur_tile_k,
        "TILE_N": argmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_non_inner,
    },
    "argmax_inner": {
        "TILE_N": argmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_inner,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "mean_non_inner": {
        "TILE_K": mean_heur_tile_k,
        "TILE_N": mean_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": mean_heur_one_tile_per_cta,
        # No mean-specific warp heuristic exists; the softmax one is reused.
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    # Attention kernels use fixed meta-parameters per variant.
    "mha_varlen_prefill": {
        "BLOCK_M": lambda args: 128,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_varlen_decode": {
        "BLOCK_M": lambda args: 16,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_block_128": {
        "BLOCK_M": lambda args: 128,
        "BLOCK_N": lambda args: 8,
        "num_warps": lambda args: 16,
        "num_stages": lambda args: 1,
    },
    "mha_block_64": {
        "BLOCK_M": lambda args: 64,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_block_32": {
        "BLOCK_M": lambda args: 32,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_block_16": {
        "BLOCK_M": lambda args: 16,
        "BLOCK_N": lambda args: 16,
        "num_warps": lambda args: 8,
        "num_stages": lambda args: 1,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 8,
    },
    "sum_inner": {
        "TILE_N": sum_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": sum_heur_one_tile_per_cta,
        "num_warps": sum_heur_num_warps_inner,
    },
    "sum_non_inner": {
        "TILE_K": sum_heur_tile_k,
        "TILE_N": sum_heur_tile_n_non_inner,
        # Use the sum-specific predicate, as "sum_inner" does. Its logic is
        # the same TILE_N >= N check as the softmax predicate.
        "ONE_TILE_PER_CTA": sum_heur_one_tile_per_cta,
        # No sum-specific non-inner warp heuristic exists; reuse softmax's.
        "num_warps": softmax_heur_num_warps_non_inner,
    },
}