Coverage for src/flag_gems/runtime/backend/_sunrise/heuristics_config_utils.py: 0%
260 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import torch # noqa: F401
2import triton
# Floor for TILE_N in the mean non-inner heuristics; also used to bound TILE_K.
_MIN_TILE_N = 1 << 6  # 64
# Per-row element budget: TILE_N is capped at this value // TILE_K.
_MAX_TILE_N_PER_ROW = 1 << 12  # 4096
# Absolute cap on TILE_N requested by mean_heur_tile_n_non_inner.
_MAX_ONE_TILE_N = 1 << 11  # 2048
def simple_elementwise_blocksize_heur(args):
    """Return the fixed BLOCK_SIZE (1024) for generic elementwise kernels.

    ``args`` is ignored; it is accepted so the function matches the other
    heuristics callbacks registered in HEURISTICS_CONFIGS.
    """
    block_size = 1024
    return block_size
def argmax_heur_tile_k(args):
    """Pick TILE_K for the non-inner argmax kernel.

    Reads ``K``, ``M`` and the dtype of ``inp`` from ``args``.

    * (M, K) == (64, 512) is hand-tuned: 64 for float32, 128 otherwise.
    * K <= 128: the largest power of two not exceeding K (1 if K <= 0).
    * Otherwise TILE_K starts at 64 and is doubled, up to min(K, 512), while
      M * cdiv(K, TILE_K) blocks still need more than one wave over the
      device's multiprocessors.
    """
    max_tile_k = 512
    sm_count = torch.ptpu.get_device_properties(
        torch.ptpu.current_device()
    ).multi_processor_count

    k = args["K"]
    m = args["M"]
    is_fp32 = args["inp"].dtype == torch.float32

    if m == 64 and k == 512:
        return 64 if is_fp32 else 128

    if k <= 128:
        if k <= 0:
            return 1
        # Highest power of two that is <= k.
        return 1 << (k.bit_length() - 1)

    tile_k = 64
    limit = min(k, max_tile_k)
    while tile_k <= limit:
        waves = m * triton.cdiv(k, tile_k) / sm_count
        # Stop once the grid fits in a single wave or doubling would overshoot.
        if waves <= 1 or tile_k * 2 > limit:
            break
        tile_k *= 2
    return tile_k
def argmax_heur_tile_n_non_inner(args):
    """Pick TILE_N for the non-inner argmax kernel.

    Reads ``N`` and the already-chosen ``TILE_K`` from ``args``.

    * N <= 128: next_power_of_2(N), so a single tile covers the whole row.
    * Otherwise: next_power_of_2(min(N, 8192)) clamped to [64, 4096]. It is
      then shrunk (never below 64) so that TILE_N * TILE_K <= 32768.
    """
    n = args["N"]
    tile_k = args["TILE_K"]

    if n <= 128:
        # Round up rather than returning N as-is. TILE_N is presumably used
        # as a ``tl.arange`` extent in the kernel (not visible here), and
        # Triton only accepts power-of-two ranges there. A raw N such as 100
        # would therefore fail to compile. TILE_N >= N still holds, so
        # ONE_TILE_PER_CTA is unaffected.
        return triton.next_power_of_2(n)

    tile_n = triton.next_power_of_2(min(8192, n))
    tile_n = max(64, min(tile_n, 4096))

    # Keep the TILE_N x TILE_K tile within 32768 elements.
    if tile_n * tile_k > 32768:
        tile_n = max(64, 32768 // tile_k)

    return tile_n
def argmax_heur_one_tile_per_cta(args):
    """Return whether TILE_N >= N (the ONE_TILE_PER_CTA flag for argmax)."""
    tile_n = args["TILE_N"]
    n = args["N"]
    return n <= tile_n
def argmax_heur_num_warps_non_inner(args):
    """Choose num_warps for the non-inner argmax kernel.

    TILE_N <= 32 gives 2 warps, <= 128 gives 4, and anything larger gives 8.
    Float32 inputs are then capped at 4 warps. Every other dtype is treated
    the same way as fp16.
    """
    tile_n = args["TILE_N"]
    is_fp32 = args["inp"].dtype == torch.float32

    if tile_n <= 32:
        warps = 2
    elif tile_n <= 128:
        warps = 4
    else:
        warps = 8

    return min(warps, 4) if is_fp32 else warps
def argmax_heur_tile_n_inner(args):
    """Pick TILE_N for the inner argmax kernel.

    Returns 4096 when N exceeds 32 * 1024, otherwise next_power_of_2(N).
    """
    n = args["N"]
    if n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(n)
def argmax_heur_num_warps_inner(args):
    """Choose num_warps for the inner argmax kernel from TILE_N.

    Returns 4 below 2048, 8 below 4096, and 16 otherwise.
    """
    tile = args["TILE_N"]
    if tile >= 4096:
        return 16
    if tile >= 2048:
        return 8
    return 4
def argmax_heur_block_m(args):
    """Return BLOCK_M for argmax: 8 when M >= 4096, else 1."""
    return 8 if args["M"] >= 4096 else 1
def argmax_heur_block_n(args):
    """Return BLOCK_N for argmax: next_power_of_2(N), capped at 4096."""
    rounded = triton.next_power_of_2(args["N"])
    return rounded if rounded < 4096 else 4096
def argmin_heur_block_m(args):
    """Return BLOCK_M for argmin: 8 when M >= 4096, else 4."""
    return 8 if args["M"] >= 4096 else 4
def argmin_heur_block_n(args):
    """Return BLOCK_N for argmin: next_power_of_2(N), capped at 4096."""
    rounded = triton.next_power_of_2(args["N"])
    return rounded if rounded < 4096 else 4096
def bmm_heur_divisible_m(args):
    """Return True when M is an exact multiple of TILE_M."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0
def bmm_heur_divisible_n(args):
    """Return True when N is an exact multiple of TILE_N."""
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0
def bmm_heur_divisible_k(args):
    """Return True when K is an exact multiple of TILE_K."""
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0
def dropout_heur_block(args):
    """Return BLOCK for dropout: 256 when N <= 512, else 512."""
    return 512 if args["N"] > 512 else 256
def dropout_heur_num_warps(args):
    """Return num_warps for dropout: 2 (N <= 512), 4 (N <= 2048), else 8."""
    n = args["N"]
    if n > 2048:
        return 8
    if n > 512:
        return 4
    return 2
def exponential_heur_block(args):
    """Return BLOCK for exponential_: 512 when N <= 512, else 1024."""
    return 1024 if args["N"] > 512 else 512
def exponential_heur_num_warps(args):
    """Return num_warps for exponential_: 4 (N <= 512), 8 (N <= 1024), else 16."""
    n = args["N"]
    if n > 1024:
        return 16
    if n > 512:
        return 8
    return 4
def gather_heur_block_m(args):
    """Return BLOCK_M for gather, which is always 1 on this backend.

    ``args`` is accepted so the function matches the other heuristics
    callbacks. It is not used.
    """
    # A form like min(1, next_power_of_2(cdiv(N, 2048))) can only ever give
    # 1 for N >= 1. For N == 0 it gives 0, which is not a usable block size.
    # Returning the cap directly keeps the effective value and removes the
    # degenerate 0.
    return 1
def gather_heur_block_n(args):
    """Return BLOCK_N for gather: next_power_of_2(N), capped at 2048."""
    rounded = triton.next_power_of_2(args["N"])
    return rounded if rounded < 2048 else 2048
def index_select_heur_block_m(args):
    """Return BLOCK_M for index_select: next_power_of_2(cdiv(256, N)), capped at 4."""
    rows = triton.next_power_of_2(triton.cdiv(256, args["N"]))
    return rows if rows < 4 else 4
def index_select_heur_block_n(args):
    """Return BLOCK_N for index_select.

    The value is next_power_of_2(cdiv(N, 16)), clamped to [16, 512].
    """
    cols = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    return max(16, min(cols, 512))
def mm_heur_even_k(args):
    """Return True when K is a multiple of BLOCK_K * SPLIT_K."""
    k = args["K"]
    step = args["BLOCK_K"] * args["SPLIT_K"]
    return k % step == 0
def rand_heur_block(args):
    """Return BLOCK for rand: 512 when N <= 512, else 1024."""
    return 1024 if args["N"] > 512 else 512
def rand_heur_num_warps(args):
    """Return num_warps for rand: 4 (N <= 512), 8 (N <= 1024), else 16."""
    n = args["N"]
    if n <= 512:
        return 4
    return 8 if n <= 1024 else 16
def randn_heur_block(args):
    """Return BLOCK for randn: 512 when N <= 512, else 1024."""
    return 1024 if args["N"] > 512 else 512
def randn_heur_num_warps(args):
    """Return num_warps for randn: 4 (N <= 512), 8 (N <= 1024), else 16."""
    n = args["N"]
    if n > 1024:
        return 16
    return 8 if n > 512 else 4
def softmax_heur_tile_k(args):
    """Pick TILE_K for the non-inner softmax kernel.

    Starting at 1, TILE_K is doubled, up to min(K, 512), while
    M * cdiv(K, TILE_K) blocks still need more than one wave over a
    fixed count of 32 processors.
    """
    max_tile_k = 512
    # The multiprocessor count is hard-coded. The original code notes that
    # querying it from the device is not supported here.
    num_sms = 32

    k = args["K"]
    tile_k = 1
    limit = min(k, max_tile_k)
    while tile_k <= limit:
        waves = args["M"] * triton.cdiv(k, tile_k) / num_sms
        if waves <= 1 or tile_k * 2 > limit:
            break
        tile_k *= 2
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """Return TILE_N for non-inner softmax: cdiv(8192, TILE_K)."""
    tile_k = args["TILE_K"]
    return triton.cdiv(8192, tile_k)
def softmax_heur_one_tile_per_cta(args):
    """Return whether TILE_N >= N (the ONE_TILE_PER_CTA flag)."""
    tile_n = args["TILE_N"]
    n = args["N"]
    return n <= tile_n
def softmax_heur_num_warps_non_inner(args):
    """Choose num_warps from the TILE_N * TILE_K tile size.

    Returns 4 below 2048 elements, 8 below 4096, and 16 otherwise.
    """
    tile_elems = args["TILE_N"] * args["TILE_K"]
    for bound, warps in ((2048, 4), (4096, 8)):
        if tile_elems < bound:
            return warps
    return 16
def softmax_heur_tile_n_inner(args):
    """Pick TILE_N for inner softmax.

    Returns next_power_of_2(N) for N <= 32, 256 for N <= 1024, and 512
    otherwise.
    """
    n = args["N"]
    if n > 1024:
        return 512
    if n > 32:
        return 256
    return triton.next_power_of_2(n)
def softmax_heur_num_warps_inner(args):
    """Choose num_warps for inner softmax from TILE_N.

    Returns 2 below 64, 4 below 2048, 8 below 4096, and 16 otherwise.
    """
    tile = args["TILE_N"]
    for bound, warps in ((64, 2), (2048, 4), (4096, 8)):
        if tile < bound:
            return warps
    return 16
def softmax_heur_tile_n_bwd_non_inner(args):
    """Return TILE_N for non-inner softmax backward: 1024 // TILE_K, at least 1."""
    per_k = 1024 // args["TILE_K"]
    return per_k if per_k > 1 else 1
def softmax_heur_tile_m(args):
    """Return TILE_M for inner softmax backward: 1024 // TILE_N, at least 1."""
    per_n = 1024 // args["TILE_N"]
    return per_n if per_n > 1 else 1
def uniform_heur_block(args):
    """Return BLOCK for uniform: 512 when N <= 512, else 1024."""
    return 1024 if args["N"] > 512 else 512
def uniform_heur_num_warps(args):
    """Return num_warps for uniform: 4 (N <= 512), 8 (N <= 1024), else 16."""
    n = args["N"]
    if n <= 512:
        return 4
    return 8 if n <= 1024 else 16
def var_mean_heur_block_n(args):
    """Return BLOCK_N for var_mean: next_power_of_2(BLOCK_NUM)."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest2d_SAME_H(args):
    """Return True when output height OH equals input height IH."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h
def upsample_nearest2d_SAME_W(args):
    """Return True when output width OW equals input width IW."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w
def upsample_nearest2d_USE_INT32_IDX(args):
    """Return True when N * C * OH * OW fits in a signed 32-bit int."""
    int32_max = 2**31 - 1
    total = args["N"] * args["C"] * args["OH"] * args["OW"]
    return total <= int32_max
def batch_norm_heur_block_m(args):
    """Return BLOCK_M for batch_norm: next_power_of_2(batch_dim), capped at 256."""
    rows = triton.next_power_of_2(args["batch_dim"])
    return rows if rows < 256 else 256
def batch_norm_heur_block_n(args):
    """Pick BLOCK_N for the batch_norm kernel.

    BLOCK_N is next_power_of_2(spatial_dim), capped at max(1, 1024 // BLOCK_M),
    where BLOCK_M comes from batch_norm_heur_block_m.
    """
    # Cap the BLOCK_M x BLOCK_N tile at 2**10 = 1024 elements loaded at once.
    # This holds exactly because BLOCK_M is a power of two <= 256.
    BLOCK_M = batch_norm_heur_block_m(args)
    BLOCK_N = triton.next_power_of_2(args["spatial_dim"])
    return min(BLOCK_N, max(1, 2**10 // BLOCK_M))
def vdot_heur_block_size(args):
    """Return BLOCK_SIZE for vdot from n_elements.

    Returns 32 below 1024, 256 below 8192, and 1024 otherwise.
    """
    count = args["n_elements"]
    if count >= 8192:
        return 1024
    if count >= 1024:
        return 256
    return 32
def mean_heur_tile_k(args):
    """Pick TILE_K for the non-inner mean kernel.

    Starting at 1, TILE_K is doubled while M * cdiv(K, TILE_K) blocks still
    need more than one wave over the device's multiprocessors. The upper
    bound is min(K, 512, max(1, _MAX_TILE_N_PER_ROW // _MIN_TILE_N)), so a
    TILE_N of at least _MIN_TILE_N still fits the per-row budget.
    """
    sm_count = torch.ptpu.get_device_properties(
        torch.ptpu.current_device()
    ).multi_processor_count

    tile_k = 1
    k_cap_from_tile_n = max(1, _MAX_TILE_N_PER_ROW // _MIN_TILE_N)
    limit = min(min(args["K"], 512), k_cap_from_tile_n)
    while tile_k <= limit:
        waves = args["M"] * triton.cdiv(args["K"], tile_k) / sm_count
        if waves <= 1 or tile_k * 2 > limit:
            break
        tile_k *= 2
    return tile_k
def sum_heur_num_warps_inner(args):
    """Choose num_warps for inner sum from TILE_N.

    Returns 2 below 64, 4 below 2048, 8 below 4096, and 16 otherwise.
    """
    tile = args["TILE_N"]
    if tile >= 4096:
        return 16
    if tile >= 2048:
        return 8
    if tile >= 64:
        return 4
    return 2
def sum_heur_tile_n_inner(args):
    """Pick TILE_N for inner sum.

    Returns next_power_of_2(N) for N <= 32, 128 for N <= 1024, and 256
    otherwise.
    """
    n = args["N"]
    if n > 1024:
        return 256
    if n > 32:
        return 128
    return triton.next_power_of_2(n)
def sum_heur_one_tile_per_cta(args):
    """Return whether TILE_N >= N (the ONE_TILE_PER_CTA flag for sum)."""
    tile_n = args["TILE_N"]
    n = args["N"]
    return n <= tile_n
def sum_heur_tile_k(args):
    """Pick TILE_K for the non-inner sum kernel.

    Starting at 1, TILE_K is doubled, up to min(K, 128), while
    M * cdiv(K, TILE_K) blocks still need more than one wave over a fixed
    count of 32 processors.
    """
    max_tile_k = 128
    # The multiprocessor count is hard-coded. The original code notes that
    # querying it from the device is not supported here.
    num_sms = 32

    k = args["K"]
    tile_k = 1
    limit = min(k, max_tile_k)
    while tile_k <= limit:
        waves = args["M"] * triton.cdiv(k, tile_k) / num_sms
        if waves <= 1 or tile_k * 2 > limit:
            break
        tile_k *= 2
    return tile_k
def mean_heur_tile_n_non_inner(args):
    """Pick TILE_N for the non-inner mean kernel.

    TILE_K defaults to 1 and N defaults to 1 if missing. The requested width
    is max(N, _MIN_TILE_N), limited to _MAX_ONE_TILE_N and to the per-row
    budget max(1, _MAX_TILE_N_PER_ROW // TILE_K). It is rounded up to a
    power of two, clipped back to the per-row budget, and finally raised to
    at least _MIN_TILE_N.
    """
    tile_k = args.get("TILE_K", 1)
    row_cap = max(1, _MAX_TILE_N_PER_ROW // tile_k)
    n = args.get("N", 1)

    wanted = min(max(n, _MIN_TILE_N), row_cap, _MAX_ONE_TILE_N)
    tile_n = min(triton.next_power_of_2(wanted), row_cap)
    return max(tile_n, _MIN_TILE_N)
def mean_heur_one_tile_per_cta(args):
    """Return whether TILE_N >= N (the ONE_TILE_PER_CTA flag for mean)."""
    tile_n = args["TILE_N"]
    n = args["N"]
    return n <= tile_n
def sum_heur_tile_n_non_inner(args):
    """Return TILE_N for non-inner sum: cdiv(256, TILE_K)."""
    tile_k = args["TILE_K"]
    return triton.cdiv(256, tile_k)
# Operator key -> {meta-parameter name -> callable(args) -> value}.
# Each callable reads the kernel's argument dict (e.g. args["N"], args["M"]).
# Presumably this is consumed via ``triton.heuristics`` elsewhere; confirm
# against the caller.
HEURISTICS_CONFIGS = {
    "argmax_non_inner": {
        "TILE_K": argmax_heur_tile_k,
        "TILE_N": argmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_non_inner,
    },
    "argmax_inner": {
        "TILE_N": argmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_inner,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "mean_non_inner": {
        "TILE_K": mean_heur_tile_k,
        "TILE_N": mean_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": mean_heur_one_tile_per_cta,
        # No mean-specific warp heuristic; reuse the softmax one.
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    # Fixed configurations for the attention kernels.
    "mha_varlen_prefill": {
        "BLOCK_M": lambda args: 128,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_varlen_decode": {
        "BLOCK_M": lambda args: 16,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_block_128": {
        "BLOCK_M": lambda args: 128,
        "BLOCK_N": lambda args: 8,
        "num_warps": lambda args: 16,
        "num_stages": lambda args: 1,
    },
    "mha_block_64": {
        "BLOCK_M": lambda args: 64,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_block_32": {
        "BLOCK_M": lambda args: 32,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
    "mha_block_16": {
        "BLOCK_M": lambda args: 16,
        "BLOCK_N": lambda args: 16,
        "num_warps": lambda args: 8,
        "num_stages": lambda args: 1,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 8,
    },
    "sum_inner": {
        "TILE_N": sum_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": sum_heur_one_tile_per_cta,
        "num_warps": sum_heur_num_warps_inner,
    },
    "sum_non_inner": {
        "TILE_K": sum_heur_tile_k,
        "TILE_N": sum_heur_tile_n_non_inner,
        # Use the sum-specific flag, consistent with "sum_inner". It computes
        # the same TILE_N >= N check as the softmax one it replaced.
        "ONE_TILE_PER_CTA": sum_heur_one_tile_per_cta,
        # No sum-specific non-inner warp heuristic; reuse the softmax one.
        "num_warps": softmax_heur_num_warps_non_inner,
    },
}