Coverage for src/flag_gems/runtime/backend/_tsingmicro/heuristics_config_utils.py: 0%
230 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import torch
2import triton
# Tile-size limits shared by the mean heuristics defined later in this module.
# Smallest TILE_N that mean_heur_tile_n_non_inner will return.
_MIN_TILE_N = 64
# Budget divided by TILE_K to derive the per-row TILE_N limit.
_MAX_TILE_N_PER_ROW = 4096
# Cap applied to the desired TILE_N before rounding to a power of two.
_MAX_ONE_TILE_N = 2048
def simple_elementwise_blocksize_heur(args):
    """Return the fixed block size used for generic elementwise kernels.

    ``args`` is accepted for interface compatibility and is not inspected.
    """
    block_size = 1024
    return block_size
def argmax_heur_tile_k(args):
    """Pick TILE_K for the non-inner argmax kernel.

    Args:
        args: mapping with integer entries ``"M"`` and ``"K"`` and an
            ``"inp"`` entry exposing ``.dtype``.

    Returns:
        int: the chosen TILE_K.

    Selection order:
      1. ``M == 64 and K == 512`` is special-cased: 64 for float32 input,
         128 for anything else.
      2. ``K <= 128`` returns the largest power of two not exceeding ``K``
         (1 when ``K`` is not positive).
      3. Otherwise start at 64 and keep doubling while the launch would
         span more than one wave of blocks and the doubled tile still fits
         under ``min(K, 512)``.

    The device-properties query is only issued for case 3; the early-exit
    cases never need the multiprocessor count.
    """
    MAX_TILE_K = 512

    K = args["K"]
    M = args["M"]

    # Hand-tuned special case; only this branch depends on the input dtype.
    if M == 64 and K == 512:
        return 64 if args["inp"].dtype == torch.float32 else 128

    # Small K: round down to a power of two, no occupancy reasoning needed.
    if K <= 128:
        return 1 << (K.bit_length() - 1) if K > 0 else 1

    # Only the occupancy-driven search needs the multiprocessor count.
    NUM_SMS = torch.txda.get_device_properties(
        torch.txda.current_device()
    ).multi_processor_count

    tile_k = 64
    upper_bound = min(K, MAX_TILE_K)

    while tile_k <= upper_bound:
        num_blocks = M * triton.cdiv(K, tile_k)
        num_waves = num_blocks / NUM_SMS

        # Stop once the grid fits in one wave or doubling would overshoot.
        if not (num_waves > 1 and tile_k * 2 <= upper_bound):
            break
        tile_k *= 2

    return tile_k
def argmax_heur_tile_n_non_inner(args):
    """Pick TILE_N for the non-inner argmax kernel.

    Reads ``args["N"]`` and ``args["TILE_K"]``. For ``N <= 128`` the value of
    ``N`` is returned unchanged. Otherwise ``N`` (capped at 8192) is rounded
    up to a power of two, clamped to ``[64, 4096]``, and then shrunk so that
    ``TILE_N * TILE_K`` does not exceed 32768 (never going below 64).
    """
    total_n, tile_k = args["N"], args["TILE_K"]

    if total_n <= 128:
        return total_n

    rounded = triton.next_power_of_2(min(8192, total_n))
    tile_n = max(64, min(rounded, 4096))

    # Keep the TILE_N x TILE_K footprint within 32768 elements.
    footprint = tile_n * tile_k
    if footprint > 32768:
        tile_n = max(64, 32768 // tile_k)

    return tile_n
def argmax_heur_one_tile_per_cta(args):
    """Return True when ``args["TILE_N"]`` is at least ``args["N"]``."""
    tile_n = args["TILE_N"]
    extent_n = args["N"]
    return tile_n >= extent_n
def argmax_heur_num_warps_non_inner(args):
    """Return the warp count for the non-inner argmax kernel (always 1).

    ``args`` is not inspected. A previous selection based on ``TILE_N``
    (2 warps up to 32, 4 up to 128, 8 beyond, capped at 4 for float32 input)
    is disabled in this backend; the function is pinned to a single warp.
    """
    num_warps = 1
    return num_warps
def argmax_heur_tile_n_inner(args):
    """Pick TILE_N for the inner argmax kernel.

    Returns 4096 when ``args["N"]`` exceeds 32 * 1024; otherwise ``N`` rounded
    up to the next power of two.
    """
    if args["N"] > (32 * 1024):
        return 4096
    return triton.next_power_of_2(args["N"])
def argmax_heur_num_warps_inner(args):
    """Return the warp count for the inner argmax kernel (always 1).

    ``args`` is not inspected. A previous ``TILE_N``-based selection
    (4 below 2048, 8 below 4096, 16 otherwise) is disabled in this backend.
    """
    num_warps = 1
    return num_warps
def argmin_heur_block_m(args):
    """Return BLOCK_M for argmin: 16 when ``args["M"] < 4096``, else 32."""
    if args["M"] < 4096:
        return 16
    return 32
def argmin_heur_block_n(args):
    """Return BLOCK_N for argmin: ``N`` rounded up to a power of two, capped at 16384."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(16384, rounded_n)
def bmm_heur_divisible_m(args):
    """Return True when ``args["M"]`` is an exact multiple of ``args["TILE_M"]``."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0
def bmm_heur_divisible_n(args):
    """Return True when ``args["N"]`` is an exact multiple of ``args["TILE_N"]``."""
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0
def bmm_heur_divisible_k(args):
    """Return True when ``args["K"]`` is an exact multiple of ``args["TILE_K"]``."""
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0
def baddbmm_heur_divisible_m(args):
    """Return True when ``args["M"]`` is an exact multiple of ``args["TILE_M"]``."""
    leftover = args["M"] % args["TILE_M"]
    return leftover == 0
def baddbmm_heur_divisible_n(args):
    """Return True when ``args["N"]`` is an exact multiple of ``args["TILE_N"]``."""
    leftover = args["N"] % args["TILE_N"]
    return leftover == 0
def baddbmm_heur_divisible_k(args):
    """Return True when ``args["K"]`` is an exact multiple of ``args["TILE_K"]``."""
    leftover = args["K"] % args["TILE_K"]
    return leftover == 0
def dropout_heur_block(args):
    """Return BLOCK for dropout: 512 when ``args["N"] <= 512``, else 1024."""
    return 512 if args["N"] <= 512 else 1024
def dropout_heur_num_warps(args):
    """Return the dropout warp count: 4 for N <= 512, 8 for N <= 1024, else 16."""
    if args["N"] <= 512:
        return 4
    if args["N"] <= 1024:
        return 8
    return 16
def exponential_heur_block(args):
    """Return BLOCK for exponential_: 512 when ``args["N"] <= 512``, else 1024."""
    return 512 if args["N"] <= 512 else 1024
def exponential_heur_num_warps(args):
    """Return the exponential_ warp count: 4 for N <= 512, 8 for N <= 1024, else 16."""
    if args["N"] <= 512:
        return 4
    if args["N"] <= 1024:
        return 8
    return 16
def gather_heur_block_m(args):
    """Return BLOCK_M for gather.

    Computes ``ceil(N / 2048)``, rounds it up to a power of two and caps the
    result at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))
def gather_heur_block_n(args):
    """Return BLOCK_N for gather: ``N`` rounded up to a power of two, capped at 2048."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(2048, rounded_n)
def index_select_heur_block_m(args):
    """Return BLOCK_M for index_select.

    Computes ``ceil(256 / N)``, rounds it up to a power of two and caps the
    result at 4.
    """
    rows_per_256 = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(rows_per_256))
def index_select_heur_block_n(args):
    """Return BLOCK_N for index_select.

    Computes ``ceil(N / 16)``, rounds it up to a power of two and clamps the
    result to the range ``[16, 512]``.
    """
    rounded = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    capped = min(rounded, 512)
    return max(capped, 16)
def mm_heur_even_k(args):
    """Return True when ``K`` is an exact multiple of ``BLOCK_K * SPLIT_K``."""
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % k_step == 0
def rand_heur_block(args):
    """Return BLOCK for rand: 512 when ``args["N"] <= 512``, else 1024."""
    return 512 if args["N"] <= 512 else 1024
def rand_heur_num_warps(args):
    """Return the rand warp count: 4 for N <= 512, 8 for N <= 1024, else 16."""
    if args["N"] <= 512:
        return 4
    if args["N"] <= 1024:
        return 8
    return 16
def randn_heur_block(args):
    """Return BLOCK for randn.

    Rounds ``N // (16 * 4)`` up to a power of two and clamps the result to
    the range ``[512, 32768]``.
    """
    rounded = triton.next_power_of_2(args["N"] // (16 * 4))
    return min(max(rounded, 512), 32768)
def randn_heur_num_warps(args):
    """Return the randn warp count (always 1); ``args`` is not inspected."""
    num_warps = 1
    return num_warps
def softmax_heur_tile_k(args):
    """Pick TILE_K for the non-inner softmax kernel.

    Starts at 1 and doubles while ``M * ceil(K / tile_k)`` blocks exceed the
    multiprocessor count reported for the current ``torch.txda`` device and
    the doubled tile still fits under ``min(K, 8192)``.
    """
    MAX_TILE_K = 8192
    device_props = torch.txda.get_device_properties(torch.txda.current_device())
    NUM_SMS = device_props.multi_processor_count

    limit = min(args["K"], MAX_TILE_K)
    tile_k = 1
    while tile_k <= limit:
        blocks = args["M"] * triton.cdiv(args["K"], tile_k)
        waves = blocks / NUM_SMS
        # Stop once the grid fits in one wave or doubling would overshoot.
        if not ((waves > 1) and (tile_k * 2 <= limit)):
            break
        tile_k *= 2
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """Return TILE_N for non-inner softmax: ``ceil(8192 / TILE_K)``."""
    tile_k = args["TILE_K"]
    return triton.cdiv(8192, tile_k)
def softmax_heur_one_tile_per_cta(args):
    """Return True when ``args["TILE_N"]`` is at least ``args["N"]``."""
    tile_n = args["TILE_N"]
    extent_n = args["N"]
    return tile_n >= extent_n
def softmax_heur_num_warps_non_inner(args):
    """Return the warp count from the ``TILE_N * TILE_K`` tile size.

    4 warps below 2048 elements, 8 below 4096, 16 otherwise.
    """
    elements = args["TILE_N"] * args["TILE_K"]
    if elements < 2048:
        return 4
    if elements < 4096:
        return 8
    return 16
def softmax_heur_tile_n_inner(args):
    """Pick TILE_N for the inner softmax kernel.

    Returns 4096 when ``args["N"]`` exceeds 32 * 1024; otherwise ``N`` rounded
    up to the next power of two.
    """
    if args["N"] > (32 * 1024):
        return 4096
    return triton.next_power_of_2(args["N"])
def softmax_heur_num_warps_inner(args):
    """Return the warp count from ``TILE_N``: 4 below 2048, 8 below 4096, else 16."""
    tile_n = args["TILE_N"]
    if tile_n < 2048:
        return 4
    if tile_n < 4096:
        return 8
    return 16
def softmax_heur_tile_n_bwd_non_inner(args):
    """Return TILE_N for non-inner softmax backward: ``1024 // TILE_K``, at least 1."""
    per_k = 1024 // args["TILE_K"]
    return max(1, per_k)
def softmax_heur_tile_m(args):
    """Return TILE_M for inner softmax backward: ``1024 // TILE_N``, at least 1."""
    per_n = 1024 // args["TILE_N"]
    return max(1, per_n)
def uniform_heur_block(args):
    """Return BLOCK for uniform: 512 when ``args["N"] <= 512``, else 1024."""
    return 512 if args["N"] <= 512 else 1024
def uniform_heur_num_warps(args):
    """Return the uniform warp count: 4 for N <= 512, 8 for N <= 1024, else 16."""
    if args["N"] <= 512:
        return 4
    if args["N"] <= 1024:
        return 8
    return 16
def var_mean_heur_block_n(args):
    """Return BLOCK_N for var_mean: ``args["BLOCK_NUM"]`` rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest1d_SAME_L(args):
    """Return True when ``args["OL"]`` equals ``args["IL"]``."""
    out_len = args["OL"]
    in_len = args["IL"]
    return out_len == in_len
def upsample_nearest1d_USE_INT32_IDX(args):
    """Return True when ``N * C * OL`` does not exceed the int32 maximum."""
    int32_max = 2**31 - 1
    product = args["N"] * args["C"] * args["OL"]
    return product <= int32_max
def upsample_nearest2d_SAME_H(args):
    """Return True when ``args["OH"]`` equals ``args["IH"]``."""
    out_h = args["OH"]
    in_h = args["IH"]
    return out_h == in_h
def upsample_nearest2d_SAME_W(args):
    """Return True when ``args["OW"]`` equals ``args["IW"]``."""
    out_w = args["OW"]
    in_w = args["IW"]
    return out_w == in_w
def upsample_nearest2d_USE_INT32_IDX(args):
    """Return True when ``N * C * OH * OW`` does not exceed the int32 maximum."""
    int32_max = 2**31 - 1
    product = args["N"] * args["C"] * args["OH"] * args["OW"]
    return product <= int32_max
def batch_norm_heur_block_m(args):
    """Return BLOCK_M for batch_norm: ``batch_dim`` rounded up to a power of two, capped at 2048."""
    rounded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(2048, rounded_batch)
def batch_norm_heur_block_n(args):
    """Return BLOCK_N for batch_norm.

    ``spatial_dim`` is rounded up to a power of two and then limited so that
    ``BLOCK_M * BLOCK_N`` stays within 2**14 (16384) elements, with BLOCK_M
    taken from ``batch_norm_heur_block_m``. The limit never drops below 1.
    """
    block_m = batch_norm_heur_block_m(args)
    rounded_spatial = triton.next_power_of_2(args["spatial_dim"])
    # A maximum of 16384 elements are loaded at once.
    budget = max(1, 2**14 // block_m)
    return min(rounded_spatial, budget)
def vdot_heur_block_size(args):
    """Return BLOCK_SIZE for vdot from ``args["n_elements"]``.

    32 below 1024 elements, 256 below 8192, 1024 otherwise.
    """
    count = args["n_elements"]
    if count < 1024:
        return 32
    if count < 8192:
        return 256
    return 1024
def mean_heur_tile_k(args):
    """Pick TILE_K for the non-inner mean kernel.

    The upper bound is the smallest of ``K``, 512 and
    ``_MAX_TILE_N_PER_ROW // _MIN_TILE_N`` (at least 1). Starting from 1, the
    tile doubles while ``M * ceil(K / tile_k)`` blocks exceed the
    multiprocessor count reported for the current ``torch.txda`` device and
    the doubled tile still fits under that bound.
    """
    MAX_TILE_K = 512
    device_props = torch.txda.get_device_properties(torch.txda.current_device())
    NUM_SMS = device_props.multi_processor_count

    tile_k = 1
    # Bound TILE_K so that a TILE_N of at least _MIN_TILE_N still fits the row budget.
    cap_from_tile_n = max(1, _MAX_TILE_N_PER_ROW // _MIN_TILE_N)
    limit = min(min(args["K"], MAX_TILE_K), cap_from_tile_n)

    while tile_k <= limit:
        blocks = args["M"] * triton.cdiv(args["K"], tile_k)
        waves = blocks / NUM_SMS
        # Stop once the grid fits in one wave or doubling would overshoot.
        if not ((waves > 1) and (tile_k * 2 <= limit)):
            break
        tile_k *= 2
    return tile_k
def mean_heur_tile_n_non_inner(args):
    """Pick TILE_N for the non-inner mean kernel.

    Reads ``TILE_K`` and ``N`` from ``args`` (both default to 1 when absent).
    The desired size is ``N`` raised to at least ``_MIN_TILE_N`` and capped by
    both ``_MAX_ONE_TILE_N`` and the per-row limit
    ``_MAX_TILE_N_PER_ROW // TILE_K``. It is then rounded up to a power of
    two, capped by the per-row limit again, and finally raised to at least
    ``_MIN_TILE_N``.
    """
    tile_k = args.get("TILE_K", 1)
    row_limit = max(1, _MAX_TILE_N_PER_ROW // tile_k)
    extent_n = args.get("N", 1)

    wanted = min(max(extent_n, _MIN_TILE_N), row_limit, _MAX_ONE_TILE_N)
    # Rounding up may overshoot the per-row limit, so cap once more.
    tile_n = min(triton.next_power_of_2(wanted), row_limit)
    return max(tile_n, _MIN_TILE_N)
def mean_heur_one_tile_per_cta(args):
    """Return True when ``args["TILE_N"]`` is at least ``args["N"]``."""
    tile_n = args["TILE_N"]
    extent_n = args["N"]
    return tile_n >= extent_n
def mha_varlen_heur_block_m(params):
    """Return BLOCK_M for the varlen MHA forward kernel from ``params.seqlen_q``.

    A query length of exactly 1 yields 1. Otherwise the first matching
    threshold wins: >=1024 -> 512, >=512 -> 256, >=256 -> 128, >=128 -> 64,
    >=64 -> 32, and anything smaller -> 16.
    """
    if params.seqlen_q == 1:
        return 1

    # (minimum seqlen_q, BLOCK_M), checked from the largest threshold down.
    thresholds = ((1024, 512), (512, 256), (256, 128), (128, 64), (64, 32))
    for min_seqlen, block_m in thresholds:
        if params.seqlen_q >= min_seqlen:
            return block_m
    return 16
def mha_varlen_heur_block_n(params):
    """Return BLOCK_N for the varlen MHA forward kernel (always 16); ``params`` is not inspected."""
    block_n = 16
    return block_n
# Registry of heuristics: config name -> {meta-parameter name -> callable}.
# Each callable takes the kernel argument mapping (or, for "mha_varlen_fwd",
# a params object) and returns the value for that meta-parameter.
# NOTE(review): presumably consumed through triton.heuristics by the op
# implementations - confirm against the callers.
HEURISTICS_CONFIGS = {
    # --- argmax / argmin -------------------------------------------------
    "argmax_non_inner": {
        "TILE_K": argmax_heur_tile_k,
        "TILE_N": argmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_non_inner,
    },
    "argmax_inner": {
        "TILE_N": argmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": argmax_heur_one_tile_per_cta,
        "num_warps": argmax_heur_num_warps_inner,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    # --- batched matmul --------------------------------------------------
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "baddbmm": {
        "DIVISIBLE_M": baddbmm_heur_divisible_m,
        "DIVISIBLE_N": baddbmm_heur_divisible_n,
        "DIVISIBLE_K": baddbmm_heur_divisible_k,
    },
    # --- dropout, exponential_, gather, index_select, mm -----------------
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    # --- rand / randn ----------------------------------------------------
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    # --- softmax and mean ------------------------------------------------
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "mean_non_inner": {
        "TILE_K": mean_heur_tile_k,
        "TILE_N": mean_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": mean_heur_one_tile_per_cta,
        # Reuses the softmax non-inner warp heuristic.
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    # --- uniform and upsample --------------------------------------------
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest1d": {
        "SAME_L": upsample_nearest1d_SAME_L,
        "USE_INT32_IDX": upsample_nearest1d_USE_INT32_IDX,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    # --- statistics, vdot, attention, generic elementwise -----------------
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    "mha_varlen_fwd": {
        "BLOCK_M": mha_varlen_heur_block_m,
        "BLOCK_N": mha_varlen_heur_block_n,
        "num_warps": lambda args: 1,
        "num_stages": lambda args: 1,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 8,
    },
}