Coverage for src/flag_gems/runtime/configs_loader.py: 65%

236 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import copy 

2import inspect 

3import os 

4import warnings 

5 

6import triton 

7 

8from . import backend, common 

9from .backend.device import DeviceDetector 

10 

11 

class TunedConfigLoader(object):
    """Singleton that loads Triton autotune configs and heuristics configs.

    Tuned configs for an op come from the first source that defines it:
    - the architecture-specific configs (when ``backend.BackendArchEvent``
      reports an arch);
    - then the current vendor's configs;
    - then the ``nvidia`` configs.

    Heuristics configs follow the same arch > vendor > nvidia order.

    A tuned-config entry that contains the ``gen`` key is expanded into the
    cartesian product described by its ``param_map`` (see ``to_gen_config``).

    Separately, ``ops_get_configs`` builds configs for a fixed set of ops
    from "expand" YAML files located by ``_iter_expand_config_candidates``.
    """

    _instance = None

    # Layouts used by _build_configs_by_op. Each op maps to an ordered tuple
    # of (META key, ranges key, optional) triples:
    # - The tuple order is the nesting order of the generated configs
    #   (outermost first).
    # - num_stages (ranges["s"]) and num_warps (ranges["w"]) are always the
    #   two innermost axes.
    # - An optional axis falls back to [None] when its key is missing from
    #   ranges, and None values are omitted from META.
    _BLOCK_MNK = (
        ("BLOCK_M", "BLOCK_M", False),
        ("BLOCK_N", "BLOCK_N", False),
        ("BLOCK_K", "BLOCK_K", False),
    )
    _BLOCK_SIZE_MNK = (
        ("BLOCK_SIZE_M", "BLOCK_M", False),
        ("BLOCK_SIZE_N", "BLOCK_N", False),
        ("BLOCK_SIZE_K", "BLOCK_K", False),
    )
    _TILE_MNK = (
        ("TILE_M", "BLOCK_M", False),
        ("TILE_N", "BLOCK_N", False),
        ("TILE_K", "BLOCK_K", False),
    )
    _EXPAND_META_LAYOUTS = {
        "bmm": _TILE_MNK,
        "addmm": _BLOCK_SIZE_MNK,
        "baddbmm": _TILE_MNK,
        "mv": (("BLOCK_N", "BLOCK_N", False), ("BLOCK_M", "BLOCK_M", False)),
        "mm_general_tma": _BLOCK_MNK,
        "mm": _BLOCK_MNK,
        "mm_sqmma": _BLOCK_MNK,
        "bmm_sqmma": _BLOCK_SIZE_MNK,
        "addmm_sqmma": _BLOCK_SIZE_MNK,
        "gemv": (("BLOCK_M", "BLOCK_M", False), ("BLOCK_K", "BLOCK_K", False)),
        "sparse_attention": (("BLOCK", "BLOCK", False),),
        "w8a8_block_fp8_general": _BLOCK_MNK + (("GROUP_M", "GROUP_M", False),),
        "w8a8_block_fp8_general_tma": _BLOCK_MNK + (("GROUP_M", "GROUP_M", True),),
        "w8a8_block_fp8_general_splitk": _BLOCK_MNK
        + (("SPLIT_K", "SPLIT_K", False),),
        "mm_splitk": _BLOCK_MNK + (("SPLIT_K", "SPLIT_K", False),),
    }
    # META fields computed from other META fields: op -> (key, fn(meta)).
    # NOTE(review): bmm uses ``== 32`` while baddbmm uses ``<= 32``. This is
    # preserved as-is; confirm whether the difference is intentional.
    _EXPAND_DERIVED_META = {
        "bmm": ("GROUP_M", lambda meta: 1 if meta["TILE_M"] == 32 else 2),
        "baddbmm": ("GROUP_M", lambda meta: 1 if meta["TILE_M"] <= 32 else 2),
    }

    def __new__(cls, *args, **kwargs):
        # Every construction returns the same instance. __init__ skips its
        # body on later calls via the ``initialized`` attribute.
        if cls._instance is None:
            cls._instance = super(TunedConfigLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        """Load all config sources once and pre-build the vendor's configs."""
        if not hasattr(self, "initialized"):
            self.initialized = True
            self.device = DeviceDetector()
            # *_primitive_yaml_config holds what backend.get_tune_config
            # returns (the dictionary produced by the YAML loader). It is
            # kept as a plain attribute so vendors can customize it.
            self.arch_specialized_yaml_config = None
            self.arch_heuristics_config = None
            self.vendor_primitive_yaml_config = self.get_vendor_tune_config()
            self.default_primitive_yaml_config = self.get_default_tune_config()
            self.vendor_heuristics_config = self.get_vendor_heuristics_config()
            self.default_heuristics_config = self.get_default_heuristics_config()
            self.update_config_from_arch()

            if self.vendor_heuristics_config is None:
                vendorname = self.device.vendor_name
                warnings.warn(
                    f"The {vendorname} configuration of heuristics_config is None"
                )
            # gen_key marks a tuned-config entry that must be expanded
            # automatically (see to_gen_config).
            self.gen_key = "gen"
            # Cache: op_name -> list of triton.Config built from the YAML.
            self.loaded_triton_config = {}
            # Launch parameters used when a YAML entry does not set them.
            self.triton_config_default = {
                "num_stages": 2,
                "num_warps": 4,
                "num_ctas": 1,
            }
            if self.device.vendor_name == "hygon":
                self.triton_config_default["num_ldmatrixes"] = 0
            self.expand_config_registry = self._build_expand_registry()
            self.load_all()

    def update_config_from_arch(self):
        """Pick up arch-specific autotune/heuristics configs if available.

        This is best-effort: any error from the backend is printed and the
        arch configs stay None.
        """
        try:
            archEvent = backend.BackendArchEvent()
            if archEvent.has_arch:
                self.arch_specialized_yaml_config = archEvent.autotune_configs
                self.arch_heuristics_config = archEvent.heuristics_configs
        except Exception as err:
            print(f"[INFO] : {err}")

    def _get_op_configs(self, op_name):
        """Return the tuned-config entries for op_name, or [] if none.

        Sources are checked in the order arch > vendor > default. The first
        truthy source that contains op_name wins.
        """
        for config in (
            self.arch_specialized_yaml_config,
            self.vendor_primitive_yaml_config,
            self.default_primitive_yaml_config,
        ):
            if config and op_name in config:
                return config[op_name]
        return []

    def _create_triton_config(self, single_config, current_config):
        """Create a triton.Config from META plus launch parameters.

        Args:
            single_config: Mapping whose "META" entry becomes the kernel
                meta-parameters.
            current_config: Mapping providing num_warps, num_stages and
                num_ctas (and num_ldmatrixes on hygon).

        On hygon, num_ldmatrixes is forwarded only when the installed
        triton.Config signature accepts it.
        """
        kwargs = {
            "num_warps": current_config["num_warps"],
            "num_stages": current_config["num_stages"],
            "num_ctas": current_config["num_ctas"],
        }
        if (
            self.device.vendor_name == "hygon"
            and "num_ldmatrixes" in inspect.signature(triton.Config).parameters
        ):
            kwargs["num_ldmatrixes"] = current_config["num_ldmatrixes"]
        return triton.Config(single_config["META"], **kwargs)

    def _build_configs_by_op(self, op_name, ranges, pre_hook=None):
        """Build the cartesian product of triton.Config objects for op_name.

        Args:
            op_name: Operator name. Ops without an entry in
                _EXPAND_META_LAYOUTS yield [].
            ranges: Mapping from range key (e.g. "BLOCK_M", "s", "w") to the
                candidate values, as produced by get_expand_config.
            pre_hook: Optional hook forwarded to every triton.Config.

        Returns:
            A list of triton.Config, nested in layout order with num_stages
            and num_warps innermost.
        """
        import itertools

        layout = self._EXPAND_META_LAYOUTS.get(op_name)
        if layout is None:
            return []
        derived = self._EXPAND_DERIVED_META.get(op_name)

        axes = [
            ranges.get(range_key, [None]) if optional else ranges[range_key]
            for _, range_key, optional in layout
        ]
        axes.append(ranges["s"])
        axes.append(ranges["w"])

        configs = []
        for *meta_values, num_stages, num_warps in itertools.product(*axes):
            meta = {}
            for (meta_key, _, optional), value in zip(layout, meta_values):
                # Optional axes drop out of META entirely when unset.
                if optional and value is None:
                    continue
                meta[meta_key] = value
            if derived is not None:
                derived_key, compute = derived
                meta[derived_key] = compute(meta)
            configs.append(
                triton.Config(
                    meta,
                    num_stages=num_stages,
                    num_warps=num_warps,
                    pre_hook=pre_hook,
                )
            )
        return configs

    def _build_single_expand_spec(
        self,
        op_name,
        expand_yaml_path=None,
        yaml_op_name=None,
    ):
        """Describe how to load the expand config for op_name.

        Raises KeyError if op_name is missing from common.OP_KEY_ORDERS or
        common.DEFAULT_STRATEGIES.
        """
        return {
            "yaml_op_name": yaml_op_name or op_name,
            "key": common.OP_KEY_ORDERS[op_name],
            "default_strategy": common.DEFAULT_STRATEGIES[op_name],
            "expand_yaml_path": expand_yaml_path,
        }

    def _iter_expand_config_candidates(self, op_name):
        """Yield candidate expand-YAML paths for op_name, most specific first.

        Directories searched:
        - the current arch directory, if the backend reports one (errors
          from the backend are ignored);
        - then ``backend/_<vendor>`` next to this file.

        Within each directory:
        - op-specific files come before the general_ops files;
        - backend-named files come before vendor-named files, which come
          before unsuffixed files.

        Duplicate paths are yielded only once.
        """
        vendor_name = self.device.vendor_name
        contexts = []
        try:
            arch_event = backend.BackendArchEvent()
            current_arch_path = getattr(arch_event, "current_arch_path", None)
            arch_name = getattr(arch_event, "arch", None)
            if arch_event.has_arch and current_arch_path:
                contexts.append((current_arch_path, arch_name))
        except Exception:
            # Best-effort: without arch info only the vendor dir is searched.
            pass

        backend_dir = os.path.join(os.path.dirname(__file__), "backend")
        contexts.append((os.path.join(backend_dir, f"_{vendor_name}"), vendor_name))

        seen = set()
        for base_dir, backend_name in contexts:
            filenames = []
            if op_name:
                filenames.extend(
                    (
                        f"{op_name}_{backend_name}_expand.yaml",
                        f"{op_name}_{vendor_name}_expand.yaml",
                        f"{op_name}_expand.yaml",
                    )
                )
            filenames.extend(
                (
                    f"general_ops_{backend_name}_configs.yaml",
                    f"general_ops_{vendor_name}_configs.yaml",
                    "general_ops_configs.yaml",
                )
            )

            for filename in filenames:
                path = os.path.normpath(os.path.join(base_dir, filename))
                if path in seen:
                    continue
                seen.add(path)
                yield path

    def _get_expand_config_path(self, op_name):
        """Return the first existing candidate expand-YAML path, or None."""
        return next(
            (
                path
                for path in self._iter_expand_config_candidates(op_name)
                if os.path.exists(path)
            ),
            None,
        )

    def _build_expand_registry(self):
        """Map each expandable op to its spec (see _build_single_expand_spec).

        Only the ops listed in ``yaml_backed`` look up an expand YAML on
        disk. For the others, expand_yaml_path stays None unless a caller
        passes yaml_path to get_expand_config.
        """
        yaml_backed = {"addmm", "baddbmm", "bmm", "mm", "mv"}
        op_names = (
            "addmm",
            "addmm_sqmma",
            "baddbmm",
            "bmm",
            "bmm_sqmma",
            "gemv",
            "mm",
            "mm_general_tma",
            "mv",
            "w8a8_block_fp8_general",
            "w8a8_block_fp8_general_splitk",
            "w8a8_block_fp8_general_tma",
            "mm_splitk",
            "sparse_attention",
        )
        registry = {}
        for name in op_names:
            path = self._get_expand_config_path(name) if name in yaml_backed else None
            registry[name] = self._build_single_expand_spec(
                name, expand_yaml_path=path
            )
        return registry

    def load_all(self):
        """Pre-build triton configs for every op in the vendor tune config.

        A None vendor tune config is treated as empty.
        """
        for key in self.vendor_primitive_yaml_config or {}:
            self.loaded_triton_config[key] = self.get_tuned_config(key)

    def get_vendor_heuristics_config(self):
        return backend.get_heuristic_config(self.device.vendor_name)

    def get_default_heuristics_config(self):
        return backend.get_heuristic_config("nvidia")

    def get_default_tune_config(self):
        return backend.get_tune_config("nvidia")

    def get_vendor_tune_config(self):
        return backend.get_tune_config(self.device.vendor_name)

    def get_heuristics_config(self, op_name):
        """Return the heuristics config for op_name (arch > vendor > default).

        Any source may be None; such sources are skipped. Warns and returns
        None when no source defines op_name.
        """
        for source in (
            self.arch_heuristics_config,
            self.vendor_heuristics_config,
            self.default_heuristics_config,
        ):
            if source and op_name in source:
                return source[op_name]
        warnings.warn(f"No heuristics config found for {op_name}")
        return None

    def _resolve_iteration_values(self, gen_config, config_var_key):
        """Return the values to iterate for one step of a gen plan.

        The step's source is interpreted as follows:
        - a list/tuple is used as-is;
        - an int becomes a one-element list;
        - anything else is a key looked up in gen_config.
        """
        if isinstance(config_var_key, (list, tuple)):
            return config_var_key
        if isinstance(config_var_key, int):
            return [config_var_key]
        return gen_config[config_var_key]

    def _gen_impl(
        self,
        gen_config,
        iteration_plan,
        std_config,
    ):
        """Expand iteration_plan over std_config into triton.Config objects.

        Each plan entry is {"key", "source", "kind"}. Its "kind" controls
        how each value is applied:
        - "meta_field" sets META[key];
        - "meta_block" replaces META wholesale;
        - any other kind sets the top-level launch parameter ``key``.

        The expansion is depth-first with an explicit stack. The output order
        therefore follows LIFO pops; it is kept as-is because it determines
        autotune ordering.
        """
        all_configs = []
        final_step = len(iteration_plan)
        stack = [{"cur_config": std_config, "current_step": 0}]

        while stack:
            cur_state = stack.pop()
            cur_config = cur_state.get("cur_config")
            current_step = cur_state.get("current_step")

            if current_step == final_step:
                # Shared builder keeps vendor extras (num_ldmatrixes on
                # hygon) consistent with non-generated configs.
                all_configs.append(self._create_triton_config(cur_config, cur_config))
                continue

            cur_entry = iteration_plan[current_step]
            cur_key = cur_entry["key"]
            key_config = self._resolve_iteration_values(gen_config, cur_entry["source"])
            for single_value in key_config:
                new_config = copy.deepcopy(cur_config)
                if cur_entry["kind"] == "meta_field":
                    new_config["META"][cur_key] = single_value
                elif cur_entry["kind"] == "meta_block":
                    new_config["META"] = copy.deepcopy(single_value)
                else:
                    new_config[cur_key] = single_value
                stack.append(
                    {
                        "cur_config": new_config,
                        "current_step": current_step + 1,
                    }
                )
        return all_configs

    def to_gen_config(self, gen_config):
        """Expand a ``gen`` tuned-config entry into triton.Config objects.

        gen_config["param_map"]["META"] has two accepted forms:
        - a dict of META key -> source, expanded field by field;
        - any other value, treated as a single source of whole META dicts.

        Every other param_map entry maps a launch parameter to a source. A
        source is resolved by _resolve_iteration_values. Parameters not set
        by the plan keep their triton_config_default values.
        """
        param_config = gen_config["param_map"]
        meta_config = param_config["META"]
        iteration_plan = []

        if isinstance(meta_config, dict):
            for meta_key, source in meta_config.items():
                iteration_plan.append(
                    {"key": meta_key, "source": source, "kind": "meta_field"}
                )
        else:
            iteration_plan.append(
                {"key": "META", "source": meta_config, "kind": "meta_block"}
            )

        for key, source in param_config.items():
            if key == "META":
                continue
            iteration_plan.append(
                {"key": key, "source": source, "kind": "config_field"}
            )

        current_config = {"META": {}}
        current_config.update(self.triton_config_default)
        return self._gen_impl(
            gen_config,
            iteration_plan,
            current_config,
        )

    def get_expand_config(self, op_name, yaml_path=None):
        """Load the expand ranges and strategy for op_name.

        Args:
            op_name: Key into expand_config_registry.
            yaml_path: Optional override for the spec's expand_yaml_path.

        Returns:
            On success, a dict with two entries:
            - "ranges": upper-cased META names plus "s"/"w" mapped to the
              candidate value lists;
            - "strategy": the per-key strategy list.

            Returns -1 in any of these cases:
            - the op is unregistered;
            - no YAML path is known;
            - the YAML has no param_map entry;
            - loading or parsing fails for any reason (best-effort).
        """
        op_spec = self.expand_config_registry.get(op_name)
        if op_spec is None:
            return -1

        key = op_spec.get("key", [])
        default_strategy = op_spec.get("default_strategy")
        expand_yaml_path = yaml_path or op_spec.get("expand_yaml_path")
        yaml_op_name = op_spec.get("yaml_op_name", op_name)
        if not expand_yaml_path:
            return -1

        try:
            expand_configs = backend.get_expand_config(
                op_name=yaml_op_name,
                file_path=expand_yaml_path,
            )
            if not isinstance(expand_configs, list):
                return -1

            # If several entries define param_map / strategy, the last wins.
            gen_config = None
            strategy_config = None
            for single_config in expand_configs:
                if not isinstance(single_config, dict):
                    continue
                if "param_map" in single_config:
                    gen_config = single_config
                if "strategy" in single_config:
                    strategy_config = single_config.get("strategy")

            if gen_config is None:
                return -1

            param_map = gen_config.get("param_map")
            meta_map = param_map.get("META")

            strategy = default_strategy
            if isinstance(strategy_config, dict):
                strategy = [
                    strategy_config.get(k, default_strategy[idx])
                    for idx, k in enumerate(key)
                ]

            ranges = {}
            for mapped_key in meta_map.values():
                ranges[mapped_key.upper()] = gen_config[mapped_key]
            ranges["s"] = gen_config[param_map.get("num_stages")]
            ranges["w"] = gen_config[param_map.get("num_warps")]

            return {
                "ranges": ranges,
                "strategy": strategy,
            }
        except Exception:
            # Best-effort: a malformed or unreadable expand YAML means "no
            # expand configs" rather than a hard failure.
            return -1

    def ops_get_configs(self, op_name, yaml_path=None, pre_hook=None):
        """Return expanded triton configs for op_name, or [] if unavailable."""
        expand_config = self.get_expand_config(op_name, yaml_path=yaml_path)
        if not isinstance(expand_config, dict):
            return []
        return self._build_configs_by_op(
            op_name, expand_config["ranges"], pre_hook=pre_hook
        )

    def get_tuned_config(self, op_name):
        """Return the triton configs for op_name, using the cache if present.

        Behaviour per tuned-config entry:
        - an entry containing the gen key is expanded via to_gen_config;
        - any other entry supplies META plus optional overrides of the
          triton_config_default launch parameters.

        Returns [] when no source defines op_name.
        """
        if op_name in self.loaded_triton_config:
            return self.loaded_triton_config[op_name]

        current_op_configs = self._get_op_configs(op_name)
        if not current_op_configs:
            return []

        configs = []
        for single_config in current_op_configs:
            if self.gen_key in single_config:
                configs.extend(self.to_gen_config(single_config))
                continue

            current_config = {
                param: single_config[param] if param in single_config else default
                for param, default in self.triton_config_default.items()
            }
            configs.append(self._create_triton_config(single_config, current_config))
        return configs