Coverage for src/flag_gems/config.py: 47%

104 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import os 

2import warnings 

3from pathlib import Path 

4 

5import yaml 

6 

# The runtime package is only needed by helpers further down, so a failed
# import is tolerated here and recorded as ``None`` instead of aborting the
# import of this module.
try:  # pragma: no cover - best effort fallback
    from flag_gems import runtime as _runtime
except Exception:  # noqa: BLE001
    _runtime = None

# Conservative defaults; they are upgraded below only when the corresponding
# import succeeds.
has_c_extension = False
use_c_extension = False
aten_patch_list = []

# set FLAGGEMS_SOURCE_DIR for cpp extension to find
os.environ["FLAGGEMS_SOURCE_DIR"] = str(Path(__file__).parent.resolve())

# Probe for the compiled operators module.
try:
    from flag_gems import c_operators
except ImportError:
    c_operators = None
    has_c_extension = False
else:
    has_c_extension = True


# The C extension is opt-in: it is used only when USE_C_EXTENSION is exactly "1".
use_env_c_extension = os.environ.get("USE_C_EXTENSION", "0") == "1"

if use_env_c_extension:
    if not has_c_extension:
        # Requested but not importable: tell the user and stay on Python.
        warnings.warn(
            "[FlagGems] USE_C_EXTENSION is set, but C extension is not available. "
            "Falling back to pure Python implementation.",
            RuntimeWarning,
        )
    else:
        # Requested and importable: collect the ops registered by aten_patch.
        try:
            from flag_gems import aten_patch

            aten_patch_list = aten_patch.get_registered_ops()
            use_c_extension = True
        except (ImportError, AttributeError):
            aten_patch_list = []
            use_c_extension = False

47 

48 

def load_enable_config_from_yaml(yaml_path, key="include"):
    """Load one operator list (``include`` or ``exclude``) from a YAML file.

    Expected YAML structure::

        include:   # operators to explicitly enable
          - op_a
          - op_b
        exclude:   # operators to skip
          - op_c

    Both keys are optional; a missing or empty key yields an empty list.

    Args:
        yaml_path: Location of the YAML file (``str`` or ``os.PathLike``).
        key: Which list to read; must be ``"include"`` or ``"exclude"``.

    Returns:
        The operator names found under ``key``, de-duplicated and kept in the
        order they appear in the file. An empty list is returned when the file
        is missing, cannot be read or parsed, is empty, is not a mapping, or
        when ``key`` is invalid; every case except an empty file or an
        absent/empty key also emits a warning.
    """
    yaml_path = Path(yaml_path)
    if not yaml_path.is_file():
        warnings.warn(f"load_enable_config_from_yaml: yaml not found: {yaml_path}")
        return []

    try:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding.
        data = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
    except Exception as err:  # noqa: BLE001 - any read/parse failure means "no config"
        warnings.warn(
            f"load_enable_config_from_yaml: unexpected error reading {yaml_path}: {err}"
        )
        return []

    if key not in ("include", "exclude"):
        warnings.warn(
            f"load_enable_config_from_yaml: key must be 'include' or 'exclude', got: {key}"
        )
        return []

    # An empty document parses to None.
    if data is None:
        return []

    if not isinstance(data, dict):
        warnings.warn(
            f"load_enable_config_from_yaml: yaml {yaml_path} must be a mapping with 'include'/'exclude' lists"
        )
        return []

    entries = data.get(key)
    # ``include:`` with nothing after it parses to None, not to an empty list.
    if entries is None:
        return []
    # A bare scalar (``include: op_a``) is one operator, not a sequence of
    # characters.
    if isinstance(entries, str):
        entries = [entries]
    elif not isinstance(entries, (list, tuple)):
        warnings.warn(
            f"load_enable_config_from_yaml: '{key}' in {yaml_path} must be a list of operator names"
        )
        return []

    try:
        # dict.fromkeys de-duplicates while preserving file order, so the
        # result is deterministic across runs.
        return list(dict.fromkeys(entries))
    except TypeError:
        # Nested lists/mappings are unhashable and cannot be operator names.
        warnings.warn(
            f"load_enable_config_from_yaml: '{key}' in {yaml_path} contains non-scalar entries"
        )
        return []

93 

94 

def get_default_enable_config(vendor_name=None, arch_name=None):
    """Return candidate ``enable_configs.yaml`` paths, most specific first.

    Paths are built under ``<this package>/runtime/backend``. When the vendor
    directory (``_<vendor_name>``, or the backend directory itself if no
    vendor is given) exists, the arch-specific file (only if ``arch_name`` is
    given) and the vendor-level file are listed. The nvidia config is always
    the last candidate and serves as the default.

    Args:
        vendor_name: Vendor identifier; the directory looked up is
            ``_<vendor_name>``.
        arch_name: Optional architecture sub-directory inside the vendor
            directory.

    Returns:
        A list of ``Path`` objects without duplicates. The files themselves
        are not checked for existence.
    """
    base_dir = Path(__file__).resolve().parent / "runtime" / "backend"
    vendor_dir = base_dir / f"_{vendor_name}" if vendor_name else base_dir

    candidates = []
    if vendor_dir.is_dir():
        if arch_name:
            candidates.append(vendor_dir / arch_name / "enable_configs.yaml")
        candidates.append(vendor_dir / "enable_configs.yaml")

    # use nvidia as default; skip it when the vendor *is* nvidia so the same
    # file is not listed (and later loaded) twice.
    nvidia_default = base_dir / "_nvidia" / "enable_configs.yaml"
    if nvidia_default not in candidates:
        candidates.append(nvidia_default)
    return candidates

108 

109 

def resolve_user_setting(user_setting_info, user_setting_type="include"):
    """
    Resolve user setting for include/exclude operator lists.

    Args:
        user_setting_info: One of
            * a list/tuple/set/frozenset of operator names, used directly;
            * ``"default"``, which loads the packaged default YAML for the
              detected vendor/architecture;
            * ``None``, treated like ``"default"`` for the ``"include"`` type
              only (for ``"exclude"`` it resolves to no candidates);
            * a ``str`` or ``os.PathLike`` path to a YAML file.
        user_setting_type: Either "include" or "exclude".

    Returns:
        List of operators based on the user setting, de-duplicated with the
        first occurrence of each name kept in place. An empty list (after a
        warning) when nothing could be resolved.
    """
    # Explicit collections are used as-is. dict.fromkeys removes duplicates
    # while keeping the caller's order, which makes the result deterministic.
    if isinstance(user_setting_info, (list, tuple, set, frozenset)):
        return list(dict.fromkeys(user_setting_info))

    yaml_candidates = []
    # If set to "default" or None (for include type),
    # load from default YAML config files based on vendor and architecture
    if user_setting_info == "default" or (
        user_setting_type == "include" and user_setting_info is None
    ):
        vendor_name = None
        arch_name = None
        if _runtime is None:
            # The optional runtime import at module load failed; without it
            # the vendor cannot be detected, so only the generic candidates
            # (ending with the nvidia default) are tried.
            warnings.warn(
                "resolve_user_setting: flag_gems.runtime is unavailable; "
                "falling back to the default enable config"
            )
        else:
            vendor_name = _runtime.device.vendor_name
            arch_event = _runtime.backend.BackendArchEvent()
            if arch_event.has_arch:
                arch_name = getattr(arch_event, "arch", None)
        yaml_candidates = get_default_enable_config(vendor_name, arch_name)

    # A string or path-like object is treated as a YAML file path
    elif isinstance(user_setting_info, (str, os.PathLike)):
        yaml_candidates.append(user_setting_info)

    # Iterate through candidate YAML paths and return the first non-empty list
    for yaml_path in yaml_candidates:
        operator_list = load_enable_config_from_yaml(yaml_path, user_setting_type)
        if operator_list:
            return operator_list
        # NOTE: this also fires when the file exists but yields no operators
        # for this type; the message text is kept unchanged for compatibility.
        warnings.warn(
            f"resolve_user_setting: {user_setting_type} yaml not found: {yaml_path}"
        )

    # If no operators found in any YAML, warn and return empty list
    warnings.warn(
        f"resolve_user_setting: no {user_setting_type} ops found; returning empty list"
    )
    return []

158 

159 

# Precision-check skip set – derived from conf/operators.yaml

_CONF_DIR = Path(__file__).resolve().parent.parent.parent / "conf"
_OPERATORS_YAML = _CONF_DIR / "operators.yaml"

# Cache for get_skip_precision_check_ops(); None means "not computed yet".
_skip_precision_check_ops: "frozenset[str] | None" = None


def _as_name_list(value):
    """Normalise a YAML value to a list: None -> [], scalar string -> [value]."""
    if value is None:
        return []
    if isinstance(value, str):
        return [value]
    if isinstance(value, (list, tuple)):
        return list(value)
    # Anything else (number, mapping, ...) cannot be a list of names.
    return []


def _collect_skip_precision_check_ops(data):
    """Return base names of ops labelled ``skip_precision_check`` in parsed YAML."""
    if not isinstance(data, dict):
        return set()

    collected = set()
    for entry in _as_name_list(data.get("ops")):
        # Malformed entries (scalars, nulls) are skipped instead of crashing.
        if not isinstance(entry, dict):
            continue
        if "skip_precision_check" not in _as_name_list(entry.get("labels")):
            continue
        for op_name in _as_name_list(entry.get("for")):
            if op_name is None:
                continue
            # Extract base name (strip overload suffix)
            collected.add(str(op_name).split(".")[0])
    return collected


def get_skip_precision_check_ops() -> "frozenset[str]":
    """Return the frozenset of operator base-names that carry the
    ``skip_precision_check`` label in ``conf/operators.yaml``.

    The set is built by scanning each entry under the top-level ``ops`` key;
    if its ``labels`` list contains ``"skip_precision_check"``, every name in
    its ``for`` field is included (with any overload suffix like ``.Tensor``
    stripped to yield the base name). A scalar ``for``/``labels`` value is
    treated as a one-element list; null or malformed entries are ignored.

    A missing or unreadable file produces a warning and an empty set.

    The result (including an empty one) is cached after the first call, so
    later calls do not touch the file again.
    """
    global _skip_precision_check_ops
    if _skip_precision_check_ops is not None:
        return _skip_precision_check_ops

    ops: set = set()
    if _OPERATORS_YAML.is_file():
        try:
            data = yaml.safe_load(_OPERATORS_YAML.read_text(encoding="utf-8"))
        except Exception as err:  # noqa: BLE001
            warnings.warn(
                f"get_skip_precision_check_ops: failed to read "
                f"{_OPERATORS_YAML}: {err}"
            )
            _skip_precision_check_ops = frozenset()
            return _skip_precision_check_ops

        ops = _collect_skip_precision_check_ops(data)
    else:
        warnings.warn(
            f"get_skip_precision_check_ops: operators.yaml not found at "
            f"{_OPERATORS_YAML}"
        )

    _skip_precision_check_ops = frozenset(ops)
    return _skip_precision_check_ops

213 

214 

# Names exported by ``from flag_gems.config import *``.
__all__ = [
    "aten_patch_list",
    "has_c_extension",
    "use_c_extension",
    "resolve_user_setting",
    "get_skip_precision_check_ops",
]