Coverage for src/flag_gems/runtime/configloader.py: 62%
204 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
import copy
import inspect
import itertools
import warnings

import triton

from . import backend, common
from .backend.device import DeviceDetector
class ConfigLoader(object):
    """Singleton that loads tuning/heuristics configs and builds ``triton.Config`` lists.

    Per-op lookups consult, in priority order, the architecture-specialized
    config (when the backend reports one), the current vendor's config, and
    the "nvidia" config, which is used as the default.
    """

    _instance = None

    # Shared (meta_key, ranges_key) layouts.  The order of the pairs is both
    # the META key order and the loop-nesting order (outermost first).
    _TILE_MNK = (("TILE_M", "BLOCK_M"), ("TILE_N", "BLOCK_N"), ("TILE_K", "BLOCK_K"))
    _BLOCK_MNK = (
        ("BLOCK_M", "BLOCK_M"),
        ("BLOCK_N", "BLOCK_N"),
        ("BLOCK_K", "BLOCK_K"),
    )
    _BLOCK_SIZE_MNK = (
        ("BLOCK_SIZE_M", "BLOCK_M"),
        ("BLOCK_SIZE_N", "BLOCK_N"),
        ("BLOCK_SIZE_K", "BLOCK_K"),
    )

    # Ops understood by _build_configs_by_op and the META layout each one uses.
    # "bmm"/"baddbmm" additionally get a GROUP_M derived from the M tile, and
    # "w8a8_block_fp8_general_tma" gets an optional GROUP_M axis; both are
    # handled inside _build_configs_by_op.
    _OP_META_LAYOUTS = {
        "bmm": _TILE_MNK,
        "addmm": _BLOCK_SIZE_MNK,
        "baddbmm": _TILE_MNK,
        "mv": (("BLOCK_N", "BLOCK_N"), ("BLOCK_M", "BLOCK_M")),
        "mm_general_tma": _BLOCK_MNK,
        "mm": _BLOCK_MNK,
        "mm_sqmma": _BLOCK_MNK,
        "bmm_sqmma": _BLOCK_SIZE_MNK,
        "addmm_sqmma": _BLOCK_SIZE_MNK,
        "gemv": (("BLOCK_M", "BLOCK_M"), ("BLOCK_K", "BLOCK_K")),
        "sparse_attention": (("BLOCK", "BLOCK"),),
        "w8a8_block_fp8_general": _BLOCK_MNK + (("GROUP_M", "GROUP_M"),),
        "w8a8_block_fp8_general_tma": _BLOCK_MNK,
        "w8a8_block_fp8_general_splitk": _BLOCK_MNK + (("SPLIT_K", "SPLIT_K"),),
        "mm_splitk": _BLOCK_MNK + (("SPLIT_K", "SPLIT_K"),),
    }

    def __new__(cls, *args, **kargs):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Load every config source once; later constructions are no-ops."""
        # __init__ runs on every ConfigLoader() call even though __new__
        # returns the same object, so guard against re-initialization.
        if hasattr(self, "initialized"):
            return
        self.initialized = True
        self.device = DeviceDetector()
        # primitive_yaml_config is simply the dictionary returned by yaml
        # and is reserved from being an attr for vendor customizability
        self.arch_specialized_yaml_config = None
        self.arch_heuristics_config = None
        self.vendor_primitive_yaml_config = self.get_vendor_tune_config()
        self.default_primitive_yaml_config = self.get_default_tune_config()
        self.vendor_heuristics_config = self.get_vendor_heuristics_config()
        self.default_heuristics_config = self.get_default_heuristics_config()
        self.update_config_from_arch()

        if self.vendor_heuristics_config is None:
            vendorname = self.device.vendor_name
            warnings.warn(
                f"The {vendorname} configuration of heuristics_config is None"
            )
        # gen_key is an identifier that indicates whether the current config
        # needs to be generated automatically
        self.gen_key = "gen"
        # loaded_triton_config is wrapped in triton.Config according to
        # primitive_yaml_config
        self.loaded_triton_config = {}
        self.triton_config_default = {
            "num_stages": 2,
            "num_warps": 4,
            "num_ctas": 1,
        }
        if self.device.vendor_name == "hygon":
            self.triton_config_default["num_ldmatrixes"] = 0
        self.expand_config_registry = self._build_expand_registry()
        self.load_all()

    def update_config_from_arch(self):
        """Adopt arch-specific autotune/heuristics configs when the backend has them.

        Best effort: any failure is printed and the arch configs keep their
        current values.
        """
        try:
            arch_event = backend.BackendArchEvent()
            if arch_event.has_arch:
                self.arch_specialized_yaml_config = arch_event.autotune_configs
                self.arch_heuristics_config = arch_event.heuristics_configs
        except Exception as err:
            print(f"[INFO] : {err}")

    def _get_op_configs(self, op_name):
        """Return the config entries for ``op_name`` from the first source defining it.

        Sources are tried in the order arch-specialized, vendor, default.
        Returns ``[]`` when no source has the op.
        """
        sources = (
            self.arch_specialized_yaml_config,
            self.vendor_primitive_yaml_config,
            self.default_primitive_yaml_config,
        )
        for source in sources:
            if source and op_name in source:
                return source[op_name]
        return []

    def _create_triton_config(self, single_config, current_config):
        """Build a ``triton.Config`` from ``single_config["META"]`` and launch params.

        ``num_ldmatrixes`` is forwarded only for the "hygon" vendor and only
        when the installed ``triton.Config`` accepts that parameter.
        """
        kwargs = {
            "num_warps": current_config["num_warps"],
            "num_stages": current_config["num_stages"],
            "num_ctas": current_config["num_ctas"],
        }
        if (
            self.device.vendor_name == "hygon"
            and "num_ldmatrixes" in inspect.signature(triton.Config).parameters
        ):
            kwargs["num_ldmatrixes"] = current_config["num_ldmatrixes"]
        return triton.Config(single_config["META"], **kwargs)

    def _build_configs_by_op(self, op_name, ranges, pre_hook=None):
        """Expand ``ranges`` into the cartesian product of configs for ``op_name``.

        Args:
            op_name: Operator name; unknown names yield ``[]``.
            ranges: Mapping from range key (e.g. "BLOCK_M") to an iterable of
                candidate values, plus "s" (num_stages) and "w" (num_warps).
            pre_hook: Passed through to every ``triton.Config``.

        Returns:
            A list of ``triton.Config`` objects, ordered with the first META
            axis varying slowest and ``num_warps`` varying fastest.
        """
        layout = self._OP_META_LAYOUTS.get(op_name)
        if layout is None:
            return []

        axes = [ranges[range_key] for _, range_key in layout]
        has_optional_group_m = op_name == "w8a8_block_fp8_general_tma"
        if has_optional_group_m:
            # A missing GROUP_M range still produces configs, just without
            # the GROUP_M meta key.
            axes.append(ranges.get("GROUP_M", [None]))
        axes.append(ranges["s"])
        axes.append(ranges["w"])

        configs = []
        for combo in itertools.product(*axes):
            *meta_values, num_stages, num_warps = combo
            # zip stops at the layout length, so the optional GROUP_M value
            # (if present) is not consumed here.
            meta = {
                meta_key: value for (meta_key, _), value in zip(layout, meta_values)
            }
            if has_optional_group_m:
                group_m = meta_values[-1]
                if group_m is not None:
                    meta["GROUP_M"] = group_m
            elif op_name == "bmm":
                meta["GROUP_M"] = 1 if meta["TILE_M"] == 32 else 2
            elif op_name == "baddbmm":
                meta["GROUP_M"] = 1 if meta["TILE_M"] <= 32 else 2
            configs.append(
                triton.Config(
                    meta,
                    num_stages=num_stages,
                    num_warps=num_warps,
                    pre_hook=pre_hook,
                )
            )
        return configs

    def _build_single_expand_spec(
        self,
        op_name,
        expand_yaml_path=None,
        yaml_op_name=None,
    ):
        """Return the expand-registry entry for one op.

        ``yaml_op_name`` defaults to ``op_name``; key order and default
        strategy come from ``common.OP_KEY_ORDERS`` / ``common.DEFAULT_STRATEGIES``.
        """
        return {
            "yaml_op_name": yaml_op_name or op_name,
            "key": common.OP_KEY_ORDERS[op_name],
            "default_strategy": common.DEFAULT_STRATEGIES[op_name],
            "expand_yaml_path": expand_yaml_path,
        }

    def _build_expand_registry(self):
        """Build the op-name -> expand-spec registry used by get_expand_config."""
        default_path = common.DEFAULT_EXPAND_CONFIG_PATH
        # (op name, whether the op is pinned to the default expand YAML path)
        registered_ops = (
            ("addmm", True),
            ("addmm_sqmma", False),
            ("baddbmm", True),
            ("bmm", True),
            ("bmm_sqmma", False),
            ("gemv", False),
            ("mm", False),
            ("mm_general_tma", False),
            ("mv", True),
            ("w8a8_block_fp8_general", False),
            ("w8a8_block_fp8_general_splitk", False),
            ("w8a8_block_fp8_general_tma", False),
            ("mm_splitk", False),
            ("sparse_attention", False),
        )
        return {
            op: self._build_single_expand_spec(
                op, expand_yaml_path=default_path if pinned else None
            )
            for op, pinned in registered_ops
        }

    def load_all(self):
        """Pre-compute and cache tuned configs for every op in the vendor config."""
        # The vendor config may be missing; treat that as "nothing to preload".
        for key in self.vendor_primitive_yaml_config or ():
            self.loaded_triton_config[key] = self.get_tuned_config(key)

    def get_vendor_heuristics_config(self):
        """Return the heuristics config for the detected vendor."""
        return backend.get_heuristic_config(self.device.vendor_name)

    def get_default_heuristics_config(self):
        """Return the "nvidia" heuristics config, used as the default."""
        return backend.get_heuristic_config("nvidia")

    def get_default_tune_config(self):
        """Return the "nvidia" tune config, used as the default."""
        return backend.get_tune_config("nvidia")

    def get_vendor_tune_config(self):
        """Return the tune config for the detected vendor."""
        return backend.get_tune_config(self.device.vendor_name)

    def get_heuristics_config(self, op_name):
        """Return the heuristics config for ``op_name``, or ``None`` with a warning.

        Sources are tried in the order arch, vendor, default.  Sources that
        are ``None`` or empty are skipped, so a missing vendor config falls
        through to the default instead of raising.
        """
        sources = (
            self.arch_heuristics_config,
            self.vendor_heuristics_config,
            self.default_heuristics_config,
        )
        for source in sources:
            if source and op_name in source:
                return source[op_name]
        warnings.warn(f"No heuristics config found for {op_name}")
        return None

    def _resolve_iteration_values(self, gen_config, config_var_key):
        """Resolve a ``param_map`` source into the values to iterate over.

        Lists/tuples are used as-is, an int becomes a one-element list, and
        anything else is treated as a key into ``gen_config``.
        """
        if isinstance(config_var_key, (list, tuple)):
            return config_var_key
        if isinstance(config_var_key, int):
            return [config_var_key]
        return gen_config[config_var_key]

    def _gen_impl(
        self,
        gen_config,
        iteration_plan,
        std_config,
    ):
        """Enumerate every combination described by ``iteration_plan``.

        Uses an explicit stack (depth-first), so configs are emitted in the
        order the stack unwinds rather than in plain nested-loop order.

        Args:
            gen_config: The raw "gen" config entry; value sources are looked
                up in it by _resolve_iteration_values.
            iteration_plan: List of ``{"key", "source", "kind"}`` steps.
            std_config: Starting config; must contain "META" plus the
                num_warps / num_stages / num_ctas defaults.

        Returns:
            List of ``triton.Config`` objects, one per full combination.
        """
        all_configs = []
        final_step = len(iteration_plan)
        stack = [{"cur_config": std_config, "current_step": 0}]

        while stack:
            state = stack.pop()
            cur_config = state["cur_config"]
            current_step = state["current_step"]

            if current_step == final_step:
                # NOTE(review): unlike _create_triton_config, this path never
                # forwards num_ldmatrixes, even for "hygon" - confirm intended.
                all_configs.append(
                    triton.Config(
                        cur_config["META"],
                        num_warps=cur_config["num_warps"],
                        num_stages=cur_config["num_stages"],
                        num_ctas=cur_config["num_ctas"],
                    )
                )
                continue

            cur_entry = iteration_plan[current_step]
            cur_key = cur_entry["key"]
            kind = cur_entry["kind"]
            values = self._resolve_iteration_values(gen_config, cur_entry["source"])
            for single_value in values:
                # Each branch gets its own copy so siblings never share META.
                new_config = copy.deepcopy(cur_config)
                if kind == "meta_field":
                    new_config["META"][cur_key] = single_value
                elif kind == "meta_block":
                    new_config["META"] = copy.deepcopy(single_value)
                else:
                    new_config[cur_key] = single_value
                stack.append(
                    {
                        "cur_config": new_config,
                        "current_step": current_step + 1,
                    }
                )
        return all_configs

    def to_gen_config(self, gen_config):
        """Expand a "gen" config entry into concrete ``triton.Config`` objects.

        ``gen_config["param_map"]["META"]`` is either a dict (each meta key is
        iterated independently) or a single source whose values are complete
        META blocks.  Every other ``param_map`` key is a top-level config
        field such as ``num_warps``.
        """
        param_config = gen_config["param_map"]
        meta_config = param_config["META"]

        if isinstance(meta_config, dict):
            iteration_plan = [
                {"key": meta_key, "source": source, "kind": "meta_field"}
                for meta_key, source in meta_config.items()
            ]
        else:
            iteration_plan = [
                {"key": "META", "source": meta_config, "kind": "meta_block"}
            ]

        iteration_plan.extend(
            {"key": key, "source": source, "kind": "config_field"}
            for key, source in param_config.items()
            if key != "META"
        )

        current_config = {"META": {}}
        current_config.update(self.triton_config_default)
        return self._gen_impl(
            gen_config,
            iteration_plan,
            current_config,
        )

    def get_expand_config(self, op_name, yaml_path=None):
        """Load the expand config for ``op_name`` as value ranges plus a strategy.

        Args:
            op_name: Operator name; must be present in the expand registry.
            yaml_path: Fallback YAML path, used only when the registry entry
                does not pin its own ``expand_yaml_path``.

        Returns:
            ``{"ranges": ..., "strategy": ...}`` on success, or ``-1`` when
            the op is unregistered, the YAML content has an unexpected shape,
            or loading fails for any reason.
        """
        op_spec = self.expand_config_registry.get(op_name)
        if op_spec is None:
            return -1

        key = op_spec.get("key", [])
        default_strategy = op_spec.get("default_strategy")
        expand_yaml_path = op_spec.get("expand_yaml_path") or yaml_path
        yaml_op_name = op_spec.get("yaml_op_name", op_name)

        try:
            expand_configs = backend.get_expand_config(
                op_name=yaml_op_name,
                file_path=expand_yaml_path,
            )
            if not isinstance(expand_configs, list):
                return -1

            gen_config = None
            strategy_config = None
            for single_config in expand_configs:
                if not isinstance(single_config, dict):
                    continue
                if "param_map" in single_config:
                    gen_config = single_config
                if "strategy" in single_config:
                    strategy_config = single_config.get("strategy")

            # Without a param_map entry there is nothing to expand.
            if gen_config is None:
                return -1
            param_map = gen_config.get("param_map")
            if not isinstance(param_map, dict):
                return -1
            meta_map = param_map.get("META")
            if not isinstance(meta_map, dict):
                return -1

            strategy = default_strategy
            if isinstance(strategy_config, dict):
                strategy = [
                    strategy_config.get(k, default_strategy[idx])
                    for idx, k in enumerate(key)
                ]

            ranges = {
                mapped_key.upper(): gen_config[mapped_key]
                for mapped_key in meta_map.values()
            }
            ranges["s"] = gen_config[param_map.get("num_stages")]
            ranges["w"] = gen_config[param_map.get("num_warps")]

            return {
                "ranges": ranges,
                "strategy": strategy,
            }
        except Exception:
            # Callers rely on -1 for "no usable expand config"; loader errors
            # and malformed YAML content are deliberately folded into it.
            return -1

    def ops_get_configs(self, op_name, yaml_path=None, pre_hook=None):
        """Return expanded ``triton.Config`` objects for ``op_name``, or ``[]``."""
        expand_config = self.get_expand_config(op_name, yaml_path=yaml_path)
        if expand_config == -1:
            return []
        return self._build_configs_by_op(
            op_name, expand_config["ranges"], pre_hook=pre_hook
        )

    def get_tuned_config(self, op_name):
        """Return the list of ``triton.Config`` objects for ``op_name``.

        Cached results from ``loaded_triton_config`` are returned directly.
        Otherwise the list is built from the op's config entries: entries
        containing the ``gen_key`` are expanded via to_gen_config, the rest
        become one config each with defaults filled in.  The result is not
        stored in the cache here; load_all does that.
        """
        if op_name in self.loaded_triton_config:
            return self.loaded_triton_config[op_name]

        current_op_configs = self._get_op_configs(op_name)
        if not current_op_configs:
            return []

        configs = []
        for single_config in current_op_configs:
            if self.gen_key in single_config:
                configs.extend(self.to_gen_config(single_config))
                continue

            # Start from the defaults and let the entry override any of them.
            current_config = {
                param: single_config[param] if param in single_config else default
                for param, default in self.triton_config_default.items()
            }
            configs.append(self._create_triton_config(single_config, current_config))
        return configs