Coverage for src/flag_gems/__init__.py: 91%
80 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1# ruff: noqa: F405
2import warnings
4import torch
5from packaging import version
7from flag_gems import testing # noqa: F401
8from flag_gems import runtime
9from flag_gems.config import aten_patch_list, resolve_user_setting
10from flag_gems.experimental_ops import * # noqa: F403
11from flag_gems.fused import * # noqa: F403
12from flag_gems.logging_utils import setup_flaggems_logging, teardown_flaggems_logging
13from flag_gems.modules import * # noqa: F403
14from flag_gems.ops import * # noqa: F403
15from flag_gems.patches import * # noqa: F403
16from flag_gems.runtime import flagtune
17from flag_gems.runtime.backend import SpecOpRegistrar
18from flag_gems.runtime.op_registrar import GeneralOpRegistrar
__version__ = "5.0.2"
# Device/backend description exposed at package level; all three values are
# read from flag_gems.runtime.device.
device = runtime.device.name
vendor_name = runtime.device.vendor_name
backend_info = runtime.device
# Shared IMPL-mode library for the "aten" namespace; used as the default `lib`
# argument of enable() and only_enable().
aten_lib = torch.library.Library("aten", "IMPL")

# Register all ops in the current backend with SpecOpRegistrar to support
# architecture-specialized implementations. It is handed this module's globals();
# NOTE(review): presumably it rebinds op names in this namespace to
# backend/arch-specific versions, which would be why it has to run before
# _FULL_CONFIG captures the implementations below -- confirm in
# flag_gems.runtime.backend.
SpecOpRegistrar(globals()).apply()

# Default registrar class used by enable() and only_enable().
registrar = GeneralOpRegistrar
# Registrar instance created by the latest enable()/only_enable() call; queried
# by all_registered_ops() and all_registered_keys(). None until ops are enabled.
current_work_registrar = None
# Name of the Autograd dispatch key; used as the explicit dispatch key of the
# ctc_loss entries in _FULL_CONFIG.
AUTOGRAD_DISPATCH_KEY = torch._C.DispatchKey.Autograd.name
def torch_ge(v):
    """Report whether the running torch is at least version ``v``.

    Args:
        v: Version string to compare against, e.g. "2.5".

    Returns:
        True when ``torch.__version__`` >= ``v`` under ``packaging.version``
        ordering (so "2.10" correctly sorts after "2.9").
    """
    installed, required = version.parse(torch.__version__), version.parse(v)
    return installed >= required
# Master table of aten ops that FlagGems can register. Each entry is a tuple:
#   (aten op name with optional ".overload", implementation[, condition[, dispatch_keys]])
# - condition: None or a zero-argument callable returning bool. Every callable
#   here gates on the installed torch version via torch_ge(); presumably the
#   registrar evaluates it to decide whether to register the entry (confirm in
#   GeneralOpRegistrar).
# - dispatch_keys: tuple of dispatch-key names, e.g. (AUTOGRAD_DISPATCH_KEY,) for
#   ctc_loss; presumably overrides the registrar's default key (confirm).
# enable()/only_enable() hand this table to the registrar, and
# FULL_CONFIG_BY_FUNC indexes it by implementation name.
_FULL_CONFIG = (
    ("__ior__.Scalar", bitwise_or_scalar_),
    ("__ior__.Tensor", bitwise_or_tensor_),
    ("__or__.Scalar", bitwise_or_scalar),
    ("__or__.Tensor", bitwise_or_tensor),
    ("_assert_async", _assert_async),
    ("_conv_depthwise2d", _conv_depthwise2d),
    ("_euclidean_dist", _euclidean_dist),
    ("_flash_attention_forward", flash_attention_forward),
    (
        "_functional_sym_constrain_range_for_size",
        _functional_sym_constrain_range_for_size,
    ),
    ("_grouped_mm", group_mm),
    ("_index_put_impl_", _index_put_impl_),
    ("_is_all_true", _is_all_true),
    ("_log_softmax", log_softmax),
    ("_log_softmax.out", log_softmax_out),
    ("_log_softmax_backward_data", log_softmax_backward),
    ("_log_softmax_backward_data.out", log_softmax_backward_out),
    ("_safe_softmax", _safe_softmax),
    ("_scaled_mm", scaled_mm, lambda: torch_ge("2.5")),
    ("_scaled_mm.out", scaled_mm_out, lambda: torch_ge("2.5")),
    ("_softmax", softmax),
    ("_softmax.out", softmax_out),
    ("_softmax_backward_data", softmax_backward),
    ("_softmax_backward_data.out", softmax_backward_out),
    # Same gate as before (torch >= 2.4), expressed via the shared helper.
    ("_to_copy", to_copy, lambda: torch_ge("2.4")),
    ("_unique2", _unique2),
    ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa),
    ("_upsample_bicubic2d_aa_backward", _upsample_bicubic2d_aa_backward),
    ("_upsample_nearest_exact1d", _upsample_nearest_exact1d),
    ("_weight_norm_interface", weight_norm_interface),
    ("_weight_norm_interface_backward", weight_norm_interface_backward),
    ("abs", abs),
    ("abs_", abs_),
    ("absolute", absolute),
    ("acos", acos),
    ("add.Tensor", add),
    ("add_.Tensor", add_),
    ("add_rms_norm", add_rms_norm),
    ("addcdiv", addcdiv),
    ("addcdiv.out", addcdiv_out),
    ("addcmul", addcmul),
    ("addcmul.out", addcmul_out),
    ("addmv", addmv),
    ("addmv.out", addmv_out),
    ("addmm", addmm),
    ("addmm.out", addmm_out),
    ("addmm.dtype", addmm_dtype),
    ("addmm.dtype_out", addmm_dtype_out),
    ("addr", addr),
    ("affine_grid_generator", affine_grid_generator),
    ("alias_copy", alias_copy),
    ("all", all),
    ("all.dim", all_dim),
    ("all.dims", all_dims),
    ("allclose", allclose),
    ("amax", amax),
    ("aminmax", aminmax),
    ("angle", angle),
    ("any", any),
    ("any.dim", any_dim),
    ("any.dims", any_dims),
    ("arange", arange),
    ("arange.start", arange_start),
    ("arange.start_step", arange_start),
    ("arcsinh", arcsinh),
    ("arcsinh.out", arcsinh_out),
    ("arcsinh_", arcsinh_),
    ("argmax", argmax),
    ("argmin", argmin),
    ("argsort", argsort),
    ("as_strided_copy", as_strided_copy),
    ("as_strided_copy.out", as_strided_copy_out),
    ("asinh", asinh),
    ("asinh.out", asinh_out),
    ("asinh_", asinh_),
    ("atan", atan),
    ("atan_", atan_),
    ("atan2", atan2),
    ("atan2.out", atan2_out),
    ("arctanh_", arctanh_),
    ("avg_pool2d", avg_pool2d),
    ("avg_pool2d_backward", avg_pool2d_backward),
    ("avg_pool3d", avg_pool3d),
    ("avg_pool3d_backward", avg_pool3d_backward),
    ("baddbmm", baddbmm),
    ("bernoulli_.float", bernoulli_),
    ("bincount", bincount),
    ("bitwise_and.Scalar", bitwise_and_scalar),
    ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
    ("bitwise_and.Tensor", bitwise_and_tensor),
    ("bitwise_and_.Scalar", bitwise_and_scalar_),
    ("bitwise_and_.Tensor", bitwise_and_tensor_),
    ("bitwise_left_shift", bitwise_left_shift),
    ("bitwise_not", bitwise_not),
    ("bitwise_not_", bitwise_not_),
    ("bitwise_or.Scalar", bitwise_or_scalar),
    ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor),
    ("bitwise_or.Tensor", bitwise_or_tensor),
    ("bitwise_or_.Scalar", bitwise_or_scalar_),
    ("bitwise_or_.Tensor", bitwise_or_tensor_),
    ("bitwise_right_shift", bitwise_right_shift),
    ("bmm", bmm),
    ("bmm.out", bmm_out),
    ("cat", cat),
    ("cat.out", cat_out),
    ("cauchy", cauchy),
    ("cauchy_", cauchy_),
    ("celu", celu),
    ("celu_", celu_),
    ("ceil", ceil),
    ("ceil_", ceil_),
    ("ceil.out", ceil_out),
    ("clamp", clamp),
    ("clamp.Tensor", clamp_tensor),
    ("clamp_max", clamp_max),
    ("clamp_min", clamp_min),
    ("clamp_", clamp_),
    ("clamp_.Tensor", clamp_tensor_),
    ("clamp_max_", clamp_max_),
    ("clamp_min_", clamp_min_),
    ("clip", clip),
    ("clip_", clip_),
    ("col2im", col2im),
    ("concatenate", concatenate),
    ("conj_physical", conj_physical),
    ("constant_pad_nd", constant_pad_nd),
    # NOTE: "contiguous" registration is disabled; the reason is not recorded here.
    # ("contiguous", contiguous),
    ("conv1d", conv1d),
    ("conv1d.padding", conv1d),
    ("conv2d", conv2d),
    ("conv2d.padding", conv2d),
    ("conv_transpose2d", conv_transpose2d),
    ("conv3d", conv3d),
    ("conv3d.padding", conv3d),
    ("conv_transpose1d", conv_transpose1d),
    # Same gate as before (torch >= 2.4), expressed via the shared helper.
    ("copy_", copy_, lambda: torch_ge("2.4")),
    ("cos", cos),
    ("cos_", cos_),
    ("cosh", cosh),
    ("cosh_", cosh_),
    ("cosh.out", cosh_out),
    ("copysign", copysign),
    ("copysign.out", copysign_out),
    ("count_nonzero", count_nonzero),
    ("ctc_loss.IntList", ctc_loss, None, (AUTOGRAD_DISPATCH_KEY,)),
    ("ctc_loss.Tensor", ctc_loss, None, (AUTOGRAD_DISPATCH_KEY,)),
    ("cudnn_convolution", cudnn_convolution),
    ("cummax", cummax),
    ("cummin", cummin),
    ("cumprod", cumprod),
    ("cumprod_", cumprod_),
    ("cumsum", cumsum),
    ("cumsum.out", cumsum_out),
    ("diag", diag),
    ("diag_embed", diag_embed),
    ("diagonal_backward", diagonal_backward),
    ("diff", diff),
    ("digamma_", digamma_),
    ("div.Scalar", true_divide),
    ("div.Scalar_mode", div_mode),
    ("div.Tensor", true_divide),
    ("div.Tensor_mode", div_mode),
    ("div.out", true_divide_out),
    ("div_.Scalar", true_divide_),
    ("div_.Scalar_mode", div_mode_),
    ("div_.Tensor", true_divide_),
    ("div_.Tensor_mode", div_mode_),
    ("divide.Scalar", true_divide),
    ("divide.Scalar_mode", div_mode),
    ("divide.Tensor", true_divide),
    ("divide.Tensor_mode", div_mode),
    ("divide_.Scalar", true_divide_),
    ("divide_.Scalar_mode", div_mode_),
    ("divide_.Tensor", true_divide_),
    ("divide_.Tensor_mode", div_mode_),
    ("dot", dot),
    ("elu", elu),
    ("elu_", elu_),
    ("elu_backward", elu_backward),
    ("embedding", embedding),
    ("embedding_backward", embedding_backward),
    ("embedding_dense_backward", embedding_dense_backward),
    ("eq.Scalar", eq_scalar),
    ("eq.Tensor", eq),
    ("equal", equal),
    ("erf", erf),
    ("erf_", erf_),
    ("exp", exp),
    ("exp_", exp_),
    ("exp.out", exp_out),
    ("exp2", exp2),
    ("exp2_", exp2_),
    ("expm1", expm1),
    ("expm1_", expm1_),
    ("expm1.out", expm1_out),
    ("exponential_", exponential_),
    ("feature_dropout", feature_dropout),
    ("feature_dropout_", feature_dropout_),
    ("eye", eye),
    ("eye.m", eye_m),
    ("fill.Scalar", fill_scalar),
    ("fill.Scalar_out", fill_scalar_out),
    ("fill.Tensor", fill_tensor),
    ("fill.Tensor_out", fill_tensor_out),
    ("fill_.Scalar", fill_scalar_),
    ("fill_.Tensor", fill_tensor_),
    ("flip", flip),
    ("floor_", floor_),
    ("floor", floor),
    ("floor.out", floor_out),
    ("floor_divide", floor_divide),
    ("floor_divide.Scalar", floor_divide),
    ("floor_divide_.Scalar", floor_divide_),
    ("floor_divide_.Tensor", floor_divide_),
    ("fmin", fmin),
    ("fmin.out", fmin_out),
    ("fmod.Scalar", fmod_scalar),
    ("fmod.Tensor", fmod_tensor),
    ("fmod_.Scalar", fmod_scalar_),
    ("fmod_.Tensor", fmod_tensor_),
    ("full", full),
    ("full_like", full_like),
    ("gather", gather),
    ("gather_backward", gather_backward),
    ("gcd", gcd),
    ("gcd.out", gcd_out),
    ("ge.Scalar", ge_scalar),
    ("ge.Tensor", ge),
    ("gelu", gelu),
    ("gelu_", gelu_),
    ("gelu_backward", gelu_backward),
    ("glu", glu),
    ("glu_backward", glu_backward),
    ("greater.Scalar", greater_scalar),
    ("greater.Tensor", greater),
    ("greater.Scalar_out", greater_scalar_out),
    ("greater.out", greater_out),
    ("grid_sample", grid_sample),
    ("gt.Scalar", gt_scalar),
    ("gt.Tensor", gt),
    ("hardsigmoid", hardsigmoid),
    ("hardsigmoid.out", hardsigmoid_out),
    ("hardswish_", hardswish_),
    ("histc", histc),
    ("hstack", hstack),
    ("hypot", hypot),
    ("i0", i0),
    ("i0.out", i0_out),
    ("i0_", i0_),
    ("index.Tensor", index),
    ("index_add", index_add),
    ("index_add_", index_add_),
    ("index_copy", index_copy),
    ("index_copy_", index_copy_),
    ("index_put", index_put),
    ("index_put_", index_put_),
    ("index_select", index_select),
    ("isclose", isclose),
    ("isfinite", isfinite),
    ("isin.Scalar_Tensor", isin),
    ("isin.Tensor_Scalar", isin),
    ("isin.Tensor_Tensor", isin),
    ("isinf", isinf),
    ("isnan", isnan),
    ("isneginf", isneginf),
    ("isneginf.out", isneginf_out),
    ("kron", kron),
    ("le.Scalar", le_scalar),
    ("le.Tensor", le),
    ("leaky_relu", leaky_relu),
    ("leaky_relu_", leaky_relu_),
    ("leaky_relu.out", leaky_relu_out),
    ("lerp.Scalar", lerp_scalar),
    ("lerp.Tensor", lerp_tensor),
    ("lerp_.Scalar", lerp_scalar_),
    ("lerp_.Tensor", lerp_tensor_),
    ("lift_fresh_copy", lift_fresh_copy),
    ("linalg_vector_norm", vector_norm),
    ("linspace", linspace),
    ("log", log),
    ("log10", log10),
    ("log10_", log10_),
    ("log10.out", log10_out),
    ("log1p", log1p),
    ("log1p_", log1p_),
    ("log_sigmoid", log_sigmoid),
    ("logaddexp", logaddexp),
    ("logaddexp.out", logaddexp_out),
    ("logical_and", logical_and),
    ("logical_and_", logical_and_),
    ("logical_not", logical_not),
    ("logical_or", logical_or),
    ("logical_or_", logical_or_),
    ("logical_xor", logical_xor),
    ("logit", logit),
    ("logit.out", logit_out),
    ("logit_", logit_),
    ("logspace", logspace),
    ("logsumexp", logsumexp),
    ("lt.Scalar", lt_scalar),
    ("lt.Tensor", lt),
    ("margin_ranking_loss", margin_ranking_loss),
    ("masked_fill.Scalar", masked_fill),
    ("masked_fill.Tensor", masked_fill),
    ("masked_fill_.Scalar", masked_fill_),
    ("masked_fill_.Tensor", masked_fill_),
    ("masked_scatter", masked_scatter),
    ("masked_scatter_", masked_scatter_),
    ("masked_select", masked_select),
    ("max", max),
    ("max.dim", max_dim),
    ("max_pool2d_backward", max_pool2d_backward),
    ("max_pool2d_with_indices", max_pool2d_with_indices),
    ("max_pool3d_backward", max_pool3d_backward),
    ("max_pool3d_with_indices", max_pool3d_with_indices),
    ("maximum", maximum),
    ("mean", mean),
    ("mean.dim", mean_dim),
    ("median", median),
    ("median.out", median_out),
    ("median.dim", median_dim),
    ("median.dim_values", median_dim_values),
    ("min", min),
    ("min.dim", min_dim),
    ("minimum", minimum),
    ("mm", mm),
    ("mm.out", mm_out),
    ("mse_loss", mse_loss),
    ("mul.Tensor", mul),
    ("mul_.Tensor", mul_),
    ("multinomial", multinomial),
    ("mv", mv),
    ("nan_to_num", nan_to_num),
    ("native_batch_norm", batch_norm),
    ("native_batch_norm_backward", batch_norm_backward),
    ("native_dropout", dropout),
    ("native_dropout_backward", dropout_backward),
    ("native_group_norm", group_norm),
    ("native_group_norm_backward", group_norm_backward),
    ("native_layer_norm", layer_norm),
    ("native_layer_norm_backward", layer_norm_backward),
    ("ne.Scalar", ne_scalar),
    ("ne.Tensor", ne),
    ("neg", neg),
    ("neg_", neg_),
    ("new_full.Tensor", new_full),
    ("nll_loss_backward", nll_loss_backward),
    ("nll_loss_forward", nll_loss_forward),
    ("nll_loss_nd_forward", nll_loss_nd_forward),
    ("nll_loss_nd_backward", nll_loss_nd_backward),
    ("nll_loss2d_backward", nll_loss2d_backward),
    ("nll_loss2d_forward", nll_loss2d_forward),
    ("nonzero", nonzero),
    ("nonzero_numpy", nonzero_numpy),
    ("normal.Tensor_float", normal_tensor_float),
    ("normal.Tensor_Tensor", normal_tensor_tensor),
    ("normal.float_Tensor", normal_float_tensor),
    ("normal_", normal_),
    ("normed_cumsum", normed_cumsum),
    ("one_hot", one_hot),
    ("ones", ones),
    ("ones_like", ones_like),
    ("pad", pad),
    ("pixel_shuffle", pixel_shuffle),
    ("pixel_unshuffle", pixel_unshuffle),
    ("pixel_unshuffle.out", pixel_unshuffle_out),
    ("poisson", poisson),
    ("polar", polar),
    ("pow.Scalar", pow_scalar),
    ("pow.Tensor_Scalar", pow_tensor_scalar),
    ("pow.Tensor_Tensor", pow_tensor_tensor),
    ("pow_.Scalar", pow_tensor_scalar_),
    ("pow_.Tensor", pow_tensor_tensor_),
    ("prelu", prelu),
    ("prod", prod),
    ("prod.dim_int", prod_dim),
    ("quantile", quantile),
    ("rad2deg", rad2deg),
    ("rad2deg_", rad2deg_),
    ("rand", rand),
    ("rand_like", rand_like),
    ("randn", randn),
    ("randn_like", randn_like),
    ("randint", randint),
    ("randperm", randperm),
    ("reciprocal", reciprocal),
    ("reciprocal_", reciprocal_),
    ("reflection_pad1d", reflection_pad1d),
    ("reflection_pad1d.out", reflection_pad1d_out),
    ("reflection_pad1d_backward", reflection_pad1d_backward),
    ("reflection_pad2d", reflection_pad2d),
    ("reflection_pad2d.out", reflection_pad2d_out),
    ("relu", relu),
    ("relu_", relu_),
    ("relu6", relu6),
    ("remainder.Scalar", remainder),
    ("remainder.Scalar_Tensor", remainder),
    ("remainder.Tensor", remainder),
    ("remainder_.Scalar", remainder_),
    ("remainder_.Tensor", remainder_),
    ("repeat", repeat),
    ("repeat_interleave.self_int", repeat_interleave_self_int),
    ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
    ("repeat_interleave.Tensor", repeat_interleave_tensor),
    ("replication_pad1d", replication_pad1d),
    ("replication_pad1d.out", replication_pad1d_out),
    ("replication_pad3d", replication_pad3d),
    ("resolve_conj", resolve_conj),
    ("resolve_neg", resolve_neg),
    ("rms_norm", rms_norm),
    ("roll", roll),
    ("round", round),
    ("round_", round_),
    ("round.out", round_out),
    ("rrelu_with_noise_backward", rrelu_with_noise_backward),
    ("rsqrt", rsqrt),
    ("rsqrt_", rsqrt_),
    ("rsub.Scalar", rsub_scalar),
    ("rsub.Tensor", rsub_tensor),
    ("scaled_softmax_backward", scaled_softmax_backward),
    ("scaled_softmax_forward", scaled_softmax_forward),
    ("scatter.reduce", scatter),
    ("scatter.src", scatter),
    ("scatter_.reduce", scatter_),
    ("scatter_.src", scatter_),
    ("scatter_add_", scatter_add_),
    ("scatter_reduce.two", scatter_reduce),
    ("scatter_reduce_.two", scatter_reduce_),
    ("scatter_reduce.two_out", scatter_reduce_out),
    ("select_backward", select_backward),
    ("select_scatter", select_scatter),
    ("selu", selu),
    ("selu_", selu_),
    ("sgn_", sgn_),
    ("sigmoid", sigmoid),
    ("sigmoid_", sigmoid_),
    ("sigmoid_backward", sigmoid_backward),
    ("signbit", signbit),
    ("signbit.out", signbit_out),
    ("silu", silu),
    ("silu_", silu_),
    ("silu_backward", silu_backward),
    ("sin", sin),
    ("sin_", sin_),
    ("sinh_", sinh_),
    ("slice_backward", slice_backward),
    ("slice_scatter", slice_scatter),
    ("smooth_l1_loss", smooth_l1_loss),
    ("smooth_l1_loss_backward", smooth_l1_loss_backward),
    ("smooth_l1_loss.out", smooth_l1_loss_out),
    ("soft_margin_loss", soft_margin_loss),
    ("softplus", softplus),
    ("softshrink", softshrink),
    ("softshrink.out", softshrink_out),
    ("sort", sort),
    ("sort.stable", sort_stable),
    ("special_i0e", special_i0e),
    ("special_i0e.out", special_i0e_out),
    ("special_i1", special_i1),
    ("special_i1.out", special_i1_out),
    ("sqrt", sqrt),
    ("sqrt_", sqrt_),
    ("square", square),
    ("square_", square_),
    ("square.out", square_out),
    ("stack", stack),
    ("std.correction", std),
    ("sub.Tensor", sub),
    ("sub_.Tensor", sub_),
    ("sum", sum),
    ("sum.IntList_out", sum_dim_out),
    ("sum.dim_IntList", sum_dim),
    ("sum.out", sum_out),
    ("svd", svd),
    ("t_copy", t_copy),
    ("t_copy.out", t_copy_out),
    ("tan", tan),
    ("tan_", tan_),
    ("tanh", tanh),
    ("tanh_", tanh_),
    ("tanh_backward", tanh_backward),
    ("threshold", threshold),
    ("threshold_backward", threshold_backward),
    ("tile", tile),
    ("topk", topk),
    ("trace", trace),
    ("tril", tril),
    ("tril.out", tril_out),
    ("tril_", tril_),
    ("triu", triu),
    ("triu_", triu_),
    ("true_divide.Scalar", true_divide),
    ("true_divide.Tensor", true_divide),
    ("true_divide_.Scalar", true_divide_),
    ("true_divide_.Tensor", true_divide_),
    ("unfold_backward", unfold_backward),
    ("uniform_", uniform_),
    ("unique_consecutive", unique_consecutive),
    ("upsample_bicubic2d", upsample_bicubic2d),
    ("upsample_linear1d", upsample_linear1d),
    ("upsample_nearest1d", upsample_nearest1d),
    ("upsample_nearest2d", upsample_nearest2d),
    ("upsample_nearest3d", upsample_nearest3d),
    ("var_mean.correction", var_mean),
    ("var", var),
    ("var.correction", var_correction),
    ("var.dim", var_dim),
    ("vdot", vdot),
    ("vstack", vstack),
    ("where.self", where_self),
    ("where.self_out", where_self_out),
    ("zero", zero),
    ("zero_", zero_),
    ("zero.out", zero_out),
    ("zeros", zeros),
    ("zeros_like", zeros_like),
)
def _build_full_config_by_func(config, aliases=()):
    """Group registration entries by the name of their implementation function.

    Building the index inside a function keeps the loop variables out of the
    package namespace (previously `fn` and `func_name` leaked as public
    attributes of flag_gems).

    Args:
        config: Iterable of ``_FULL_CONFIG``-style tuples. Entries that are
            falsy or have fewer than two elements are skipped.
        aliases: Iterable of ``(alias, target)`` name pairs. When ``target`` is
            already a key, its entries are also appended under ``alias``.

    Returns:
        dict mapping implementation name -> list of config entries, in the
        order they appear in ``config``.
    """
    by_func = {}
    for item in config:
        if not item or len(item) < 2:
            continue
        impl = item[1]
        # Callables are keyed by __name__; anything without one falls back to str().
        name = impl.__name__ if hasattr(impl, "__name__") else str(impl)
        by_func.setdefault(name, []).append(item)
    for alias, target in aliases:
        if target in by_func:
            by_func.setdefault(alias, []).extend(by_func[target])
    return by_func


# Friendly names for only_enable(include=[...]) when the registered impl is *.out
_OUT_VARIANT_ALIASES = (
    ("softmax", "softmax_out"),
    ("softmax_backward", "softmax_backward_out"),
    ("log_softmax", "log_softmax_out"),
    ("log_softmax_backward", "log_softmax_backward_out"),
)

# Cache mapping from function name -> list of _FULL_CONFIG entries for quick lookup
FULL_CONFIG_BY_FUNC = _build_full_config_by_func(_FULL_CONFIG, _OUT_VARIANT_ALIASES)
def enable(
    lib=aten_lib,
    unused=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register all FlagGems ops except those explicitly excluded.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            global `aten_lib` (IMPL mode).
        unused: Which ops to skip. Supported forms:
            - list/tuple/set of function names (e.g., ["masked_fill", "mul"]).
            - str path to a YAML file ending with .yml/.yaml containing an
              `exclude:` list.
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]enable_configs.yaml if present.
        registrar: Registrar class, instantiated with the full op table;
            defaults to `GeneralOpRegistrar`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Notes:
        - If the exclude list/YAML resolves to empty, all ops are registered.
        - The created registrar is stored in the module-level
          `current_work_registrar`, which all_registered_ops() and
          all_registered_keys() query.
    """
    global current_work_registrar
    exclude_ops = resolve_user_setting(unused, "exclude")
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=[],
        user_exclude_ops=exclude_ops,
        # set() drops duplicate names so each patched op is passed only once.
        cpp_patched_ops=list(set(aten_patch_list)),
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
def only_enable(
    lib=aten_lib,
    include=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register only the specified FlagGems ops and skip the rest.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            global `aten_lib` (IMPL mode).
        include: Which ops to register. Supported forms:
            - list/tuple/set of function names (e.g., ["rms_norm", "softmax"]).
            - str path to a YAML file ending with .yml/.yaml (expects a list or
              an `include:` key).
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]only_enable_configs.yaml if present.
        registrar: Registrar class, instantiated with the full op table;
            defaults to `GeneralOpRegistrar`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Classic usage:
        - Only register a few ops:
            only_enable(include=["rms_norm", "softmax"])
        - Use vendor default YAML:
            only_enable(include="default")  # or include=None
        - Use a custom YAML:
            only_enable(include="/path/to/only_enable.yaml")

    Notes:
        - If the include list/YAML resolves to empty, a UserWarning is emitted
          and the function returns without registering anything; logging is
          not set up and `current_work_registrar` is left unchanged.
        - How names that match no known op are handled is up to `registrar`
          (NOTE(review): earlier docs said this also warns -- confirm in
          GeneralOpRegistrar).
    """
    global current_work_registrar
    include_ops = resolve_user_setting(include, "include")
    if not include_ops:
        # stacklevel=2 attributes the warning to the caller of only_enable().
        warnings.warn(
            "only_enable failed: No include entries resolved from list or yaml.",
            stacklevel=2,
        )
        return

    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=include_ops,
        user_exclude_ops=[],
        # set() drops duplicate names so each patched op is passed only once.
        cpp_patched_ops=list(set(aten_patch_list)),
        # Index of _FULL_CONFIG by implementation name; presumably used to map
        # the include names to table entries.
        full_config_by_func=FULL_CONFIG_BY_FUNC,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
class use_gems:
    """Context manager that registers FlagGems ops for the duration of a block.

    Each instance creates its own ``torch.library.Library("aten", "IMPL")`` and
    registers into it on entry; the library is torn down in ``__exit__``.

    The 'include' parameter has higher priority than 'exclude': when 'include'
    is a non-empty list/tuple/set/str, only_enable() is used and 'exclude' is
    ignored; otherwise enable() is used with 'exclude'.

    Args:
        exclude: Ops to skip (same forms as ``enable(unused=...)``). Any value
            that is not a list/tuple/set/str, including None, becomes ``[]``.
            NOTE(review): so vendor default configs are presumably not
            auto-loaded here, unlike ``enable(unused=None)`` -- confirm in
            resolve_user_setting.
        include: Ops to register exclusively (same forms as
            ``only_enable(include=...)``); normalized like ``exclude``.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.
    """

    def __init__(self, exclude=None, include=None, record=False, once=False, path=None):
        self.lib = torch.library.Library("aten", "IMPL")
        self.exclude = exclude if isinstance(exclude, (list, tuple, set, str)) else []
        self.include = include if isinstance(include, (list, tuple, set, str)) else []
        self.registrar = GeneralOpRegistrar
        self.record = record
        self.once = once
        self.path = path

    def __enter__(self):
        """Register ops into this instance's library and return ``self``."""
        if self.include:
            only_enable(
                lib=self.lib,
                include=self.include,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        else:
            enable(
                lib=self.lib,
                unused=self.exclude,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        # Returning self lets callers write `with use_gems(...) as g:`.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Tear down this instance's registrations; exceptions are not suppressed."""
        global current_work_registrar
        # torch.__version__ is a TorchVersion, whose comparisons parse the version
        # rather than comparing strings lexicographically.
        if torch.__version__ >= "2.5":
            self.lib._destroy()
        del self.lib
        del self.exclude
        del self.include
        del self.registrar
        # Reset instead of `del`: deleting the module global made later
        # all_registered_ops()/all_registered_keys() calls raise NameError, and
        # made a later __exit__ crash when __enter__ had not re-created it
        # (e.g. only_enable() returned early on an empty include list).
        current_work_registrar = None
        if self.record:
            teardown_flaggems_logging()

    @property
    def experimental_ops(self):
        """The ``flag_gems.experimental_ops`` module, imported on first access."""
        import flag_gems.experimental_ops

        return flag_gems.experimental_ops
def all_registered_ops():
    """Return the ops known to the registrar from the latest enable()/only_enable().

    Delegates to ``current_work_registrar.get_all_ops()``. Fails if no registrar
    is active (e.g. AttributeError while ``current_work_registrar`` is None).
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_ops()
def all_registered_keys():
    """Return the keys known to the registrar from the latest enable()/only_enable().

    Delegates to ``current_work_registrar.get_all_keys()``. Fails if no registrar
    is active (e.g. AttributeError while ``current_work_registrar`` is None).
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_keys()
# Public names exported by `from flag_gems import *`. Because __all__ is
# defined, the op implementations pulled in by the star imports at the top of
# this module are not re-exported by a star import of the package (they remain
# reachable as attributes, e.g. flag_gems.<name>).
__all__ = [
    "all_registered_keys",
    "all_registered_ops",
    "enable",
    "flagtune",
    "only_enable",
    "use_gems",
]