Coverage for src/flag_gems/__init__.py: 91%
80 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1# ruff: noqa: F405
2import warnings
4import torch
5from packaging import version
7from flag_gems import testing # noqa: F401
8from flag_gems import runtime
9from flag_gems.config import aten_patch_list, resolve_user_setting
10from flag_gems.experimental_ops import * # noqa: F403
11from flag_gems.fused import * # noqa: F403
12from flag_gems.logging_utils import setup_flaggems_logging, teardown_flaggems_logging
13from flag_gems.modules import * # noqa: F403
14from flag_gems.ops import * # noqa: F403
15from flag_gems.patches import * # noqa: F403
16from flag_gems.runtime import flagtune
17from flag_gems.runtime.backend import SpecOpRegistrar
18from flag_gems.runtime.op_registrar import GeneralOpRegistrar
__version__ = "5.0.2"

# The runtime's detected device object, plus two of its attributes re-exported
# at package level for convenience.
backend_info = runtime.device
device = backend_info.name
vendor_name = backend_info.vendor_name

# Default library handle that `enable` / `only_enable` register into.
aten_lib = torch.library.Library("aten", "IMPL")

# Register all ops in the current backend with SpecOpRegistrar to support
# architecture-specialized implementations. This must run after the star
# imports above because it is handed this module's globals().
SpecOpRegistrar(globals()).apply()

# Default registrar class and the registrar instance created by the most
# recent `enable` / `only_enable` call (None until one of them has run).
registrar = GeneralOpRegistrar
current_work_registrar = None

AUTOGRAD_DISPATCH_KEY = torch._C.DispatchKey.Autograd.name
def torch_ge(v):
    """Return True when the installed torch version is at least ``v``.

    Args:
        v: Version string understood by ``packaging.version.parse``
            (for example ``"2.5"``).

    Returns:
        bool: result of comparing the parsed ``torch.__version__`` against
        the parsed ``v`` with ``>=``.
    """
    installed = version.parse(torch.__version__)
    required = version.parse(v)
    return installed >= required
# Registration table handed to the registrar by `enable` / `only_enable`.
#
# Every entry is a tuple in one of these shapes:
#   (op_name, implementation)
#   (op_name, implementation, condition)
#   (op_name, implementation, condition, dispatch_keys)
# where `condition` is either None or a zero-argument callable (used below as
# a torch version gate) and `dispatch_keys` is a tuple such as
# (AUTOGRAD_DISPATCH_KEY,).
# NOTE(review): op names look like aten overload names, and the optional
# third/fourth fields are interpreted by the registrar class, which is not
# defined in this file - confirm the exact semantics there.
_FULL_CONFIG = (
    ("__ior__.Scalar", bitwise_or_scalar_),
    ("__ior__.Tensor", bitwise_or_tensor_),
    ("__or__.Scalar", bitwise_or_scalar),
    ("__or__.Tensor", bitwise_or_tensor),
    ("_assert_async", _assert_async),
    ("_conv_depthwise2d", _conv_depthwise2d),
    ("_euclidean_dist", _euclidean_dist),
    ("_flash_attention_forward", flash_attention_forward),
    (
        "_functional_sym_constrain_range_for_size",
        _functional_sym_constrain_range_for_size,
    ),
    ("_grouped_mm", group_mm),
    ("_index_put_impl_", _index_put_impl_),
    ("_is_all_true", _is_all_true),
    ("_log_softmax", log_softmax),
    ("_log_softmax.out", log_softmax_out),
    ("_log_softmax_backward_data", log_softmax_backward),
    ("_log_softmax_backward_data.out", log_softmax_backward_out),
    ("_safe_softmax", _safe_softmax),
    ("_scaled_mm", scaled_mm, lambda: torch_ge("2.5")),
    ("_scaled_mm.out", scaled_mm_out, lambda: torch_ge("2.5")),
    ("_softmax", softmax),
    ("_softmax.out", softmax_out),
    ("_softmax_backward_data", softmax_backward),
    ("_softmax_backward_data.out", softmax_backward_out),
    # Version gate expressed through the shared torch_ge helper, like _scaled_mm.
    ("_to_copy", to_copy, lambda: torch_ge("2.4")),
    ("_unique2", _unique2),
    ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa),
    ("_upsample_bicubic2d_aa_backward", _upsample_bicubic2d_aa_backward),
    ("_upsample_nearest_exact1d", _upsample_nearest_exact1d),
    ("_weight_norm_interface", weight_norm_interface),
    ("_weight_norm_interface_backward", weight_norm_interface_backward),
    ("abs", abs),
    ("abs_", abs_),
    ("absolute", absolute),
    ("acos", acos),
    ("add.Tensor", add),
    ("add_.Tensor", add_),
    ("add_rms_norm", add_rms_norm),
    ("addcdiv", addcdiv),
    ("addcdiv.out", addcdiv_out),
    ("addcmul", addcmul),
    ("addcmul.out", addcmul_out),
    ("addmv", addmv),
    ("addmv.out", addmv_out),
    ("addmm", addmm),
    ("addmm.out", addmm_out),
    ("addmm.dtype", addmm_dtype),
    ("addmm.dtype_out", addmm_dtype_out),
    ("addr", addr),
    ("affine_grid_generator", affine_grid_generator),
    ("alias_copy", alias_copy),
    ("all", all),
    ("all.dim", all_dim),
    ("all.dims", all_dims),
    ("allclose", allclose),
    ("amax", amax),
    ("aminmax", aminmax),
    ("angle", angle),
    ("any", any),
    ("any.dim", any_dim),
    ("any.dims", any_dims),
    ("arange", arange),
    ("arange.start", arange_start),
    ("arange.start_step", arange_start),
    ("arcsinh", arcsinh),
    ("arcsinh.out", arcsinh_out),
    ("arcsinh_", arcsinh_),
    ("argmax", argmax),
    ("argmin", argmin),
    ("argsort", argsort),
    ("as_strided_copy", as_strided_copy),
    ("as_strided_copy.out", as_strided_copy_out),
    ("asinh", asinh),
    ("asinh.out", asinh_out),
    ("asinh_", asinh_),
    ("atan", atan),
    ("atan_", atan_),
    ("atan2", atan2),
    ("atan2.out", atan2_out),
    ("arctanh_", arctanh_),
    ("avg_pool2d", avg_pool2d),
    ("avg_pool2d_backward", avg_pool2d_backward),
    ("avg_pool3d", avg_pool3d),
    ("avg_pool3d_backward", avg_pool3d_backward),
    ("baddbmm", baddbmm),
    ("bernoulli_.float", bernoulli_),
    ("bincount", bincount),
    ("bitwise_and.Scalar", bitwise_and_scalar),
    ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
    ("bitwise_and.Tensor", bitwise_and_tensor),
    ("bitwise_and_.Scalar", bitwise_and_scalar_),
    ("bitwise_and_.Tensor", bitwise_and_tensor_),
    ("bitwise_left_shift", bitwise_left_shift),
    ("bitwise_not", bitwise_not),
    ("bitwise_not_", bitwise_not_),
    ("bitwise_or.Scalar", bitwise_or_scalar),
    ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor),
    ("bitwise_or.Tensor", bitwise_or_tensor),
    ("bitwise_or_.Scalar", bitwise_or_scalar_),
    ("bitwise_or_.Tensor", bitwise_or_tensor_),
    ("bitwise_right_shift", bitwise_right_shift),
    ("bmm", bmm),
    ("bmm.out", bmm_out),
    ("cat", cat),
    ("cat.out", cat_out),
    ("cauchy", cauchy),
    ("cauchy_", cauchy_),
    ("celu", celu),
    ("celu_", celu_),
    ("ceil", ceil),
    ("ceil_", ceil_),
    ("ceil.out", ceil_out),
    ("clamp", clamp),
    ("clamp.Tensor", clamp_tensor),
    ("clamp_max", clamp_max),
    ("clamp_min", clamp_min),
    ("clamp_", clamp_),
    ("clamp_.Tensor", clamp_tensor_),
    ("clamp_max_", clamp_max_),
    ("clamp_min_", clamp_min_),
    ("clip", clip),
    ("clip_", clip_),
    ("col2im", col2im),
    ("concatenate", concatenate),
    ("conj_physical", conj_physical),
    ("constant_pad_nd", constant_pad_nd),
    # ("contiguous", contiguous),
    ("conv1d", conv1d),
    ("conv1d.padding", conv1d),
    ("conv2d", conv2d),
    ("conv2d.padding", conv2d),
    ("conv_transpose2d", conv_transpose2d),
    ("conv3d", conv3d),
    ("conv3d.padding", conv3d),
    ("conv_transpose1d", conv_transpose1d),
    # Version gate expressed through the shared torch_ge helper, like _scaled_mm.
    ("copy_", copy_, lambda: torch_ge("2.4")),
    ("cos", cos),
    ("cos_", cos_),
    ("cosh", cosh),
    ("cosh_", cosh_),
    ("cosh.out", cosh_out),
    ("copysign", copysign),
    ("copysign.out", copysign_out),
    ("count_nonzero", count_nonzero),
    ("ctc_loss.IntList", ctc_loss, None, (AUTOGRAD_DISPATCH_KEY,)),
    ("ctc_loss.Tensor", ctc_loss, None, (AUTOGRAD_DISPATCH_KEY,)),
    ("cudnn_convolution", cudnn_convolution),
    ("cummax", cummax),
    ("cummin", cummin),
    ("cumprod", cumprod),
    ("cumprod_", cumprod_),
    ("cumsum", cumsum),
    ("cumsum.out", cumsum_out),
    ("diag", diag),
    ("diag_embed", diag_embed),
    ("diagonal_backward", diagonal_backward),
    ("diff", diff),
    ("digamma_", digamma_),
    ("div.Scalar", true_divide),
    ("div.Scalar_mode", div_mode),
    ("div.Tensor", true_divide),
    ("div.Tensor_mode", div_mode),
    ("div.out", true_divide_out),
    ("div_.Scalar", true_divide_),
    ("div_.Scalar_mode", div_mode_),
    ("div_.Tensor", true_divide_),
    ("div_.Tensor_mode", div_mode_),
    ("divide.Scalar", true_divide),
    ("divide.Scalar_mode", div_mode),
    ("divide.Tensor", true_divide),
    ("divide.Tensor_mode", div_mode),
    ("divide_.Scalar", true_divide_),
    ("divide_.Scalar_mode", div_mode_),
    ("divide_.Tensor", true_divide_),
    ("divide_.Tensor_mode", div_mode_),
    ("dot", dot),
    ("elu", elu),
    ("elu_", elu_),
    ("elu_backward", elu_backward),
    ("embedding", embedding),
    ("embedding_backward", embedding_backward),
    ("embedding_dense_backward", embedding_dense_backward),
    ("eq.Scalar", eq_scalar),
    ("eq.Tensor", eq),
    ("equal", equal),
    ("erf", erf),
    ("erf_", erf_),
    ("exp", exp),
    ("exp_", exp_),
    ("exp.out", exp_out),
    ("exp2", exp2),
    ("exp2_", exp2_),
    ("expm1", expm1),
    ("expm1_", expm1_),
    ("expm1.out", expm1_out),
    ("exponential_", exponential_),
    ("feature_dropout", feature_dropout),
    ("feature_dropout_", feature_dropout_),
    ("eye", eye),
    ("eye.m", eye_m),
    ("fill.Scalar", fill_scalar),
    ("fill.Scalar_out", fill_scalar_out),
    ("fill.Tensor", fill_tensor),
    ("fill.Tensor_out", fill_tensor_out),
    ("fill_.Scalar", fill_scalar_),
    ("fill_.Tensor", fill_tensor_),
    ("flip", flip),
    ("floor_", floor_),
    ("floor", floor),
    ("floor.out", floor_out),
    ("floor_divide", floor_divide),
    ("floor_divide.Scalar", floor_divide),
    ("floor_divide_.Scalar", floor_divide_),
    ("floor_divide_.Tensor", floor_divide_),
    ("fmin", fmin),
    ("fmin.out", fmin_out),
    ("fmod.Scalar", fmod_scalar),
    ("fmod.Tensor", fmod_tensor),
    ("fmod_.Scalar", fmod_scalar_),
    ("fmod_.Tensor", fmod_tensor_),
    ("full", full),
    ("full_like", full_like),
    ("gather", gather),
    ("gather_backward", gather_backward),
    ("gcd", gcd),
    ("gcd.out", gcd_out),
    ("ge.Scalar", ge_scalar),
    ("ge.Tensor", ge),
    ("gelu", gelu),
    ("gelu_", gelu_),
    ("gelu_backward", gelu_backward),
    ("glu", glu),
    ("glu_backward", glu_backward),
    ("greater.Scalar", greater_scalar),
    ("greater.Tensor", greater),
    ("greater.Scalar_out", greater_scalar_out),
    ("greater.out", greater_out),
    ("grid_sample", grid_sample),
    ("gt.Scalar", gt_scalar),
    ("gt.Tensor", gt),
    ("hardsigmoid", hardsigmoid),
    ("hardsigmoid.out", hardsigmoid_out),
    ("hardswish_", hardswish_),
    ("histc", histc),
    ("hstack", hstack),
    ("hypot", hypot),
    ("i0", i0),
    ("i0.out", i0_out),
    ("i0_", i0_),
    ("index.Tensor", index),
    ("index_add", index_add),
    ("index_add_", index_add_),
    ("index_copy", index_copy),
    ("index_copy_", index_copy_),
    ("index_put", index_put),
    ("index_put_", index_put_),
    ("index_select", index_select),
    ("isclose", isclose),
    ("isfinite", isfinite),
    ("isin.Scalar_Tensor", isin),
    ("isin.Tensor_Scalar", isin),
    ("isin.Tensor_Tensor", isin),
    ("isinf", isinf),
    ("isnan", isnan),
    ("isneginf", isneginf),
    ("isneginf.out", isneginf_out),
    ("kron", kron),
    ("le.Scalar", le_scalar),
    ("le.Tensor", le),
    ("leaky_relu", leaky_relu),
    ("leaky_relu_", leaky_relu_),
    ("leaky_relu.out", leaky_relu_out),
    ("lerp.Scalar", lerp_scalar),
    ("lerp.Tensor", lerp_tensor),
    ("lerp_.Scalar", lerp_scalar_),
    ("lerp_.Tensor", lerp_tensor_),
    ("lift_fresh_copy", lift_fresh_copy),
    ("linalg_vector_norm", vector_norm),
    ("linspace", linspace),
    ("log", log),
    ("log10", log10),
    ("log10_", log10_),
    ("log10.out", log10_out),
    ("log1p", log1p),
    ("log1p_", log1p_),
    ("log_sigmoid", log_sigmoid),
    ("logaddexp", logaddexp),
    ("logaddexp.out", logaddexp_out),
    ("logical_and", logical_and),
    ("logical_and_", logical_and_),
    ("logical_not", logical_not),
    ("logical_or", logical_or),
    ("logical_or_", logical_or_),
    ("logical_xor", logical_xor),
    ("logit", logit),
    ("logit.out", logit_out),
    ("logit_", logit_),
    ("logspace", logspace),
    ("logsumexp", logsumexp),
    ("lt.Scalar", lt_scalar),
    ("lt.Tensor", lt),
    ("margin_ranking_loss", margin_ranking_loss),
    ("masked_fill.Scalar", masked_fill),
    ("masked_fill.Tensor", masked_fill),
    ("masked_fill_.Scalar", masked_fill_),
    ("masked_fill_.Tensor", masked_fill_),
    ("masked_scatter", masked_scatter),
    ("masked_scatter_", masked_scatter_),
    ("masked_select", masked_select),
    ("max", max),
    ("max.dim", max_dim),
    ("max_pool2d_backward", max_pool2d_backward),
    ("max_pool2d_with_indices", max_pool2d_with_indices),
    ("max_pool3d_backward", max_pool3d_backward),
    ("max_pool3d_with_indices", max_pool3d_with_indices),
    ("maximum", maximum),
    ("mean", mean),
    ("mean.dim", mean_dim),
    ("median", median),
    ("median.out", median_out),
    ("median.dim", median_dim),
    ("median.dim_values", median_dim_values),
    ("min", min),
    ("min.dim", min_dim),
    ("minimum", minimum),
    ("mm", mm),
    ("mm.out", mm_out),
    ("mse_loss", mse_loss),
    ("mul.Tensor", mul),
    ("mul_.Tensor", mul_),
    ("multinomial", multinomial),
    ("mv", mv),
    ("nan_to_num", nan_to_num),
    ("native_batch_norm", batch_norm),
    ("native_batch_norm_backward", batch_norm_backward),
    ("native_dropout", dropout),
    ("native_dropout_backward", dropout_backward),
    ("native_group_norm", group_norm),
    ("native_group_norm_backward", group_norm_backward),
    ("native_layer_norm", layer_norm),
    ("native_layer_norm_backward", layer_norm_backward),
    ("ne.Scalar", ne_scalar),
    ("ne.Tensor", ne),
    ("neg", neg),
    ("neg_", neg_),
    ("new_full.Tensor", new_full),
    ("nll_loss_backward", nll_loss_backward),
    ("nll_loss_forward", nll_loss_forward),
    ("nll_loss_nd_forward", nll_loss_nd_forward),
    ("nll_loss_nd_backward", nll_loss_nd_backward),
    ("nll_loss2d_backward", nll_loss2d_backward),
    ("nll_loss2d_forward", nll_loss2d_forward),
    ("nonzero", nonzero),
    ("nonzero_numpy", nonzero_numpy),
    ("normal.Tensor_float", normal_tensor_float),
    ("normal.Tensor_Tensor", normal_tensor_tensor),
    ("normal.float_Tensor", normal_float_tensor),
    ("normal_", normal_),
    ("normed_cumsum", normed_cumsum),
    ("one_hot", one_hot),
    ("ones", ones),
    ("ones_like", ones_like),
    ("pad", pad),
    ("pixel_shuffle", pixel_shuffle),
    ("pixel_unshuffle", pixel_unshuffle),
    ("pixel_unshuffle.out", pixel_unshuffle_out),
    ("poisson", poisson),
    ("polar", polar),
    ("pow.Scalar", pow_scalar),
    ("pow.Tensor_Scalar", pow_tensor_scalar),
    ("pow.Tensor_Tensor", pow_tensor_tensor),
    ("pow_.Scalar", pow_tensor_scalar_),
    ("pow_.Tensor", pow_tensor_tensor_),
    ("prelu", prelu),
    ("prod", prod),
    ("prod.dim_int", prod_dim),
    ("quantile", quantile),
    ("rad2deg", rad2deg),
    ("rad2deg_", rad2deg_),
    ("rand", rand),
    ("rand_like", rand_like),
    ("randn", randn),
    ("randn_like", randn_like),
    ("randint", randint),
    ("randint_like", randint_like),
    ("randperm", randperm),
    ("reciprocal", reciprocal),
    ("reciprocal_", reciprocal_),
    ("reflection_pad1d", reflection_pad1d),
    ("reflection_pad1d.out", reflection_pad1d_out),
    ("reflection_pad1d_backward", reflection_pad1d_backward),
    ("reflection_pad2d", reflection_pad2d),
    ("reflection_pad2d.out", reflection_pad2d_out),
    ("relu", relu),
    ("relu_", relu_),
    ("relu6", relu6),
    ("remainder", remainder),
    ("remainder.Scalar", remainder),
    ("remainder.Scalar_Tensor", remainder),
    ("remainder.Tensor", remainder),
    ("remainder_.Scalar", remainder_),
    ("remainder_.Tensor", remainder_),
    ("renorm", renorm),
    ("renorm_", renorm_),
    ("repeat", repeat),
    ("repeat_interleave.self_int", repeat_interleave_self_int),
    ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
    ("repeat_interleave.Tensor", repeat_interleave_tensor),
    ("replication_pad1d", replication_pad1d),
    ("replication_pad1d.out", replication_pad1d_out),
    ("replication_pad3d", replication_pad3d),
    ("resolve_conj", resolve_conj),
    ("resolve_neg", resolve_neg),
    ("rms_norm", rms_norm),
    ("roll", roll),
    ("round", round),
    ("round_", round_),
    ("round.out", round_out),
    ("rrelu_with_noise_backward", rrelu_with_noise_backward),
    ("rsqrt", rsqrt),
    ("rsqrt_", rsqrt_),
    ("rsub.Scalar", rsub_scalar),
    ("rsub.Tensor", rsub_tensor),
    ("scaled_softmax_backward", scaled_softmax_backward),
    ("scaled_softmax_forward", scaled_softmax_forward),
    ("scatter.reduce", scatter),
    ("scatter.src", scatter),
    ("scatter_.reduce", scatter_),
    ("scatter_.src", scatter_),
    ("scatter_add_", scatter_add_),
    ("scatter_reduce.two", scatter_reduce),
    ("scatter_reduce_.two", scatter_reduce_),
    ("scatter_reduce.two_out", scatter_reduce_out),
    ("select_backward", select_backward),
    ("select_scatter", select_scatter),
    ("selu", selu),
    ("selu_", selu_),
    ("sgn_", sgn_),
    ("sigmoid", sigmoid),
    ("sigmoid_", sigmoid_),
    ("sigmoid_backward", sigmoid_backward),
    ("signbit", signbit),
    ("signbit.out", signbit_out),
    ("silu", silu),
    ("silu_", silu_),
    ("silu_backward", silu_backward),
    ("sin", sin),
    ("sin_", sin_),
    ("sinh_", sinh_),
    ("slice_backward", slice_backward),
    ("slice_scatter", slice_scatter),
    ("smooth_l1_loss", smooth_l1_loss),
    ("smooth_l1_loss_backward", smooth_l1_loss_backward),
    ("smooth_l1_loss.out", smooth_l1_loss_out),
    ("soft_margin_loss", soft_margin_loss),
    ("softplus", softplus),
    ("softshrink", softshrink),
    ("softshrink.out", softshrink_out),
    ("sort", sort),
    ("sort.stable", sort_stable),
    ("special_i0e", special_i0e),
    ("special_i0e.out", special_i0e_out),
    ("special_i1", special_i1),
    ("special_i1.out", special_i1_out),
    ("sqrt", sqrt),
    ("sqrt_", sqrt_),
    ("square", square),
    ("square_", square_),
    ("square.out", square_out),
    ("stack", stack),
    ("std.correction", std),
    ("sub.Tensor", sub),
    ("sub_.Tensor", sub_),
    ("sum", sum),
    ("sum.IntList_out", sum_dim_out),
    ("sum.dim_IntList", sum_dim),
    ("sum.out", sum_out),
    ("svd", svd),
    ("t_copy", t_copy),
    ("t_copy.out", t_copy_out),
    ("tan", tan),
    ("tan_", tan_),
    ("tanh", tanh),
    ("tanh_", tanh_),
    ("tanh_backward", tanh_backward),
    ("threshold", threshold),
    ("threshold_backward", threshold_backward),
    ("tile", tile),
    ("topk", topk),
    ("trace", trace),
    ("tril", tril),
    ("tril.out", tril_out),
    ("tril_", tril_),
    ("triu", triu),
    ("triu_", triu_),
    ("true_divide.Scalar", true_divide),
    ("true_divide.Tensor", true_divide),
    ("true_divide_.Scalar", true_divide_),
    ("true_divide_.Tensor", true_divide_),
    ("unfold_backward", unfold_backward),
    ("uniform_", uniform_),
    ("unique_consecutive", unique_consecutive),
    ("upsample_bicubic2d", upsample_bicubic2d),
    ("upsample_linear1d", upsample_linear1d),
    ("upsample_nearest1d", upsample_nearest1d),
    ("upsample_nearest2d", upsample_nearest2d),
    ("upsample_nearest3d", upsample_nearest3d),
    ("var_mean.correction", var_mean),
    ("var", var),
    ("var.correction", var_correction),
    ("var.dim", var_dim),
    ("vdot", vdot),
    ("vstack", vstack),
    ("where.self", where_self),
    ("where.self_out", where_self_out),
    ("zero", zero),
    ("zero_", zero_),
    ("zero.out", zero_out),
    ("zeros", zeros),
    ("zeros_like", zeros_like),
)
# Reverse index over _FULL_CONFIG: implementation function name -> list of the
# config entries that use that implementation. Passed to the registrar by
# `only_enable` so include names can be looked up quickly.
FULL_CONFIG_BY_FUNC = {}
for _item in _FULL_CONFIG:
    # Entries need at least (op_name, implementation) to be indexable.
    if _item and len(_item) >= 2:
        fn = _item[1]
        try:
            func_name = fn.__name__
        except AttributeError:
            # Objects without a __name__ are keyed by their string form.
            func_name = str(fn)
        if func_name not in FULL_CONFIG_BY_FUNC:
            FULL_CONFIG_BY_FUNC[func_name] = []
        FULL_CONFIG_BY_FUNC[func_name].append(_item)

# Friendly names for only_enable(include=[...]) when the registered impl is
# *.out: the alias key additionally receives every entry indexed under the
# corresponding *_out implementation.
for _alias, _target in {
    "softmax": "softmax_out",
    "softmax_backward": "softmax_backward_out",
    "log_softmax": "log_softmax_out",
    "log_softmax_backward": "log_softmax_backward_out",
}.items():
    if _target not in FULL_CONFIG_BY_FUNC:
        continue
    FULL_CONFIG_BY_FUNC.setdefault(_alias, [])
    FULL_CONFIG_BY_FUNC[_alias] += FULL_CONFIG_BY_FUNC[_target]
def enable(
    lib=aten_lib,
    unused=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register every op in ``_FULL_CONFIG`` except the excluded ones.

    The resulting registrar instance is stored in the module-level
    ``current_work_registrar``.

    Args:
        lib: torch.library.Library instance to register into. Defaults to
            the module-level ``aten_lib`` ("aten", IMPL).
        unused: Description of the ops to skip; it is resolved through
            ``resolve_user_setting(unused, "exclude")``. Documented forms:
            - list/tuple/set of function names (e.g., ["masked_fill", "mul"]).
            - str path to a YAML file ending with .yml/.yaml containing an
              ``exclude:`` list.
            - "default" or None: auto-load the vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]enable_configs.yaml if
              present.
        registrar: Registrar class to instantiate. Defaults to the
            module-level ``registrar`` (``GeneralOpRegistrar``).
        record: Forwarded to ``setup_flaggems_logging``; whether to enable
            FlagGems logging.
        once: Forwarded to ``setup_flaggems_logging``; when True, log only
            once.
        path: Forwarded to ``setup_flaggems_logging``; optional log output
            path when recording.

    Notes:
        - If the exclude list/YAML resolves to empty, all ops are registered.
    """
    global current_work_registrar

    skipped_ops = resolve_user_setting(unused, "exclude")
    # De-duplicate the patch list before handing it to the registrar.
    patched_ops = list(set(aten_patch_list))

    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=[],
        user_exclude_ops=skipped_ops,
        cpp_patched_ops=patched_ops,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
def only_enable(
    lib=aten_lib,
    include=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register only the requested FlagGems ops and skip the rest.

    On success the resulting registrar instance is stored in the
    module-level ``current_work_registrar``.

    Args:
        lib: torch.library.Library instance to register into. Defaults to
            the module-level ``aten_lib`` ("aten", IMPL).
        include: Description of the ops to register; it is resolved through
            ``resolve_user_setting(include, "include")``. Documented forms:
            - list/tuple/set of function names (e.g., ["rms_norm", "softmax"]).
            - str path to a YAML file ending with .yml/.yaml (expects a list
              or an ``include:`` key).
            - "default" or None: auto-load the vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]only_enable_configs.yaml if
              present.
        registrar: Registrar class to instantiate. Defaults to the
            module-level ``registrar`` (``GeneralOpRegistrar``).
        record: Forwarded to ``setup_flaggems_logging``; whether to enable
            FlagGems logging.
        once: Forwarded to ``setup_flaggems_logging``; when True, log only
            once.
        path: Forwarded to ``setup_flaggems_logging``; optional log output
            path when recording.

    Classic usage:
        - Only register a few ops:
            only_enable(include=["rms_norm", "softmax"])
        - Use vendor default YAML:
            only_enable(include="default")  # or include=None
        - Use a custom YAML:
            only_enable(include="/path/to/only_enable.yaml")

    Notes:
        - If the include list/YAML resolves to empty, the function warns and
          returns without registering or touching logging.
        - NOTE(review): names that match no known op are handled by the
          registrar, not here - confirm its behavior for that case.
    """
    global current_work_registrar

    wanted_ops = resolve_user_setting(include, "include")
    if not wanted_ops:
        message = "only_enable failed: No include entries resolved from list or yaml."
        warnings.warn(message)
        return

    # De-duplicate the patch list before handing it to the registrar.
    patched_ops = list(set(aten_patch_list))
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=wanted_ops,
        user_exclude_ops=[],
        cpp_patched_ops=patched_ops,
        full_config_by_func=FULL_CONFIG_BY_FUNC,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
class use_gems:
    """Context manager that enables FlagGems ops for the body of a ``with``.

    A private ``torch.library.Library("aten", "IMPL")`` is created per
    instance and passed to ``only_enable`` / ``enable`` on entry; on exit the
    library is destroyed (torch >= 2.5) or released.

    The 'include' parameter has higher priority than 'exclude'.
    When 'include' is not None, use_gems will not process 'exclude'.

    Args:
        exclude: Ops to skip; a list/tuple/set/str is forwarded as the
            ``unused`` argument of ``enable``, any other value is treated as
            an empty list.
        include: Ops to register; a non-empty list/tuple/set/str selects
            ``only_enable``, any other value is treated as an empty list.
        record: Forwarded to the enable call; when truthy,
            ``teardown_flaggems_logging`` is called on exit.
        once: Forwarded to the enable call.
        path: Forwarded to the enable call.
    """

    def __init__(self, exclude=None, include=None, record=False, once=False, path=None):
        self.lib = torch.library.Library("aten", "IMPL")
        # Anything that is not a recognised container/str collapses to "nothing given".
        self.exclude = exclude if isinstance(exclude, (list, tuple, set, str)) else []
        self.include = include if isinstance(include, (list, tuple, set, str)) else []
        self.registrar = GeneralOpRegistrar
        self.record = record
        self.once = once
        self.path = path

    def __enter__(self):
        """Register the selected ops and return this context manager.

        Returning ``self`` lets ``with use_gems() as g:`` reach
        ``g.experimental_ops``.
        """
        if self.include:
            only_enable(
                lib=self.lib,
                include=self.include,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        else:
            enable(
                lib=self.lib,
                unused=self.exclude,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Tear down this context's library and reset the active registrar.

        Exceptions raised inside the ``with`` body are not suppressed.
        """
        global current_work_registrar
        # torch.__version__ is a TorchVersion, whose comparison operators parse
        # version components, so this is not a plain lexicographic compare.
        if torch.__version__ >= "2.5":
            self.lib._destroy()
        # Drop the references held by this (single-use) context manager.
        del self.lib
        del self.exclude
        del self.include
        del self.registrar
        # Reset to the initial "nothing registered" value instead of deleting
        # the global name: deleting it made any later exit (e.g. an enclosing
        # use_gems) or all_registered_ops() fail with NameError.
        current_work_registrar = None
        if self.record:
            teardown_flaggems_logging()

    @property
    def experimental_ops(self):
        """The ``flag_gems.experimental_ops`` module, imported on access."""
        import flag_gems.experimental_ops

        return flag_gems.experimental_ops
def all_registered_ops():
    """Return ``get_all_ops()`` of the currently active registrar.

    The active registrar is the module-level ``current_work_registrar``,
    which is set by ``enable`` / ``only_enable``.
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_ops()
def all_registered_keys():
    """Return ``get_all_keys()`` of the currently active registrar.

    The active registrar is the module-level ``current_work_registrar``,
    which is set by ``enable`` / ``only_enable``.
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_keys()
# Names exported by `from flag_gems import *`.
__all__ = [
    "all_registered_keys",
    "all_registered_ops",
    "enable",
    "flagtune",
    "only_enable",
    "use_gems",
]