Coverage for src/flag_gems/__init__.py: 90%
77 statements
coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1# ruff: noqa: F405
2import warnings
4import torch
5from packaging import version
7from flag_gems import testing # noqa: F401
8from flag_gems import runtime
9from flag_gems.config import aten_patch_list, resolve_user_setting
10from flag_gems.experimental_ops import * # noqa: F403
11from flag_gems.fused import * # noqa: F403
12from flag_gems.logging_utils import setup_flaggems_logging, teardown_flaggems_logging
13from flag_gems.modules import * # noqa: F403
14from flag_gems.ops import * # noqa: F403
15from flag_gems.patches import * # noqa: F403
16from flag_gems.runtime.register import Register
__version__ = "5.0.2"
# Name and vendor of the active runtime device, as exposed by flag_gems.runtime.
device = runtime.device.name
vendor_name = runtime.device.vendor_name
# Global "aten" library in IMPL mode; default `lib` for enable()/only_enable().
aten_lib = torch.library.Library("aten", "IMPL")
# Default registrar class for enable()/only_enable().
registrar = Register
# Registrar instance created by the most recent enable()/only_enable() call;
# None until one of them runs.
current_work_registrar = None
# Passes this module's namespace (already populated by the star imports above)
# to the runtime; presumably it swaps in vendor-customized op implementations.
# NOTE(review): confirm against runtime.replace_customized_ops.
runtime.replace_customized_ops(globals())
# String name of the Autograd dispatch key; used as the extra dispatch-key
# element of some _FULL_CONFIG entries (e.g. ctc_loss).
AUTOGRAD_DISPATCH_KEY = torch._C.DispatchKey.Autograd.name
def torch_ge(v):
    """Return True if the installed torch version is at least ``v``.

    Both sides are parsed with ``packaging.version``, so pre-release and local
    version suffixes compare by PEP 440 rules rather than as plain strings.
    """
    installed = version.parse(torch.__version__)
    required = version.parse(v)
    return installed >= required
# Registration table handed to the registrar by enable()/only_enable().
# Entry shapes used below:
#   (aten_op_name, impl)
#   (aten_op_name, impl, condition)
#       condition: zero-arg callable (or None); presumably the entry is only
#       registered when it returns True -- NOTE(review): confirm in Register.
#   (aten_op_name, impl, condition, dispatch_keys)
#       dispatch_keys: tuple of extra dispatch-key names (e.g. Autograd).
# Version gates call torch_ge() so the comparison logic lives in one place.
_FULL_CONFIG = (
    ("__ior__.Scalar", bitwise_or_scalar_),
    ("__ior__.Tensor", bitwise_or_tensor_),
    ("__or__.Scalar", bitwise_or_scalar),
    ("__or__.Tensor", bitwise_or_tensor),
    ("_assert_async", _assert_async),
    ("_conv_depthwise2d", _conv_depthwise2d),
    ("_euclidean_dist", _euclidean_dist),
    ("_flash_attention_forward", flash_attention_forward),
    (
        "_functional_sym_constrain_range_for_size",
        _functional_sym_constrain_range_for_size,
    ),
    ("_grouped_mm", group_mm),
    ("_index_put_impl_", _index_put_impl_),
    ("_is_all_true", _is_all_true),
    ("_log_softmax", log_softmax),
    ("_log_softmax.out", log_softmax_out),
    ("_log_softmax_backward_data", log_softmax_backward),
    ("_log_softmax_backward_data.out", log_softmax_backward_out),
    ("_safe_softmax", _safe_softmax),
    ("_softmax", softmax),
    ("_softmax.out", softmax_out),
    ("_softmax_backward_data", softmax_backward),
    ("_softmax_backward_data.out", softmax_backward_out),
    (
        "_to_copy",
        to_copy,
        lambda: torch_ge("2.4"),
    ),
    ("_unique2", _unique2),
    ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa),
    ("_upsample_bicubic2d_aa_backward", _upsample_bicubic2d_aa_backward),
    ("_upsample_nearest_exact1d", _upsample_nearest_exact1d),
    ("_weight_norm_interface", weight_norm_interface),
    ("_weight_norm_interface_backward", weight_norm_interface_backward),
    ("abs", abs),
    ("abs_", abs_),
    ("absolute", absolute),
    ("acos", acos),
    ("add.Tensor", add),
    ("add_.Tensor", add_),
    ("addcdiv", addcdiv),
    ("addcdiv.out", addcdiv_out),
    ("addcmul", addcmul),
    ("addcmul.out", addcmul_out),
    ("addmv", addmv),
    ("addmv.out", addmv_out),
    ("addmm", addmm),
    ("addmm.out", addmm_out),
    ("addmm.dtype", addmm_dtype),
    ("addmm.dtype_out", addmm_dtype_out),
    ("addr", addr),
    ("affine_grid_generator", affine_grid_generator),
    ("alias_copy", alias_copy),
    ("all", all),
    ("all.dim", all_dim),
    ("all.dims", all_dims),
    ("allclose", allclose),
    ("amax", amax),
    ("aminmax", aminmax),
    ("angle", angle),
    ("any", any),
    ("any.dim", any_dim),
    ("any.dims", any_dims),
    ("arange", arange),
    ("arange.start", arange_start),
    ("arange.start_step", arange_start),
    ("arcsinh", arcsinh),
    ("arcsinh.out", arcsinh_out),
    ("arcsinh_", arcsinh_),
    ("argmax", argmax),
    ("argmin", argmin),
    ("asinh", asinh),
    ("asinh.out", asinh_out),
    ("asinh_", asinh_),
    ("atan", atan),
    ("atan_", atan_),
    ("atan2", atan2),
    ("atan2.out", atan2_out),
    ("arctanh_", arctanh_),
    ("avg_pool2d", avg_pool2d),
    ("avg_pool2d_backward", avg_pool2d_backward),
    ("avg_pool3d", avg_pool3d),
    ("avg_pool3d_backward", avg_pool3d_backward),
    ("baddbmm", baddbmm),
    ("bernoulli_.float", bernoulli_),
    ("bincount", bincount),
    ("bitwise_and.Scalar", bitwise_and_scalar),
    ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
    ("bitwise_and.Tensor", bitwise_and_tensor),
    ("bitwise_and_.Scalar", bitwise_and_scalar_),
    ("bitwise_and_.Tensor", bitwise_and_tensor_),
    ("bitwise_left_shift", bitwise_left_shift),
    ("bitwise_not", bitwise_not),
    ("bitwise_not_", bitwise_not_),
    ("bitwise_or.Scalar", bitwise_or_scalar),
    ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor),
    ("bitwise_or.Tensor", bitwise_or_tensor),
    ("bitwise_or_.Scalar", bitwise_or_scalar_),
    ("bitwise_or_.Tensor", bitwise_or_tensor_),
    ("bitwise_right_shift", bitwise_right_shift),
    ("bmm", bmm),
    ("bmm.out", bmm_out),
    ("cat", cat),
    ("cat.out", cat_out),
    ("celu", celu),
    ("celu_", celu_),
    ("ceil", ceil),
    ("ceil_", ceil_),
    ("ceil.out", ceil_out),
    ("clamp", clamp),
    ("clamp.Tensor", clamp_tensor),
    ("clamp_min", clamp_min),
    ("clamp_", clamp_),
    ("clamp_.Tensor", clamp_tensor_),
    ("clamp_min_", clamp_min_),
    ("clip", clip),
    ("clip_", clip_),
    ("col2im", col2im),
    ("concatenate", concatenate),
    ("conj_physical", conj_physical),
    ("constant_pad_nd", constant_pad_nd),
    # Disabled entry (reason not recorded here):
    # ("contiguous", contiguous),
    ("conv1d", conv1d),
    ("conv1d.padding", conv1d),
    ("conv2d", conv2d),
    ("conv2d.padding", conv2d),
    ("conv_transpose2d", conv_transpose2d),
    ("conv3d", conv3d),
    ("conv3d.padding", conv3d),
    ("conv_transpose1d", conv_transpose1d),
    (
        "copy_",
        copy_,
        lambda: torch_ge("2.4"),
    ),
    ("cos", cos),
    ("cos_", cos_),
    ("cosh", cosh),
    ("cosh_", cosh_),
    ("cosh.out", cosh_out),
    ("copysign", copysign),
    ("copysign.out", copysign_out),
    ("count_nonzero", count_nonzero),
    ("ctc_loss.IntList", ctc_loss, None, (AUTOGRAD_DISPATCH_KEY,)),
    ("ctc_loss.Tensor", ctc_loss, None, (AUTOGRAD_DISPATCH_KEY,)),
    ("cudnn_convolution", cudnn_convolution),
    ("cummax", cummax),
    ("cummin", cummin),
    ("cumprod", cumprod),
    ("cumprod_", cumprod_),
    ("cumsum", cumsum),
    ("cumsum.out", cumsum_out),
    ("diag", diag),
    ("diag_embed", diag_embed),
    ("diagonal_backward", diagonal_backward),
    ("diff", diff),
    ("digamma_", digamma_),
    ("div.Scalar", true_divide),
    ("div.Scalar_mode", div_mode),
    ("div.Tensor", true_divide),
    ("div.Tensor_mode", div_mode),
    ("div.out", true_divide_out),
    ("div_.Scalar", true_divide_),
    ("div_.Scalar_mode", div_mode_),
    ("div_.Tensor", true_divide_),
    ("div_.Tensor_mode", div_mode_),
    ("divide.Scalar", true_divide),
    ("divide.Scalar_mode", div_mode),
    ("divide.Tensor", true_divide),
    ("divide.Tensor_mode", div_mode),
    ("divide_.Scalar", true_divide_),
    ("divide_.Scalar_mode", div_mode_),
    ("divide_.Tensor", true_divide_),
    ("divide_.Tensor_mode", div_mode_),
    ("dot", dot),
    ("elu", elu),
    ("elu_", elu_),
    ("elu_backward", elu_backward),
    ("embedding", embedding),
    ("embedding_backward", embedding_backward),
    ("embedding_dense_backward", embedding_dense_backward),
    ("eq.Scalar", eq_scalar),
    ("eq.Tensor", eq),
    ("equal", equal),
    ("erf", erf),
    ("erf_", erf_),
    ("exp", exp),
    ("exp_", exp_),
    ("exp.out", exp_out),
    ("exp2", exp2),
    ("exp2_", exp2_),
    ("expm1", expm1),
    ("expm1_", expm1_),
    ("expm1.out", expm1_out),
    ("exponential_", exponential_),
    ("feature_dropout", feature_dropout),
    ("feature_dropout_", feature_dropout_),
    ("eye", eye),
    ("eye.m", eye_m),
    ("fill.Scalar", fill_scalar),
    ("fill.Scalar_out", fill_scalar_out),
    ("fill.Tensor", fill_tensor),
    ("fill.Tensor_out", fill_tensor_out),
    ("fill_.Scalar", fill_scalar_),
    ("fill_.Tensor", fill_tensor_),
    ("flip", flip),
    ("floor_", floor_),
    ("floor_divide", floor_divide),
    ("floor_divide.Scalar", floor_divide),
    ("floor_divide_.Scalar", floor_divide_),
    ("floor_divide_.Tensor", floor_divide_),
    ("fmin", fmin),
    ("fmin.out", fmin_out),
    ("fmod.Scalar", fmod_scalar),
    ("fmod.Tensor", fmod_tensor),
    ("fmod_.Scalar", fmod_scalar_),
    ("fmod_.Tensor", fmod_tensor_),
    ("full", full),
    ("full_like", full_like),
    ("gather", gather),
    ("gather_backward", gather_backward),
    ("gcd", gcd),
    ("gcd.out", gcd_out),
    ("ge.Scalar", ge_scalar),
    ("ge.Tensor", ge),
    ("gelu", gelu),
    ("gelu_", gelu_),
    ("gelu_backward", gelu_backward),
    ("glu", glu),
    ("glu_backward", glu_backward),
    ("greater.Scalar", greater_scalar),
    ("greater.Tensor", greater),
    ("greater.Scalar_out", greater_scalar_out),
    ("greater.out", greater_out),
    ("grid_sample", grid_sample),
    ("gt.Scalar", gt_scalar),
    ("gt.Tensor", gt),
    ("hardsigmoid", hardsigmoid),
    ("hardsigmoid.out", hardsigmoid_out),
    ("hardswish_", hardswish_),
    ("histc", histc),
    ("hstack", hstack),
    ("hypot", hypot),
    ("i0", i0),
    ("i0.out", i0_out),
    ("i0_", i0_),
    ("index.Tensor", index),
    ("index_add", index_add),
    ("index_add_", index_add_),
    ("index_copy", index_copy),
    ("index_copy_", index_copy_),
    ("index_put", index_put),
    ("index_put_", index_put_),
    ("index_select", index_select),
    ("isclose", isclose),
    ("isfinite", isfinite),
    ("isin.Scalar_Tensor", isin),
    ("isin.Tensor_Scalar", isin),
    ("isin.Tensor_Tensor", isin),
    ("isinf", isinf),
    ("isnan", isnan),
    ("isneginf", isneginf),
    ("isneginf.out", isneginf_out),
    ("kron", kron),
    ("le.Scalar", le_scalar),
    ("le.Tensor", le),
    ("leaky_relu", leaky_relu),
    ("leaky_relu_", leaky_relu_),
    ("leaky_relu.out", leaky_relu_out),
    ("lerp.Scalar", lerp_scalar),
    ("lerp.Tensor", lerp_tensor),
    ("lerp_.Scalar", lerp_scalar_),
    ("lerp_.Tensor", lerp_tensor_),
    ("lift_fresh_copy", lift_fresh_copy),
    ("linalg_vector_norm", vector_norm),
    ("linspace", linspace),
    ("log", log),
    ("log10", log10),
    ("log10_", log10_),
    ("log10.out", log10_out),
    ("log1p_", log1p_),
    ("log_sigmoid", log_sigmoid),
    ("logaddexp", logaddexp),
    ("logaddexp.out", logaddexp_out),
    ("logical_and", logical_and),
    ("logical_and_", logical_and_),
    ("logical_not", logical_not),
    ("logical_or", logical_or),
    ("logical_or_", logical_or_),
    ("logical_xor", logical_xor),
    ("logit", logit),
    ("logit.out", logit_out),
    ("logit_", logit_),
    ("logspace", logspace),
    ("logsumexp", logsumexp),
    ("lt.Scalar", lt_scalar),
    ("lt.Tensor", lt),
    ("margin_ranking_loss", margin_ranking_loss),
    ("masked_fill.Scalar", masked_fill),
    ("masked_fill.Tensor", masked_fill),
    ("masked_fill_.Scalar", masked_fill_),
    ("masked_fill_.Tensor", masked_fill_),
    ("masked_scatter", masked_scatter),
    ("masked_scatter_", masked_scatter_),
    ("masked_select", masked_select),
    ("max", max),
    ("max.dim", max_dim),
    ("max_pool2d_backward", max_pool2d_backward),
    ("max_pool2d_with_indices", max_pool2d_with_indices),
    ("max_pool3d_backward", max_pool3d_backward),
    ("max_pool3d_with_indices", max_pool3d_with_indices),
    ("maximum", maximum),
    ("mean", mean),
    ("mean.dim", mean_dim),
    ("median", median),
    ("median.out", median_out),
    ("median.dim", median_dim),
    ("median.dim_values", median_dim_values),
    ("min", min),
    ("min.dim", min_dim),
    ("minimum", minimum),
    ("mm", mm),
    ("mm.out", mm_out),
    ("mse_loss", mse_loss),
    ("mul.Tensor", mul),
    ("mul_.Tensor", mul_),
    ("multinomial", multinomial),
    ("mv", mv),
    ("nan_to_num", nan_to_num),
    ("native_batch_norm", batch_norm),
    ("native_batch_norm_backward", batch_norm_backward),
    ("native_dropout", dropout),
    ("native_dropout_backward", dropout_backward),
    ("native_group_norm", group_norm),
    ("native_group_norm_backward", group_norm_backward),
    ("native_layer_norm", layer_norm),
    ("native_layer_norm_backward", layer_norm_backward),
    ("ne.Scalar", ne_scalar),
    ("ne.Tensor", ne),
    ("neg", neg),
    ("neg_", neg_),
    ("new_full.Tensor", new_full),
    ("nll_loss_backward", nll_loss_backward),
    ("nll_loss_forward", nll_loss_forward),
    ("nll_loss_nd_forward", nll_loss_nd_forward),
    ("nll_loss_nd_backward", nll_loss_nd_backward),
    ("nll_loss2d_backward", nll_loss2d_backward),
    ("nll_loss2d_forward", nll_loss2d_forward),
    ("nonzero", nonzero),
    ("nonzero_numpy", nonzero_numpy),
    ("normal.Tensor_float", normal_tensor_float),
    ("normal.Tensor_Tensor", normal_tensor_tensor),
    ("normal.float_Tensor", normal_float_tensor),
    ("normal_", normal_),
    ("normed_cumsum", normed_cumsum),
    ("one_hot", one_hot),
    ("ones", ones),
    ("ones_like", ones_like),
    ("pad", pad),
    ("pixel_shuffle", pixel_shuffle),
    ("pixel_unshuffle", pixel_unshuffle),
    ("pixel_unshuffle.out", pixel_unshuffle_out),
    ("poisson", poisson),
    ("polar", polar),
    ("pow.Scalar", pow_scalar),
    ("pow.Tensor_Scalar", pow_tensor_scalar),
    ("pow.Tensor_Tensor", pow_tensor_tensor),
    ("pow_.Scalar", pow_tensor_scalar_),
    ("pow_.Tensor", pow_tensor_tensor_),
    ("prelu", prelu),
    ("prod", prod),
    ("prod.dim_int", prod_dim),
    ("quantile", quantile),
    ("rand", rand),
    ("rand_like", rand_like),
    ("randn", randn),
    ("randn_like", randn_like),
    ("randperm", randperm),
    ("reciprocal", reciprocal),
    ("reciprocal_", reciprocal_),
    ("reflection_pad1d", reflection_pad1d),
    ("reflection_pad1d.out", reflection_pad1d_out),
    ("reflection_pad2d", reflection_pad2d),
    ("reflection_pad2d.out", reflection_pad2d_out),
    ("relu", relu),
    ("relu_", relu_),
    ("relu6", relu6),
    ("remainder.Scalar", remainder),
    ("remainder.Scalar_Tensor", remainder),
    ("remainder.Tensor", remainder),
    ("remainder_.Scalar", remainder_),
    ("remainder_.Tensor", remainder_),
    ("repeat", repeat),
    ("repeat_interleave.self_int", repeat_interleave_self_int),
    ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
    ("repeat_interleave.Tensor", repeat_interleave_tensor),
    ("replication_pad1d", replication_pad1d),
    ("replication_pad1d.out", replication_pad1d_out),
    ("replication_pad3d", replication_pad3d),
    ("resolve_conj", resolve_conj),
    ("resolve_neg", resolve_neg),
    ("rms_norm", rms_norm),
    ("roll", roll),
    ("round", round),
    ("round_", round_),
    ("round.out", round_out),
    ("rrelu_with_noise_backward", rrelu_with_noise_backward),
    ("rsqrt", rsqrt),
    ("rsqrt_", rsqrt_),
    ("rsub.Scalar", rsub_scalar),
    ("rsub.Tensor", rsub_tensor),
    ("scaled_softmax_backward", scaled_softmax_backward),
    ("scaled_softmax_forward", scaled_softmax_forward),
    ("scatter.reduce", scatter),
    ("scatter.src", scatter),
    ("scatter_.reduce", scatter_),
    ("scatter_.src", scatter_),
    ("scatter_add_", scatter_add_),
    ("scatter_reduce.two", scatter_reduce),
    ("scatter_reduce_.two", scatter_reduce_),
    ("scatter_reduce.two_out", scatter_reduce_out),
    ("select_backward", select_backward),
    ("select_scatter", select_scatter),
    ("selu", selu),
    ("selu_", selu_),
    ("sgn_", sgn_),
    ("sigmoid", sigmoid),
    ("sigmoid_", sigmoid_),
    ("sigmoid_backward", sigmoid_backward),
    ("signbit", signbit),
    ("signbit.out", signbit_out),
    ("silu", silu),
    ("silu_", silu_),
    ("silu_backward", silu_backward),
    ("sin", sin),
    ("sin_", sin_),
    ("sinh_", sinh_),
    ("slice_backward", slice_backward),
    ("slice_scatter", slice_scatter),
    ("smooth_l1_loss", smooth_l1_loss),
    ("smooth_l1_loss_backward", smooth_l1_loss_backward),
    ("smooth_l1_loss.out", smooth_l1_loss_out),
    ("soft_margin_loss", soft_margin_loss),
    ("softplus", softplus),
    ("softshrink", softshrink),
    ("softshrink.out", softshrink_out),
    ("sort", sort),
    ("sort.stable", sort_stable),
    ("special_i0e", special_i0e),
    ("special_i0e.out", special_i0e_out),
    ("special_i1", special_i1),
    ("special_i1.out", special_i1_out),
    ("sqrt", sqrt),
    ("sqrt_", sqrt_),
    ("square", square),
    ("square_", square_),
    ("square.out", square_out),
    ("stack", stack),
    ("std.correction", std),
    ("sub.Tensor", sub),
    ("sub_.Tensor", sub_),
    ("sum", sum),
    ("sum.IntList_out", sum_dim_out),
    ("sum.dim_IntList", sum_dim),
    ("sum.out", sum_out),
    ("svd", svd),
    ("t_copy", t_copy),
    ("t_copy.out", t_copy_out),
    ("tan", tan),
    ("tan_", tan_),
    ("tanh", tanh),
    ("tanh_", tanh_),
    ("tanh_backward", tanh_backward),
    ("threshold", threshold),
    ("threshold_backward", threshold_backward),
    ("tile", tile),
    ("topk", topk),
    ("trace", trace),
    ("tril", tril),
    ("tril.out", tril_out),
    ("tril_", tril_),
    ("triu", triu),
    ("triu_", triu_),
    ("true_divide.Scalar", true_divide),
    ("true_divide.Tensor", true_divide),
    ("true_divide_.Scalar", true_divide_),
    ("true_divide_.Tensor", true_divide_),
    ("unfold_backward", unfold_backward),
    ("uniform_", uniform_),
    ("unique_consecutive", unique_consecutive),
    ("upsample_bicubic2d", upsample_bicubic2d),
    ("upsample_linear1d", upsample_linear1d),
    ("upsample_nearest1d", upsample_nearest1d),
    ("upsample_nearest2d", upsample_nearest2d),
    ("upsample_nearest3d", upsample_nearest3d),
    ("var_mean.correction", var_mean),
    ("var", var),
    ("var.correction", var_correction),
    ("var.dim", var_dim),
    ("vdot", vdot),
    ("vstack", vstack),
    ("where.self", where_self),
    ("where.self_out", where_self_out),
    ("zero", zero),
    ("zero_", zero_),
    ("zero.out", zero_out),
    ("zeros", zeros),
    ("zeros_like", zeros_like),
)
# Friendly names for only_enable(include=[...]) when the registered impl is *.out:
# every entry indexed under the target name is also listed under the alias.
_OUT_VARIANT_ALIASES = (
    ("softmax", "softmax_out"),
    ("softmax_backward", "softmax_backward_out"),
    ("log_softmax", "log_softmax_out"),
    ("log_softmax_backward", "log_softmax_backward_out"),
)


def _build_full_config_by_func(config, aliases=_OUT_VARIANT_ALIASES):
    """Index registration entries by the name of their implementation function.

    Built inside a function so the loop temporaries do not leak into the
    package namespace (previously `fn`, `func_name`, `_item`, `_alias` and
    `_target` became attributes of `flag_gems`).

    Args:
        config: Iterable of ``(aten_op_name, impl, ...)`` tuples. Entries that
            are empty or have fewer than two elements are skipped.
        aliases: ``(alias, target)`` pairs. If ``target`` is indexed, its
            entries are appended to the list for ``alias`` (created if absent).

    Returns:
        dict mapping function name -> list of the original config entries,
        in config order.
    """
    by_func = {}
    for entry in config:
        if not entry or len(entry) < 2:
            continue
        impl = entry[1]
        # Callables without __name__ (e.g. functools.partial) fall back to str().
        name = impl.__name__ if hasattr(impl, "__name__") else str(impl)
        by_func.setdefault(name, []).append(entry)
    for alias, target in aliases:
        if target in by_func:
            by_func.setdefault(alias, []).extend(by_func[target])
    return by_func


# Cache mapping from function name -> list of _FULL_CONFIG entries for quick lookup
FULL_CONFIG_BY_FUNC = _build_full_config_by_func(_FULL_CONFIG)
def enable(
    lib=aten_lib,
    unused=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register all FlagGems ops except those explicitly excluded.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            global `aten_lib` (IMPL mode).
        unused: Which ops to skip. Supported forms:
            - list/tuple/set of function names (e.g., ["masked_fill", "mul"]).
            - str path to a YAML file ending with .yml/.yaml containing an
              `exclude:` list.
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]enable_configs.yaml if present.
        registrar: Registrar class; defaults to `Register`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Side effects:
        Rebinds the module-level `current_work_registrar` to the registrar
        instance built from `_FULL_CONFIG`, then configures FlagGems logging.

    Notes:
        - If the exclude list/YAML resolves to empty, all ops are registered.
    """
    global current_work_registrar
    exclude_ops = resolve_user_setting(unused, "exclude")
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=[],
        user_exclude_ops=exclude_ops,
        # dict.fromkeys de-duplicates like set() but keeps first-seen order, so
        # the list is identical across runs (set order of str items varies with
        # hash randomization).
        cpp_patched_ops=list(dict.fromkeys(aten_patch_list)),
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
def only_enable(
    lib=aten_lib,
    include=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register only the specified FlagGems ops and skip the rest.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            global `aten_lib` (IMPL mode).
        include: Which ops to register. Supported forms:
            - list/tuple/set of function names (e.g., ["rms_norm", "softmax"]).
            - str path to a YAML file ending with .yml/.yaml (expects a list or
              an `include:` key).
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]only_enable_configs.yaml if present.
        registrar: Registrar class; defaults to `Register`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Classic usage:
        - Only register a few ops:
            only_enable(include=["rms_norm", "softmax"])
        - Use vendor default YAML:
            only_enable(include="default")  # or include=None
        - Use a custom YAML:
            only_enable(include="/path/to/only_enable.yaml")

    Side effects:
        On success, rebinds the module-level `current_work_registrar` and
        configures FlagGems logging.

    Notes:
        - If the include list/YAML resolves to empty, the function warns and
          returns without registering; `current_work_registrar` and logging
          are left untouched in that case.
        - Include names are passed to the registrar unchecked; handling of
          names that match no known op is up to the registrar.
    """
    global current_work_registrar
    include_ops = resolve_user_setting(include, "include")
    if not include_ops:
        warnings.warn(
            "only_enable failed: No include entries resolved from list or yaml."
        )
        return

    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=include_ops,
        user_exclude_ops=[],
        # Order-preserving de-dup: deterministic across runs, unlike set().
        cpp_patched_ops=list(dict.fromkeys(aten_patch_list)),
        full_config_by_func=FULL_CONFIG_BY_FUNC,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
class use_gems:
    """Context manager that registers FlagGems ops for the duration of a block.

    On entry, a private ``aten`` IMPL library is populated via only_enable()
    when ``include`` is non-empty, otherwise via enable() with ``exclude``.
    The 'include' parameter has higher priority than 'exclude': when 'include'
    is a non-empty list/tuple/set/str, 'exclude' is not processed.

    On exit the private library is released and the module-level
    `current_work_registrar` is restored to whatever it was before entry, so
    nested ``with use_gems()`` blocks unwind correctly and the same instance
    can be entered again.
    """

    def __init__(self, exclude=None, include=None, record=False, once=False, path=None):
        self.lib = torch.library.Library("aten", "IMPL")
        # Any value that is not a list/tuple/set/str (including None) means "not given".
        self.exclude = exclude if isinstance(exclude, (list, tuple, set, str)) else []
        self.include = include if isinstance(include, (list, tuple, set, str)) else []
        self.registrar = Register
        self.record = record
        self.once = once
        self.path = path
        # Registrar that was active before __enter__; restored by __exit__.
        self._previous_registrar = None

    def __enter__(self):
        if self.lib is None:
            # A previous __exit__ released the library; make a fresh one so this
            # instance can be re-entered.
            self.lib = torch.library.Library("aten", "IMPL")
        self._previous_registrar = current_work_registrar
        if self.include:
            only_enable(
                lib=self.lib,
                include=self.include,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        else:
            enable(
                lib=self.lib,
                unused=self.exclude,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )

    def __exit__(self, exc_type, exc_val, exc_tb):
        global current_work_registrar
        if torch_ge("2.5"):
            # Explicit teardown of the library's registrations (only called on
            # torch >= 2.5, matching the original gate).
            self.lib._destroy()
        # Drop our reference so older torch versions can release the library.
        self.lib = None
        # Restore instead of `del`: deleting the global made all_registered_ops()
        # raise NameError and broke the outer __exit__ of nested use_gems blocks.
        current_work_registrar = self._previous_registrar
        self._previous_registrar = None
        if self.record:
            teardown_flaggems_logging()

    @property
    def experimental_ops(self):
        """Return the `flag_gems.experimental_ops` module (imported on access)."""
        import flag_gems.experimental_ops

        return flag_gems.experimental_ops
def all_registered_ops():
    """Return ``get_all_ops()`` of the currently active registrar.

    The active registrar is the module-level ``current_work_registrar`` set by
    enable()/only_enable(). While it is still None (nothing enabled yet), the
    attribute lookup raises AttributeError.
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_ops()
def all_registered_keys():
    """Return ``get_all_keys()`` of the currently active registrar.

    The active registrar is the module-level ``current_work_registrar`` set by
    enable()/only_enable(). While it is still None (nothing enabled yet), the
    attribute lookup raises AttributeError.
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_keys()
# Public names exported by `from flag_gems import *`. Op implementations pulled
# in by the star imports at the top of this module remain reachable as module
# attributes but are not part of this list.
__all__ = [
    "enable",
    "only_enable",
    "use_gems",
    "all_registered_ops",
    "all_registered_keys",
]