Coverage for src/flag_gems/__init__.py: 90%
77 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1# ruff: noqa: F405
2import warnings
4import torch
5from packaging import version
7from flag_gems import testing # noqa: F401
8from flag_gems import runtime
9from flag_gems.config import aten_patch_list, resolve_user_setting
10from flag_gems.experimental_ops import * # noqa: F403
11from flag_gems.fused import * # noqa: F403
12from flag_gems.logging_utils import setup_flaggems_logging, teardown_flaggems_logging
13from flag_gems.modules import * # noqa: F403
14from flag_gems.ops import * # noqa: F403
15from flag_gems.patches import * # noqa: F403
16from flag_gems.runtime.register import Register
# Package version string.
__version__ = "5.0.2"
# Device name and vendor name as reported by the runtime layer.
device = runtime.device.name
vendor_name = runtime.device.vendor_name
# Library handle opened on the "aten" namespace in "IMPL" mode; it is the
# default `lib` argument of enable() and only_enable().
aten_lib = torch.library.Library("aten", "IMPL")
# Default registrar class used by enable() and only_enable().
registrar = Register
# Assigned by enable() / only_enable(); starts out as None.
current_work_registrar = None
# Hands this module's global namespace to the runtime.
# NOTE(review): presumably this swaps in vendor-customized op implementations
# before the registration table below is built - confirm in flag_gems.runtime.
runtime.replace_customized_ops(globals())
# Name of the Autograd dispatch key, used by entries of the registration table.
AUTOGRAD_DISPATCH_KEY = torch._C.DispatchKey.Autograd.name
def torch_ge(v):
    """Return True when the installed torch version is at least `v`.

    Args:
        v: Version string to compare against (e.g. "2.4").
    """
    installed = version.parse(torch.__version__)
    required = version.parse(v)
    return installed >= required
# Registration table passed to the registrar by enable() and only_enable().
# Rows have 2, 3 or 4 fields: an op-name string, the implementation callable,
# optionally a zero-argument callable (here a torch-version check) or None,
# and optionally a tuple of dispatch key names.
# NOTE(review): how the registrar interprets the optional third and fourth
# fields is not visible in this file - presumably an enable condition and the
# target dispatch keys; confirm in flag_gems.runtime.register.
_FULL_CONFIG = (
    ("__ior__.Scalar", bitwise_or_scalar_),
    ("__ior__.Tensor", bitwise_or_tensor_),
    ("__or__.Scalar", bitwise_or_scalar),
    ("__or__.Tensor", bitwise_or_tensor),
    ("_assert_async", _assert_async),
    ("_conv_depthwise2d", _conv_depthwise2d),
    ("_euclidean_dist", _euclidean_dist),
    ("_flash_attention_forward", flash_attention_forward),
    (
        "_functional_sym_constrain_range_for_size",
        _functional_sym_constrain_range_for_size,
    ),
    ("_grouped_mm", group_mm),
    ("_index_put_impl_", _index_put_impl_),
    ("_is_all_true", _is_all_true),
    ("_log_softmax", log_softmax),
    ("_log_softmax.out", log_softmax_out),
    ("_log_softmax_backward_data", log_softmax_backward),
    ("_log_softmax_backward_data.out", log_softmax_backward_out),
    ("_safe_softmax", _safe_softmax),
    ("_softmax", softmax),
    ("_softmax.out", softmax_out),
    ("_softmax_backward_data", softmax_backward),
    ("_softmax_backward_data.out", softmax_backward_out),
    (
        "_to_copy",
        to_copy,
        lambda: version.parse(torch.__version__) >= version.parse("2.4"),
    ),
    ("_unique2", _unique2),
    ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa),
    ("_upsample_bicubic2d_aa_backward", _upsample_bicubic2d_aa_backward),
    ("_upsample_nearest_exact1d", _upsample_nearest_exact1d),
    ("_weight_norm_interface", weight_norm_interface),
    ("_weight_norm_interface_backward", weight_norm_interface_backward),
    ("abs", abs),
    ("abs_", abs_),
    ("absolute", absolute),
    ("acos", acos),
    ("add.Tensor", add),
    ("add_.Tensor", add_),
    ("addcdiv", addcdiv),
    ("addcdiv.out", addcdiv_out),
    ("addcmul", addcmul),
    ("addcmul.out", addcmul_out),
    ("addmv", addmv),
    ("addmv.out", addmv_out),
    ("addmm", addmm),
    ("addmm.out", addmm_out),
    ("addmm.dtype", addmm_dtype),
    ("addmm.dtype_out", addmm_dtype_out),
    ("addr", addr),
    ("affine_grid_generator", affine_grid_generator),
    ("alias_copy", alias_copy),
    ("all", all),
    ("all.dim", all_dim),
    ("all.dims", all_dims),
    ("allclose", allclose),
    ("amax", amax),
    ("aminmax", aminmax),
    ("angle", angle),
    ("any", any),
    ("any.dim", any_dim),
    ("any.dims", any_dims),
    ("arange", arange),
    ("arange.start", arange_start),
    ("arange.start_step", arange_start),
    ("arcsinh", arcsinh),
    ("arcsinh.out", arcsinh_out),
    ("arcsinh_", arcsinh_),
    ("argmax", argmax),
    ("argmin", argmin),
    ("asinh", asinh),
    ("asinh.out", asinh_out),
    ("asinh_", asinh_),
    ("atan", atan),
    ("atan_", atan_),
    ("atan2", atan2),
    ("atan2.out", atan2_out),
    ("arctanh_", arctanh_),
    ("avg_pool2d", avg_pool2d),
    ("avg_pool2d_backward", avg_pool2d_backward),
    ("avg_pool3d", avg_pool3d),
    ("avg_pool3d_backward", avg_pool3d_backward),
    ("baddbmm", baddbmm),
    ("bernoulli_.float", bernoulli_),
    ("bincount", bincount),
    ("bitwise_and.Scalar", bitwise_and_scalar),
    ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
    ("bitwise_and.Tensor", bitwise_and_tensor),
    ("bitwise_and_.Scalar", bitwise_and_scalar_),
    ("bitwise_and_.Tensor", bitwise_and_tensor_),
    ("bitwise_left_shift", bitwise_left_shift),
    ("bitwise_not", bitwise_not),
    ("bitwise_not_", bitwise_not_),
    ("bitwise_or.Scalar", bitwise_or_scalar),
    ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor),
    ("bitwise_or.Tensor", bitwise_or_tensor),
    ("bitwise_or_.Scalar", bitwise_or_scalar_),
    ("bitwise_or_.Tensor", bitwise_or_tensor_),
    ("bitwise_right_shift", bitwise_right_shift),
    ("bmm", bmm),
    ("bmm.out", bmm_out),
    ("cat", cat),
    ("cat.out", cat_out),
    ("cauchy", cauchy),
    ("cauchy_", cauchy_),
    ("celu", celu),
    ("celu_", celu_),
    ("ceil", ceil),
    ("ceil_", ceil_),
    ("ceil.out", ceil_out),
    ("clamp", clamp),
    ("clamp.Tensor", clamp_tensor),
    ("clamp_min", clamp_min),
    ("clamp_", clamp_),
    ("clamp_.Tensor", clamp_tensor_),
    ("clamp_min_", clamp_min_),
    ("clip", clip),
    ("clip_", clip_),
    ("col2im", col2im),
    ("concatenate", concatenate),
    ("conj_physical", conj_physical),
    ("constant_pad_nd", constant_pad_nd),
    # ("contiguous", contiguous),
    ("conv1d", conv1d),
    ("conv1d.padding", conv1d),
    ("conv2d", conv2d),
    ("conv2d.padding", conv2d),
    ("conv_transpose2d", conv_transpose2d),
    ("conv3d", conv3d),
    ("conv3d.padding", conv3d),
    ("conv_transpose1d", conv_transpose1d),
    (
        "copy_",
        copy_,
        lambda: version.parse(torch.__version__) >= version.parse("2.4"),
    ),
    ("cos", cos),
    ("cos_", cos_),
    ("cosh", cosh),
    ("cosh_", cosh_),
    ("cosh.out", cosh_out),
    ("copysign", copysign),
    ("copysign.out", copysign_out),
    ("count_nonzero", count_nonzero),
    ("ctc_loss.IntList", ctc_loss, None, (AUTOGRAD_DISPATCH_KEY,)),
    ("ctc_loss.Tensor", ctc_loss, None, (AUTOGRAD_DISPATCH_KEY,)),
    ("cudnn_convolution", cudnn_convolution),
    ("cummax", cummax),
    ("cummin", cummin),
    ("cumprod", cumprod),
    ("cumprod_", cumprod_),
    ("cumsum", cumsum),
    ("cumsum.out", cumsum_out),
    ("diag", diag),
    ("diag_embed", diag_embed),
    ("diagonal_backward", diagonal_backward),
    ("diff", diff),
    ("digamma_", digamma_),
    ("div.Scalar", true_divide),
    ("div.Scalar_mode", div_mode),
    ("div.Tensor", true_divide),
    ("div.Tensor_mode", div_mode),
    ("div.out", true_divide_out),
    ("div_.Scalar", true_divide_),
    ("div_.Scalar_mode", div_mode_),
    ("div_.Tensor", true_divide_),
    ("div_.Tensor_mode", div_mode_),
    ("divide.Scalar", true_divide),
    ("divide.Scalar_mode", div_mode),
    ("divide.Tensor", true_divide),
    ("divide.Tensor_mode", div_mode),
    ("divide_.Scalar", true_divide_),
    ("divide_.Scalar_mode", div_mode_),
    ("divide_.Tensor", true_divide_),
    ("divide_.Tensor_mode", div_mode_),
    ("dot", dot),
    ("elu", elu),
    ("elu_", elu_),
    ("elu_backward", elu_backward),
    ("embedding", embedding),
    ("embedding_backward", embedding_backward),
    ("embedding_dense_backward", embedding_dense_backward),
    ("eq.Scalar", eq_scalar),
    ("eq.Tensor", eq),
    ("equal", equal),
    ("erf", erf),
    ("erf_", erf_),
    ("exp", exp),
    ("exp_", exp_),
    ("exp.out", exp_out),
    ("exp2", exp2),
    ("exp2_", exp2_),
    ("expm1", expm1),
    ("expm1_", expm1_),
    ("expm1.out", expm1_out),
    ("exponential_", exponential_),
    ("feature_dropout", feature_dropout),
    ("feature_dropout_", feature_dropout_),
    ("eye", eye),
    ("eye.m", eye_m),
    ("fill.Scalar", fill_scalar),
    ("fill.Scalar_out", fill_scalar_out),
    ("fill.Tensor", fill_tensor),
    ("fill.Tensor_out", fill_tensor_out),
    ("fill_.Scalar", fill_scalar_),
    ("fill_.Tensor", fill_tensor_),
    ("flip", flip),
    ("floor_", floor_),
    ("floor_divide", floor_divide),
    ("floor_divide.Scalar", floor_divide),
    ("floor_divide_.Scalar", floor_divide_),
    ("floor_divide_.Tensor", floor_divide_),
    ("fmin", fmin),
    ("fmin.out", fmin_out),
    ("fmod.Scalar", fmod_scalar),
    ("fmod.Tensor", fmod_tensor),
    ("fmod_.Scalar", fmod_scalar_),
    ("fmod_.Tensor", fmod_tensor_),
    ("full", full),
    ("full_like", full_like),
    ("gather", gather),
    ("gather_backward", gather_backward),
    ("gcd", gcd),
    ("gcd.out", gcd_out),
    ("ge.Scalar", ge_scalar),
    ("ge.Tensor", ge),
    ("gelu", gelu),
    ("gelu_", gelu_),
    ("gelu_backward", gelu_backward),
    ("glu", glu),
    ("glu_backward", glu_backward),
    ("greater.Scalar", greater_scalar),
    ("greater.Tensor", greater),
    ("greater.Scalar_out", greater_scalar_out),
    ("greater.out", greater_out),
    ("grid_sample", grid_sample),
    ("gt.Scalar", gt_scalar),
    ("gt.Tensor", gt),
    ("hardsigmoid", hardsigmoid),
    ("hardsigmoid.out", hardsigmoid_out),
    ("hardswish_", hardswish_),
    ("histc", histc),
    ("hstack", hstack),
    ("hypot", hypot),
    ("i0", i0),
    ("i0.out", i0_out),
    ("i0_", i0_),
    ("index.Tensor", index),
    ("index_add", index_add),
    ("index_add_", index_add_),
    ("index_copy", index_copy),
    ("index_copy_", index_copy_),
    ("index_put", index_put),
    ("index_put_", index_put_),
    ("index_select", index_select),
    ("isclose", isclose),
    ("isfinite", isfinite),
    ("isin.Scalar_Tensor", isin),
    ("isin.Tensor_Scalar", isin),
    ("isin.Tensor_Tensor", isin),
    ("isinf", isinf),
    ("isnan", isnan),
    ("isneginf", isneginf),
    ("isneginf.out", isneginf_out),
    ("kron", kron),
    ("le.Scalar", le_scalar),
    ("le.Tensor", le),
    ("leaky_relu", leaky_relu),
    ("leaky_relu_", leaky_relu_),
    ("leaky_relu.out", leaky_relu_out),
    ("lerp.Scalar", lerp_scalar),
    ("lerp.Tensor", lerp_tensor),
    ("lerp_.Scalar", lerp_scalar_),
    ("lerp_.Tensor", lerp_tensor_),
    ("lift_fresh_copy", lift_fresh_copy),
    ("linalg_vector_norm", vector_norm),
    ("linspace", linspace),
    ("log", log),
    ("log10", log10),
    ("log10_", log10_),
    ("log10.out", log10_out),
    ("log1p_", log1p_),
    ("log_sigmoid", log_sigmoid),
    ("logaddexp", logaddexp),
    ("logaddexp.out", logaddexp_out),
    ("logical_and", logical_and),
    ("logical_and_", logical_and_),
    ("logical_not", logical_not),
    ("logical_or", logical_or),
    ("logical_or_", logical_or_),
    ("logical_xor", logical_xor),
    ("logit", logit),
    ("logit.out", logit_out),
    ("logit_", logit_),
    ("logspace", logspace),
    ("logsumexp", logsumexp),
    ("lt.Scalar", lt_scalar),
    ("lt.Tensor", lt),
    ("margin_ranking_loss", margin_ranking_loss),
    ("masked_fill.Scalar", masked_fill),
    ("masked_fill.Tensor", masked_fill),
    ("masked_fill_.Scalar", masked_fill_),
    ("masked_fill_.Tensor", masked_fill_),
    ("masked_scatter", masked_scatter),
    ("masked_scatter_", masked_scatter_),
    ("masked_select", masked_select),
    ("max", max),
    ("max.dim", max_dim),
    ("max_pool2d_backward", max_pool2d_backward),
    ("max_pool2d_with_indices", max_pool2d_with_indices),
    ("max_pool3d_backward", max_pool3d_backward),
    ("max_pool3d_with_indices", max_pool3d_with_indices),
    ("maximum", maximum),
    ("mean", mean),
    ("mean.dim", mean_dim),
    ("median", median),
    ("median.out", median_out),
    ("median.dim", median_dim),
    ("median.dim_values", median_dim_values),
    ("min", min),
    ("min.dim", min_dim),
    ("minimum", minimum),
    ("mm", mm),
    ("mm.out", mm_out),
    ("mse_loss", mse_loss),
    ("mul.Tensor", mul),
    ("mul_.Tensor", mul_),
    ("multinomial", multinomial),
    ("mv", mv),
    ("nan_to_num", nan_to_num),
    ("native_batch_norm", batch_norm),
    ("native_batch_norm_backward", batch_norm_backward),
    ("native_dropout", dropout),
    ("native_dropout_backward", dropout_backward),
    ("native_group_norm", group_norm),
    ("native_group_norm_backward", group_norm_backward),
    ("native_layer_norm", layer_norm),
    ("native_layer_norm_backward", layer_norm_backward),
    ("ne.Scalar", ne_scalar),
    ("ne.Tensor", ne),
    ("neg", neg),
    ("neg_", neg_),
    ("new_full.Tensor", new_full),
    ("nll_loss_backward", nll_loss_backward),
    ("nll_loss_forward", nll_loss_forward),
    ("nll_loss_nd_forward", nll_loss_nd_forward),
    ("nll_loss_nd_backward", nll_loss_nd_backward),
    ("nll_loss2d_backward", nll_loss2d_backward),
    ("nll_loss2d_forward", nll_loss2d_forward),
    ("nonzero", nonzero),
    ("nonzero_numpy", nonzero_numpy),
    ("normal.Tensor_float", normal_tensor_float),
    ("normal.Tensor_Tensor", normal_tensor_tensor),
    ("normal.float_Tensor", normal_float_tensor),
    ("normal_", normal_),
    ("normed_cumsum", normed_cumsum),
    ("one_hot", one_hot),
    ("ones", ones),
    ("ones_like", ones_like),
    ("pad", pad),
    ("pixel_shuffle", pixel_shuffle),
    ("pixel_unshuffle", pixel_unshuffle),
    ("pixel_unshuffle.out", pixel_unshuffle_out),
    ("poisson", poisson),
    ("polar", polar),
    ("pow.Scalar", pow_scalar),
    ("pow.Tensor_Scalar", pow_tensor_scalar),
    ("pow.Tensor_Tensor", pow_tensor_tensor),
    ("pow_.Scalar", pow_tensor_scalar_),
    ("pow_.Tensor", pow_tensor_tensor_),
    ("prelu", prelu),
    ("prod", prod),
    ("prod.dim_int", prod_dim),
    ("quantile", quantile),
    ("rad2deg", rad2deg),
    ("rad2deg_", rad2deg_),
    ("rand", rand),
    ("rand_like", rand_like),
    ("randn", randn),
    ("randn_like", randn_like),
    ("randint", randint),
    ("randperm", randperm),
    ("reciprocal", reciprocal),
    ("reciprocal_", reciprocal_),
    ("reflection_pad1d", reflection_pad1d),
    ("reflection_pad1d.out", reflection_pad1d_out),
    ("reflection_pad2d", reflection_pad2d),
    ("reflection_pad2d.out", reflection_pad2d_out),
    ("relu", relu),
    ("relu_", relu_),
    ("relu6", relu6),
    ("remainder.Scalar", remainder),
    ("remainder.Scalar_Tensor", remainder),
    ("remainder.Tensor", remainder),
    ("remainder_.Scalar", remainder_),
    ("remainder_.Tensor", remainder_),
    ("repeat", repeat),
    ("repeat_interleave.self_int", repeat_interleave_self_int),
    ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
    ("repeat_interleave.Tensor", repeat_interleave_tensor),
    ("replication_pad1d", replication_pad1d),
    ("replication_pad1d.out", replication_pad1d_out),
    ("replication_pad3d", replication_pad3d),
    ("resolve_conj", resolve_conj),
    ("resolve_neg", resolve_neg),
    ("rms_norm", rms_norm),
    ("roll", roll),
    ("round", round),
    ("round_", round_),
    ("round.out", round_out),
    ("rrelu_with_noise_backward", rrelu_with_noise_backward),
    ("rsqrt", rsqrt),
    ("rsqrt_", rsqrt_),
    ("rsub.Scalar", rsub_scalar),
    ("rsub.Tensor", rsub_tensor),
    ("scaled_softmax_backward", scaled_softmax_backward),
    ("scaled_softmax_forward", scaled_softmax_forward),
    ("scatter.reduce", scatter),
    ("scatter.src", scatter),
    ("scatter_.reduce", scatter_),
    ("scatter_.src", scatter_),
    ("scatter_add_", scatter_add_),
    ("scatter_reduce.two", scatter_reduce),
    ("scatter_reduce_.two", scatter_reduce_),
    ("scatter_reduce.two_out", scatter_reduce_out),
    ("select_backward", select_backward),
    ("select_scatter", select_scatter),
    ("selu", selu),
    ("selu_", selu_),
    ("sgn_", sgn_),
    ("sigmoid", sigmoid),
    ("sigmoid_", sigmoid_),
    ("sigmoid_backward", sigmoid_backward),
    ("signbit", signbit),
    ("signbit.out", signbit_out),
    ("silu", silu),
    ("silu_", silu_),
    ("silu_backward", silu_backward),
    ("sin", sin),
    ("sin_", sin_),
    ("sinh_", sinh_),
    ("slice_backward", slice_backward),
    ("slice_scatter", slice_scatter),
    ("smooth_l1_loss", smooth_l1_loss),
    ("smooth_l1_loss_backward", smooth_l1_loss_backward),
    ("smooth_l1_loss.out", smooth_l1_loss_out),
    ("soft_margin_loss", soft_margin_loss),
    ("softplus", softplus),
    ("softshrink", softshrink),
    ("softshrink.out", softshrink_out),
    ("sort", sort),
    ("sort.stable", sort_stable),
    ("special_i0e", special_i0e),
    ("special_i0e.out", special_i0e_out),
    ("special_i1", special_i1),
    ("special_i1.out", special_i1_out),
    ("sqrt", sqrt),
    ("sqrt_", sqrt_),
    ("square", square),
    ("square_", square_),
    ("square.out", square_out),
    ("stack", stack),
    ("std.correction", std),
    ("sub.Tensor", sub),
    ("sub_.Tensor", sub_),
    ("sum", sum),
    ("sum.IntList_out", sum_dim_out),
    ("sum.dim_IntList", sum_dim),
    ("sum.out", sum_out),
    ("svd", svd),
    ("t_copy", t_copy),
    ("t_copy.out", t_copy_out),
    ("tan", tan),
    ("tan_", tan_),
    ("tanh", tanh),
    ("tanh_", tanh_),
    ("tanh_backward", tanh_backward),
    ("threshold", threshold),
    ("threshold_backward", threshold_backward),
    ("tile", tile),
    ("topk", topk),
    ("trace", trace),
    ("tril", tril),
    ("tril.out", tril_out),
    ("tril_", tril_),
    ("triu", triu),
    ("triu_", triu_),
    ("true_divide.Scalar", true_divide),
    ("true_divide.Tensor", true_divide),
    ("true_divide_.Scalar", true_divide_),
    ("true_divide_.Tensor", true_divide_),
    ("unfold_backward", unfold_backward),
    ("uniform_", uniform_),
    ("unique_consecutive", unique_consecutive),
    ("upsample_bicubic2d", upsample_bicubic2d),
    ("upsample_linear1d", upsample_linear1d),
    ("upsample_nearest1d", upsample_nearest1d),
    ("upsample_nearest2d", upsample_nearest2d),
    ("upsample_nearest3d", upsample_nearest3d),
    ("var_mean.correction", var_mean),
    ("var", var),
    ("var.correction", var_correction),
    ("var.dim", var_dim),
    ("vdot", vdot),
    ("vstack", vstack),
    ("where.self", where_self),
    ("where.self_out", where_self_out),
    ("zero", zero),
    ("zero_", zero_),
    ("zero.out", zero_out),
    ("zeros", zeros),
    ("zeros_like", zeros_like),
)
# Reverse index: implementation function name -> list of the _FULL_CONFIG rows
# that use that implementation, so lookups by function name are cheap.
FULL_CONFIG_BY_FUNC = {}
for _item in _FULL_CONFIG:
    # Only rows carrying at least (op name, implementation) can be indexed.
    if _item and len(_item) >= 2:
        fn = _item[1]
        # Fall back to the textual form for objects that expose no __name__.
        func_name = str(fn) if not hasattr(fn, "__name__") else fn.__name__
        if func_name not in FULL_CONFIG_BY_FUNC:
            FULL_CONFIG_BY_FUNC[func_name] = []
        FULL_CONFIG_BY_FUNC[func_name].append(_item)
# Let only_enable(include=[...]) accept the plain name when the registered
# implementation is the *.out variant: the alias also receives the rows
# indexed under the corresponding *_out function name.
for _alias, _target in (
    ("softmax", "softmax_out"),
    ("softmax_backward", "softmax_backward_out"),
    ("log_softmax", "log_softmax_out"),
    ("log_softmax_backward", "log_softmax_backward_out"),
):
    # Nothing to alias when the *_out implementation was never indexed.
    if _target not in FULL_CONFIG_BY_FUNC:
        continue
    FULL_CONFIG_BY_FUNC.setdefault(_alias, []).extend(FULL_CONFIG_BY_FUNC[_target])
def enable(
    lib=aten_lib,
    unused=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register every FlagGems op except the ones the caller excludes.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            global `aten_lib` (IMPL mode).
        unused: Ops to leave unregistered. Supported forms:
            - list/tuple/set of function names (e.g., ["masked_fill", "mul"]).
            - str path to a YAML file ending with .yml/.yaml containing an
              `exclude:` list.
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]enable_configs.yaml if present.
        registrar: Registrar class; defaults to `Register`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Notes:
        - If the exclude list/YAML resolves to empty, all ops are registered.
        - The registrar instance is stored in the module-level
          `current_work_registrar`.
    """
    global current_work_registrar

    excluded = resolve_user_setting(unused, "exclude")
    # De-duplicate the patched-op names before handing them to the registrar.
    patched = list(set(aten_patch_list))
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=[],
        user_exclude_ops=excluded,
        cpp_patched_ops=patched,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
def only_enable(
    lib=aten_lib,
    include=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register just the requested FlagGems ops and leave the rest alone.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            global `aten_lib` (IMPL mode).
        include: Ops to register. Supported forms:
            - list/tuple/set of function names (e.g., ["rms_norm", "softmax"]).
            - str path to a YAML file ending with .yml/.yaml (expects a list or
              an `include:` key).
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]only_enable_configs.yaml if
              present.
        registrar: Registrar class; defaults to `Register`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Typical usage:
        - Only register a few ops:
            only_enable(include=["rms_norm", "softmax"])
        - Use vendor default YAML:
            only_enable(include="default")  # or include=None
        - Use a custom YAML:
            only_enable(include="/path/to/only_enable.yaml")

    Notes:
        - If the include list/YAML resolves to empty, the function warns and
          returns without registering anything or configuring logging.
        - Otherwise the registrar instance is stored in the module-level
          `current_work_registrar`.
    """
    global current_work_registrar

    include_ops = resolve_user_setting(include, "include")
    if not include_ops:
        # Nothing to register: warn and leave the current registrar untouched.
        warnings.warn(
            "only_enable failed: No include entries resolved from list or yaml."
        )
        return

    # De-duplicate the patched-op names before handing them to the registrar.
    patched = list(set(aten_patch_list))
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=include_ops,
        user_exclude_ops=[],
        cpp_patched_ops=patched,
        full_config_by_func=FULL_CONFIG_BY_FUNC,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
class use_gems:
    """Context manager that registers FlagGems ops for the body of a `with`.

    The 'include' parameter has higher priority than 'exclude': when
    'include' is non-empty, only those ops are registered via only_enable()
    and 'exclude' is ignored; otherwise enable() is called with 'exclude'.

    Args:
        exclude: Ops to skip; a list/tuple/set of names or a str. Any other
            type (including None) is treated as an empty list.
        include: Ops to register; same accepted types as `exclude`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Notes:
        - An instance is single-use: __exit__ deletes `lib`, `exclude`,
          `include` and `registrar`, so re-entering the same instance fails.
    """

    def __init__(self, exclude=None, include=None, record=False, once=False, path=None):
        # A private library handle, so it can be torn down on exit without
        # touching the module-level `aten_lib`.
        self.lib = torch.library.Library("aten", "IMPL")
        self.exclude = exclude if isinstance(exclude, (list, tuple, set, str)) else []
        self.include = include if isinstance(include, (list, tuple, set, str)) else []
        self.registrar = Register
        self.record = record
        self.once = once
        self.path = path

    def __enter__(self):
        """Register the ops and return this instance for `with ... as`."""
        if self.include:
            only_enable(
                lib=self.lib,
                include=self.include,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        else:
            enable(
                lib=self.lib,
                unused=self.exclude,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Tear down the private library and reset the active registrar.

        Returns None, so an exception raised inside the `with` body propagates.
        """
        global current_work_registrar
        # NOTE(review): this relies on torch.__version__ comparing by version
        # rather than lexicographically - confirm for the oldest supported torch.
        if torch.__version__ >= "2.5":
            self.lib._destroy()
        del self.lib
        del self.exclude
        del self.include
        del self.registrar
        # Reset instead of `del`: deleting the module global made a later
        # __exit__ raise NameError whenever no registrar had been re-created
        # in between (e.g. only_enable() returned early).
        current_work_registrar = None
        if self.record:
            teardown_flaggems_logging()

    @property
    def experimental_ops(self):
        """The `flag_gems.experimental_ops` module, imported on access."""
        import flag_gems.experimental_ops

        return flag_gems.experimental_ops
def all_registered_ops():
    """Return `get_all_ops()` of the registrar stored in `current_work_registrar`."""
    active = current_work_registrar
    return active.get_all_ops()
def all_registered_keys():
    """Return `get_all_keys()` of the registrar stored in `current_work_registrar`."""
    active = current_work_registrar
    return active.get_all_keys()
# Names exported by `from flag_gems import *`: the registration entry points
# and the two registrar query helpers.
__all__ = [
    "all_registered_keys",
    "all_registered_ops",
    "enable",
    "only_enable",
    "use_gems",
]