Coverage for src/flag_gems/runtime/configloader.py: 62%
203 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import copy
2import warnings
4import triton
6from . import backend, common
7from .backend.device import DeviceDetector
class ConfigLoader(object):
    """Process-wide singleton that loads Triton autotune and heuristics configs.

    Tuning configs are looked up in priority order: arch-specialized YAML
    (from ``backend.BackendArchEvent``), the detected vendor's YAML, then the
    ``"nvidia"`` YAML as the default. Heuristics configs use the same
    arch -> vendor -> default order.
    """

    _instance = None

    # META layouts used by _build_configs_by_op: ordered (META key, ``ranges``
    # key) pairs. Earlier pairs are the outer loops of the generated product.
    _BLOCK_MNK_LAYOUT = (
        ("BLOCK_M", "BLOCK_M"),
        ("BLOCK_N", "BLOCK_N"),
        ("BLOCK_K", "BLOCK_K"),
    )
    _BLOCK_SIZE_MNK_LAYOUT = (
        ("BLOCK_SIZE_M", "BLOCK_M"),
        ("BLOCK_SIZE_N", "BLOCK_N"),
        ("BLOCK_SIZE_K", "BLOCK_K"),
    )
    _TILE_MNK_LAYOUT = (
        ("TILE_M", "BLOCK_M"),
        ("TILE_N", "BLOCK_N"),
        ("TILE_K", "BLOCK_K"),
    )
    # Ops whose META dict is a plain Cartesian product of their ranges.
    _PLAIN_META_LAYOUTS = {
        "addmm": _BLOCK_SIZE_MNK_LAYOUT,
        "addmm_sqmma": _BLOCK_SIZE_MNK_LAYOUT,
        "bmm_sqmma": _BLOCK_SIZE_MNK_LAYOUT,
        "mm": _BLOCK_MNK_LAYOUT,
        "mm_sqmma": _BLOCK_MNK_LAYOUT,
        "mm_general_tma": _BLOCK_MNK_LAYOUT,
        "mv": (("BLOCK_N", "BLOCK_N"), ("BLOCK_M", "BLOCK_M")),
        "gemv": (("BLOCK_M", "BLOCK_M"), ("BLOCK_K", "BLOCK_K")),
        "sparse_attention": (("BLOCK", "BLOCK"),),
        "w8a8_block_fp8_general": _BLOCK_MNK_LAYOUT + (("GROUP_M", "GROUP_M"),),
        "w8a8_block_fp8_general_splitk": (
            _BLOCK_MNK_LAYOUT + (("SPLIT_K", "SPLIT_K"),)
        ),
        "mm_splitk": _BLOCK_MNK_LAYOUT + (("SPLIT_K", "SPLIT_K"),),
    }

    def __new__(cls, *args, **kwargs):
        # Singleton: every construction returns the same object. __init__
        # skips re-initialisation via the ``initialized`` attribute.
        if cls._instance is None:
            cls._instance = super(ConfigLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        """Load all config sources once; later constructions are no-ops."""
        # NOTE(review): ``initialized`` is set before loading finishes, so if
        # any loader below raises, later ConfigLoader() calls return a
        # partially initialised instance.
        if hasattr(self, "initialized"):
            return
        self.initialized = True
        self.device = DeviceDetector()
        # primitive_yaml_config is simply the dictionary returned by yaml
        # and is reserved from being an attr for vendor customizability
        self.arch_specialized_yaml_config = None
        self.arch_heuristics_config = None
        self.vendor_primitive_yaml_config = self.get_vendor_tune_config()
        self.default_primitive_yaml_config = self.get_default_tune_config()
        self.vendor_heuristics_config = self.get_vendor_heuristics_config()
        self.default_heuristics_config = self.get_default_heuristics_config()
        self.update_config_from_arch()

        if self.vendor_heuristics_config is None:
            vendorname = self.device.vendor_name
            warnings.warn(
                f"The {vendorname} configuration of heuristics_config is None"
            )
        # gen_key is an identifier that indicates whether the current config
        # needs to be generated automatically
        self.gen_key = "gen"
        # loaded_triton_config caches op_name -> list[triton.Config]
        self.loaded_triton_config = {}
        # Launch parameters used when a YAML entry does not specify them.
        self.triton_config_default = {
            "num_stages": 2,
            "num_warps": 4,
            "num_ctas": 1,
        }
        if self.device.vendor_name == "hygon":
            self.triton_config_default["num_ldmatrixes"] = 0
        self.expand_config_registry = self._build_expand_registry()
        self.load_all()

    def update_config_from_arch(self):
        """Pick up arch-specialized autotune/heuristics configs, if any.

        Best effort: any failure is reported via print and leaves the
        arch-level configs as they were.
        """
        try:
            archEvent = backend.BackendArchEvent()
            if archEvent.has_arch:
                self.arch_specialized_yaml_config = archEvent.autotune_configs
                self.arch_heuristics_config = archEvent.heuristics_configs
        except Exception as err:
            print(f"[INFO] : {err}")

    def _get_op_configs(self, op_name):
        """Get config for op_name from available config sources.

        Returns the entry from the first source (arch, vendor, default) that
        is non-empty and contains ``op_name``; ``[]`` if none does.
        """
        for config in (
            self.arch_specialized_yaml_config,
            self.vendor_primitive_yaml_config,
            self.default_primitive_yaml_config,
        ):
            if config and op_name in config:
                return config[op_name]
        return []

    def _create_triton_config(self, single_config, current_config):
        """Create a triton.Config from ``single_config["META"]``.

        Launch parameters come from ``current_config``; ``num_ldmatrixes`` is
        forwarded only on hygon devices.
        """
        kwargs = {
            "num_warps": current_config["num_warps"],
            "num_stages": current_config["num_stages"],
            "num_ctas": current_config["num_ctas"],
        }
        if self.device.vendor_name == "hygon":
            kwargs["num_ldmatrixes"] = current_config["num_ldmatrixes"]
        return triton.Config(single_config["META"], **kwargs)

    def _product_configs(self, ranges, layout, pre_hook=None, finalize_meta=None):
        """Build one ``triton.Config`` per point of a Cartesian product.

        Args:
            ranges: mapping from range name to candidate values; must also
                contain ``"s"`` (num_stages) and ``"w"`` (num_warps).
            layout: ordered ``(meta_key, range_key)`` pairs. The product nests
                in layout order (first pair outermost), then ``"s"``, then
                ``"w"`` innermost.
            pre_hook: forwarded unchanged to every ``triton.Config``.
            finalize_meta: optional callable that receives the META dict built
                from ``layout`` and returns the META dict to use.

        Returns:
            list of ``triton.Config``.
        """
        import itertools

        meta_keys = [meta_key for meta_key, _ in layout]
        axes = [ranges[range_key] for _, range_key in layout]
        configs = []
        for *values, num_stages, num_warps in itertools.product(
            *axes, ranges["s"], ranges["w"]
        ):
            meta = dict(zip(meta_keys, values))
            if finalize_meta is not None:
                meta = finalize_meta(meta)
            configs.append(
                triton.Config(
                    meta,
                    num_stages=num_stages,
                    num_warps=num_warps,
                    pre_hook=pre_hook,
                )
            )
        return configs

    def _build_configs_by_op(self, op_name, ranges, pre_hook=None):
        """Expand ``ranges`` into the ``triton.Config`` list for ``op_name``.

        ``ranges`` maps upper-cased range names (``BLOCK_M``, ``SPLIT_K``, ...)
        plus ``"s"``/``"w"`` to candidate values, as built by
        ``get_expand_config``. Returns ``[]`` for an unsupported ``op_name``.
        """
        layout = self._PLAIN_META_LAYOUTS.get(op_name)
        if layout is not None:
            return self._product_configs(ranges, layout, pre_hook)

        # NOTE(review): bmm picks GROUP_M=1 only for block_m == 32 while
        # baddbmm uses block_m <= 32; preserved as-is, confirm intent.
        if op_name == "bmm":
            return self._product_configs(
                ranges,
                self._TILE_MNK_LAYOUT,
                pre_hook,
                finalize_meta=lambda meta: {
                    **meta,
                    "GROUP_M": 1 if meta["TILE_M"] == 32 else 2,
                },
            )

        if op_name == "baddbmm":
            return self._product_configs(
                ranges,
                self._TILE_MNK_LAYOUT,
                pre_hook,
                finalize_meta=lambda meta: {
                    **meta,
                    "GROUP_M": 1 if meta["TILE_M"] <= 32 else 2,
                },
            )

        if op_name == "w8a8_block_fp8_general_tma":
            # GROUP_M is optional here: when absent from ``ranges`` a single
            # None placeholder is iterated, and a None GROUP_M is left out of
            # META entirely.
            tma_ranges = dict(ranges, GROUP_M=ranges.get("GROUP_M", [None]))
            return self._product_configs(
                tma_ranges,
                self._BLOCK_MNK_LAYOUT + (("GROUP_M", "GROUP_M"),),
                pre_hook,
                finalize_meta=lambda meta: {
                    key: value
                    for key, value in meta.items()
                    if not (key == "GROUP_M" and value is None)
                },
            )

        return []

    def _build_single_expand_spec(
        self,
        op_name,
        expand_yaml_path=None,
        yaml_op_name=None,
    ):
        """Describe how to load the expand config for ``op_name``.

        The key order and default strategy come from ``common``;
        ``yaml_op_name`` defaults to ``op_name``.
        """
        return {
            "yaml_op_name": yaml_op_name or op_name,
            "key": common.OP_KEY_ORDERS[op_name],
            "default_strategy": common.DEFAULT_STRATEGIES[op_name],
            "expand_yaml_path": expand_yaml_path,
        }

    def _build_expand_registry(self):
        """Map each expandable op name to its expand spec.

        Only the ops in ``ops_with_default_path`` are bound to
        ``common.DEFAULT_EXPAND_CONFIG_PATH``. The others rely on the caller's
        ``yaml_path`` in ``get_expand_config``.
        """
        # NOTE(review): "mm_sqmma" has a branch in _build_configs_by_op but no
        # entry here, so ops_get_configs("mm_sqmma") always returns [].
        default_path = common.DEFAULT_EXPAND_CONFIG_PATH
        ops_with_default_path = {"addmm", "baddbmm", "bmm", "mv"}
        op_names = (
            "addmm",
            "addmm_sqmma",
            "baddbmm",
            "bmm",
            "bmm_sqmma",
            "gemv",
            "mm",
            "mm_general_tma",
            "mv",
            "w8a8_block_fp8_general",
            "w8a8_block_fp8_general_splitk",
            "w8a8_block_fp8_general_tma",
            "mm_splitk",
            "sparse_attention",
        )
        registry = {}
        for name in op_names:
            path = default_path if name in ops_with_default_path else None
            registry[name] = self._build_single_expand_spec(
                name, expand_yaml_path=path
            )
        return registry

    def load_all(self):
        """Eagerly build and cache configs for every op in the vendor YAML.

        A missing (None) vendor tune config is treated as empty.
        """
        for key in self.vendor_primitive_yaml_config or ():
            self.loaded_triton_config[key] = self.get_tuned_config(key)

    def get_vendor_heuristics_config(self):
        return backend.get_heuristic_config(self.device.vendor_name)

    def get_default_heuristics_config(self):
        # The "nvidia" heuristics serve as the cross-vendor fallback.
        return backend.get_heuristic_config("nvidia")

    def get_default_tune_config(self):
        # The "nvidia" tune config serves as the cross-vendor fallback.
        return backend.get_tune_config("nvidia")

    def get_vendor_tune_config(self):
        return backend.get_tune_config(self.device.vendor_name)

    def get_heuristics_config(self, op_name):
        """Return the heuristics entry for ``op_name``, or None with a warning.

        Sources are checked in order: arch, vendor, default. A source that is
        None or empty is skipped; __init__ explicitly allows a None vendor
        heuristics config.
        """
        for source in (
            self.arch_heuristics_config,
            self.vendor_heuristics_config,
            self.default_heuristics_config,
        ):
            if source and op_name in source:
                return source[op_name]
        warnings.warn(f"No heuristics config found for {op_name}")
        return None

    def _resolve_iteration_values(self, gen_config, config_var_key):
        """Turn an iteration source into a list of values.

        A list/tuple is used as-is, an int becomes a one-element list, and
        anything else is treated as a key into ``gen_config``.
        """
        if isinstance(config_var_key, (list, tuple)):
            return config_var_key
        if isinstance(config_var_key, int):
            return [config_var_key]
        return gen_config[config_var_key]

    def _gen_impl(
        self,
        gen_config,
        iteration_plan,
        std_config,
    ):
        """Enumerate every combination described by ``iteration_plan``.

        Uses an explicit stack: each state holds a partially filled config and
        the index of the next plan step. Each step is one of:

        * ``meta_field``: sets one META key;
        * ``meta_block``: replaces the whole META dict;
        * anything else: sets a top-level launch parameter.

        Finished configs are converted with ``_create_triton_config``, so they
        get the same vendor-specific kwargs (e.g. hygon ``num_ldmatrixes``)
        as hand-written YAML entries.
        """
        all_configs = []
        final_step = len(iteration_plan)
        stack = [{"cur_config": std_config, "current_step": 0}]

        while stack:
            cur_state = stack.pop()
            cur_config = cur_state.get("cur_config")
            current_step = cur_state.get("current_step")

            if current_step == final_step:
                all_configs.append(self._create_triton_config(cur_config, cur_config))
                continue

            cur_entry = iteration_plan[current_step]
            cur_key = cur_entry["key"]
            key_config = self._resolve_iteration_values(
                gen_config, cur_entry["source"]
            )
            for single_value in key_config:
                # Deep copy so sibling branches never share a META dict.
                new_config = copy.deepcopy(cur_config)
                if cur_entry["kind"] == "meta_field":
                    new_config["META"][cur_key] = single_value
                elif cur_entry["kind"] == "meta_block":
                    new_config["META"] = copy.deepcopy(single_value)
                else:
                    new_config[cur_key] = single_value
                stack.append(
                    {
                        "cur_config": new_config,
                        "current_step": current_step + 1,
                    }
                )
        return all_configs

    def to_gen_config(self, gen_config):
        """Expand a ``gen`` YAML entry (with ``param_map``) into triton.Configs.

        META entries are iterated first, then the remaining ``param_map`` keys
        in mapping order. Unspecified launch parameters start from
        ``triton_config_default``.
        """
        param_config = gen_config["param_map"]
        meta_config = param_config["META"]
        iteration_plan = []

        if isinstance(meta_config, dict):
            for meta_key, source in meta_config.items():
                iteration_plan.append(
                    {"key": meta_key, "source": source, "kind": "meta_field"}
                )
        else:
            iteration_plan.append(
                {"key": "META", "source": meta_config, "kind": "meta_block"}
            )

        for key, source in param_config.items():
            if key == "META":
                continue
            iteration_plan.append(
                {"key": key, "source": source, "kind": "config_field"}
            )

        current_config = {"META": {}}
        current_config.update(self.triton_config_default)
        return self._gen_impl(
            gen_config,
            iteration_plan,
            current_config,
        )

    def get_expand_config(self, op_name, yaml_path=None):
        """Load the expand ranges and strategy for ``op_name``.

        Returns:
            ``{"ranges": ..., "strategy": ...}`` on success, or ``-1`` when
            the op is not registered or its expand YAML is missing or
            malformed.

        ``ranges`` maps each upper-cased META source name to its value list,
        plus ``"s"`` (num_stages) and ``"w"`` (num_warps).
        """
        op_spec = self.expand_config_registry.get(op_name)
        if op_spec is None:
            return -1

        key = op_spec.get("key", [])
        default_strategy = op_spec.get("default_strategy")
        # NOTE(review): a registry path takes precedence over the caller's
        # ``yaml_path``; the argument is only used when the registry has none.
        expand_yaml_path = op_spec.get("expand_yaml_path") or yaml_path
        yaml_op_name = op_spec.get("yaml_op_name", op_name)

        try:
            expand_configs = backend.get_expand_config(
                op_name=yaml_op_name,
                file_path=expand_yaml_path,
            )
            if not isinstance(expand_configs, list):
                return -1

            # The last dict carrying each section wins.
            gen_config = None
            strategy_config = None
            for single_config in expand_configs:
                if not isinstance(single_config, dict):
                    continue
                if "param_map" in single_config:
                    gen_config = single_config
                if "strategy" in single_config:
                    strategy_config = single_config.get("strategy")

            if gen_config is None:
                return -1

            param_map = gen_config.get("param_map")
            meta_map = param_map.get("META")

            strategy = default_strategy
            if isinstance(strategy_config, dict):
                strategy = [
                    strategy_config.get(k, default_strategy[idx])
                    for idx, k in enumerate(key)
                ]

            ranges = {
                mapped_key.upper(): gen_config[mapped_key]
                for mapped_key in meta_map.values()
            }
            ranges["s"] = gen_config[param_map.get("num_stages")]
            ranges["w"] = gen_config[param_map.get("num_warps")]

            return {
                "ranges": ranges,
                "strategy": strategy,
            }
        except Exception:
            # Best effort by design: any loading/shape problem means
            # "no expand config" for this op.
            return -1

    def ops_get_configs(self, op_name, yaml_path=None, pre_hook=None):
        """Return the expanded triton.Config list for ``op_name``, or []."""
        expand_config = self.get_expand_config(op_name, yaml_path=yaml_path)
        if expand_config == -1:
            return []
        ranges = expand_config["ranges"]
        return self._build_configs_by_op(op_name, ranges, pre_hook=pre_hook)

    def get_tuned_config(self, op_name):
        """Return the list of triton.Config for ``op_name`` (cached if loaded).

        Entries containing ``gen_key`` are expanded via ``to_gen_config``. The
        other entries use their own launch parameters, falling back to
        ``triton_config_default``.
        """
        if op_name in self.loaded_triton_config:
            return self.loaded_triton_config[op_name]

        current_op_configs = self._get_op_configs(op_name)
        if not current_op_configs:
            return []

        configs = []
        for single_config in current_op_configs:
            if self.gen_key in single_config:
                configs.extend(self.to_gen_config(single_config))
                continue

            current_config = {
                param: single_config.get(param, default)
                for param, default in self.triton_config_default.items()
            }
            configs.append(self._create_triton_config(single_config, current_config))
        return configs