Coverage for src/flag_gems/runtime/backend/backend_utils.py: 59%

49 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

import functools
import os
from dataclasses import dataclass
from typing import Optional

import yaml

6 

7 

# Metadata template. Each vendor needs to specialize an instance of this
# template with its own values.
@dataclass
class VendorInfoBase:
    """Metadata record describing one vendor backend.

    The first three fields are required; ``dispatch_key`` and
    ``triton_extra_name`` are optional and default to ``None``.
    """

    # Vendor name string.
    vendor_name: str
    # Device name string for this vendor.
    device_name: str
    # NOTE(review): presumably a command used to query/detect the device -- confirm.
    device_query_cmd: str
    # Optional; presumably the framework dispatch key for this backend -- confirm.
    # Annotated Optional because None is the default.
    dispatch_key: Optional[str] = None
    # Optional; extra name associated with the Triton backend -- confirm meaning.
    triton_extra_name: Optional[str] = None

16 

17 

def get_tune_config(vendor_name=None, file_mode="r", file_path=None):
    """Load and parse a ``tune_configs.yaml`` file.

    The file is located in one of two ways:

    * ``file_path`` is falsy: ``<this module's dir>/_<vendor_name>/tune_configs.yaml``
      is read, and a missing file raises ``FileNotFoundError``.
    * ``file_path`` is given: ``<file_path>/tune_configs.yaml`` is read, and a
      missing file is tolerated (``None`` is returned).

    Args:
        vendor_name: Vendor name used to build the bundled config directory
            name (``_<vendor_name>``); required when ``file_path`` is falsy.
        file_mode: Mode passed to ``open``. Text modes are decoded as UTF-8.
        file_path: Directory containing ``tune_configs.yaml``.

    Returns:
        The object produced by ``yaml.safe_load`` (``None`` for an empty
        document), or ``None`` when ``file_path`` was given but the file does
        not exist.

    Raises:
        ValueError: If ``vendor_name`` is missing while ``file_path`` is falsy,
            or if the file is not valid YAML.
        FileNotFoundError: If the bundled per-vendor file does not exist.
        RuntimeError: For any other error while building the path or reading
            the file.
    """
    # NOTE(review): missing-file tolerance keys on ``is not None`` while the
    # path choice keys on truthiness, so ``file_path=""`` reads the bundled
    # vendor file yet still returns None if it is missing. Kept for callers.
    tolerate_missing = file_path is not None

    # Fail with a clear message instead of letting "_" + None surface as an
    # "unexpected" RuntimeError below.
    if not file_path and not vendor_name:
        raise ValueError("vendor_name is required when file_path is not given")

    config_file = file_path
    try:
        if file_path:
            config_file = os.path.join(file_path, "tune_configs.yaml")
        else:
            base_dir = os.path.dirname(os.path.abspath(__file__))
            config_file = os.path.join(
                base_dir, f"_{vendor_name}", "tune_configs.yaml"
            )
        # YAML streams are UTF-8 by spec; don't depend on the platform locale.
        # Binary modes must not receive an encoding argument.
        encoding = None if "b" in file_mode else "utf-8"
        with open(config_file, file_mode, encoding=encoding) as file:
            return yaml.safe_load(file)
    except FileNotFoundError as e:
        if tolerate_missing:
            return None
        raise FileNotFoundError(f"Configuration file not found: {config_file}") from e
    except yaml.YAMLError as e:
        raise ValueError(f"Failed to parse YAML file: {e}") from e
    except Exception as e:
        raise RuntimeError(f"An unexpected error occurred: {e}") from e

40 

41 

42@functools.lru_cache(maxsize=None) 

43def _load_expand_config(file_path, file_mode="r"): 

44 with open(file_path, file_mode) as file: 

45 return yaml.safe_load(file) or {} 

46 

47 

def get_expand_config(op_name=None, file_mode="r", file_path=None):
    """Return the parsed expand config, or the entry for a single op.

    Args:
        op_name: Top-level key to look up; ``None`` returns the whole config.
        file_mode: Mode passed through to ``open``.
        file_path: Path to the YAML config file (required).

    Returns:
        A deep copy of the whole parsed config (``{}`` for an empty file) when
        ``op_name`` is ``None``; otherwise a deep copy of the value stored
        under ``op_name`` (``None`` if the key is absent).

    Raises:
        ValueError: If ``file_path`` is missing, the YAML is invalid, or
            ``op_name`` is given but the top-level document is not a mapping.
        FileNotFoundError: If the file does not exist.
        RuntimeError: For any other error while loading the file.
    """
    import copy

    if not file_path:
        raise ValueError("expand config file path is required")
    try:
        config = _load_expand_config(file_path, file_mode)
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Configuration file not found: {file_path}") from e
    except yaml.YAMLError as e:
        raise ValueError(f"Failed to parse YAML file: {e}") from e
    except Exception as e:
        raise RuntimeError(f"An unexpected error occurred: {e}") from e

    if op_name is None:
        selected = config
    elif isinstance(config, dict):
        selected = config.get(op_name)
    else:
        # A top-level list/scalar document has no per-op entries; report it
        # clearly instead of leaking an AttributeError from ``.get``.
        raise ValueError(
            f"Expand config {file_path} must be a mapping to look up "
            f"{op_name!r}, got {type(config).__name__}"
        )
    # _load_expand_config is lru_cached, so ``config`` is shared by every
    # caller; return a copy so mutations cannot poison later lookups.
    return copy.deepcopy(selected)