Coverage for src/flag_gems/runtime/backend/_ascend/heuristics_config_utils.py: 0%
136 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import triton
def argmax_heur_block_m(args):
    """Heuristic for argmax ``BLOCK_M``: always 16 on this backend.

    Args:
        args: Heuristic argument mapping; not consulted.
    """
    block_m = 16
    return block_m


def argmax_heur_block_n(args):
    """Heuristic for argmax ``BLOCK_N``: always 100 on this backend.

    Args:
        args: Heuristic argument mapping; not consulted.
    """
    block_n = 100
    return block_n
def argmin_heur_block_m(args):
    """Heuristic for argmin ``BLOCK_M``: always 16 on this backend.

    Args:
        args: Heuristic argument mapping; not consulted.
    """
    block_m = 16
    return block_m


def argmin_heur_block_n(args):
    """Heuristic for argmin ``BLOCK_N``: always 100 on this backend.

    Args:
        args: Heuristic argument mapping; not consulted.
    """
    block_n = 100
    return block_n
def bmm_heur_divisible_m(args):
    """Return True when ``args["M"]`` is an exact multiple of ``args["TILE_M"]``."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0


def bmm_heur_divisible_n(args):
    """Return True when ``args["N"]`` is an exact multiple of ``args["TILE_N"]``."""
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0


def bmm_heur_divisible_k(args):
    """Return True when ``args["K"]`` is an exact multiple of ``args["TILE_K"]``."""
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0
def dropout_heur_block(args):
    """Dropout ``BLOCK``: 512 for ``N <= 512``, otherwise 4096.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    return 512 if args["N"] <= 512 else 4096


def dropout_heur_num_warps(args):
    """Dropout ``num_warps``: 4 for ``N <= 512``, 8 for ``N <= 1024``, else 16.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    n = args["N"]
    # (inclusive upper bound on N, num_warps), checked in ascending order.
    for limit, warps in ((512, 4), (1024, 8)):
        if n <= limit:
            return warps
    return 16
def exponential_heur_block(args):
    """exponential_ ``BLOCK``: 512 for ``N <= 512``, otherwise 1024.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    n = args["N"]
    if n <= 512:
        block = 512
    else:
        block = 1024
    return block


def exponential_heur_num_warps(args):
    """exponential_ ``num_warps``: 4 / 8 / 16 for N up to 512 / up to 1024 / larger.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    n = args["N"]
    return 4 if n <= 512 else (8 if n <= 1024 else 16)
def gather_heur_block_m(args):
    """Gather ``BLOCK_M``: the number of 2048-wide chunks needed to cover ``N``,
    rounded up to a power of two and capped at 4.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))


def gather_heur_block_n(args):
    """Gather ``BLOCK_N``: ``N`` rounded up to a power of two, capped at 2048.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    rounded_n = triton.next_power_of_2(args["N"])
    return min(2048, rounded_n)
def index_select_heur_block_m(args):
    """index_select ``BLOCK_M``: ``cdiv(256, N)`` rounded up to a power of two,
    capped at 4.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    rows = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(rows))


def index_select_heur_block_n(args):
    """index_select ``BLOCK_N``: ``cdiv(N, 16)`` rounded up to a power of two,
    clamped to the range [16, 512].

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    rounded = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    capped = min(rounded, 512)
    return max(capped, 16)
def rand_heur_block(args):
    """rand ``BLOCK``: 2048 for ``N <= 512``, otherwise 4097.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    # NOTE(review): 4097 is not a power of two, unlike the other BLOCK values
    # in this file, and standard Triton's tl.arange requires power-of-two
    # extents. The same value appears for randn and uniform, so it may be
    # deliberate for the Ascend kernels. Confirm before changing.
    n = args["N"]
    block = 2048 if n <= 512 else 4097
    return block


def rand_heur_num_warps(args):
    """rand ``num_warps``: 4 for ``N <= 512``, 8 for ``N <= 1024``, else 16.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    n = args["N"]
    if n <= 512:
        warps = 4
    elif n <= 1024:
        warps = 8
    else:
        warps = 16
    return warps
def randn_heur_block(args):
    """randn ``BLOCK``: 2048 for ``N <= 512``, otherwise 4097.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    # NOTE(review): 4097 is not a power of two (standard Triton's tl.arange
    # requires one). It matches rand/uniform, so it may be deliberate for
    # Ascend. Confirm before changing.
    if args["N"] <= 512:
        return 2048
    return 4097


def randn_heur_num_warps(args):
    """randn ``num_warps``: 4 for ``N <= 512``, 8 for ``N <= 1024``, else 16.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    n = args["N"]
    # (inclusive upper bound on N, num_warps), checked in ascending order.
    thresholds = ((512, 4), (1024, 8))
    for bound, warps in thresholds:
        if n <= bound:
            return warps
    return 16
def softmax_heur_tile_k(args, *, num_sms=40, max_tile_k=4096):
    """Choose ``TILE_K`` for the non-inner softmax kernel.

    Starting from 1, ``TILE_K`` is doubled while two conditions hold:
    the launch would still need more than one wave of programs
    (``M * cdiv(K, TILE_K) > num_sms``), and the doubled value stays
    within ``min(K, max_tile_k)``.

    Args:
        args: Heuristic argument mapping; reads ``"M"`` and ``"K"``.
        num_sms: Number of parallel compute units one wave is sized for.
            Defaults to the previously hard-coded 40.
        max_tile_k: Upper cap on ``TILE_K``. Defaults to the previously
            hard-coded 4096.

    Returns:
        A power of two ``>= 1``.
    """
    # FIXME: num_sms should be obtained by API. On Ascend it is the number of
    # AIV cores, which depends on the hardware version; callers can now pass
    # the real value instead of relying on the hard-coded default.
    m = args["M"]
    k = args["K"]
    upper_bound = min(k, max_tile_k)
    tile_k = 1
    while tile_k * 2 <= upper_bound:
        num_blocks = m * triton.cdiv(k, tile_k)
        # Integer form of "num_blocks / num_sms > 1"; avoids float rounding.
        if num_blocks <= num_sms:
            break  # the launch already fits in a single wave
        tile_k *= 2
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """Non-inner softmax ``TILE_N``: ``cdiv(1024, TILE_K)``.

    This targets roughly 1024 elements per ``TILE_N x TILE_K`` tile.

    Args:
        args: Heuristic argument mapping; reads ``"TILE_K"``.
    """
    tile_k = args["TILE_K"]
    return triton.cdiv(1024, tile_k)
def softmax_heur_one_tile_per_cta(args):
    """Return True when a single ``TILE_N`` tile spans all ``N`` elements.

    Args:
        args: Heuristic argument mapping; reads ``"TILE_N"`` and ``"N"``.
    """
    tile_n = args["TILE_N"]
    n = args["N"]
    return tile_n >= n
def softmax_heur_num_warps_non_inner(args):
    """Non-inner softmax ``num_warps`` from the tile area ``TILE_N * TILE_K``.

    The result is 4 below 2048 elements, 8 below 4096, and 16 otherwise.

    Args:
        args: Heuristic argument mapping; reads ``"TILE_N"`` and ``"TILE_K"``.
    """
    area = args["TILE_N"] * args["TILE_K"]
    # (exclusive upper bound on area, num_warps), checked in ascending order.
    for bound, warps in ((2048, 4), (4096, 8)):
        if area < bound:
            return warps
    return 16
def softmax_heur_tile_n_inner(args):
    """Inner softmax ``TILE_N``.

    For ``N`` up to 32 Ki elements, this is ``N`` rounded up to a power of
    two. Larger ``N`` uses a fixed 4096.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    n = args["N"]
    return triton.next_power_of_2(n) if n <= 32 * 1024 else 4096
def softmax_heur_num_warps_inner(args):
    """Inner softmax ``num_warps`` from ``TILE_N``.

    The result is 4 below 2048, 8 below 4096, and 16 otherwise.

    Args:
        args: Heuristic argument mapping; reads ``"TILE_N"``.
    """
    tile_n = args["TILE_N"]
    if tile_n < 2048:
        return 4
    if tile_n < 4096:
        return 8
    return 16
def softmax_heur_tile_n_bwd_non_inner(args):
    """Non-inner softmax-backward ``TILE_N``: ``1024 // TILE_K``, at least 1.

    Args:
        args: Heuristic argument mapping; reads ``"TILE_K"``.
    """
    quotient = 1024 // args["TILE_K"]
    return max(1, quotient)


def softmax_heur_tile_m(args):
    """Inner softmax-backward ``TILE_M``: ``1024 // TILE_N``, at least 1.

    Args:
        args: Heuristic argument mapping; reads ``"TILE_N"``.
    """
    quotient = 1024 // args["TILE_N"]
    return max(1, quotient)
def uniform_heur_block(args):
    """uniform_ ``BLOCK`` from ``N``.

    The result is 512 for ``N <= 512``, 4097 for ``N >= 2**30``, and
    1024 otherwise.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    # NOTE(review): 4097 is not a power of two (standard Triton's tl.arange
    # requires one). It matches rand/randn, so it may be deliberate for
    # Ascend. Confirm before changing.
    n = args["N"]
    if n <= 512:
        block = 512
    elif n >= 1073741824:  # 2**30
        block = 4097
    else:
        block = 1024
    return block


def uniform_heur_num_warps(args):
    """uniform_ ``num_warps``: 4 for ``N <= 512``, 8 for ``N <= 1024``, else 16.

    Args:
        args: Heuristic argument mapping; reads ``"N"``.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16
def var_mean_heur_block_n(args):
    """var_mean ``BLOCK_N``: ``BLOCK_NUM`` rounded up to a power of two.

    Args:
        args: Heuristic argument mapping; reads ``"BLOCK_NUM"``.
    """
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest2d_SAME_H(args):
    """Return True when output height ``OH`` equals input height ``IH``.

    Args:
        args: Heuristic argument mapping; reads ``"OH"`` and ``"IH"``.
    """
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h


def upsample_nearest2d_SAME_W(args):
    """Return True when output width ``OW`` equals input width ``IW``.

    Args:
        args: Heuristic argument mapping; reads ``"OW"`` and ``"IW"``.
    """
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w
def batch_norm_heur_block_m(args):
    """batch_norm ``BLOCK_M``: ``batch_dim`` rounded up to a power of two,
    capped at 64.

    Args:
        args: Heuristic argument mapping; reads ``"batch_dim"``.
    """
    rounded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(64, rounded_batch)


def batch_norm_heur_block_n(args):
    """batch_norm ``BLOCK_N``: ``spatial_dim`` rounded up to a power of two.

    The result is limited so that ``BLOCK_M * BLOCK_N`` stays within 2**10
    elements, with a floor of 1 on the limit.

    Args:
        args: Heuristic argument mapping; reads ``"batch_dim"`` (through
            ``batch_norm_heur_block_m``) and ``"spatial_dim"``.
    """
    block_m = batch_norm_heur_block_m(args)
    rounded_spatial = triton.next_power_of_2(args["spatial_dim"])
    element_budget = max(1, 2**10 // block_m)
    return min(rounded_spatial, element_budget)
def vdot_heur_block_size(args):
    """vdot ``BLOCK_SIZE`` from ``n_elements``.

    The result is 32 below 1024, 256 below 8192, and 1024 otherwise.

    Args:
        args: Heuristic argument mapping; reads ``"n_elements"``.
    """
    n_elements = args["n_elements"]
    # (exclusive upper bound on n_elements, block size), ascending.
    for bound, block_size in ((1024, 32), (8192, 256)):
        if n_elements < bound:
            return block_size
    return 1024
def mm_heur_even_k(args):
    """Return True when ``K`` is an exact multiple of ``BLOCK_K * SPLIT_K``.

    Args:
        args: Heuristic argument mapping; reads ``"K"``, ``"BLOCK_K"`` and
            ``"SPLIT_K"``.
    """
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return args["K"] % k_step == 0
# bmm and baddbmm use identical divisibility heuristics. Each entry below gets
# its own copy so the two per-kernel mappings remain independent objects.
_BMM_DIVISIBILITY_HEURISTICS = {
    "DIVISIBLE_M": bmm_heur_divisible_m,
    "DIVISIBLE_N": bmm_heur_divisible_n,
    "DIVISIBLE_K": bmm_heur_divisible_k,
}

# Heuristic tables keyed by kernel/operator name. Each value maps a kernel
# meta-parameter name to a callable that takes the heuristic argument mapping
# and returns that parameter's value.
HEURISTICS_CONFIGS = {
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "baddbmm": dict(_BMM_DIVISIBILITY_HEURISTICS),
    "bmm": dict(_BMM_DIVISIBILITY_HEURISTICS),
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    "mha_block_16": {
        # Constant meta-parameters. The trailing numbers record alternative
        # values (presumably the defaults used on other backends; confirm).
        # Only BLOCK_N differs from its note.
        "BLOCK_M": lambda args: 16,  # 16
        "BLOCK_N": lambda args: 16,  # 64
        "num_warps": lambda args: 4,  # 4
        "num_stages": lambda args: 3,  # 3
    },
}