Coverage for src/flag_gems/__init__.py: 89%
76 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import warnings
3import torch
4from packaging import version
6from flag_gems import testing # noqa: F401
7from flag_gems import runtime
8from flag_gems.config import aten_patch_list, resolve_user_setting
9from flag_gems.experimental_ops import * # noqa: F403
10from flag_gems.fused import * # noqa: F403
11from flag_gems.logging_utils import setup_flaggems_logging, teardown_flaggems_logging
12from flag_gems.modules import * # noqa: F403
13from flag_gems.ops import * # noqa: F403
14from flag_gems.patches import * # noqa: F403
15from flag_gems.runtime.register import Register
__version__ = "5.0.2"
# Device name and vendor name as reported by the runtime layer.
device = runtime.device.name
vendor_name = runtime.device.vendor_name
# Module-wide torch Library for the "aten" namespace in IMPL mode; this is the
# default `lib` argument of enable() / only_enable().
aten_lib = torch.library.Library("aten", "IMPL")
# Registrar class used by default when registering ops.
registrar = Register
# Registrar instance created by the most recent enable() / only_enable() call;
# stays None until one of them has run.
current_work_registrar = None
# Hands this module's namespace to the runtime layer.
# NOTE(review): presumably this swaps vendor-customized op implementations into
# the module globals before _FULL_CONFIG captures them -- confirm in
# runtime.replace_customized_ops.
runtime.replace_customized_ops(globals())
def torch_ge(v):
    """Tell whether the installed torch is at least version ``v``.

    Args:
        v: Version string to compare against (e.g. ``"2.4"``).

    Returns:
        ``True`` when ``torch.__version__`` parses to a version that is
        greater than or equal to the parsed ``v``, ``False`` otherwise.
    """
    installed = version.parse(torch.__version__)
    required = version.parse(v)
    return installed >= required
# Registration table handed to the registrar by enable() / only_enable().
#
# Each entry is a tuple of one of two shapes:
#   (op_key, implementation)
#   (op_key, implementation, predicate)
# where
#   - op_key is a string of the form "<name>" or "<name>.<overload>"
#     (the table is registered into the "aten" library, so these look like
#     aten operator keys);
#   - implementation is a callable that is not defined in this file and is
#     expected to come from the wildcard imports at the top of the module;
#   - predicate is a zero-argument callable; the two predicates used below
#     both test that the installed torch version is >= 2.4.
#     NOTE(review): presumably the registrar only registers such an entry when
#     the predicate returns True -- confirm against Register.
#
# Several op keys may share one implementation (e.g. the "div" / "divide" /
# "true_divide" families).
_FULL_CONFIG = (
    ("__ior__.Scalar", bitwise_or_scalar_),
    ("__ior__.Tensor", bitwise_or_tensor_),
    ("__or__.Scalar", bitwise_or_scalar),
    ("__or__.Tensor", bitwise_or_tensor),
    ("_assert_async", _assert_async),
    ("_conv_depthwise2d", _conv_depthwise2d),
    ("_flash_attention_forward", flash_attention_forward),
    (
        "_functional_sym_constrain_range_for_size",
        _functional_sym_constrain_range_for_size,
    ),
    ("_grouped_mm", group_mm),
    ("_index_put_impl_", _index_put_impl_),
    ("_is_all_true", _is_all_true),
    ("_log_softmax", log_softmax),
    ("_log_softmax.out", log_softmax_out),
    ("_log_softmax_backward_data", log_softmax_backward),
    ("_log_softmax_backward_data.out", log_softmax_backward_out),
    ("_safe_softmax", _safe_softmax),
    ("_softmax", softmax),
    ("_softmax.out", softmax_out),
    ("_softmax_backward_data", softmax_backward),
    ("_softmax_backward_data.out", softmax_backward_out),
    (
        "_to_copy",
        to_copy,
        lambda: version.parse(torch.__version__) >= version.parse("2.4"),
    ),
    ("_unique2", _unique2),
    ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa),
    ("_upsample_bicubic2d_aa_backward", _upsample_bicubic2d_aa_backward),
    ("_upsample_nearest_exact1d", _upsample_nearest_exact1d),
    ("_weight_norm_interface", weight_norm_interface),
    ("_weight_norm_interface_backward", weight_norm_interface_backward),
    ("abs", abs),
    ("abs_", abs_),
    ("absolute", absolute),
    ("acos", acos),
    ("add.Tensor", add),
    ("add_.Tensor", add_),
    ("addcdiv", addcdiv),
    ("addcdiv.out", addcdiv_out),
    ("addcmul", addcmul),
    ("addcmul.out", addcmul_out),
    ("addmv", addmv),
    ("addmv.out", addmv_out),
    ("addmm", addmm),
    ("addmm.out", addmm_out),
    ("addmm.dtype", addmm_dtype),
    ("addmm.dtype_out", addmm_dtype_out),
    ("addr", addr),
    ("alias_copy", alias_copy),
    ("all", all),
    ("all.dim", all_dim),
    ("all.dims", all_dims),
    ("allclose", allclose),
    ("amax", amax),
    ("aminmax", aminmax),
    ("angle", angle),
    ("any", any),
    ("any.dim", any_dim),
    ("any.dims", any_dims),
    ("arange", arange),
    ("arange.start", arange_start),
    ("arange.start_step", arange_start),
    ("arcsinh", arcsinh),
    ("arcsinh.out", arcsinh_out),
    ("arcsinh_", arcsinh_),
    ("argmax", argmax),
    ("argmin", argmin),
    ("asinh", asinh),
    ("asinh.out", asinh_out),
    ("asinh_", asinh_),
    ("atan", atan),
    ("atan_", atan_),
    ("atan2", atan2),
    ("atan2.out", atan2_out),
    ("arctanh_", arctanh_),
    ("avg_pool2d", avg_pool2d),
    ("avg_pool2d_backward", avg_pool2d_backward),
    ("avg_pool3d", avg_pool3d),
    ("avg_pool3d_backward", avg_pool3d_backward),
    ("baddbmm", baddbmm),
    ("bernoulli_.float", bernoulli_),
    ("bincount", bincount),
    ("bitwise_and.Scalar", bitwise_and_scalar),
    ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
    ("bitwise_and.Tensor", bitwise_and_tensor),
    ("bitwise_and_.Scalar", bitwise_and_scalar_),
    ("bitwise_and_.Tensor", bitwise_and_tensor_),
    ("bitwise_left_shift", bitwise_left_shift),
    ("bitwise_not", bitwise_not),
    ("bitwise_not_", bitwise_not_),
    ("bitwise_or.Scalar", bitwise_or_scalar),
    ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor),
    ("bitwise_or.Tensor", bitwise_or_tensor),
    ("bitwise_or_.Scalar", bitwise_or_scalar_),
    ("bitwise_or_.Tensor", bitwise_or_tensor_),
    ("bitwise_right_shift", bitwise_right_shift),
    ("bmm", bmm),
    ("bmm.out", bmm_out),
    ("cat", cat),
    ("cat.out", cat_out),
    ("celu", celu),
    ("celu_", celu_),
    ("ceil", ceil),
    ("ceil_", ceil_),
    ("ceil.out", ceil_out),
    ("clamp", clamp),
    ("clamp.Tensor", clamp_tensor),
    ("clamp_min", clamp_min),
    ("clamp_", clamp_),
    ("clamp_.Tensor", clamp_tensor_),
    ("clamp_min_", clamp_min_),
    ("clip", clip),
    ("clip_", clip_),
    ("conj_physical", conj_physical),
    ("constant_pad_nd", constant_pad_nd),
    # ("contiguous", contiguous),
    ("conv1d", conv1d),
    ("conv1d.padding", conv1d),
    ("conv2d", conv2d),
    ("conv2d.padding", conv2d),
    ("conv3d", conv3d),
    ("conv3d.padding", conv3d),
    (
        "copy_",
        copy_,
        lambda: version.parse(torch.__version__) >= version.parse("2.4"),
    ),
    ("cos", cos),
    ("cos_", cos_),
    ("cosh", cosh),
    ("cosh_", cosh_),
    ("cosh.out", cosh_out),
    ("copysign", copysign),
    ("copysign.out", copysign_out),
    ("count_nonzero", count_nonzero),
    ("cummax", cummax),
    ("cummin", cummin),
    ("cumsum", cumsum),
    ("cumsum.out", cumsum_out),
    ("diag", diag),
    ("diag_embed", diag_embed),
    ("diagonal_backward", diagonal_backward),
    ("digamma_", digamma_),
    ("div.Scalar", true_divide),
    ("div.Scalar_mode", div_mode),
    ("div.Tensor", true_divide),
    ("div.Tensor_mode", div_mode),
    ("div.out", true_divide_out),
    ("div_.Scalar", true_divide_),
    ("div_.Scalar_mode", div_mode_),
    ("div_.Tensor", true_divide_),
    ("div_.Tensor_mode", div_mode_),
    ("divide.Scalar", true_divide),
    ("divide.Scalar_mode", div_mode),
    ("divide.Tensor", true_divide),
    ("divide.Tensor_mode", div_mode),
    ("divide_.Scalar", true_divide_),
    ("divide_.Scalar_mode", div_mode_),
    ("divide_.Tensor", true_divide_),
    ("divide_.Tensor_mode", div_mode_),
    ("dot", dot),
    ("elu", elu),
    ("elu_", elu_),
    ("elu_backward", elu_backward),
    ("embedding", embedding),
    ("embedding_backward", embedding_backward),
    ("embedding_dense_backward", embedding_dense_backward),
    ("eq.Scalar", eq_scalar),
    ("eq.Tensor", eq),
    ("equal", equal),
    ("erf", erf),
    ("erf_", erf_),
    ("exp", exp),
    ("exp_", exp_),
    ("exp.out", exp_out),
    ("exp2", exp2),
    ("exp2_", exp2_),
    ("expm1", expm1),
    ("expm1_", expm1_),
    ("expm1.out", expm1_out),
    ("exponential_", exponential_),
    ("eye", eye),
    ("eye.m", eye_m),
    ("fill.Scalar", fill_scalar),
    ("fill.Scalar_out", fill_scalar_out),
    ("fill.Tensor", fill_tensor),
    ("fill.Tensor_out", fill_tensor_out),
    ("fill_.Scalar", fill_scalar_),
    ("fill_.Tensor", fill_tensor_),
    ("flip", flip),
    ("floor_", floor_),
    ("floor_divide", floor_divide),
    ("floor_divide.Scalar", floor_divide),
    ("floor_divide_.Scalar", floor_divide_),
    ("floor_divide_.Tensor", floor_divide_),
    ("fmin", fmin),
    ("fmin.out", fmin_out),
    ("full", full),
    ("full_like", full_like),
    ("gather", gather),
    ("gather_backward", gather_backward),
    ("gcd", gcd),
    ("gcd.out", gcd_out),
    ("ge.Scalar", ge_scalar),
    ("ge.Tensor", ge),
    ("gelu", gelu),
    ("gelu_", gelu_),
    ("gelu_backward", gelu_backward),
    ("glu", glu),
    ("glu_backward", glu_backward),
    ("greater.Scalar", greater_scalar),
    ("greater.Tensor", greater),
    ("greater.Scalar_out", greater_scalar_out),
    ("greater.out", greater_out),
    ("grid_sample", grid_sample),
    ("gt.Scalar", gt_scalar),
    ("gt.Tensor", gt),
    ("hardsigmoid", hardsigmoid),
    ("hardsigmoid.out", hardsigmoid_out),
    ("hardswish_", hardswish_),
    ("hstack", hstack),
    ("hypot", hypot),
    ("i0", i0),
    ("i0.out", i0_out),
    ("i0_", i0_),
    ("index.Tensor", index),
    ("index_add", index_add),
    ("index_add_", index_add_),
    ("index_put", index_put),
    ("index_put_", index_put_),
    ("index_select", index_select),
    ("isclose", isclose),
    ("isfinite", isfinite),
    ("isin.Scalar_Tensor", isin),
    ("isin.Tensor_Scalar", isin),
    ("isin.Tensor_Tensor", isin),
    ("isinf", isinf),
    ("isnan", isnan),
    ("isneginf", isneginf),
    ("isneginf.out", isneginf_out),
    ("kron", kron),
    ("le.Scalar", le_scalar),
    ("le.Tensor", le),
    ("leaky_relu", leaky_relu),
    ("leaky_relu_", leaky_relu_),
    ("leaky_relu.out", leaky_relu_out),
    ("lerp.Scalar", lerp_scalar),
    ("lerp.Tensor", lerp_tensor),
    ("lerp_.Scalar", lerp_scalar_),
    ("lerp_.Tensor", lerp_tensor_),
    ("lift_fresh_copy", lift_fresh_copy),
    ("linalg_vector_norm", vector_norm),
    ("linspace", linspace),
    ("log", log),
    ("log10", log10),
    ("log10_", log10_),
    ("log10.out", log10_out),
    ("log1p_", log1p_),
    ("log_sigmoid", log_sigmoid),
    ("logaddexp", logaddexp),
    ("logaddexp.out", logaddexp_out),
    ("logical_and", logical_and),
    ("logical_and_", logical_and_),
    ("logical_not", logical_not),
    ("logical_or", logical_or),
    ("logical_or_", logical_or_),
    ("logical_xor", logical_xor),
    ("logit", logit),
    ("logit.out", logit_out),
    ("logit_", logit_),
    ("logspace", logspace),
    ("lt.Scalar", lt_scalar),
    ("lt.Tensor", lt),
    ("margin_ranking_loss", margin_ranking_loss),
    ("masked_fill.Scalar", masked_fill),
    ("masked_fill.Tensor", masked_fill),
    ("masked_fill_.Scalar", masked_fill_),
    ("masked_fill_.Tensor", masked_fill_),
    ("masked_scatter", masked_scatter),
    ("masked_scatter_", masked_scatter_),
    ("masked_select", masked_select),
    ("max", max),
    ("max.dim", max_dim),
    ("max_pool2d_backward", max_pool2d_backward),
    ("max_pool2d_with_indices", max_pool2d_with_indices),
    ("max_pool3d_backward", max_pool3d_backward),
    ("max_pool3d_with_indices", max_pool3d_with_indices),
    ("maximum", maximum),
    ("mean", mean),
    ("mean.dim", mean_dim),
    ("min", min),
    ("min.dim", min_dim),
    ("minimum", minimum),
    ("mm", mm),
    ("mm.out", mm_out),
    ("mse_loss", mse_loss),
    ("mul.Tensor", mul),
    ("mul_.Tensor", mul_),
    ("multinomial", multinomial),
    ("mv", mv),
    ("nan_to_num", nan_to_num),
    ("native_batch_norm", batch_norm),
    ("native_batch_norm_backward", batch_norm_backward),
    ("native_dropout", dropout),
    ("native_dropout_backward", dropout_backward),
    ("native_group_norm", group_norm),
    ("native_group_norm_backward", group_norm_backward),
    ("native_layer_norm", layer_norm),
    ("native_layer_norm_backward", layer_norm_backward),
    ("ne.Scalar", ne_scalar),
    ("ne.Tensor", ne),
    ("neg", neg),
    ("neg_", neg_),
    ("new_full.Tensor", new_full),
    ("nll_loss_backward", nll_loss_backward),
    ("nll_loss_forward", nll_loss_forward),
    ("nll_loss_nd_forward", nll_loss_nd_forward),
    ("nll_loss_nd_backward", nll_loss_nd_backward),
    ("nll_loss2d_backward", nll_loss2d_backward),
    ("nll_loss2d_forward", nll_loss2d_forward),
    ("nonzero", nonzero),
    ("normal.Tensor_float", normal_tensor_float),
    ("normal.Tensor_Tensor", normal_tensor_tensor),
    ("normal.float_Tensor", normal_float_tensor),
    ("normal_", normal_),
    ("normed_cumsum", normed_cumsum),
    ("one_hot", one_hot),
    ("ones", ones),
    ("ones_like", ones_like),
    ("pad", pad),
    ("pixel_shuffle", pixel_shuffle),
    ("pixel_unshuffle", pixel_unshuffle),
    ("pixel_unshuffle.out", pixel_unshuffle_out),
    ("polar", polar),
    ("pow.Scalar", pow_scalar),
    ("pow.Tensor_Scalar", pow_tensor_scalar),
    ("pow.Tensor_Tensor", pow_tensor_tensor),
    ("pow_.Scalar", pow_tensor_scalar_),
    ("pow_.Tensor", pow_tensor_tensor_),
    ("prelu", prelu),
    ("prod", prod),
    ("prod.dim_int", prod_dim),
    ("quantile", quantile),
    ("rand", rand),
    ("rand_like", rand_like),
    ("randn", randn),
    ("randn_like", randn_like),
    ("randperm", randperm),
    ("reciprocal", reciprocal),
    ("reciprocal_", reciprocal_),
    ("reflection_pad1d", reflection_pad1d),
    ("reflection_pad1d.out", reflection_pad1d_out),
    ("reflection_pad2d", reflection_pad2d),
    ("reflection_pad2d.out", reflection_pad2d_out),
    ("relu", relu),
    ("relu_", relu_),
    ("relu6", relu6),
    ("remainder.Scalar", remainder),
    ("remainder.Scalar_Tensor", remainder),
    ("remainder.Tensor", remainder),
    ("remainder_.Scalar", remainder_),
    ("remainder_.Tensor", remainder_),
    ("repeat", repeat),
    ("repeat_interleave.self_int", repeat_interleave_self_int),
    ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
    ("repeat_interleave.Tensor", repeat_interleave_tensor),
    ("replication_pad1d", replication_pad1d),
    ("replication_pad1d.out", replication_pad1d_out),
    ("replication_pad3d", replication_pad3d),
    ("resolve_conj", resolve_conj),
    ("resolve_neg", resolve_neg),
    ("rms_norm", rms_norm),
    ("roll", roll),
    ("round", round),
    ("round_", round_),
    ("round.out", round_out),
    ("rrelu_with_noise_backward", rrelu_with_noise_backward),
    ("rsqrt", rsqrt),
    ("rsqrt_", rsqrt_),
    ("scaled_softmax_backward", scaled_softmax_backward),
    ("scaled_softmax_forward", scaled_softmax_forward),
    ("scatter.reduce", scatter),
    ("scatter.src", scatter),
    ("scatter_.reduce", scatter_),
    ("scatter_.src", scatter_),
    ("scatter_add_", scatter_add_),
    ("select_backward", select_backward),
    ("select_scatter", select_scatter),
    ("selu", selu),
    ("selu_", selu_),
    ("sgn_", sgn_),
    ("sigmoid", sigmoid),
    ("sigmoid_", sigmoid_),
    ("sigmoid_backward", sigmoid_backward),
    ("signbit", signbit),
    ("signbit.out", signbit_out),
    ("silu", silu),
    ("silu_", silu_),
    ("silu_backward", silu_backward),
    ("sin", sin),
    ("sin_", sin_),
    ("sinh_", sinh_),
    ("slice_backward", slice_backward),
    ("slice_scatter", slice_scatter),
    ("soft_margin_loss", soft_margin_loss),
    ("softplus", softplus),
    ("softshrink", softshrink),
    ("softshrink.out", softshrink_out),
    ("sort", sort),
    ("sort.stable", sort_stable),
    ("special_i0e", special_i0e),
    ("special_i0e.out", special_i0e_out),
    ("special_i1", special_i1),
    ("special_i1.out", special_i1_out),
    ("sqrt", sqrt),
    ("sqrt_", sqrt_),
    ("square", square),
    ("square_", square_),
    ("square.out", square_out),
    ("stack", stack),
    ("std.correction", std),
    ("sub.Tensor", sub),
    ("sub_.Tensor", sub_),
    ("sum", sum),
    ("sum.IntList_out", sum_dim_out),
    ("sum.dim_IntList", sum_dim),
    ("sum.out", sum_out),
    ("t_copy", t_copy),
    ("t_copy.out", t_copy_out),
    ("tan", tan),
    ("tan_", tan_),
    ("tanh", tanh),
    ("tanh_", tanh_),
    ("tanh_backward", tanh_backward),
    ("threshold", threshold),
    ("threshold_backward", threshold_backward),
    ("tile", tile),
    ("topk", topk),
    ("trace", trace),
    ("tril", tril),
    ("triu", triu),
    ("triu_", triu_),
    ("true_divide.Scalar", true_divide),
    ("true_divide.Tensor", true_divide),
    ("true_divide_.Scalar", true_divide_),
    ("true_divide_.Tensor", true_divide_),
    ("unfold_backward", unfold_backward),
    ("uniform_", uniform_),
    ("unique_consecutive", unique_consecutive),
    ("upsample_bicubic2d", upsample_bicubic2d),
    ("upsample_linear1d", upsample_linear1d),
    ("upsample_nearest1d", upsample_nearest1d),
    ("upsample_nearest2d", upsample_nearest2d),
    ("upsample_nearest3d", upsample_nearest3d),
    ("var_mean.correction", var_mean),
    ("var", var),
    ("var.correction", var_correction),
    ("var.dim", var_dim),
    ("vdot", vdot),
    ("vstack", vstack),
    ("where.self", where_self),
    ("where.self_out", where_self_out),
    ("zero", zero),
    ("zero_", zero_),
    ("zero.out", zero_out),
    ("zeros", zeros),
    ("zeros_like", zeros_like),
)
# Reverse index: implementation function name -> every _FULL_CONFIG entry that
# uses that implementation, so a lookup by name does not rescan the table.
FULL_CONFIG_BY_FUNC = {}
for _item in _FULL_CONFIG:
    # Only entries carrying at least (op key, implementation) can be indexed.
    if _item and len(_item) >= 2:
        fn = _item[1]
        try:
            func_name = fn.__name__
        except AttributeError:
            # Objects without a __name__ are keyed by their string form.
            func_name = str(fn)
        if func_name not in FULL_CONFIG_BY_FUNC:
            FULL_CONFIG_BY_FUNC[func_name] = []
        FULL_CONFIG_BY_FUNC[func_name].append(_item)
# Let only_enable(include=[...]) accept the plain names below even when the
# registered implementation is the matching "<name>_out" function: the alias
# key receives every entry already indexed under "<name>_out".
for _alias in (
    "softmax",
    "softmax_backward",
    "log_softmax",
    "log_softmax_backward",
):
    _target = _alias + "_out"
    if _target not in FULL_CONFIG_BY_FUNC:
        continue
    FULL_CONFIG_BY_FUNC.setdefault(_alias, []).extend(FULL_CONFIG_BY_FUNC[_target])
def enable(
    lib=aten_lib,
    unused=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register every FlagGems op, minus an optional exclusion list.

    Args:
        lib: ``torch.library.Library`` that receives the registrations; the
            module-level ``aten_lib`` (IMPL mode) by default.
        unused: Ops to leave out. Accepted forms:
            - a list/tuple/set of function names, e.g.
              ``["masked_fill", "mul"]``;
            - a path string to a ``.yml``/``.yaml`` file containing an
              ``exclude:`` list;
            - ``"default"`` or ``None`` to auto-load the vendor/arch-specific
              ``runtime/backend/_<vendor>/[<arch>/]enable_configs.yaml`` when
              it exists.
        registrar: Registrar class to instantiate; ``Register`` by default.
        record: Turn FlagGems logging on.
        once: When True, log only once.
        path: Optional log destination used when recording.

    Notes:
        - An exclusion list/YAML that resolves to nothing means all ops are
          registered.
        - The registrar instance is stored in the module-level
          ``current_work_registrar``.
    """
    global current_work_registrar

    skipped_ops = resolve_user_setting(unused, "exclude")
    # De-duplicate the patched-op names before handing them to the registrar.
    patched_ops = list(set(aten_patch_list))
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=[],
        user_exclude_ops=skipped_ops,
        cpp_patched_ops=patched_ops,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
def only_enable(
    lib=aten_lib,
    include=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register just the requested FlagGems ops and leave the rest alone.

    Args:
        lib: ``torch.library.Library`` that receives the registrations; the
            module-level ``aten_lib`` (IMPL mode) by default.
        include: Ops to register. Accepted forms:
            - a list/tuple/set of function names, e.g.
              ``["rms_norm", "softmax"]``;
            - a path string to a ``.yml``/``.yaml`` file (a list, or a mapping
              with an ``include:`` key);
            - ``"default"`` or ``None`` to auto-load the vendor/arch-specific
              ``runtime/backend/_<vendor>/[<arch>/]only_enable_configs.yaml``
              when it exists.
        registrar: Registrar class to instantiate; ``Register`` by default.
        record: Turn FlagGems logging on.
        once: When True, log only once.
        path: Optional log destination used when recording.

    Typical calls:
        - A handful of ops:
            only_enable(include=["rms_norm", "softmax"])
        - The vendor default YAML:
            only_enable(include="default")  # or include=None
        - A custom YAML:
            only_enable(include="/path/to/only_enable.yaml")

    Notes:
        - When the include list/YAML resolves to nothing, a warning is issued
          and the function returns without registering anything or touching
          ``current_work_registrar``.
        - How names that match no known op are treated is decided by the
          registrar and is not visible in this function.
    """
    global current_work_registrar

    requested_ops = resolve_user_setting(include, "include")
    if not requested_ops:
        warnings.warn(
            "only_enable failed: No include entries resolved from list or yaml."
        )
    else:
        # De-duplicate the patched-op names before handing them to the registrar.
        patched_ops = list(set(aten_patch_list))
        current_work_registrar = registrar(
            _FULL_CONFIG,
            user_include_ops=requested_ops,
            user_exclude_ops=[],
            cpp_patched_ops=patched_ops,
            full_config_by_func=FULL_CONFIG_BY_FUNC,
            lib=lib,
        )
        setup_flaggems_logging(path=path, record=record, once=once)
class use_gems:
    """Context manager that registers FlagGems ops for the span of a ``with`` block.

    'include' has higher priority than 'exclude': when 'include' is a
    non-empty list/tuple/set/str, only those ops are registered (via
    ``only_enable``) and 'exclude' is not processed. Otherwise all ops except
    'exclude' are registered (via ``enable``).

    Args:
        exclude: Ops to skip; a list/tuple/set/str. Any other type (including
            ``None``) is treated as an empty list.
        include: Ops to register exclusively; a list/tuple/set/str. Any other
            type (including ``None``) is treated as an empty list.
        record: Whether to enable FlagGems logging; logging is torn down on
            exit when this is truthy.
        once: Forwarded to the logging setup (log only once).
        path: Forwarded to the logging setup (log output path).

    Notes:
        An instance is single-use: ``__exit__`` deletes ``lib``, ``exclude``,
        ``include`` and ``registrar`` from the instance.
    """

    def __init__(self, exclude=None, include=None, record=False, once=False, path=None):
        # A dedicated Library per context, passed as `lib` to enable() /
        # only_enable() instead of the module-level `aten_lib`.
        self.lib = torch.library.Library("aten", "IMPL")
        self.exclude = exclude if isinstance(exclude, (list, tuple, set, str)) else []
        self.include = include if isinstance(include, (list, tuple, set, str)) else []
        self.registrar = Register
        self.record = record
        self.once = once
        self.path = path

    def __enter__(self):
        """Register the ops and return this instance.

        Returning ``self`` makes ``with use_gems() as gems:`` usable, e.g. to
        reach ``gems.experimental_ops``.
        """
        if self.include:
            only_enable(
                lib=self.lib,
                include=self.include,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        else:
            enable(
                lib=self.lib,
                unused=self.exclude,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Drop this context's library and registrar; never suppresses exceptions."""
        global current_work_registrar
        # NOTE(review): relies on torch.__version__ (a TorchVersion) comparing
        # as a version rather than as a plain string -- confirm for the
        # minimum supported torch.
        if torch.__version__ >= "2.5":
            self.lib._destroy()
        del self.lib
        del self.exclude
        del self.include
        del self.registrar
        # Reset to the module's initial value instead of `del`-ing the global:
        # deleting it made a later exit (nested contexts, or an only_enable()
        # that returned early without creating a registrar) fail with
        # NameError before the logging teardown below could run.
        current_work_registrar = None
        if self.record:
            teardown_flaggems_logging()

    @property
    def experimental_ops(self):
        """The ``flag_gems.experimental_ops`` module, imported on access."""
        import flag_gems.experimental_ops

        return flag_gems.experimental_ops
def all_registered_ops():
    """Return ``get_all_ops()`` of the current registrar.

    The registrar is the module-level ``current_work_registrar``, which is set
    by ``enable`` / ``only_enable``.
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_ops()
def all_registered_keys():
    """Return ``get_all_keys()`` of the current registrar.

    The registrar is the module-level ``current_work_registrar``, which is set
    by ``enable`` / ``only_enable``.
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_keys()
# Names exported by a wildcard import of this module: the registration entry
# points, the context manager and the two registrar query helpers.
__all__ = [
    "enable",
    "only_enable",
    "use_gems",
    "all_registered_ops",
    "all_registered_keys",
]