Coverage for src/flag_gems/runtime/configloader.py: 62%

203 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import copy 

2import warnings 

3 

4import triton 

5 

6from . import backend, common 

7from .backend.device import DeviceDetector 

8 

9 

class ConfigLoader(object):
    """Singleton that resolves Triton tuning and heuristics configurations.

    Configurations are looked up in priority order: architecture-specific
    (when ``backend.BackendArchEvent`` reports one), then the detected
    vendor's, then the "nvidia" configuration, which this class uses as the
    default fallback.
    """

    _instance = None

    # Range keys shared by the matmul-like ops, outermost loop first.
    _MNK_AXES = ("BLOCK_M", "BLOCK_N", "BLOCK_K")
    _TILE_META = ("TILE_M", "TILE_N", "TILE_K")
    _BLOCK_SIZE_META = ("BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K")

    # op name -> (range keys iterated outermost first, META keys paired
    # positionally with those range keys). The "s" (num_stages) and "w"
    # (num_warps) ranges are always iterated innermost, in that order.
    _OP_GRID_LAYOUTS = {
        "bmm": (_MNK_AXES, _TILE_META),
        "addmm": (_MNK_AXES, _BLOCK_SIZE_META),
        "baddbmm": (_MNK_AXES, _TILE_META),
        "mv": (("BLOCK_N", "BLOCK_M"), ("BLOCK_N", "BLOCK_M")),
        "mm_general_tma": (_MNK_AXES, _MNK_AXES),
        "mm": (_MNK_AXES, _MNK_AXES),
        "mm_sqmma": (_MNK_AXES, _MNK_AXES),
        "bmm_sqmma": (_MNK_AXES, _BLOCK_SIZE_META),
        "addmm_sqmma": (_MNK_AXES, _BLOCK_SIZE_META),
        "gemv": (("BLOCK_M", "BLOCK_K"), ("BLOCK_M", "BLOCK_K")),
        "sparse_attention": (("BLOCK",), ("BLOCK",)),
        "w8a8_block_fp8_general": (
            _MNK_AXES + ("GROUP_M",),
            _MNK_AXES + ("GROUP_M",),
        ),
        "w8a8_block_fp8_general_tma": (
            _MNK_AXES + ("GROUP_M",),
            _MNK_AXES + ("GROUP_M",),
        ),
        "w8a8_block_fp8_general_splitk": (
            _MNK_AXES + ("SPLIT_K",),
            _MNK_AXES + ("SPLIT_K",),
        ),
        "mm_splitk": (_MNK_AXES + ("SPLIT_K",), _MNK_AXES + ("SPLIT_K",)),
    }

    # Ops registered for expand configs, in registry insertion order.
    _EXPAND_OP_NAMES = (
        "addmm",
        "addmm_sqmma",
        "baddbmm",
        "bmm",
        "bmm_sqmma",
        "gemv",
        "mm",
        "mm_general_tma",
        "mv",
        "w8a8_block_fp8_general",
        "w8a8_block_fp8_general_splitk",
        "w8a8_block_fp8_general_tma",
        "mm_splitk",
        "sparse_attention",
    )

    # Ops whose expand spec is pinned to common.DEFAULT_EXPAND_CONFIG_PATH.
    _DEFAULT_EXPAND_PATH_OPS = frozenset({"addmm", "baddbmm", "bmm", "mv"})

    def __new__(cls, *args, **kargs):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super(ConfigLoader, cls).__new__(cls)
        return cls._instance

    def __init__(self):
        """Load all configuration sources once; later calls are no-ops."""
        # __init__ runs on every ConfigLoader() call even though __new__
        # returns the same object, so guard against re-initialisation.
        if hasattr(self, "initialized"):
            return
        self.initialized = True
        self.device = DeviceDetector()
        # primitive_yaml_config is simply the dictionary returned by yaml
        # and is reserved from being an attr for vendor customizability
        self.arch_specialized_yaml_config = None
        self.arch_heuristics_config = None
        self.vendor_primitive_yaml_config = self.get_vendor_tune_config()
        self.default_primitive_yaml_config = self.get_default_tune_config()
        self.vendor_heuristics_config = self.get_vendor_heuristics_config()
        self.default_heuristics_config = self.get_default_heuristics_config()
        self.update_config_from_arch()

        if self.vendor_heuristics_config is None:
            vendorname = self.device.vendor_name
            warnings.warn(
                f"The {vendorname} configuration of heuristics_config is None"
            )
        # gen_key is an identifier that indicates whether the current config
        # needs to be generated automatically
        self.gen_key = "gen"
        # loaded_triton_config is wrapped in triton.Config according to
        # primitive_yaml_config
        self.loaded_triton_config = {}
        self.triton_config_default = {
            "num_stages": 2,
            "num_warps": 4,
            "num_ctas": 1,
        }
        if self.device.vendor_name == "hygon":
            self.triton_config_default["num_ldmatrixes"] = 0
        self.expand_config_registry = self._build_expand_registry()
        self.load_all()

    def update_config_from_arch(self):
        """Adopt architecture-specific configs when the backend reports one.

        Any exception raised while querying the backend is printed with an
        ``[INFO]`` prefix and otherwise ignored, leaving the arch configs
        at their previous values.
        """
        try:
            archEvent = backend.BackendArchEvent()
            if archEvent.has_arch:
                self.arch_specialized_yaml_config = archEvent.autotune_configs
                self.arch_heuristics_config = archEvent.heuristics_configs
        except Exception as err:
            print(f"[INFO] : {err}")

    def _get_op_configs(self, op_name):
        """Get config for op_name from available config sources.

        Sources are tried in order arch -> vendor -> default; empty or
        ``None`` sources are skipped. Returns ``[]`` when none has the op.
        """
        for config in (
            self.arch_specialized_yaml_config,
            self.vendor_primitive_yaml_config,
            self.default_primitive_yaml_config,
        ):
            if config and op_name in config:
                return config[op_name]
        return []

    def _create_triton_config(self, single_config, current_config):
        """Create a triton.Config with appropriate parameters.

        ``single_config["META"]`` supplies the kernel meta-parameters and
        ``current_config`` the launch parameters. ``num_ldmatrixes`` is
        forwarded only when the detected vendor is "hygon".
        """
        kwargs = {
            "num_warps": current_config["num_warps"],
            "num_stages": current_config["num_stages"],
            "num_ctas": current_config["num_ctas"],
        }
        if self.device.vendor_name == "hygon":
            kwargs["num_ldmatrixes"] = current_config["num_ldmatrixes"]
        return triton.Config(single_config["META"], **kwargs)

    @staticmethod
    def _iter_range_grid(ranges, axes):
        """Yield one value tuple per point of the nested loops over ``axes``.

        Equivalent to ``for a in ranges[axes[0]] for b in ranges[axes[1]]
        ...``: the first axis is outermost, and each inner ``ranges[axis]``
        is looked up only when the loops enclosing it produce a value.
        """
        if not axes:
            yield ()
            return
        head, rest = axes[0], axes[1:]
        for value in ranges[head]:
            for tail in ConfigLoader._iter_range_grid(ranges, rest):
                yield (value,) + tail

    @classmethod
    def _iter_meta_grid(cls, op_name, ranges):
        """Yield ``(meta, num_stages, num_warps)`` for every grid point.

        ``meta`` is the kernel META dict for ``op_name``. Yields nothing
        for ops that have no entry in ``_OP_GRID_LAYOUTS``.
        """
        layout = cls._OP_GRID_LAYOUTS.get(op_name)
        if layout is None:
            return
        axes, meta_keys = layout

        if op_name == "w8a8_block_fp8_general_tma":
            # GROUP_M is optional for this op; a single None placeholder
            # keeps the loop nest uniform and is dropped from META below.
            ranges = dict(ranges, GROUP_M=ranges.get("GROUP_M", [None]))

        for point in cls._iter_range_grid(ranges, axes + ("s", "w")):
            *meta_values, num_stages, num_warps = point
            meta = dict(zip(meta_keys, meta_values))
            if op_name == "bmm":
                # bmm and baddbmm intentionally keep their distinct
                # comparisons (== 32 vs <= 32).
                meta["GROUP_M"] = 1 if meta["TILE_M"] == 32 else 2
            elif op_name == "baddbmm":
                meta["GROUP_M"] = 1 if meta["TILE_M"] <= 32 else 2
            elif (
                op_name == "w8a8_block_fp8_general_tma"
                and meta["GROUP_M"] is None
            ):
                del meta["GROUP_M"]
            yield meta, num_stages, num_warps

    def _build_configs_by_op(self, op_name, ranges, pre_hook=None):
        """Expand ``ranges`` into the list of triton.Config for ``op_name``.

        Args:
            op_name: operator name; ops without a layout produce ``[]``.
            ranges: mapping from range key (e.g. "BLOCK_M", "s", "w") to
                the iterable of candidate values.
            pre_hook: forwarded unchanged to every triton.Config.
        """
        return [
            triton.Config(
                meta,
                num_stages=num_stages,
                num_warps=num_warps,
                pre_hook=pre_hook,
            )
            for meta, num_stages, num_warps in self._iter_meta_grid(
                op_name, ranges
            )
        ]

    def _build_single_expand_spec(
        self,
        op_name,
        expand_yaml_path=None,
        yaml_op_name=None,
    ):
        """Build the registry entry describing how to expand one op.

        ``yaml_op_name`` falls back to ``op_name``. The key order and the
        default strategy are read from ``common.OP_KEY_ORDERS`` and
        ``common.DEFAULT_STRATEGIES``.
        """
        return {
            "yaml_op_name": yaml_op_name or op_name,
            "key": common.OP_KEY_ORDERS[op_name],
            "default_strategy": common.DEFAULT_STRATEGIES[op_name],
            "expand_yaml_path": expand_yaml_path,
        }

    def _build_expand_registry(self):
        """Return the op-name -> expand-spec mapping for all known ops."""
        default_path = common.DEFAULT_EXPAND_CONFIG_PATH
        registry = {}
        for op_name in self._EXPAND_OP_NAMES:
            pinned_path = (
                default_path
                if op_name in self._DEFAULT_EXPAND_PATH_OPS
                else None
            )
            registry[op_name] = self._build_single_expand_spec(
                op_name, expand_yaml_path=pinned_path
            )
        return registry

    def load_all(self):
        """Pre-resolve tuned configs for every op in the vendor tune config.

        A missing (``None``) vendor tune config is treated as empty, the
        same way ``_get_op_configs`` skips it.
        """
        for key in self.vendor_primitive_yaml_config or ():
            self.loaded_triton_config[key] = self.get_tuned_config(key)

    def get_vendor_heuristics_config(self):
        """Return the backend heuristics config for the detected vendor."""
        return backend.get_heuristic_config(self.device.vendor_name)

    def get_default_heuristics_config(self):
        """Return the "nvidia" heuristics config, used as the default."""
        return backend.get_heuristic_config("nvidia")

    def get_default_tune_config(self):
        """Return the "nvidia" tune config, used as the default."""
        return backend.get_tune_config("nvidia")

    def get_vendor_tune_config(self):
        """Return the backend tune config for the detected vendor."""
        return backend.get_tune_config(self.device.vendor_name)

    def get_heuristics_config(self, op_name):
        """Return the heuristics config for ``op_name``.

        Sources are tried in order arch -> vendor -> default. ``None`` or
        empty sources are skipped, so a vendor without heuristics (which
        ``__init__`` only warns about) still falls back to the default
        instead of raising ``TypeError``. Warns and returns ``None`` when
        no source has the op.
        """
        for source in (
            self.arch_heuristics_config,
            self.vendor_heuristics_config,
            self.default_heuristics_config,
        ):
            if source and op_name in source:
                return source[op_name]
        warnings.warn(f"No heuristics config found for {op_name}")
        return None

    def _resolve_iteration_values(self, gen_config, config_var_key):
        """Return the candidate values described by ``config_var_key``.

        A list or tuple is returned as is, an int is wrapped in a
        one-element list, and anything else is used as a key into
        ``gen_config``.
        """
        if isinstance(config_var_key, (list, tuple)):
            return config_var_key
        if isinstance(config_var_key, int):
            return [config_var_key]
        return gen_config[config_var_key]

    def _gen_impl(
        self,
        gen_config,
        iteration_plan,
        std_config,
    ):
        """Expand ``iteration_plan`` into every combination of triton.Config.

        Each plan entry contributes one dimension. Pending states live on
        a LIFO stack, so within each dimension the later candidate values
        are completed first; that determines the order of the result.
        """
        all_configs = []
        final_step = len(iteration_plan)
        stack = [{"cur_config": std_config, "current_step": 0}]

        while stack:
            state = stack.pop()
            cur_config = state["cur_config"]
            current_step = state["current_step"]

            if current_step == final_step:
                # NOTE(review): unlike _create_triton_config, this does not
                # forward "num_ldmatrixes" on hygon. Confirm whether
                # generated configs should carry it.
                all_configs.append(
                    triton.Config(
                        cur_config["META"],
                        num_warps=cur_config["num_warps"],
                        num_stages=cur_config["num_stages"],
                        num_ctas=cur_config["num_ctas"],
                    )
                )
                continue

            cur_entry = iteration_plan[current_step]
            cur_key = cur_entry["key"]
            kind = cur_entry["kind"]
            candidates = self._resolve_iteration_values(
                gen_config, cur_entry["source"]
            )
            for single_value in candidates:
                # Deep copy so sibling branches never share a META dict.
                new_config = copy.deepcopy(cur_config)
                if kind == "meta_field":
                    new_config["META"][cur_key] = single_value
                elif kind == "meta_block":
                    new_config["META"] = copy.deepcopy(single_value)
                else:
                    new_config[cur_key] = single_value
                stack.append(
                    {
                        "cur_config": new_config,
                        "current_step": current_step + 1,
                    }
                )
        return all_configs

    def to_gen_config(self, gen_config):
        """Generate triton.Config objects from a ``param_map`` description.

        ``gen_config["param_map"]["META"]`` is either a dict (one dimension
        per META field) or any other value (candidates for whole META
        blocks). Every remaining ``param_map`` entry becomes a launch
        parameter dimension. Generation starts from
        ``self.triton_config_default`` with an empty META.
        """
        param_config = gen_config["param_map"]
        meta_config = param_config["META"]

        if isinstance(meta_config, dict):
            iteration_plan = [
                {"key": meta_key, "source": source, "kind": "meta_field"}
                for meta_key, source in meta_config.items()
            ]
        else:
            iteration_plan = [
                {"key": "META", "source": meta_config, "kind": "meta_block"}
            ]

        iteration_plan.extend(
            {"key": key, "source": source, "kind": "config_field"}
            for key, source in param_config.items()
            if key != "META"
        )

        current_config = {"META": {}}
        current_config.update(self.triton_config_default)
        return self._gen_impl(
            gen_config,
            iteration_plan,
            current_config,
        )

    @staticmethod
    def _summarize_expand_configs(expand_configs, key, default_strategy):
        """Extract ``{"ranges", "strategy"}`` from parsed expand configs.

        The last dict entry containing "param_map" supplies the ranges and
        the last dict entry containing "strategy" supplies per-key
        overrides of ``default_strategy``.

        Raises:
            ValueError: if no entry contains "param_map".
            Exception: lookup errors on malformed entries propagate; the
                caller converts them into its ``-1`` sentinel.
        """
        gen_config = None
        strategy_config = None
        for single_config in expand_configs:
            if not isinstance(single_config, dict):
                continue
            if "param_map" in single_config:
                gen_config = single_config
            if "strategy" in single_config:
                strategy_config = single_config.get("strategy")

        if gen_config is None:
            raise ValueError("expand config has no 'param_map' entry")

        param_map = gen_config.get("param_map")
        meta_map = param_map.get("META")

        strategy = default_strategy
        if isinstance(strategy_config, dict):
            strategy = [
                strategy_config.get(k, default_strategy[idx])
                for idx, k in enumerate(key)
            ]

        # Range keys are the upper-cased names of the mapped config keys.
        ranges = {}
        for mapped_key in meta_map.values():
            ranges[mapped_key.upper()] = gen_config[mapped_key]
        ranges["s"] = gen_config[param_map.get("num_stages")]
        ranges["w"] = gen_config[param_map.get("num_warps")]

        return {
            "ranges": ranges,
            "strategy": strategy,
        }

    def get_expand_config(self, op_name, yaml_path=None):
        """Return the expand ranges and strategy for ``op_name``, or ``-1``.

        Args:
            op_name: operator name; must be in ``expand_config_registry``.
            yaml_path: used only when the op's spec pins no path of its own.

        Returns:
            ``{"ranges": ..., "strategy": ...}`` on success. ``-1`` when
            the op is unregistered, the backend does not return a list, or
            loading/parsing raises anything (deliberate best-effort).
        """
        op_spec = self.expand_config_registry.get(op_name)
        if op_spec is None:
            return -1

        key = op_spec.get("key", [])
        default_strategy = op_spec.get("default_strategy")
        expand_yaml_path = op_spec.get("expand_yaml_path") or yaml_path
        yaml_op_name = op_spec.get("yaml_op_name", op_name)

        try:
            expand_configs = backend.get_expand_config(
                op_name=yaml_op_name,
                file_path=expand_yaml_path,
            )
            if not isinstance(expand_configs, list):
                return -1
            return self._summarize_expand_configs(
                expand_configs, key, default_strategy
            )
        except Exception:
            return -1

    def ops_get_configs(self, op_name, yaml_path=None, pre_hook=None):
        """Return expanded triton.Config objects for ``op_name``.

        Returns ``[]`` when no expand config can be obtained.
        """
        expand_config = self.get_expand_config(op_name, yaml_path=yaml_path)
        if expand_config == -1:
            return []
        return self._build_configs_by_op(
            op_name, expand_config["ranges"], pre_hook=pre_hook
        )

    def get_tuned_config(self, op_name):
        """Return the list of triton.Config for ``op_name``.

        Results cached in ``loaded_triton_config`` are returned directly.
        Otherwise each raw entry is either expanded (when it contains
        ``self.gen_key``) or converted one-to-one, with launch parameters
        missing from the entry taken from ``triton_config_default``.
        Returns ``[]`` when no source has the op. This method does not add
        to the cache itself; ``load_all`` does.
        """
        if op_name in self.loaded_triton_config:
            return self.loaded_triton_config[op_name]

        current_op_configs = self._get_op_configs(op_name)
        if not current_op_configs:
            return []

        configs = []
        for single_config in current_op_configs:
            if self.gen_key in single_config:
                configs.extend(self.to_gen_config(single_config))
                continue

            current_config = {
                param: (
                    single_config[param] if param in single_config else default
                )
                for param, default in self.triton_config_default.items()
            }
            configs.append(
                self._create_triton_config(single_config, current_config)
            )
        return configs

568 return configs