# Coverage for src/flag_gems/runtime/configs_loader.py: 65% (236 statements)
# Report generated by coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
import copy
import inspect
import itertools
import os
import warnings

import triton

from . import backend, common
from .backend.device import DeviceDetector
class TunedConfigLoader(object):
    """Singleton that loads tuning and heuristics configs for the current device.

    The first construction detects the device, fetches the vendor and default
    ("nvidia") tune/heuristics configs from ``backend``, optionally overlays
    arch-specialized configs, and eagerly converts every vendor tune entry
    into ``triton.Config`` objects (see ``load_all``). Later constructions
    return the same instance without re-running the initialisation.
    """

    _instance = None

    # Shared (meta key, range key) layouts. The tuple order fixes both the key
    # order of the generated meta dict and the nesting order of the sweep
    # (first entry is the outermost loop).
    _BLOCK_MNK = (
        ("BLOCK_M", "BLOCK_M"),
        ("BLOCK_N", "BLOCK_N"),
        ("BLOCK_K", "BLOCK_K"),
    )
    _BLOCK_SIZE_MNK = (
        ("BLOCK_SIZE_M", "BLOCK_M"),
        ("BLOCK_SIZE_N", "BLOCK_N"),
        ("BLOCK_SIZE_K", "BLOCK_K"),
    )
    _TILE_MNK = (
        ("TILE_M", "BLOCK_M"),
        ("TILE_N", "BLOCK_N"),
        ("TILE_K", "BLOCK_K"),
    )

    # op name -> layout used by ``_build_configs_by_op``.
    _OP_META_LAYOUTS = {
        "bmm": _TILE_MNK,
        "addmm": _BLOCK_SIZE_MNK,
        "baddbmm": _TILE_MNK,
        # mv sweeps BLOCK_N in the outer loop, BLOCK_M in the inner one.
        "mv": (("BLOCK_N", "BLOCK_N"), ("BLOCK_M", "BLOCK_M")),
        "mm_general_tma": _BLOCK_MNK,
        "mm": _BLOCK_MNK,
        "mm_sqmma": _BLOCK_MNK,
        "bmm_sqmma": _BLOCK_SIZE_MNK,
        "addmm_sqmma": _BLOCK_SIZE_MNK,
        "gemv": (("BLOCK_M", "BLOCK_M"), ("BLOCK_K", "BLOCK_K")),
        "sparse_attention": (("BLOCK", "BLOCK"),),
        "w8a8_block_fp8_general": _BLOCK_MNK + (("GROUP_M", "GROUP_M"),),
        "w8a8_block_fp8_general_tma": _BLOCK_MNK + (("GROUP_M", "GROUP_M"),),
        "w8a8_block_fp8_general_splitk": _BLOCK_MNK + (("SPLIT_K", "SPLIT_K"),),
        "mm_splitk": _BLOCK_MNK + (("SPLIT_K", "SPLIT_K"),),
    }

    # Range keys that may be absent from ``ranges``; when absent (or None) the
    # corresponding meta key is simply left out of the generated config.
    _OPTIONAL_RANGE_KEYS = {
        "w8a8_block_fp8_general_tma": frozenset({"GROUP_M"}),
    }

    # Ops whose GROUP_M is derived from the M tile instead of being swept.
    # NOTE(review): bmm uses ``== 32`` while baddbmm uses ``<= 32``; both are
    # kept exactly as found - confirm whether the difference is intentional.
    _GROUP_M_RULES = {
        "bmm": lambda block_m: 1 if block_m == 32 else 2,
        "baddbmm": lambda block_m: 1 if block_m <= 32 else 2,
    }

    def __new__(cls, *args, **kargs):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super(TunedConfigLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialise the shared instance; a no-op after the first call."""
        # __init__ runs on every TunedConfigLoader() call, so the flag guards
        # against reloading the configs on the already-initialised singleton.
        if not hasattr(self, "initialized"):
            self.initialized = True
            self.device = DeviceDetector()
            # primitive_yaml_config is simply the dictionary returned by yaml
            # and is reserved from being an attr for vendor customizability
            self.arch_specialized_yaml_config = None
            self.arch_heuristics_config = None
            self.vendor_primitive_yaml_config = self.get_vendor_tune_config()
            self.default_primitive_yaml_config = self.get_default_tune_config()
            self.vendor_heuristics_config = self.get_vendor_heuristics_config()
            self.default_heuristics_config = self.get_default_heuristics_config()
            self.update_config_from_arch()

            if self.vendor_heuristics_config is None:
                vendorname = self.device.vendor_name
                warnings.warn(
                    f"The {vendorname} configuration of heuristics_config is None"
                )
            # gen_key is an identifier that indicates whether the current
            # config needs to be generated automatically
            self.gen_key = "gen"
            # loaded_triton_config is wrapped in triton.Config according to
            # primitive_yaml_config
            self.loaded_triton_config = {}
            self.triton_config_default = {
                "num_stages": 2,
                "num_warps": 4,
                "num_ctas": 1,
            }
            if self.device.vendor_name == "hygon":
                self.triton_config_default["num_ldmatrixes"] = 0
            self.expand_config_registry = self._build_expand_registry()
            self.load_all()

    def update_config_from_arch(self):
        """Overlay arch-specific tune/heuristics configs when an arch is detected.

        Any exception raised while probing the arch is reported on stdout and
        otherwise ignored, leaving the arch configs at their previous values.
        """
        try:
            archEvent = backend.BackendArchEvent()
            if archEvent.has_arch:
                self.arch_specialized_yaml_config = archEvent.autotune_configs
                self.arch_heuristics_config = archEvent.heuristics_configs
        except Exception as err:
            print(f"[INFO] : {err}")

    def _get_op_configs(self, op_name):
        """Return the raw config entries for ``op_name``.

        Sources are consulted in priority order (arch-specialized, vendor,
        default); the first one containing ``op_name`` wins. Returns ``[]``
        when no source has the op.
        """
        sources = (
            self.arch_specialized_yaml_config,
            self.vendor_primitive_yaml_config,
            self.default_primitive_yaml_config,
        )
        for source in sources:
            if source and op_name in source:
                return source[op_name]
        return []

    def _create_triton_config(self, single_config, current_config):
        """Build a ``triton.Config`` from one config entry.

        Args:
            single_config: mapping whose ``"META"`` value is passed as the
                config's meta-parameters.
            current_config: mapping supplying ``num_warps``, ``num_stages``
                and ``num_ctas`` (plus ``num_ldmatrixes`` on hygon).
        """
        kwargs = {
            name: current_config[name]
            for name in ("num_warps", "num_stages", "num_ctas")
        }
        # Only forward num_ldmatrixes on hygon, and only when the installed
        # triton.Config actually accepts that keyword.
        if (
            self.device.vendor_name == "hygon"
            and "num_ldmatrixes" in inspect.signature(triton.Config).parameters
        ):
            kwargs["num_ldmatrixes"] = current_config["num_ldmatrixes"]
        return triton.Config(single_config["META"], **kwargs)

    def _build_configs_by_op(self, op_name, ranges, pre_hook=None):
        """Expand ``ranges`` into the cartesian product of ``triton.Config``s.

        Args:
            op_name: operator name; ops without an entry in
                ``_OP_META_LAYOUTS`` yield ``[]``.
            ranges: mapping from range key (``"BLOCK_M"``, ..., ``"s"`` for
                num_stages, ``"w"`` for num_warps) to an iterable of values.
            pre_hook: forwarded unchanged to every ``triton.Config``.

        Returns:
            A list of configs ordered with the first layout entry as the
            outermost loop and ``"w"`` as the innermost one.

        Raises:
            KeyError: if a required range key is missing from ``ranges``.
        """
        layout = self._OP_META_LAYOUTS.get(op_name)
        if layout is None:
            return []

        optional = self._OPTIONAL_RANGE_KEYS.get(op_name, frozenset())
        group_m_rule = self._GROUP_M_RULES.get(op_name)

        axes = [
            ranges.get(range_key, [None])
            if range_key in optional
            else ranges[range_key]
            for _, range_key in layout
        ]
        axes.append(ranges["s"])
        axes.append(ranges["w"])

        configs = []
        for *meta_values, stages, warps in itertools.product(*axes):
            meta = {}
            for (meta_key, range_key), value in zip(layout, meta_values):
                # A None value on an optional axis means "omit this key".
                if value is None and range_key in optional:
                    continue
                meta[meta_key] = value
            if group_m_rule is not None:
                meta["GROUP_M"] = group_m_rule(meta["TILE_M"])
            configs.append(
                triton.Config(
                    meta,
                    num_stages=stages,
                    num_warps=warps,
                    pre_hook=pre_hook,
                )
            )
        return configs

    def _build_single_expand_spec(
        self,
        op_name,
        expand_yaml_path=None,
        yaml_op_name=None,
    ):
        """Return the expand-registry entry for ``op_name``.

        The key order and default strategy are looked up in
        ``common.OP_KEY_ORDERS`` / ``common.DEFAULT_STRATEGIES``;
        ``yaml_op_name`` falls back to ``op_name`` when not given.
        """
        return {
            "yaml_op_name": yaml_op_name or op_name,
            "key": common.OP_KEY_ORDERS[op_name],
            "default_strategy": common.DEFAULT_STRATEGIES[op_name],
            "expand_yaml_path": expand_yaml_path,
        }

    def _iter_expand_config_candidates(self, op_name):
        """Yield candidate expand-YAML paths for ``op_name``, most specific first.

        The arch directory (when an arch is detected) is searched before the
        vendor directory ``backend/_<vendor>``. Within each directory the
        op-specific file names come before the general ones. Each normalised
        path is yielded at most once.
        """
        vendor_name = self.device.vendor_name
        contexts = []
        try:
            arch_event = backend.BackendArchEvent()
            current_arch_path = getattr(arch_event, "current_arch_path", None)
            arch_name = getattr(arch_event, "arch", None)
            if arch_event.has_arch and current_arch_path:
                contexts.append((current_arch_path, arch_name))
        except Exception:
            # Best effort: if the arch cannot be probed, only the vendor
            # directory is searched.
            pass

        backend_dir = os.path.join(os.path.dirname(__file__), "backend")
        contexts.append((os.path.join(backend_dir, f"_{vendor_name}"), vendor_name))

        seen = set()
        for base_dir, backend_name in contexts:
            filenames = []
            if op_name:
                filenames.extend(
                    (
                        f"{op_name}_{backend_name}_expand.yaml",
                        f"{op_name}_{vendor_name}_expand.yaml",
                        f"{op_name}_expand.yaml",
                    )
                )
            filenames.extend(
                (
                    f"general_ops_{backend_name}_configs.yaml",
                    f"general_ops_{vendor_name}_configs.yaml",
                    "general_ops_configs.yaml",
                )
            )

            for filename in filenames:
                path = os.path.normpath(os.path.join(base_dir, filename))
                if path in seen:
                    continue
                seen.add(path)
                yield path

    def _get_expand_config_path(self, op_name):
        """Return the first existing candidate path for ``op_name``, or ``None``."""
        for path in self._iter_expand_config_candidates(op_name):
            if os.path.exists(path):
                return path
        return None

    def _build_expand_registry(self):
        """Build the op-name -> expand-spec registry.

        Only addmm, baddbmm, bmm, mm and mv get an ``expand_yaml_path``
        resolved up front; the remaining ops are registered without one.
        """
        ops_with_path = ("addmm", "baddbmm", "bmm", "mm", "mv")
        ops_without_path = (
            "addmm_sqmma",
            "bmm_sqmma",
            "gemv",
            "mm_general_tma",
            "w8a8_block_fp8_general",
            "w8a8_block_fp8_general_splitk",
            "w8a8_block_fp8_general_tma",
            "mm_splitk",
            "sparse_attention",
        )
        registry = {
            op: self._build_single_expand_spec(
                op, expand_yaml_path=self._get_expand_config_path(op)
            )
            for op in ops_with_path
        }
        for op in ops_without_path:
            registry[op] = self._build_single_expand_spec(op)
        return registry

    def load_all(self):
        """Eagerly build and cache the tuned configs of every vendor op."""
        # The vendor config is treated as optional elsewhere in this class
        # (see _get_op_configs), so tolerate None/empty here as well.
        for key in self.vendor_primitive_yaml_config or ():
            self.loaded_triton_config[key] = self.get_tuned_config(key)

    def get_vendor_heuristics_config(self):
        """Fetch the heuristics config of the detected vendor from ``backend``."""
        return backend.get_heuristic_config(self.device.vendor_name)

    def get_default_heuristics_config(self):
        """Fetch the default ("nvidia") heuristics config from ``backend``."""
        return backend.get_heuristic_config("nvidia")

    def get_default_tune_config(self):
        """Fetch the default ("nvidia") tune config from ``backend``."""
        return backend.get_tune_config("nvidia")

    def get_vendor_tune_config(self):
        """Fetch the tune config of the detected vendor from ``backend``."""
        return backend.get_tune_config(self.device.vendor_name)

    def get_heuristics_config(self, op_name):
        """Return the heuristics config for ``op_name``.

        Sources are consulted in priority order (arch, vendor, default).
        A source that is ``None`` or empty is skipped; in particular the
        vendor config may legitimately be ``None`` (``__init__`` only warns
        about it). Emits a warning and returns ``None`` when no source has
        the op.
        """
        sources = (
            self.arch_heuristics_config,
            self.vendor_heuristics_config,
            self.default_heuristics_config,
        )
        for source in sources:
            if source and op_name in source:
                return source[op_name]
        warnings.warn(f"No heuristics config found for {op_name}")
        return None

    def _resolve_iteration_values(self, gen_config, config_var_key):
        """Resolve a ``param_map`` source into the values to iterate over.

        A list/tuple is returned as-is, a single int is wrapped in a list,
        and anything else is used as a key into ``gen_config``.
        """
        if isinstance(config_var_key, (list, tuple)):
            return config_var_key
        if isinstance(config_var_key, int):
            return [config_var_key]
        return gen_config[config_var_key]

    def _gen_impl(
        self,
        gen_config,
        iteration_plan,
        std_config,
    ):
        """Expand ``iteration_plan`` into every ``triton.Config`` combination.

        Uses an explicit LIFO stack rather than recursion: each state is a
        partially filled config plus the index of the next plan entry to
        apply. Because states are popped from the end of the stack, values
        listed later in a plan entry are completed first.

        NOTE(review): unlike ``_create_triton_config`` this path never passes
        ``num_ldmatrixes`` - confirm whether that is intended on hygon.
        """
        all_configs = []
        final_step = len(iteration_plan)
        stack = [(std_config, 0)]

        while stack:
            cur_config, current_step = stack.pop()

            if current_step == final_step:
                all_configs.append(
                    triton.Config(
                        cur_config["META"],
                        num_warps=cur_config["num_warps"],
                        num_stages=cur_config["num_stages"],
                        num_ctas=cur_config["num_ctas"],
                    )
                )
                continue

            cur_entry = iteration_plan[current_step]
            cur_key = cur_entry["key"]
            kind = cur_entry["kind"]
            values = self._resolve_iteration_values(gen_config, cur_entry["source"])
            for single_value in values:
                # Each branch gets its own copy so siblings never share META.
                new_config = copy.deepcopy(cur_config)
                if kind == "meta_field":
                    new_config["META"][cur_key] = single_value
                elif kind == "meta_block":
                    new_config["META"] = copy.deepcopy(single_value)
                else:
                    new_config[cur_key] = single_value
                stack.append((new_config, current_step + 1))
        return all_configs

    def to_gen_config(self, gen_config):
        """Generate ``triton.Config``s from a "gen" config entry.

        ``gen_config["param_map"]`` describes what to sweep: its ``"META"``
        value is either a dict (one sweep per meta field) or a single source
        providing whole META blocks; every other key is swept as a top-level
        config field. Unswept fields start from ``triton_config_default``.
        """
        param_config = gen_config["param_map"]
        meta_config = param_config["META"]

        if isinstance(meta_config, dict):
            iteration_plan = [
                {"key": meta_key, "source": source, "kind": "meta_field"}
                for meta_key, source in meta_config.items()
            ]
        else:
            iteration_plan = [
                {"key": "META", "source": meta_config, "kind": "meta_block"}
            ]

        iteration_plan.extend(
            {"key": key, "source": source, "kind": "config_field"}
            for key, source in param_config.items()
            if key != "META"
        )

        current_config = {"META": {}}
        current_config.update(self.triton_config_default)
        return self._gen_impl(
            gen_config,
            iteration_plan,
            current_config,
        )

    def get_expand_config(self, op_name, yaml_path=None):
        """Load the expand ranges and strategy for ``op_name``.

        Args:
            op_name: operator name; must be present in the expand registry.
            yaml_path: optional explicit YAML path overriding the registry's
                ``expand_yaml_path``.

        Returns:
            ``{"ranges": ..., "strategy": ...}`` on success, or ``-1`` when
            the op is unregistered, no YAML path is available, the YAML does
            not contain a usable ``param_map`` entry, or any error occurs
            while loading/parsing it.
        """
        op_spec = self.expand_config_registry.get(op_name)
        if op_spec is None:
            return -1

        key = op_spec.get("key", [])
        default_strategy = op_spec.get("default_strategy")
        expand_yaml_path = yaml_path or op_spec.get("expand_yaml_path")
        yaml_op_name = op_spec.get("yaml_op_name", op_name)
        if not expand_yaml_path:
            return -1

        # The -1 sentinel is the documented failure signal for callers
        # (see ops_get_configs), so any loading/parsing error maps to it.
        try:
            expand_configs = backend.get_expand_config(
                op_name=yaml_op_name,
                file_path=expand_yaml_path,
            )
            if not isinstance(expand_configs, list):
                return -1

            gen_config = None
            strategy_config = None
            for single_config in expand_configs:
                if not isinstance(single_config, dict):
                    continue
                # When several entries qualify, the last one wins.
                if "param_map" in single_config:
                    gen_config = single_config
                if "strategy" in single_config:
                    strategy_config = single_config.get("strategy")

            if gen_config is None:
                return -1

            param_map = gen_config.get("param_map")
            meta_map = param_map.get("META")

            strategy = default_strategy
            if isinstance(strategy_config, dict):
                strategy = [
                    strategy_config.get(k, default_strategy[idx])
                    for idx, k in enumerate(key)
                ]

            # Range keys are the upper-cased names that META maps to.
            ranges = {
                mapped_key.upper(): gen_config[mapped_key]
                for mapped_key in meta_map.values()
            }
            ranges["s"] = gen_config[param_map.get("num_stages")]
            ranges["w"] = gen_config[param_map.get("num_warps")]

            return {
                "ranges": ranges,
                "strategy": strategy,
            }
        except Exception:
            return -1

    def ops_get_configs(self, op_name, yaml_path=None, pre_hook=None):
        """Return the expanded ``triton.Config`` list for ``op_name``.

        Returns ``[]`` when no expand config could be loaded.
        """
        expand_config = self.get_expand_config(op_name, yaml_path=yaml_path)
        if expand_config == -1:
            return []
        return self._build_configs_by_op(
            op_name, expand_config["ranges"], pre_hook=pre_hook
        )

    def get_tuned_config(self, op_name):
        """Return the list of ``triton.Config`` objects for ``op_name``.

        A cached result from ``loaded_triton_config`` is returned when
        present. Otherwise each raw entry is converted: entries containing
        ``gen_key`` are expanded via ``to_gen_config``; the others become a
        single config whose launch parameters default to
        ``triton_config_default``. Returns ``[]`` when the op has no entries.
        """
        if op_name in self.loaded_triton_config:
            return self.loaded_triton_config[op_name]

        current_op_configs = self._get_op_configs(op_name)
        if not current_op_configs:
            return []

        configs = []
        for single_config in current_op_configs:
            if self.gen_key in single_config:
                configs.extend(self.to_gen_config(single_config))
                continue

            # Start from the defaults and let the entry override any of them.
            current_config = copy.deepcopy(self.triton_config_default)
            for default_param in current_config:
                if default_param in single_config:
                    current_config[default_param] = single_config[default_param]

            configs.append(self._create_triton_config(single_config, current_config))
        return configs