Coverage for src/flag_gems/runtime/flagtune.py: 60%

91 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import os 

2import warnings 

3from dataclasses import dataclass 

4from types import MappingProxyType 

5 

# Environment variable read by flagtune_enabled(): when its value is exactly
# "1", every registered FlagTune operator is reported as enabled.
USE_FLAGTUNE_ENV = "USE_FLAGTUNE"
# Environment variable holding the include list. flagtune() writes it as a
# comma-separated string; _include_from_env() reads it back and accepts both
# "," and ";" as separators.
FLAGTUNE_INCLUDE_ENV = "FLAGTUNE_INCLUDE"

# Maps a normalized (whitespace-stripped) operator name to its FlagTuneOpSpec.
_flagtune_op_registry = {}
# Include set selected through flagtune(); stays None until flagtune() is
# called, in which case get_flagtune_include() falls back to the environment.
_include_ops = None

11 

12 

@dataclass(frozen=True)
class FlagTuneOpSpec:
    """Immutable registry record describing one FlagTune-capable operator."""

    # Operator name; register_flagtune_op() stores it stripped of surrounding
    # whitespace.
    name: str
    # True when the operator belongs to the default include set returned by
    # get_default_flagtune_include().
    default_enabled: bool = False
    # Free-form, human-readable description of the operator.
    description: str = ""

18 

19 

def _normalize_op_name(op_name):
    """Return ``op_name`` with leading and trailing whitespace removed.

    Raises:
        TypeError: if ``op_name`` is not a ``str``.
        ValueError: if nothing remains after stripping whitespace.
    """
    if isinstance(op_name, str):
        stripped = op_name.strip()
        if stripped:
            return stripped
        raise ValueError("op_name must not be empty")
    raise TypeError("op_name must be a string")

27 

28 

def register_flagtune_op(
    op_name,
    *,
    default=False,
    description="",
    replace=False,
):
    """Register an operator name that can be selected by flag_gems.flagtune.

    Args:
        op_name: Operator name; surrounding whitespace is stripped.
        default: Truthy to place the operator in the default include set.
        description: Human-readable description; a falsy value becomes "".
        replace: Truthy to overwrite an existing registration of the same name.

    Returns:
        The FlagTuneOpSpec now stored in the registry. Re-registering an
        identical spec without ``replace`` returns the already-stored spec.

    Raises:
        TypeError: if ``op_name`` is not a string.
        ValueError: if ``op_name`` is empty, or if the name is already
            registered with a different spec and ``replace`` is falsy.
    """
    key = _normalize_op_name(op_name)
    candidate = FlagTuneOpSpec(
        name=key,
        default_enabled=bool(default),
        description=str(description or ""),
    )

    current = _flagtune_op_registry.get(key)
    if current is None or replace:
        # First registration, or the caller explicitly asked to overwrite.
        _flagtune_op_registry[key] = candidate
        return candidate

    if current == candidate:
        # Identical re-registration is a harmless no-op.
        return current
    raise ValueError(f"FlagTune op {key!r} is already registered")

52 

53 

def get_flagtune_registry():
    """Return a read-only mapping of operator name to FlagTuneOpSpec.

    The mapping wraps a copy of the registry taken at call time, so later
    registrations do not show up in a previously returned mapping.
    """
    snapshot = _flagtune_op_registry.copy()
    return MappingProxyType(snapshot)

56 

57 

def get_supported_flagtune_ops():
    """Return the names of all currently registered operators as a frozenset."""
    registered_names = _flagtune_op_registry.keys()
    return frozenset(registered_names)

60 

61 

def get_default_flagtune_include():
    """Return a frozenset of registered names whose spec is default-enabled."""
    default_names = set()
    for op_name, op_spec in _flagtune_op_registry.items():
        if op_spec.default_enabled:
            default_names.add(op_name)
    return frozenset(default_names)

66 

67 

def _split_include(include):
    """Turn an include specification into a frozenset of operator names.

    Args:
        include: ``None`` (use the registry defaults), a string whose entries
            are separated by "," or ";", or an iterable of entries. Each entry
            is converted with ``str()`` and stripped; empty entries are dropped.

    Returns:
        A frozenset of the non-empty names. No check against the registry is
        made here.

    Raises:
        TypeError: if ``include`` cannot be iterated or an entry cannot be
            converted to a string.
    """
    if include is None:
        return get_default_flagtune_include()

    candidates = include
    if isinstance(candidates, str):
        # Treat ";" as an alternative separator to ",".
        candidates = candidates.replace(";", ",").split(",")

    try:
        cleaned = {str(item).strip() for item in candidates}
    except TypeError as err:
        raise TypeError(
            "include must be a comma-separated string or an iterable"
        ) from err

    # Blank entries come from things like trailing separators.
    cleaned.discard("")
    return frozenset(cleaned)

82 

83 

def _normalize_include(include):
    """Parse ``include`` and verify that every name is a registered operator.

    Returns:
        The set of requested operator names produced by ``_split_include``.

    Raises:
        TypeError: propagated from ``_split_include`` for non-iterable input.
        ValueError: if any requested name is not in the registry; the message
            lists the offending names and the supported ones.
    """
    requested = _split_include(include)
    known = get_supported_flagtune_ops()

    unknown = requested - known
    if not unknown:
        return requested

    unknown_text = ", ".join(sorted(unknown))
    # An empty registry is shown as "<none>" rather than as a blank.
    known_text = ", ".join(sorted(known)) or "<none>"
    raise ValueError(
        f"Unsupported flagtune op(s): {unknown_text}. "
        f"Supported ops: {known_text}"
    )

95 

96 

def flagtune(include=None):
    """Select the operators for which runtime FlagTune is enabled.

    With ``include=None`` the registry's default-enabled operators are
    selected. A string or an iterable names the registered operators to
    select explicitly. The selection is stored in this module and also
    written to the FLAGTUNE_INCLUDE environment variable as a sorted,
    comma-separated list.

    Only the explicit include list is changed here; USE_FLAGTUNE=1 is still
    honoured separately as the legacy switch that enables every registered
    FlagTune operator.

    Raises:
        TypeError: if ``include`` is neither a string nor an iterable.
        ValueError: if ``include`` names an unregistered operator.
    """
    global _include_ops
    selected = _normalize_include(include)
    _include_ops = selected
    # Sorted so the exported value is deterministic.
    os.environ[FLAGTUNE_INCLUDE_ENV] = ",".join(sorted(selected))

109 

110 

def _include_from_env():
    """Read the include set from the FLAGTUNE_INCLUDE environment variable.

    Returns:
        The validated include set, or an empty frozenset when the variable is
        unset or its value is rejected. A rejected value emits a warning
        instead of raising.
    """
    raw = os.environ.get(FLAGTUNE_INCLUDE_ENV)
    if raw is not None:
        try:
            return _normalize_include(raw)
        except (TypeError, ValueError) as err:
            # A bad environment value must not break callers; report and
            # fall through to the empty set.
            warnings.warn(f"Invalid {FLAGTUNE_INCLUDE_ENV}: {err}")
    return frozenset()

120 

121 

def get_flagtune_include():
    """Return the active include set.

    The set chosen through ``flagtune()`` wins; if ``flagtune()`` has not been
    called, the set is derived from the environment on each call.
    """
    explicit = _include_ops
    return _include_from_env() if explicit is None else explicit

126 

127 

def flagtune_enabled(op_name):
    """Return whether FlagTune is enabled for ``op_name``.

    Returns False for names that are not strings, are empty, or are not
    registered. A registered operator is enabled when USE_FLAGTUNE is exactly
    "1" or when the operator is in the active include set.
    """
    try:
        normalized = _normalize_op_name(op_name)
    except (TypeError, ValueError):
        return False

    if normalized in get_supported_flagtune_ops():
        # The legacy global switch is checked before the include set.
        if os.environ.get(USE_FLAGTUNE_ENV) == "1":
            return True
        return normalized in get_flagtune_include()
    return False

136 

137 

def __getattr__(name):
    """Compute the module attributes that mirror the live registry.

    ``SUPPORTED_FLAGTUNE_OPS`` and ``DEFAULT_FLAGTUNE_INCLUDE`` are evaluated
    on every access. Any other name raises ``AttributeError``.
    """
    if name not in ("SUPPORTED_FLAGTUNE_OPS", "DEFAULT_FLAGTUNE_INCLUDE"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    if name == "SUPPORTED_FLAGTUNE_OPS":
        return get_supported_flagtune_ops()
    return get_default_flagtune_include()

144 

145 

# Built-in FlagTune operators, registered in this order at import time. None
# of them is default-enabled.
for _op_name, _op_description in (
    ("mm", "matrix multiplication"),
    ("bmm", "batched matrix multiplication"),
    ("addmm", "matrix multiplication with bias"),
    ("baddbmm", "batched matrix multiplication with bias"),
    ("mv", "matrix-vector multiplication"),
    ("w8a8_block_fp8_matmul", "W8A8 block FP8 matrix multiplication"),
):
    register_flagtune_op(_op_name, default=False, description=_op_description)
# Remove the loop variables so the module namespace is unchanged.
del _op_name, _op_description

168 

# DEFAULT_FLAGTUNE_INCLUDE and SUPPORTED_FLAGTUNE_OPS are not assigned at
# module level; the module-level __getattr__ computes them on access. The
# F822 lint check is silenced on the next line for that reason.
__all__ = [  # noqa: F822
    "DEFAULT_FLAGTUNE_INCLUDE",
    "FLAGTUNE_INCLUDE_ENV",
    "FlagTuneOpSpec",
    "SUPPORTED_FLAGTUNE_OPS",
    "USE_FLAGTUNE_ENV",
    "flagtune",
    "flagtune_enabled",
    "get_default_flagtune_include",
    "get_flagtune_include",
    "get_flagtune_registry",
    "get_supported_flagtune_ops",
    "register_flagtune_op",
]