Coverage for src/flag_gems/runtime/backend/_tsingmicro/heuristics_config_utils.py: 0%

230 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import torch 

2import triton 

3 

# Tile-size limits used by mean_heur_tile_k and mean_heur_tile_n_non_inner.
# Smallest TILE_N those heuristics will return.
_MIN_TILE_N = 64
# Budget for TILE_N * TILE_K: TILE_N is limited to this value // TILE_K.
_MAX_TILE_N_PER_ROW = 4 * 1024
# Upper bound on the TILE_N target before rounding to a power of two.
_MAX_ONE_TILE_N = 2 * 1024

7 

8 

def simple_elementwise_blocksize_heur(args):
    """Return the fixed BLOCK_SIZE for generic elementwise kernels.

    ``args`` is part of the heuristic-callback signature and is not inspected.
    """
    block_size = 1024
    return block_size

11 

12 

def argmax_heur_tile_k(args):
    """Choose TILE_K for the non-inner argmax kernel.

    Reads ``args["M"]`` and ``args["K"]``. For one hand-tuned shape it also
    reads ``args["inp"].dtype``.

    - ``M == 64 and K == 512``: hand-tuned; 64 for float32 input, else 128.
    - ``K <= 128``: the largest power of two not exceeding ``K`` (1 if K <= 0).
    - otherwise: start at 64 and keep doubling, never past ``min(K, 512)``.
      Doubling continues while ``M * cdiv(K, TILE_K)`` blocks still exceed
      the device's ``multi_processor_count``.

    The device is queried only on the last path, so the early-outs do not
    touch ``torch.txda``. Likewise, ``inp`` is read only for the hand-tuned
    shape.
    """
    max_tile_k = 512
    K = args["K"]
    M = args["M"]

    if M == 64 and K == 512:
        return 64 if args["inp"].dtype == torch.float32 else 128

    if K <= 128:
        return 1 << (K.bit_length() - 1) if K > 0 else 1

    num_sms = torch.txda.get_device_properties(
        torch.txda.current_device()
    ).multi_processor_count

    tile_k = 64
    # K > 128 on this path, so the starting tile_k already fits under the bound.
    upper_bound = min(K, max_tile_k)
    while tile_k * 2 <= upper_bound:
        num_blocks = M * triton.cdiv(K, tile_k)
        # Integer form of "more than one wave" (num_blocks / num_sms > 1).
        if num_blocks <= num_sms:
            break
        tile_k *= 2
    return tile_k

42 

43 

def argmax_heur_tile_n_non_inner(args):
    """Choose TILE_N for the non-inner argmax kernel.

    For N <= 128, N itself is returned. Otherwise the result is the next
    power of two of min(N, 8192), clamped to [64, 4096]. It is then reduced
    to max(64, 32768 // TILE_K) if TILE_N * TILE_K would exceed 32768.
    """
    n = args["N"]
    tile_k = args["TILE_K"]

    if n <= 128:
        # NOTE(review): N is returned unrounded, so TILE_N can be a
        # non-power-of-two (e.g. 100). Confirm the kernel accepts that.
        return n

    element_budget = 32768
    rounded = triton.next_power_of_2(min(8192, n))
    tile_n = max(64, min(rounded, 4096))
    over_budget = tile_n * tile_k > element_budget
    return max(64, element_budget // tile_k) if over_budget else tile_n

59 

60 

def argmax_heur_one_tile_per_cta(args):
    """Return True when one TILE_N-wide tile covers all N elements."""
    tile_n = args["TILE_N"]
    n = args["N"]
    return n <= tile_n

63 

64 

def argmax_heur_num_warps_non_inner(args):
    """Return num_warps for the non-inner argmax kernel.

    This backend always uses a single warp, and ``args`` is not inspected.
    A tile-size/dtype-based warp heuristic previously sat here as
    commented-out code. It was dropped as dead code.
    """
    return 1

84 

85 

def argmax_heur_tile_n_inner(args):
    """Choose TILE_N for the inner argmax kernel.

    N up to 32K is rounded up to a power of two. Larger N uses a fixed
    4096-wide tile.
    """
    n = args["N"]
    return triton.next_power_of_2(n) if n <= 32 * 1024 else 4096

91 

92 

def argmax_heur_num_warps_inner(args):
    """Return num_warps for the inner argmax kernel.

    This backend always uses a single warp, and ``args`` is not inspected.
    A TILE_N-based warp heuristic previously sat here as commented-out code.
    It was dropped as dead code.
    """
    return 1

103 

104 

def argmin_heur_block_m(args):
    """BLOCK_M for argmin: 16 when M < 4096, otherwise 32."""
    if args["M"] >= 4096:
        return 32
    return 16


def argmin_heur_block_n(args):
    """BLOCK_N for argmin: N rounded up to a power of two, capped at 16384."""
    block_n = triton.next_power_of_2(args["N"])
    return min(16384, block_n)

111 

112 

def _bmm_is_divisible(args, size_key, tile_key):
    """Return True when args[size_key] is an exact multiple of args[tile_key]."""
    return args[size_key] % args[tile_key] == 0


def bmm_heur_divisible_m(args):
    """DIVISIBLE_M for bmm: M % TILE_M == 0."""
    return _bmm_is_divisible(args, "M", "TILE_M")


def bmm_heur_divisible_n(args):
    """DIVISIBLE_N for bmm: N % TILE_N == 0."""
    return _bmm_is_divisible(args, "N", "TILE_N")


def bmm_heur_divisible_k(args):
    """DIVISIBLE_K for bmm: K % TILE_K == 0."""
    return _bmm_is_divisible(args, "K", "TILE_K")

123 

124 

def baddbmm_heur_divisible_m(args):
    """DIVISIBLE_M for baddbmm: True when M leaves no remainder modulo TILE_M."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0


def baddbmm_heur_divisible_n(args):
    """DIVISIBLE_N for baddbmm: True when N leaves no remainder modulo TILE_N."""
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0


def baddbmm_heur_divisible_k(args):
    """DIVISIBLE_K for baddbmm: True when K leaves no remainder modulo TILE_K."""
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0

135 

136 

def dropout_heur_block(args):
    """BLOCK for dropout: 512 when N <= 512, otherwise 1024."""
    return 512 if args["N"] <= 512 else 1024


def dropout_heur_num_warps(args):
    """num_warps for dropout: 4 for N <= 512, 8 for N <= 1024, 16 beyond."""
    n = args["N"]
    if n > 1024:
        return 16
    if n > 512:
        return 8
    return 4

151 

152 

def exponential_heur_block(args):
    """BLOCK for exponential_: 512 when N <= 512, otherwise 1024."""
    if args["N"] > 512:
        return 1024
    return 512


def exponential_heur_num_warps(args):
    """num_warps for exponential_: 4 for N <= 512, 8 for N <= 1024, 16 beyond."""
    n = args["N"]
    for limit, warps in ((512, 4), (1024, 8)):
        if n <= limit:
            return warps
    return 16

167 

168 

def gather_heur_block_m(args):
    """BLOCK_M for gather.

    Takes the number of 2048-wide chunks N spans, rounds it up to a power
    of two, and caps the result at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))


def gather_heur_block_n(args):
    """BLOCK_N for gather: N rounded up to a power of two, capped at 2048."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(2048, rounded_n)

175 

176 

def index_select_heur_block_m(args):
    """BLOCK_M for index_select: cdiv(256, N) rounded up to a power of two, max 4."""
    # NOTE(review): N == 0 makes triton.cdiv divide by zero. Confirm callers
    # never reach this heuristic with an empty N.
    rows = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(rows))


def index_select_heur_block_n(args):
    """BLOCK_N for index_select: cdiv(N, 16) rounded up to a power of two, in [16, 512]."""
    block_n = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    return max(min(block_n, 512), 16)

184 

185 

def mm_heur_even_k(args):
    """EVEN_K for mm: True when K is a multiple of BLOCK_K * SPLIT_K."""
    k = args["K"]
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return k % k_step == 0

188 

189 

def rand_heur_block(args):
    """BLOCK for rand: 512 when N <= 512, otherwise 1024."""
    small = args["N"] <= 512
    return 512 if small else 1024


def rand_heur_num_warps(args):
    """num_warps for rand: 4 for N <= 512, 8 for N <= 1024, 16 beyond."""
    n = args["N"]
    if n <= 512:
        warps = 4
    elif n <= 1024:
        warps = 8
    else:
        warps = 16
    return warps

204 

205 

def randn_heur_block(args):
    """BLOCK for randn: next power of two of N // 64, clamped to [512, 32768]."""
    bs = triton.next_power_of_2(args["N"] // (16 * 4))
    return min(max(bs, 512), 32768)


def randn_heur_num_warps(args):
    """num_warps for randn: always a single warp (args is not inspected)."""
    warps = 1
    return warps

217 

218 

def softmax_heur_tile_k(args):
    """Choose TILE_K for the non-inner softmax kernel.

    TILE_K starts at 1 and doubles, never exceeding min(K, 8192). Doubling
    continues while M * cdiv(K, TILE_K) divided by the device's
    multi_processor_count is still above 1.
    """
    sm_count = torch.txda.get_device_properties(
        torch.txda.current_device()
    ).multi_processor_count
    tile = 1
    limit = min(args["K"], 8192)
    while tile <= limit:
        waves = args["M"] * triton.cdiv(args["K"], tile) / sm_count
        if not (waves > 1 and tile * 2 <= limit):
            break
        tile *= 2
    return tile

234 

235 

def softmax_heur_tile_n_non_inner(args):
    """TILE_N for non-inner softmax: cdiv(8192, TILE_K)."""
    element_budget = 8192
    return triton.cdiv(element_budget, args["TILE_K"])

238 

239 

def softmax_heur_one_tile_per_cta(args):
    """Return True when one TILE_N-wide tile covers all N elements."""
    tile_n = args["TILE_N"]
    n = args["N"]
    return n <= tile_n

242 

243 

def softmax_heur_num_warps_non_inner(args):
    """num_warps from tile area TILE_N * TILE_K: 4 below 2048, 8 below 4096, else 16."""
    area = args["TILE_N"] * args["TILE_K"]
    for bound, warps in ((2048, 4), (4096, 8)):
        if area < bound:
            return warps
    return 16

252 

253 

def softmax_heur_tile_n_inner(args):
    """TILE_N for inner softmax.

    N above 32K uses a fixed 4096-wide tile. Otherwise N is rounded up to a
    power of two.
    """
    n = args["N"]
    if n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(n)

259 

260 

def softmax_heur_num_warps_inner(args):
    """num_warps for inner softmax from TILE_N: 4 below 2048, 8 below 4096, else 16."""
    tile_n = args["TILE_N"]
    if tile_n >= 4096:
        return 16
    if tile_n >= 2048:
        return 8
    return 4

269 

270 

def softmax_heur_tile_n_bwd_non_inner(args):
    """TILE_N for backward non-inner softmax: 1024 // TILE_K, at least 1."""
    tile_n = 1024 // args["TILE_K"]
    return tile_n if tile_n > 1 else 1


def softmax_heur_tile_m(args):
    """TILE_M for backward inner softmax: 1024 // TILE_N, at least 1."""
    tile_m = 1024 // args["TILE_N"]
    return tile_m if tile_m > 1 else 1

277 

278 

def uniform_heur_block(args):
    """BLOCK for uniform: 512 when N <= 512, otherwise 1024."""
    if args["N"] <= 512:
        block = 512
    else:
        block = 1024
    return block


def uniform_heur_num_warps(args):
    """num_warps for uniform: 4 for N <= 512, 8 for N <= 1024, 16 beyond."""
    n = args["N"]
    if n > 1024:
        return 16
    return 4 if n <= 512 else 8

293 

294 

def var_mean_heur_block_n(args):
    """BLOCK_N for var_mean: BLOCK_NUM rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)

297 

298 

def upsample_nearest1d_SAME_L(args):
    """SAME_L: True when output length OL equals input length IL."""
    out_len, in_len = args["OL"], args["IL"]
    return out_len == in_len


def upsample_nearest1d_USE_INT32_IDX(args):
    """USE_INT32_IDX: True when N * C * OL does not exceed the int32 maximum."""
    int32_max = (1 << 31) - 1
    total = args["N"] * args["C"] * args["OL"]
    return total <= int32_max


def upsample_nearest2d_SAME_H(args):
    """SAME_H: True when output height OH equals input height IH."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h


def upsample_nearest2d_SAME_W(args):
    """SAME_W: True when output width OW equals input width IW."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w


def upsample_nearest2d_USE_INT32_IDX(args):
    """USE_INT32_IDX: True when N * C * OH * OW does not exceed the int32 maximum."""
    int32_max = (1 << 31) - 1
    total = args["N"] * args["C"] * args["OH"] * args["OW"]
    return total <= int32_max

317 

318 

def batch_norm_heur_block_m(args):
    """BLOCK_M for batch_norm: batch_dim rounded up to a power of two, max 2048."""
    rounded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(2048, rounded_batch)


def batch_norm_heur_block_n(args):
    """BLOCK_N for batch_norm.

    Limits a BLOCK_M x BLOCK_N tile to 2**14 (16384) elements. BLOCK_N is
    spatial_dim rounded up to a power of two, but at most 16384 // BLOCK_M
    (and never below 1).
    """
    block_m = batch_norm_heur_block_m(args)
    rounded_spatial = triton.next_power_of_2(args["spatial_dim"])
    cap = max(1, (1 << 14) // block_m)
    return min(rounded_spatial, cap)

328 

329 

def vdot_heur_block_size(args):
    """BLOCK_SIZE for vdot: 32 below 1024 elements, 256 below 8192, else 1024."""
    n = args["n_elements"]
    for bound, size in ((1024, 32), (8192, 256)):
        if n < bound:
            return size
    return 1024

338 

339 

def mean_heur_tile_k(args):
    """Choose TILE_K for the non-inner mean kernel.

    TILE_K starts at 1 and doubles while M * cdiv(K, TILE_K) divided by the
    device's multi_processor_count is still above 1. It never exceeds the
    smallest of K, 512, and _MAX_TILE_N_PER_ROW // _MIN_TILE_N. The last
    bound keeps room for a TILE_N of at least _MIN_TILE_N within the
    TILE_N * TILE_K budget.
    """
    sm_count = torch.txda.get_device_properties(
        torch.txda.current_device()
    ).multi_processor_count
    tile = 1
    tile_k_cap = max(1, _MAX_TILE_N_PER_ROW // _MIN_TILE_N)
    limit = min(args["K"], 512, tile_k_cap)
    while tile <= limit:
        waves = args["M"] * triton.cdiv(args["K"], tile) / sm_count
        if not (waves > 1 and tile * 2 <= limit):
            break
        tile *= 2
    return tile

357 

358 

def mean_heur_tile_n_non_inner(args):
    """Choose TILE_N for the non-inner mean kernel.

    The target is N, raised to at least _MIN_TILE_N. It is then capped by
    _MAX_ONE_TILE_N and by _MAX_TILE_N_PER_ROW // TILE_K. The target is
    rounded up to a power of two, capped again by the per-row limit, and
    finally floored at _MIN_TILE_N. TILE_K and N default to 1 when absent
    from ``args``.
    """
    tile_k = args.get("TILE_K", 1)
    row_cap = max(1, _MAX_TILE_N_PER_ROW // tile_k)
    n = args.get("N", 1)
    target = min(max(n, _MIN_TILE_N), _MAX_ONE_TILE_N, row_cap)
    tile_n = min(triton.next_power_of_2(target), row_cap)
    return max(tile_n, _MIN_TILE_N)

370 

371 

def mean_heur_one_tile_per_cta(args):
    """Return True when one TILE_N-wide tile covers all N elements."""
    tile_n = args["TILE_N"]
    n = args["N"]
    return n <= tile_n

374 

375 

def mha_varlen_heur_block_m(params):
    """BLOCK_M for varlen MHA forward, chosen from ``params.seqlen_q``.

    seqlen_q == 1 gives 1. Otherwise the value is 512/256/128/64/32 for
    seqlen_q >= 1024/512/256/128/64 respectively, and 16 for anything shorter.
    """
    seqlen_q = params.seqlen_q
    if seqlen_q == 1:
        return 1
    for min_len, block_m in ((1024, 512), (512, 256), (256, 128), (128, 64), (64, 32)):
        if seqlen_q >= min_len:
            return block_m
    return 16


def mha_varlen_heur_block_n(params):
    """BLOCK_N for varlen MHA forward: fixed at 16 (params is not inspected)."""
    block_n = 16
    return block_n

395 

396 

# Registry of per-kernel heuristics for this backend.
# Each outer key names a kernel family. Each inner dict maps a meta-parameter
# or launch option (e.g. TILE_N, num_warps, num_stages) to the callable that
# computes it.
# NOTE(review): presumably consumed via triton.heuristics-style decorators;
# confirm at the call sites.
HEURISTICS_CONFIGS = {
    "argmax_non_inner": {
        "TILE_K": argmax_heur_tile_k,
        "TILE_N": argmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_non_inner,
    },
    "argmax_inner": {
        "TILE_N": argmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_inner,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    # Batched matmuls: flags for whether each dimension splits into whole tiles.
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "baddbmm": {
        "DIVISIBLE_M": baddbmm_heur_divisible_m,
        "DIVISIBLE_N": baddbmm_heur_divisible_n,
        "DIVISIBLE_K": baddbmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    # mean's non-inner kernel reuses softmax's tile-area-based warp count.
    "mean_non_inner": {
        "TILE_K": mean_heur_tile_k,
        "TILE_N": mean_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": mean_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest1d": {
        "SAME_L": upsample_nearest1d_SAME_L,
        "USE_INT32_IDX": upsample_nearest1d_USE_INT32_IDX,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    # Varlen MHA forward pins num_warps and num_stages to 1.
    "mha_varlen_fwd": {
        "BLOCK_M": mha_varlen_heur_block_m,
        "BLOCK_N": mha_varlen_heur_block_n,
        "num_warps": lambda args: 1,
        "num_stages": lambda args: 1,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 8,
    },
}