Coverage for src/flag_gems/runtime/configloader.py: 62%

204 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import copy 

2import inspect 

3import warnings 

4 

5import triton 

6 

7from . import backend, common 

8from .backend.device import DeviceDetector 

9 

10 

class ConfigLoader:
    """Singleton that resolves Triton tuning and heuristics configurations.

    Lookups go through up to three sources, in priority order:

    1. architecture-specific configs reported by ``backend.BackendArchEvent``,
    2. configs of the detected vendor (``self.device.vendor_name``),
    3. the ``"nvidia"`` configs, which this class uses as the default.

    Any of these sources may be ``None``/empty; lookups skip such sources.
    """

    _instance = None

    # (range key, META key) pairs. The tuple order is both the loop nesting
    # order (first pair is the outermost loop) and the insertion order of the
    # generated META dict.
    _AXES_TILE = (
        ("BLOCK_M", "TILE_M"),
        ("BLOCK_N", "TILE_N"),
        ("BLOCK_K", "TILE_K"),
    )
    _AXES_BLOCK_SIZE = (
        ("BLOCK_M", "BLOCK_SIZE_M"),
        ("BLOCK_N", "BLOCK_SIZE_N"),
        ("BLOCK_K", "BLOCK_SIZE_K"),
    )
    _AXES_BLOCK = (
        ("BLOCK_M", "BLOCK_M"),
        ("BLOCK_N", "BLOCK_N"),
        ("BLOCK_K", "BLOCK_K"),
    )
    _AXES_BLOCK_GROUP = _AXES_BLOCK + (("GROUP_M", "GROUP_M"),)
    _AXES_BLOCK_SPLIT = _AXES_BLOCK + (("SPLIT_K", "SPLIT_K"),)

    # op name -> (axes, GROUP_M rule or None, range keys that may be absent).
    # A GROUP_M rule receives the chosen BLOCK_M value (stored as TILE_M).
    # NOTE(review): bmm uses ``== 32`` while baddbmm uses ``<= 32``; both are
    # kept exactly as they were -- confirm whether the difference is intended.
    _EXPAND_GRID_SPECS = {
        "bmm": (_AXES_TILE, lambda block_m: 1 if block_m == 32 else 2, ()),
        "addmm": (_AXES_BLOCK_SIZE, None, ()),
        "baddbmm": (_AXES_TILE, lambda block_m: 1 if block_m <= 32 else 2, ()),
        "mv": ((("BLOCK_N", "BLOCK_N"), ("BLOCK_M", "BLOCK_M")), None, ()),
        "mm_general_tma": (_AXES_BLOCK, None, ()),
        "mm": (_AXES_BLOCK, None, ()),
        # NOTE(review): mm_sqmma has no entry in the expand registry, so
        # ops_get_configs never reaches this spec -- confirm if intended.
        "mm_sqmma": (_AXES_BLOCK, None, ()),
        "bmm_sqmma": (_AXES_BLOCK_SIZE, None, ()),
        "addmm_sqmma": (_AXES_BLOCK_SIZE, None, ()),
        "gemv": ((("BLOCK_M", "BLOCK_M"), ("BLOCK_K", "BLOCK_K")), None, ()),
        "sparse_attention": ((("BLOCK", "BLOCK"),), None, ()),
        "w8a8_block_fp8_general": (_AXES_BLOCK_GROUP, None, ()),
        "w8a8_block_fp8_general_tma": (_AXES_BLOCK_GROUP, None, ("GROUP_M",)),
        "w8a8_block_fp8_general_splitk": (_AXES_BLOCK_SPLIT, None, ()),
        "mm_splitk": (_AXES_BLOCK_SPLIT, None, ()),
    }

    # Ops registered for expand configs, in registry insertion order.
    _EXPAND_OPS = (
        "addmm",
        "addmm_sqmma",
        "baddbmm",
        "bmm",
        "bmm_sqmma",
        "gemv",
        "mm",
        "mm_general_tma",
        "mv",
        "w8a8_block_fp8_general",
        "w8a8_block_fp8_general_splitk",
        "w8a8_block_fp8_general_tma",
        "mm_splitk",
        "sparse_attention",
    )
    # Registered ops whose spec is pinned to common.DEFAULT_EXPAND_CONFIG_PATH.
    _DEFAULT_PATH_EXPAND_OPS = frozenset(("addmm", "baddbmm", "bmm", "mv"))

    def __new__(cls, *args, **kargs):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Load all config sources once; later constructions are no-ops."""
        if hasattr(self, "initialized"):
            return
        self.initialized = True
        self.device = DeviceDetector()
        # primitive_yaml_config is simply the dictionary returned by yaml
        # and is reserved from being an attr for vendor customizability
        self.arch_specialized_yaml_config = None
        self.arch_heuristics_config = None
        self.vendor_primitive_yaml_config = self.get_vendor_tune_config()
        self.default_primitive_yaml_config = self.get_default_tune_config()
        self.vendor_heuristics_config = self.get_vendor_heuristics_config()
        self.default_heuristics_config = self.get_default_heuristics_config()
        self.update_config_from_arch()

        if self.vendor_heuristics_config is None:
            vendorname = self.device.vendor_name
            warnings.warn(
                f"The {vendorname} configuration of heuristics_config is None"
            )
        # gen_key is an identifier that indicates whether the current config
        # needs to be generated automatically
        self.gen_key = "gen"
        # loaded_triton_config is wrapped in triton.Config according to
        # primitive_yaml_config
        self.loaded_triton_config = {}
        self.triton_config_default = {
            "num_stages": 2,
            "num_warps": 4,
            "num_ctas": 1,
        }
        if self.device.vendor_name == "hygon":
            self.triton_config_default["num_ldmatrixes"] = 0
        self.expand_config_registry = self._build_expand_registry()
        self.load_all()

    def update_config_from_arch(self):
        """Adopt architecture-specific configs when the backend reports any.

        Best effort: any failure is printed and the arch configs stay as-is.
        """
        try:
            arch_event = backend.BackendArchEvent()
            if arch_event.has_arch:
                self.arch_specialized_yaml_config = arch_event.autotune_configs
                self.arch_heuristics_config = arch_event.heuristics_configs
        except Exception as err:
            print(f"[INFO] : {err}")

    def _get_op_configs(self, op_name):
        """Return the tune config entries for ``op_name``, or ``[]``.

        Sources are tried in order arch -> vendor -> default; the first one
        that contains ``op_name`` wins.
        """
        sources = (
            self.arch_specialized_yaml_config,
            self.vendor_primitive_yaml_config,
            self.default_primitive_yaml_config,
        )
        for source in sources:
            if source and op_name in source:
                return source[op_name]
        return []

    def _create_triton_config(self, single_config, current_config):
        """Build a ``triton.Config`` from a META dict and launch parameters.

        Args:
            single_config: mapping whose ``"META"`` entry holds the kernel
                meta-parameters.
            current_config: mapping with ``num_warps``, ``num_stages`` and
                ``num_ctas`` (plus ``num_ldmatrixes`` on hygon).
        """
        kwargs = {
            name: current_config[name]
            for name in ("num_warps", "num_stages", "num_ctas")
        }
        # num_ldmatrixes is only forwarded on hygon, and only when the
        # installed triton.Config actually accepts that parameter.
        if (
            self.device.vendor_name == "hygon"
            and "num_ldmatrixes" in inspect.signature(triton.Config).parameters
        ):
            kwargs["num_ldmatrixes"] = current_config["num_ldmatrixes"]
        return triton.Config(single_config["META"], **kwargs)

    def _iter_expand_grid(self, op_name, ranges):
        """Yield ``(meta, num_stages, num_warps)`` for every grid point.

        The grid is the nested iteration over the op's range keys followed by
        ``ranges["s"]`` and ``ranges["w"]`` (innermost). Each range is looked
        up lazily, only once every outer loop has produced a value. Unknown
        ops yield nothing.
        """
        spec = self._EXPAND_GRID_SPECS.get(op_name)
        if spec is None:
            return
        axes, group_rule, optional_keys = spec
        loop_keys = [range_key for range_key, _ in axes] + ["s", "w"]

        def values_for(range_key):
            # An absent optional range contributes one "omit this key" step.
            if range_key in optional_keys:
                return ranges.get(range_key, [None])
            return ranges[range_key]

        def walk(depth, chosen):
            if depth == len(loop_keys):
                yield chosen
                return
            for value in values_for(loop_keys[depth]):
                yield from walk(depth + 1, chosen + (value,))

        for *meta_values, num_stages, num_warps in walk(0, ()):
            meta = {}
            for (range_key, meta_key), value in zip(axes, meta_values):
                if value is None and range_key in optional_keys:
                    continue
                meta[meta_key] = value
            if group_rule is not None:
                meta["GROUP_M"] = group_rule(meta["TILE_M"])
            yield meta, num_stages, num_warps

    def _build_configs_by_op(self, op_name, ranges, pre_hook=None):
        """Expand ``ranges`` into a list of ``triton.Config`` for ``op_name``.

        Returns ``[]`` for ops without an expand grid spec. ``pre_hook`` is
        passed through to every generated config.
        """
        return [
            triton.Config(
                meta,
                num_stages=num_stages,
                num_warps=num_warps,
                pre_hook=pre_hook,
            )
            for meta, num_stages, num_warps in self._iter_expand_grid(
                op_name, ranges
            )
        ]

    def _build_single_expand_spec(
        self,
        op_name,
        expand_yaml_path=None,
        yaml_op_name=None,
    ):
        """Build the registry entry describing how to expand one op.

        Raises:
            KeyError: if ``op_name`` is missing from ``common.OP_KEY_ORDERS``
                or ``common.DEFAULT_STRATEGIES``.
        """
        return {
            "yaml_op_name": yaml_op_name or op_name,
            "key": common.OP_KEY_ORDERS[op_name],
            "default_strategy": common.DEFAULT_STRATEGIES[op_name],
            "expand_yaml_path": expand_yaml_path,
        }

    def _build_expand_registry(self):
        """Build the op name -> expand spec mapping for all registered ops."""
        default_path = common.DEFAULT_EXPAND_CONFIG_PATH
        return {
            op_name: self._build_single_expand_spec(
                op_name,
                expand_yaml_path=(
                    default_path
                    if op_name in self._DEFAULT_PATH_EXPAND_OPS
                    else None
                ),
            )
            for op_name in self._EXPAND_OPS
        }

    def load_all(self):
        """Eagerly build and cache configs for every op in the vendor config."""
        # The vendor config may be None/empty; then there is nothing to load.
        for key in self.vendor_primitive_yaml_config or ():
            self.loaded_triton_config[key] = self.get_tuned_config(key)

    def get_vendor_heuristics_config(self):
        """Return the heuristics config of the detected vendor."""
        return backend.get_heuristic_config(self.device.vendor_name)

    def get_default_heuristics_config(self):
        """Return the ``"nvidia"`` heuristics config, used as the default."""
        return backend.get_heuristic_config("nvidia")

    def get_default_tune_config(self):
        """Return the ``"nvidia"`` tune config, used as the default."""
        return backend.get_tune_config("nvidia")

    def get_vendor_tune_config(self):
        """Return the tune config of the detected vendor."""
        return backend.get_tune_config(self.device.vendor_name)

    def get_heuristics_config(self, op_name):
        """Return the heuristics config for ``op_name``.

        Sources are tried in order arch -> vendor -> default. Sources that
        are ``None`` or empty are skipped, so a missing vendor config falls
        through to the default instead of raising. Emits a warning and
        returns ``None`` when no source knows ``op_name``.
        """
        sources = (
            self.arch_heuristics_config,
            self.vendor_heuristics_config,
            self.default_heuristics_config,
        )
        for source in sources:
            if source and op_name in source:
                return source[op_name]
        warnings.warn(f"No heuristics config found for {op_name}")
        return None

    def _resolve_iteration_values(self, gen_config, config_var_key):
        """Resolve one ``param_map`` source into the values to iterate.

        A list/tuple is used as-is, an int becomes a one-element list, and
        anything else is treated as a key into ``gen_config``.
        """
        if isinstance(config_var_key, (list, tuple)):
            return config_var_key
        if isinstance(config_var_key, int):
            return [config_var_key]
        return gen_config[config_var_key]

    def _gen_impl(
        self,
        gen_config,
        iteration_plan,
        std_config,
    ):
        """Expand ``iteration_plan`` into every combination of its values.

        Uses an explicit stack (depth-first), so within each plan step the
        last value is expanded first; the resulting order is kept as-is.
        Each finished combination goes through ``_create_triton_config`` so
        generated configs get the same vendor-specific parameters as
        statically listed ones.
        """
        all_configs = []
        final_step = len(iteration_plan)
        stack = [(std_config, 0)]

        while stack:
            cur_config, current_step = stack.pop()
            if current_step == final_step:
                all_configs.append(
                    self._create_triton_config(cur_config, cur_config)
                )
                continue

            cur_entry = iteration_plan[current_step]
            cur_key = cur_entry["key"]
            candidates = self._resolve_iteration_values(
                gen_config, cur_entry["source"]
            )
            for single_value in candidates:
                # Every branch owns its config so siblings never share state.
                new_config = copy.deepcopy(cur_config)
                kind = cur_entry["kind"]
                if kind == "meta_field":
                    new_config["META"][cur_key] = single_value
                elif kind == "meta_block":
                    new_config["META"] = copy.deepcopy(single_value)
                else:
                    new_config[cur_key] = single_value
                stack.append((new_config, current_step + 1))
        return all_configs

    def to_gen_config(self, gen_config):
        """Generate ``triton.Config`` objects from a ``gen`` config entry.

        ``gen_config["param_map"]["META"]`` is either a dict mapping each
        META key to its value source, or a single source whose values are
        complete META blocks. All other ``param_map`` keys are treated as
        launch parameters.
        """
        param_config = gen_config["param_map"]
        meta_config = param_config["META"]

        if isinstance(meta_config, dict):
            iteration_plan = [
                {"key": meta_key, "source": source, "kind": "meta_field"}
                for meta_key, source in meta_config.items()
            ]
        else:
            iteration_plan = [
                {"key": "META", "source": meta_config, "kind": "meta_block"}
            ]

        iteration_plan.extend(
            {"key": key, "source": source, "kind": "config_field"}
            for key, source in param_config.items()
            if key != "META"
        )

        current_config = {"META": {}}
        current_config.update(self.triton_config_default)
        return self._gen_impl(
            gen_config,
            iteration_plan,
            current_config,
        )

    def get_expand_config(self, op_name, yaml_path=None):
        """Load the expand ranges and strategy for ``op_name``.

        The path stored in the registry takes precedence; ``yaml_path`` is
        only used when the registry entry has none.

        Returns:
            ``{"ranges": ..., "strategy": ...}`` on success, or ``-1`` when
            the op is not registered or its expand config cannot be loaded
            or interpreted (callers compare against ``-1``).
        """
        op_spec = self.expand_config_registry.get(op_name)
        if op_spec is None:
            return -1

        key = op_spec.get("key", [])
        default_strategy = op_spec.get("default_strategy")
        expand_yaml_path = op_spec.get("expand_yaml_path") or yaml_path
        yaml_op_name = op_spec.get("yaml_op_name", op_name)

        try:
            expand_configs = backend.get_expand_config(
                op_name=yaml_op_name,
                file_path=expand_yaml_path,
            )
            if not isinstance(expand_configs, list):
                return -1

            # When several entries qualify, the last one wins.
            gen_config = None
            strategy_config = None
            for entry in expand_configs:
                if not isinstance(entry, dict):
                    continue
                if "param_map" in entry:
                    gen_config = entry
                if "strategy" in entry:
                    strategy_config = entry.get("strategy")

            if gen_config is None:
                return -1

            param_map = gen_config.get("param_map")
            meta_map = param_map.get("META")

            strategy = default_strategy
            if isinstance(strategy_config, dict):
                strategy = [
                    strategy_config.get(k, default_strategy[idx])
                    for idx, k in enumerate(key)
                ]

            ranges = {
                mapped_key.upper(): gen_config[mapped_key]
                for mapped_key in meta_map.values()
            }
            ranges["s"] = gen_config[param_map.get("num_stages")]
            ranges["w"] = gen_config[param_map.get("num_warps")]

            return {
                "ranges": ranges,
                "strategy": strategy,
            }
        except Exception:
            # Deliberate best effort: a missing or malformed expand config
            # means "no expand config" rather than an error for the caller.
            return -1

    def ops_get_configs(self, op_name, yaml_path=None, pre_hook=None):
        """Return expanded ``triton.Config`` objects for ``op_name``.

        Returns ``[]`` when no expand config is available for the op.
        """
        expand_config = self.get_expand_config(op_name, yaml_path=yaml_path)
        if expand_config == -1:
            return []
        return self._build_configs_by_op(
            op_name, expand_config["ranges"], pre_hook=pre_hook
        )

    def get_tuned_config(self, op_name):
        """Return the ``triton.Config`` list for ``op_name``.

        Configs cached in ``loaded_triton_config`` are returned directly.
        Otherwise each config entry is either expanded (when it contains
        ``self.gen_key``) or converted one-to-one, with launch parameters
        missing from the entry filled in from ``triton_config_default``.
        Returns ``[]`` when no source has entries for the op.
        """
        if op_name in self.loaded_triton_config:
            return self.loaded_triton_config[op_name]

        current_op_configs = self._get_op_configs(op_name)
        if not current_op_configs:
            return []

        configs = []
        for single_config in current_op_configs:
            if self.gen_key in single_config:
                configs.extend(self.to_gen_config(single_config))
                continue

            current_config = copy.deepcopy(self.triton_config_default)
            current_config.update(
                (name, single_config[name])
                for name in self.triton_config_default
                if name in single_config
            )
            configs.append(
                self._create_triton_config(single_config, current_config)
            )
        return configs

572 return configs