Coverage for src/flag_gems/runtime/backend/_mthreads/heuristics_config_utils.py: 0%

243 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import torch 

2import triton 

3 

# Tile-size limits shared by the mean heuristics in this module.
# Smallest TILE_N that mean_heur_tile_n_non_inner will return.
_MIN_TILE_N = 64

# Divided by TILE_K to obtain the per-row cap on TILE_N.
_MAX_TILE_N_PER_ROW = 4096

# Hard ceiling applied to the desired TILE_N before power-of-two rounding.
_MAX_ONE_TILE_N = 2048

7 

8 

def simple_elementwise_blocksize_heur(args):
    """Return the fixed ``BLOCK_SIZE`` (1024) for generic elementwise kernels.

    ``args`` is accepted for interface uniformity with the other heuristics
    and is not inspected.
    """
    block_size = 1024
    return block_size

11 

12 

def argmax_heur_tile_k(args):
    """Choose ``TILE_K`` for the non-inner argmax configuration.

    Reads ``args["K"]``, ``args["M"]`` and ``args["inp"].dtype``.  Any dtype
    other than ``torch.float32`` is handled on the "fp16" path.

    Returns an int: a hand-tuned value for the (M=64, K=512) shape, the
    largest power of two not exceeding K when K <= 128 (1 if K <= 0), and
    otherwise a value grown by doubling from 64 up to ``min(K, 512)``.
    """
    max_tile_k = 512
    device_props = torch.musa.get_device_properties(torch.musa.current_device())
    sm_count = device_props.multi_processor_count

    k_dim = args["K"]
    m_dim = args["M"]
    is_fp32 = args["inp"].dtype == torch.float32

    # Special-cased shape with dtype-dependent tile size.
    if m_dim == 64 and k_dim == 512:
        return 64 if is_fp32 else 128

    # Small K: round down to a power of two.
    if k_dim <= 128:
        if k_dim > 0:
            return 1 << (k_dim.bit_length() - 1)
        return 1

    limit = min(k_dim, max_tile_k)
    tile_k = 64
    while tile_k <= limit:
        waves = m_dim * triton.cdiv(k_dim, tile_k) / sm_count
        # Stop once the blocks fit in a single wave or doubling would
        # overshoot the limit.
        if waves <= 1 or tile_k * 2 > limit:
            break
        tile_k *= 2

    return tile_k

42 

43 

def argmax_heur_tile_n_non_inner(args):
    """Choose ``TILE_N`` for the non-inner argmax configuration.

    Reads ``args["N"]`` and ``args["TILE_K"]``.  For N <= 128 the value N is
    returned unchanged; otherwise N is rounded up to a power of two, clamped
    to [64, 4096], and reduced (never below 64) so that
    ``TILE_N * TILE_K`` stays within 32768.
    """
    size_n = args["N"]
    k_tile = args["TILE_K"]

    # NOTE(review): N is returned unrounded here, so the result is not
    # necessarily a power of two -- confirm the consumer tolerates that.
    if size_n <= 128:
        return size_n

    candidate = triton.next_power_of_2(min(8192, size_n))
    candidate = min(candidate, 4096)
    candidate = max(64, candidate)

    element_budget = 32768
    if candidate * k_tile > element_budget:
        candidate = max(64, element_budget // k_tile)

    return candidate

59 

60 

def argmax_heur_one_tile_per_cta(args):
    """Return True when ``TILE_N`` is at least ``N`` (``ONE_TILE_PER_CTA``)."""
    tile_n = args["TILE_N"]
    extent_n = args["N"]
    return tile_n >= extent_n

63 

64 

def argmax_heur_num_warps_non_inner(args):
    """Choose ``num_warps`` for the non-inner argmax configuration.

    Reads ``args["TILE_N"]`` and ``args["inp"].dtype``.  The base value is 2
    for TILE_N <= 32, 4 for TILE_N <= 128 and 8 above that; when the input
    dtype is ``torch.float32`` the result is capped at 4.
    """
    tile_n = args["TILE_N"]
    is_fp32 = args["inp"].dtype == torch.float32

    if tile_n <= 32:
        warps = 2
    elif tile_n <= 128:
        warps = 4
    else:
        warps = 8

    return min(warps, 4) if is_fp32 else warps

82 

83 

def argmax_heur_tile_n_inner(args):
    """Choose ``TILE_N`` for the inner argmax configuration.

    N is rounded up to a power of two while N <= 32 * 1024; larger N uses a
    fixed tile of 4096.
    """
    size_n = args["N"]
    if size_n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(size_n)

89 

90 

def argmax_heur_num_warps_inner(args):
    """Choose ``num_warps`` for the inner argmax configuration.

    Returns 4 when TILE_N < 2048, 8 when TILE_N < 4096, and 16 otherwise.
    """
    tile_n = args["TILE_N"]
    if tile_n < 2048:
        return 4
    if tile_n < 4096:
        return 8
    return 16

99 

100 

def argmin_heur_block_m(args):
    """Return the argmin ``BLOCK_M``: 4 when M < 4096, otherwise 8."""
    if args["M"] < 4096:
        return 4
    return 8

103 

104 

def argmin_heur_block_n(args):
    """Return the argmin ``BLOCK_N``: N rounded up to a power of two, capped at 4096."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(4096, rounded_n)

107 

108 

def bmm_heur_divisible_m(args):
    """Report whether ``M`` is an exact multiple of ``TILE_M`` (bmm ``DIVISIBLE_M``)."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0

111 

112 

def bmm_heur_divisible_n(args):
    """Report whether ``N`` is an exact multiple of ``TILE_N`` (bmm ``DIVISIBLE_N``)."""
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0

115 

116 

def bmm_heur_divisible_k(args):
    """Report whether ``K`` is an exact multiple of ``TILE_K`` (bmm ``DIVISIBLE_K``)."""
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0

119 

120 

def baddbmm_heur_divisible_m(args):
    """Report whether ``M`` is an exact multiple of ``TILE_M`` (baddbmm ``DIVISIBLE_M``)."""
    leftover = args["M"] % args["TILE_M"]
    return leftover == 0

123 

124 

def baddbmm_heur_divisible_n(args):
    """Report whether ``N`` is an exact multiple of ``TILE_N`` (baddbmm ``DIVISIBLE_N``)."""
    leftover = args["N"] % args["TILE_N"]
    return leftover == 0

127 

128 

def baddbmm_heur_divisible_k(args):
    """Report whether ``K`` is an exact multiple of ``TILE_K`` (baddbmm ``DIVISIBLE_K``)."""
    leftover = args["K"] % args["TILE_K"]
    return leftover == 0

131 

132 

def dropout_heur_block(args):
    """Return the dropout ``BLOCK``: 512 when N <= 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

138 

139 

def dropout_heur_num_warps(args):
    """Return the dropout ``num_warps``: 4 for N <= 512, 8 for N <= 1024, else 16."""
    if args["N"] <= 512:
        return 4
    if args["N"] <= 1024:
        return 8
    return 16

147 

148 

def exponential_heur_block(args):
    """Return the exponential_ ``BLOCK``: 512 when N <= 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

154 

155 

def exponential_heur_num_warps(args):
    """Return the exponential_ ``num_warps``: 4 for N <= 512, 8 for N <= 1024, else 16."""
    if args["N"] <= 512:
        return 4
    if args["N"] <= 1024:
        return 8
    return 16

163 

164 

def gather_heur_block_m(args):
    """Return the gather ``BLOCK_M``.

    Computes ceil(N / 2048), rounds it up to a power of two, and caps the
    result at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))

167 

168 

def gather_heur_block_n(args):
    """Return the gather ``BLOCK_N``: N rounded up to a power of two, capped at 2048."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(2048, rounded_n)

171 

172 

def index_select_heur_block_m(args):
    """Return the index_select ``BLOCK_M``.

    Computes ceil(256 / N), rounds it up to a power of two, and caps the
    result at 4.
    """
    rows = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(rows))

175 

176 

def index_select_heur_block_n(args):
    """Return the index_select ``BLOCK_N``.

    ceil(N / 16) is rounded up to a power of two and then clamped to the
    range [16, 512].
    """
    rounded = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    capped = min(rounded, 512)
    return max(capped, 16)

180 

181 

def index_add_heur_block_m(args):
    """Return the index_add ``BLOCK_M``: 4 when M < 4096, otherwise 8."""
    if args["M"] < 4096:
        return 4
    return 8

184 

185 

def index_add_heur_block_n(args):
    """Return the index_add ``BLOCK_N``: N rounded up to a power of two, capped at 8192."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(8192, rounded_n)

188 

189 

def mm_heur_even_k(args):
    """Report whether ``K`` is an exact multiple of ``BLOCK_K * SPLIT_K`` (mm ``EVEN_K``)."""
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % k_step == 0

192 

193 

def rand_heur_block(args):
    """Return the rand ``BLOCK``: 512 when N <= 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

199 

200 

def rand_heur_num_warps(args):
    """Return the rand ``num_warps``: 4 for N <= 512, 8 for N <= 1024, else 16."""
    if args["N"] <= 512:
        return 4
    if args["N"] <= 1024:
        return 8
    return 16

208 

209 

def randn_heur_block(args):
    """Return the randn ``BLOCK``: 512 when N <= 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

215 

216 

def randn_heur_num_warps(args):
    """Return the randn ``num_warps``: 4 for N <= 512, 8 for N <= 1024, else 16."""
    if args["N"] <= 512:
        return 4
    if args["N"] <= 1024:
        return 8
    return 16

224 

225 

def softmax_heur_tile_k(args):
    """Choose ``TILE_K`` for the non-inner softmax configuration.

    Starting from 1, the tile is doubled while the block count
    ``M * ceil(K / tile_k)`` exceeds the multiprocessor count of the current
    MUSA device and the doubled tile still fits within ``min(K, 8192)``.
    """
    max_tile_k = 8192
    device_props = torch.musa.get_device_properties(torch.musa.current_device())
    sm_count = device_props.multi_processor_count

    limit = min(args["K"], max_tile_k)
    tile_k = 1
    while tile_k <= limit:
        waves = args["M"] * triton.cdiv(args["K"], tile_k) / sm_count
        # Stop once everything fits in one wave or doubling would overshoot.
        if waves <= 1 or tile_k * 2 > limit:
            break
        tile_k *= 2
    return tile_k

241 

242 

def softmax_heur_tile_n_non_inner(args):
    """Return the non-inner softmax ``TILE_N``: ceil(8192 / TILE_K)."""
    k_tile = args["TILE_K"]
    return triton.cdiv(8192, k_tile)

245 

246 

def softmax_heur_one_tile_per_cta(args):
    """Return True when ``TILE_N`` is at least ``N`` (``ONE_TILE_PER_CTA``)."""
    tile_n = args["TILE_N"]
    extent_n = args["N"]
    return tile_n >= extent_n

249 

250 

def softmax_heur_num_warps_non_inner(args):
    """Choose ``num_warps`` from the tile area ``TILE_N * TILE_K``.

    Returns 4 when the area is below 2048, 8 when it is below 4096, and 16
    otherwise.
    """
    area = args["TILE_N"] * args["TILE_K"]
    if area < 2048:
        return 4
    if area < 4096:
        return 8
    return 16

259 

260 

def softmax_heur_tile_n_inner(args):
    """Choose ``TILE_N`` for the inner softmax configuration.

    N is rounded up to a power of two while N <= 32 * 1024; larger N uses a
    fixed tile of 4096.
    """
    size_n = args["N"]
    if size_n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(size_n)

266 

267 

def softmax_heur_num_warps_inner(args):
    """Choose ``num_warps`` for the inner softmax configuration.

    Returns 4 when TILE_N < 2048, 8 when TILE_N < 4096, and 16 otherwise.
    """
    tile_n = args["TILE_N"]
    if tile_n < 2048:
        return 4
    if tile_n < 4096:
        return 8
    return 16

276 

277 

def softmax_heur_tile_n_bwd_non_inner(args):
    """Return the backward non-inner softmax ``TILE_N``: ``1024 // TILE_K``, at least 1."""
    quotient = 1024 // args["TILE_K"]
    return max(1, quotient)

280 

281 

def softmax_heur_tile_m(args):
    """Return the backward inner softmax ``TILE_M``: ``1024 // TILE_N``, at least 1."""
    quotient = 1024 // args["TILE_N"]
    return max(1, quotient)

284 

285 

def uniform_heur_block(args):
    """Return the uniform ``BLOCK``: 512 when N <= 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024

291 

292 

def uniform_heur_num_warps(args):
    """Return the uniform ``num_warps``: 4 for N <= 512, 8 for N <= 1024, else 16."""
    if args["N"] <= 512:
        return 4
    if args["N"] <= 1024:
        return 8
    return 16

300 

301 

def var_mean_heur_block_n(args):
    """Return the var_mean ``BLOCK_N``: ``BLOCK_NUM`` rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)

304 

305 

def upsample_nearest2d_SAME_H(args):
    """Return True when ``OH`` equals ``IH``."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h

308 

309 

def upsample_nearest2d_SAME_W(args):
    """Return True when ``OW`` equals ``IW``."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w

312 

313 

def upsample_nearest2d_USE_INT32_IDX(args):
    """Return True when ``N * C * OH * OW`` does not exceed the int32 maximum."""
    int32_max = 2**31 - 1
    total = args["N"] * args["C"] * args["OH"] * args["OW"]
    return total <= int32_max

316 

317 

def upsample_nearest3d_SAME_D(args):
    """Return True when ``OD`` equals ``ID``."""
    out_d, in_d = args["OD"], args["ID"]
    return out_d == in_d

320 

321 

def upsample_nearest3d_SAME_H(args):
    """Return True when ``OH`` equals ``IH``."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h

324 

325 

def upsample_nearest3d_SAME_W(args):
    """Return True when ``OW`` equals ``IW``."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w

328 

329 

def upsample_nearest3d_USE_INT32_IDX(args):
    """Return True when ``N * C * OD * OH * OW`` does not exceed the int32 maximum."""
    int32_max = 2**31 - 1
    total = args["N"] * args["C"] * args["OD"] * args["OH"] * args["OW"]
    return total <= int32_max

332 

333 

def batch_norm_heur_block_m(args):
    """Return the batch_norm ``BLOCK_M``: ``batch_dim`` rounded up to a power of two, capped at 2048."""
    rounded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(2048, rounded_batch)

336 

337 

def batch_norm_heur_block_n(args):
    """Return the batch_norm ``BLOCK_N``.

    ``spatial_dim`` is rounded up to a power of two and then limited so that
    ``BLOCK_M * BLOCK_N`` stays within 16384 elements (2**14), with a floor
    of 1 on that limit.
    """
    block_m = batch_norm_heur_block_m(args)
    rounded_spatial = triton.next_power_of_2(args["spatial_dim"])
    # A maximum of 16384 elements are loaded at once.
    per_row_cap = max(1, 16384 // block_m)
    return min(rounded_spatial, per_row_cap)

343 

344 

def vdot_heur_block_size(args):
    """Return the vdot ``BLOCK_SIZE`` from ``n_elements``.

    32 when n < 1024, 256 when n < 8192, and 1024 otherwise.
    """
    count = args["n_elements"]
    if count < 1024:
        return 32
    if count < 8192:
        return 256
    return 1024

353 

354 

def mean_heur_tile_k(args):
    """Choose ``TILE_K`` for the non-inner mean configuration.

    The upper bound is the smallest of K, 512, and
    ``_MAX_TILE_N_PER_ROW // _MIN_TILE_N`` (floored at 1).  Starting from 1,
    the tile is doubled while the block count ``M * ceil(K / tile_k)``
    exceeds the multiprocessor count of the current MUSA device and the
    doubled tile still fits within the bound.  Finally the tile is raised,
    if needed, so that ``ceil(K / tile_k)`` does not exceed 65535.
    """
    max_tile_k = 512
    max_grid_y = 65535
    device_props = torch.musa.get_device_properties(torch.musa.current_device())
    sm_count = device_props.multi_processor_count

    cap_from_tile_n = max(1, _MAX_TILE_N_PER_ROW // _MIN_TILE_N)
    limit = min(min(args["K"], max_tile_k), cap_from_tile_n)

    tile_k = 1
    while tile_k <= limit:
        waves = args["M"] * triton.cdiv(args["K"], tile_k) / sm_count
        # Stop once everything fits in one wave or doubling would overshoot.
        if waves <= 1 or tile_k * 2 > limit:
            break
        tile_k *= 2

    # Keep the grid Y dimension within max_grid_y (the original comment
    # calls this the CUDA limit); this may push tile_k above ``limit``.
    floor_tile_k = triton.cdiv(args["K"], max_grid_y)
    if floor_tile_k > tile_k:
        tile_k = triton.next_power_of_2(floor_tile_k)
    return tile_k

377 

378 

def mean_heur_tile_n_non_inner(args):
    """Choose ``TILE_N`` for the non-inner mean configuration.

    ``TILE_K`` and ``N`` both default to 1 when absent from ``args``.  The
    desired size is N raised to at least ``_MIN_TILE_N`` and limited by both
    ``_MAX_ONE_TILE_N`` and ``_MAX_TILE_N_PER_ROW // TILE_K`` (floored at 1).
    It is then rounded up to a power of two, re-limited by the per-row cap,
    and finally raised to at least ``_MIN_TILE_N``.
    """
    k_tile = args.get("TILE_K", 1)
    per_row_cap = max(1, _MAX_TILE_N_PER_ROW // k_tile)
    size_n = args.get("N", 1)

    wanted = min(max(size_n, _MIN_TILE_N), per_row_cap, _MAX_ONE_TILE_N)
    rounded = min(triton.next_power_of_2(wanted), per_row_cap)
    # The minimum is applied last, so it wins over the per-row cap.
    return max(rounded, _MIN_TILE_N)

390 

391 

def mean_heur_one_tile_per_cta(args):
    """Return True when ``TILE_N`` is at least ``N`` (``ONE_TILE_PER_CTA``)."""
    tile_n = args["TILE_N"]
    extent_n = args["N"]
    return tile_n >= extent_n

394 

395 

# Registry of heuristic callables.  The outer key names a configuration; the
# inner mapping goes from a parameter name to a callable that takes the
# ``args`` dict and returns that parameter's value.  Insertion order of the
# inner mappings is kept as in the original, since later heuristics read
# values (e.g. TILE_K, TILE_N) that earlier ones are listed to produce.
HEURISTICS_CONFIGS = {
    # --- reductions: argmax / argmin ---
    "argmax_non_inner": {
        "TILE_K": argmax_heur_tile_k,
        "TILE_N": argmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_non_inner,
    },
    "argmax_inner": {
        "TILE_N": argmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_inner,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    # --- batched matmul divisibility flags ---
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "baddbmm": {
        "DIVISIBLE_M": baddbmm_heur_divisible_m,
        "DIVISIBLE_N": baddbmm_heur_divisible_n,
        "DIVISIBLE_K": baddbmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    # --- indexing ---
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "index_add": {
        "BLOCK_M": index_add_heur_block_m,
        "BLOCK_N": index_add_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    # --- softmax forward / backward ---
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "mean_non_inner": {
        "TILE_K": mean_heur_tile_k,
        "TILE_N": mean_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": mean_heur_one_tile_per_cta,
        # The softmax non-inner warp heuristic is reused for mean.
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    # --- upsampling ---
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "upsample_nearest3d": {
        "SAME_D": upsample_nearest3d_SAME_D,
        "SAME_H": upsample_nearest3d_SAME_H,
        "SAME_W": upsample_nearest3d_SAME_W,
        "USE_INT32_IDX": upsample_nearest3d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    # --- attention: constant-valued entries ---
    "mha_block_128": {
        "BLOCK_M": lambda args: 128,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 1,
    },
    "mha_block_64": {
        "BLOCK_M": lambda args: 64,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 1,
    },
    "mha_block_32": {
        "BLOCK_M": lambda args: 32,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 1,
    },
    # Same values as "mha_block_32".
    "mha_block_16": {
        "BLOCK_M": lambda args: 32,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 1,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 8,
    },
}