Coverage for src/flag_gems/runtime/backend/_enflame/heuristics_config_utils.py: 0%
113 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import triton
def argmax_heur_block_m(args):
    """Return the BLOCK_M value for the "argmax" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"M"``.

    Returns:
        4 when ``args["M"]`` is below 4096, otherwise 8.
    """
    return 4 if args["M"] < 4096 else 8
def argmax_heur_block_n(args):
    """Return the BLOCK_N value for the "argmax" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"N"``.

    Returns:
        ``triton.next_power_of_2(args["N"])``, capped at 4096.
    """
    return min(4096, triton.next_power_of_2(args["N"]))
def argmin_heur_block_m(args):
    """Return the BLOCK_M value for the "argmin" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"M"``.

    Returns:
        4 when ``args["M"]`` is below 4096, otherwise 8.
    """
    return 4 if args["M"] < 4096 else 8
def argmin_heur_block_n(args):
    """Return the BLOCK_N value for the "argmin" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"N"``.

    Returns:
        ``triton.next_power_of_2(args["N"])``, capped at 4096.
    """
    return min(4096, triton.next_power_of_2(args["N"]))
# def bmm_heur_divisible_m(args):
#     return args["M"] % args["BLOCK_M"] == 0


# def bmm_heur_divisible_n(args):
#     return args["N"] % args["BLOCK_N"] == 0


# def bmm_heur_divisible_k(args):
#     return args["K"] % args["BLOCK_K"] == 0
def dropout_heur_block(args):
    """Return the BLOCK value for the "dropout" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"N"``.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 4096.
    """
    if args["N"] <= 512:
        return 512
    else:
        return 4096
def dropout_heur_num_warps(args):
    """Return the num_warps value for the "dropout" entry: always 4.

    ``args`` is accepted for signature compatibility and is not read.
    """
    return 4
def exponential_heur_block(args):
    """Return the BLOCK value for the "exponential_" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"N"``.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 16384.
    """
    if args["N"] <= 512:
        return 512
    else:
        return 16384
def exponential_heur_num_warps(args):
    """Return the num_warps value for the "exponential_" entry: always 4.

    ``args`` is accepted for signature compatibility and is not read.
    """
    return 4
def gather_heur_block_m(args):
    """Return the BLOCK_M value for the "gather" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"N"``.

    Returns:
        The next power of two of ``ceil(N / 2048)``, capped at 4.
    """
    return min(4, triton.next_power_of_2(triton.cdiv(args["N"], 2048)))
def gather_heur_block_n(args):
    """Return the BLOCK_N value for the "gather" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"N"``.

    Returns:
        ``triton.next_power_of_2(args["N"])``, capped at 2048.
    """
    return min(2048, triton.next_power_of_2(args["N"]))
def index_select_heur_block_m(args):
    """Return the BLOCK_M value for the "index_select" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"N"``.
            NOTE(review): ``N == 0`` would divide by zero inside
            ``triton.cdiv`` - confirm callers never pass it.

    Returns:
        The next power of two of ``ceil(32768 / N)``, capped at 16.
    """
    return min(16, triton.next_power_of_2(triton.cdiv(32768, args["N"])))
def index_select_heur_block_n(args):
    """Return the BLOCK_N value for the "index_select" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"N"``.

    Returns:
        The next power of two of ``ceil(N / 16)``, clamped to the
        inclusive range [16, 512].
    """
    m = min(triton.next_power_of_2(triton.cdiv(args["N"], 16)), 512)
    return max(m, 16)
def mm_heur_even_k(args):
    """Return the EVEN_K flag for the "mm" heuristics entry.

    Args:
        args: Mapping that provides the integers ``"K"``, ``"BLOCK_K"``
            and ``"SPLIT_K"``.

    Returns:
        True when ``K`` is an exact multiple of ``BLOCK_K * SPLIT_K``.
    """
    return args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0
def rand_heur_block(args):
    """Return the BLOCK value for the "rand" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"N"``.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 16384.
    """
    if args["N"] <= 512:
        return 512
    else:
        return 16384
def rand_heur_num_warps(args):
    """Return the num_warps value for the "rand" entry: always 4.

    ``args`` is accepted for signature compatibility and is not read.
    """
    return 4
def randn_heur_block(args):
    """Return the BLOCK value for the "randn" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"N"``.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 16384.
    """
    if args["N"] <= 512:
        return 512
    else:
        return 16384
def randn_heur_num_warps(args):
    """Return the num_warps value for the "randn" entry: always 4.

    ``args`` is accepted for signature compatibility and is not read.
    """
    return 4
def softmax_heur_tile_k(args):
    """Return the TILE_K value for the "softmax_non_inner" heuristics entry.

    Starting from 1, the tile is doubled while the resulting block count
    ``M * ceil(K / tile_k)`` still exceeds ``NUM_SMS`` and the doubled tile
    does not exceed ``min(K, MAX_TILE_K)``.

    Args:
        args: Mapping that provides the integers ``"M"`` and ``"K"``.

    Returns:
        The selected tile size, a power of two that is at least 1.
    """
    MAX_TILE_K = 8192
    # NUM_SMS = torch.cuda.get_device_properties(
    #     torch.cuda.current_device()
    # ).multi_processor_count
    # The device query above is disabled; a fixed count is used instead.
    NUM_SMS = 64
    tile_k = 1
    upper_bound = min(args["K"], MAX_TILE_K)
    while tile_k <= upper_bound:
        num_blocks = args["M"] * triton.cdiv(args["K"], tile_k)
        num_waves = num_blocks / NUM_SMS
        # Keep growing only while more than NUM_SMS blocks remain and the
        # doubled tile still fits under the bound.
        if (num_waves > 1) and (tile_k * 2 <= upper_bound):
            tile_k *= 2
        else:
            break
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """Return the TILE_N value for the "softmax_non_inner" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"TILE_K"``.

    Returns:
        ``ceil(8192 / TILE_K)`` as computed by ``triton.cdiv``.
    """
    return triton.cdiv(8192, args["TILE_K"])
def softmax_heur_one_tile_per_cta(args):
    """Return the ONE_TILE_PER_CTA flag used by the softmax heuristics entries.

    Args:
        args: Mapping that provides the integers ``"TILE_N"`` and ``"N"``.

    Returns:
        True when ``TILE_N`` is greater than or equal to ``N``.
    """
    return args["TILE_N"] >= args["N"]
def softmax_heur_num_warps_non_inner(args):
    """Return the num_warps value for the "softmax_non_inner" entry: always 4.

    ``args`` is accepted for signature compatibility and is not read.
    """
    return 4
def softmax_heur_tile_n_inner(args):
    """Return the TILE_N value for the "softmax_inner" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"N"``.

    Returns:
        ``triton.next_power_of_2(N)`` when ``N`` is at most 32 * 1024,
        otherwise the fixed value 4096.
    """
    if args["N"] <= (32 * 1024):
        return triton.next_power_of_2(args["N"])
    else:
        return 4096
def softmax_heur_num_warps_inner(args):
    """Return the num_warps value for the "softmax_inner" entry: always 4.

    ``args`` is accepted for signature compatibility and is not read.
    """
    return 4
def softmax_heur_tile_n_bwd_non_inner(args):
    """Return the TILE_N value for the "softmax_backward_non_inner" entry.

    Args:
        args: Mapping that provides the integer ``"TILE_K"``.

    Returns:
        ``1024 // TILE_K``, but never less than 1.
    """
    return max(1, 1024 // args["TILE_K"])
def softmax_heru_tile_m(args):
    """Return the TILE_M value for the "softmax_backward_inner" entry.

    The function name (``heru``) is kept as spelled because the
    heuristics table refers to it by this name.

    Args:
        args: Mapping that provides the integer ``"TILE_N"``.

    Returns:
        ``1024 // TILE_N``, but never less than 1.
    """
    return max(1, 1024 // args["TILE_N"])
def uniform_heur_block(args):
    """Return the BLOCK value for the "uniform" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"N"``.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 16384.
    """
    if args["N"] <= 512:
        return 512
    else:
        return 16384
def uniform_heur_num_warps(args):
    """Return the num_warps value for the "uniform" entry: always 4.

    ``args`` is accepted for signature compatibility and is not read.
    """
    return 4
def var_mean_heur_block_n(args):
    """Return the BLOCK_N value for the "var_mean" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"BLOCK_NUM"``.

    Returns:
        ``triton.next_power_of_2(args["BLOCK_NUM"])``.
    """
    return triton.next_power_of_2(args["BLOCK_NUM"])
def upsample_nearest2d_NUM_TILE(args):
    """Return the NUM_TILE value for the "upsample_nearest2d" entry.

    Args:
        args: Mapping that provides the integers ``"N"`` and ``"C"``.

    Returns:
        1 when ``ceil(N * C / 4)`` is at most 128, otherwise that
        quantity divided by 128 and rounded up.
    """
    grid_y = triton.cdiv(args["N"] * args["C"], 4)
    if grid_y <= 128:
        num_tile = 1
    else:
        num_tile = triton.cdiv(grid_y, 128)
    return num_tile
def upsample_nearest2d_TOTAL_TILE(args):
    """Return the TOTAL_TILE value for the "upsample_nearest2d" entry.

    Args:
        args: Mapping that provides the integers ``"N"`` and ``"C"``.

    Returns:
        ``ceil(N * C / 4)`` as computed by ``triton.cdiv``.
    """
    return triton.cdiv(args["N"] * args["C"], 4)
def upsample_nearest2d_SAME_H(args):
    """Return the SAME_H flag for the "upsample_nearest2d" entry.

    Args:
        args: Mapping that provides ``"OH"`` and ``"IH"``.

    Returns:
        True when ``args["OH"]`` equals ``args["IH"]``.
    """
    return args["OH"] == args["IH"]
def upsample_nearest2d_SAME_W(args):
    """Return the SAME_W flag for the "upsample_nearest2d" entry.

    Args:
        args: Mapping that provides ``"OW"`` and ``"IW"``.

    Returns:
        True when ``args["OW"]`` equals ``args["IW"]``.
    """
    return args["OW"] == args["IW"]
def upsample_nearest2d_USE_INT32_IDX(args):
    """Return the USE_INT32_IDX flag for the "upsample_nearest2d" entry.

    Args:
        args: Mapping that provides the integers ``"N"``, ``"C"``,
            ``"OH"`` and ``"OW"``.

    Returns:
        True when the product ``N * C * OH * OW`` does not exceed the
        largest signed 32-bit integer.
    """
    return args["N"] * args["C"] * args["OH"] * args["OW"] <= (2**31 - 1)  # INT32 MAX
def batch_norm_heur_block_m(args):
    """Return the BLOCK_M value for the "batch_norm" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"batch_dim"``.

    Returns:
        ``triton.next_power_of_2(args["batch_dim"])``, capped at 2048.
    """
    return min(2048, triton.next_power_of_2(args["batch_dim"]))
def batch_norm_heur_block_n(args):
    """Return the BLOCK_N value for the "batch_norm" heuristics entry.

    Args:
        args: Mapping that provides the integers ``"batch_dim"`` (read via
            ``batch_norm_heur_block_m``) and ``"spatial_dim"``.

    Returns:
        The next power of two of ``spatial_dim``, limited so that
        ``BLOCK_M * BLOCK_N`` stays within 2**14 whenever that quotient is
        at least 1.
    """
    # A maximum of 16384 elements are loaded at once.
    BLOCK_M = batch_norm_heur_block_m(args)
    BLOCK_N = triton.next_power_of_2(args["spatial_dim"])
    return min(BLOCK_N, max(1, 2**14 // BLOCK_M))
def vdot_heur_block_size(args):
    """Return the BLOCK_SIZE value for the "vdot" heuristics entry.

    Args:
        args: Mapping that provides the integer ``"n_elements"``.

    Returns:
        32 for fewer than 1024 elements, 256 for fewer than 8192,
        otherwise 1024.
    """
    n = args["n_elements"]
    if n < 1024:
        return 32
    elif n < 8192:
        return 256
    else:
        return 1024
def simple_elementwise_blocksize_heur(args):
    """Return the BLOCK_SIZE value for the "elementwise_generic" entry.

    Args:
        args: Mapping that provides the integer ``"n_elements"``.

    Returns:
        1024 for fewer than 65535 elements, otherwise 16384.
    """
    n = args["n_elements"]
    if n < 65535:
        return 1024
    else:
        return 16384
# Heuristics table: maps a config name to a dict of
# {parameter name: callable(args) -> value}. Each callable receives a
# mapping of arguments and returns the value for that parameter.
HEURISTICS_CONFIGS = {
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    # The bmm divisibility heuristics are currently disabled, so this
    # entry is intentionally empty.
    "bmm": {
        # "DIVISIBLE_M": bmm_heur_divisible_m,
        # "DIVISIBLE_N": bmm_heur_divisible_n,
        # "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    # The mha_block_* entries use fixed values that ignore their arguments.
    "mha_block_128": {
        "BLOCK_M": lambda args: 64,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 1,
    },
    "mha_block_64": {
        "BLOCK_M": lambda args: 32,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 1,
    },
    "mha_block_32": {
        "BLOCK_M": lambda args: 32,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 1,
    },
    "mha_block_16": {
        "BLOCK_M": lambda args: 16,
        "BLOCK_N": lambda args: 64,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 1,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heru_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "NUM_TILE": upsample_nearest2d_NUM_TILE,
        "TOTAL_TILE": upsample_nearest2d_TOTAL_TILE,
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
        "USE_INT32_IDX": upsample_nearest2d_USE_INT32_IDX,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": lambda args: 4,
    },
    "mha_varlen_fwd": {
        "BLOCK_M": lambda args: 128,
        "BLOCK_N": lambda args: 32,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
    },
}