Coverage for src/flag_gems/__init__.py: 91%
80 statements
Report generated by coverage.py v7.6.9 at 2026-06-10 07:09 +0800.
1# ruff: noqa: F405
2import warnings
4import torch
5from packaging import version
7from flag_gems import testing # noqa: F401
8from flag_gems import runtime
9from flag_gems.config import aten_patch_list, resolve_user_setting
10from flag_gems.experimental_ops import * # noqa: F403
11from flag_gems.fused import * # noqa: F403
12from flag_gems.logging_utils import setup_flaggems_logging, teardown_flaggems_logging
13from flag_gems.modules import * # noqa: F403
14from flag_gems.ops import * # noqa: F403
15from flag_gems.patches import * # noqa: F403
16from flag_gems.runtime import flagtune
17from flag_gems.runtime.backend import SpecOpRegistrar
18from flag_gems.runtime.op_registrar import GeneralOpRegistrar
__version__ = "5.0.2"
# Description of the active device/backend, taken from flag_gems.runtime.
device = runtime.device.name
vendor_name = runtime.device.vendor_name
backend_info = runtime.device
# Default library handle used by enable()/only_enable(). Per the
# torch.library.Library docs, the "IMPL" kind registers implementations for
# operators defined elsewhere (here: the "aten" namespace).
aten_lib = torch.library.Library("aten", "IMPL")

# Register all ops in the current backend with SpecOpRegistrar to support architecture-specialized implementations
# NOTE(review): this is handed this module's globals(); if it rebinds op names
# there, it must stay above _FULL_CONFIG, which captures the function objects
# when the tuple literal is evaluated.
SpecOpRegistrar(registry=globals(), vendor=vendor_name).apply()

# Default registrar class used by enable()/only_enable().
registrar = GeneralOpRegistrar
# Registrar instance created by enable()/only_enable() (also touched by
# use_gems); None until one of them has run.
current_work_registrar = None
# Name of the Autograd dispatch key; used by _FULL_CONFIG entries that carry an
# explicit dispatch-key tuple (the ctc_loss overloads).
AUTOGRAD_DISPATCH_KEY = torch._C.DispatchKey.Autograd.name
def torch_ge(v):
    """Return True if the installed torch is version ``v`` or newer.

    Both sides are parsed with ``packaging.version`` so that multi-digit
    components (e.g. "2.10" vs "2.5") compare numerically, not lexically.
    """
    installed = version.parse(torch.__version__)
    required = version.parse(v)
    return installed >= required
# Master op table handed to the registrar by enable()/only_enable().
# Entry shapes used below:
#   (aten_schema_name, impl)
#   (aten_schema_name, impl, condition)                 condition: zero-arg callable
#   (aten_schema_name, impl, condition, dispatch_keys)  condition may be None
# NOTE(review): presumably the registrar skips an entry whose condition returns
# False and registers under `dispatch_keys` instead of the default backend key
# when given -- confirm in GeneralOpRegistrar.
# Keep entries sorted by schema name and free of duplicates; FULL_CONFIG_BY_FUNC
# is derived from this table, so a duplicate shows up there twice as well.
_FULL_CONFIG = (
    ("__ior__.Scalar", bitwise_or_scalar_),
    ("__ior__.Tensor", bitwise_or_tensor_),
    ("__or__.Scalar", bitwise_or_scalar),
    ("__or__.Tensor", bitwise_or_tensor),
    ("_assert_async", _assert_async),
    ("_conv_depthwise2d", _conv_depthwise2d),
    ("_euclidean_dist", _euclidean_dist),
    ("_flash_attention_forward", flash_attention_forward),
    (
        "_functional_sym_constrain_range_for_size",
        _functional_sym_constrain_range_for_size,
    ),
    ("_grouped_mm", group_mm),
    ("_index_put_impl_", _index_put_impl_),
    ("_is_all_true", _is_all_true),
    ("_log_softmax", log_softmax),
    ("_log_softmax.out", log_softmax_out),
    ("_log_softmax_backward_data", log_softmax_backward),
    ("_log_softmax_backward_data.out", log_softmax_backward_out),
    ("_safe_softmax", _safe_softmax),
    ("_scaled_mm", scaled_mm, lambda: torch_ge("2.5")),
    ("_scaled_mm.out", scaled_mm_out, lambda: torch_ge("2.5")),
    ("_segment_reduce_backward", _segment_reduce_backward),
    ("_segment_reduce_backward.out", _segment_reduce_backward_out),
    ("_softmax", softmax),
    ("_softmax.out", softmax_out),
    ("_softmax_backward_data", softmax_backward),
    ("_softmax_backward_data.out", softmax_backward_out),
    ("_to_copy", to_copy, lambda: torch_ge("2.4")),
    ("_unique2", _unique2),
    ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa),
    ("_upsample_bicubic2d_aa_backward", _upsample_bicubic2d_aa_backward),
    ("_upsample_nearest_exact1d", _upsample_nearest_exact1d),
    ("_weight_norm_interface", weight_norm_interface),
    ("_weight_norm_interface_backward", weight_norm_interface_backward),
    ("abs", abs),
    ("abs_", abs_),
    ("absolute", absolute),
    ("acos", acos),
    ("add.Tensor", add),
    ("add_.Tensor", add_),
    ("add_rms_norm", add_rms_norm),
    ("addcdiv", addcdiv),
    ("addcdiv.out", addcdiv_out),
    ("addcmul", addcmul),
    ("addcmul.out", addcmul_out),
    ("addmm", addmm),
    ("addmm.dtype", addmm_dtype),
    ("addmm.dtype_out", addmm_dtype_out),
    ("addmm.out", addmm_out),
    ("addmv", addmv),
    ("addmv.out", addmv_out),
    ("addr", addr),
    ("affine_grid_generator", affine_grid_generator),
    ("alias_copy", alias_copy),
    ("all", all),
    ("all.dim", all_dim),
    ("all.dims", all_dims),
    ("allclose", allclose),
    ("amax", amax),
    ("aminmax", aminmax),
    ("angle", angle),
    ("any", any),
    ("any.dim", any_dim),
    ("any.dims", any_dims),
    ("arange", arange),
    ("arange.start", arange_start),
    ("arange.start_step", arange_start),
    ("arcsinh", arcsinh),
    ("arcsinh.out", arcsinh_out),
    ("arcsinh_", arcsinh_),
    ("arctanh_", arctanh_),
    ("argmax", argmax),
    ("argmin", argmin),
    ("argsort", argsort),
    ("as_strided_copy", as_strided_copy),
    ("as_strided_copy.out", as_strided_copy_out),
    ("asinh", asinh),
    ("asinh.out", asinh_out),
    ("asinh_", asinh_),
    ("atan", atan),
    ("atan2", atan2),
    ("atan2.out", atan2_out),
    ("atan_", atan_),
    ("atanh", atanh),
    ("avg_pool2d", avg_pool2d),
    ("avg_pool2d_backward", avg_pool2d_backward),
    ("avg_pool3d", avg_pool3d),
    ("avg_pool3d_backward", avg_pool3d_backward),
    ("baddbmm", baddbmm),
    ("bernoulli_.float", bernoulli_),
    ("bincount", bincount),
    ("bitwise_and.Scalar", bitwise_and_scalar),
    ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
    ("bitwise_and.Tensor", bitwise_and_tensor),
    ("bitwise_and_.Scalar", bitwise_and_scalar_),
    ("bitwise_and_.Tensor", bitwise_and_tensor_),
    ("bitwise_left_shift", bitwise_left_shift),
    ("bitwise_not", bitwise_not),
    ("bitwise_not_", bitwise_not_),
    ("bitwise_or.Scalar", bitwise_or_scalar),
    ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor),
    ("bitwise_or.Tensor", bitwise_or_tensor),
    ("bitwise_or_.Scalar", bitwise_or_scalar_),
    ("bitwise_or_.Tensor", bitwise_or_tensor_),
    ("bitwise_right_shift", bitwise_right_shift),
    ("bmm", bmm),
    ("bmm.out", bmm_out),
    ("cat", cat),
    ("cat.out", cat_out),
    ("cauchy", cauchy),
    ("cauchy_", cauchy_),
    ("ceil", ceil),
    ("ceil.out", ceil_out),
    ("ceil_", ceil_),
    ("celu", celu),
    ("celu_", celu_),
    ("clamp", clamp),
    ("clamp.Tensor", clamp_tensor),
    ("clamp_", clamp_),
    ("clamp_.Tensor", clamp_tensor_),
    ("clamp_max", clamp_max),
    ("clamp_max_", clamp_max_),
    ("clamp_min", clamp_min),
    ("clamp_min_", clamp_min_),
    ("clip", clip),
    ("clip_", clip_),
    ("col2im", col2im),
    ("concatenate", concatenate),
    ("conj_physical", conj_physical),
    ("constant_pad_nd", constant_pad_nd),
    # ("contiguous", contiguous),
    ("conv1d", conv1d),
    ("conv1d.padding", conv1d),
    ("conv2d", conv2d),
    ("conv2d.padding", conv2d),
    ("conv3d", conv3d),
    ("conv3d.padding", conv3d),
    ("conv_transpose1d", conv_transpose1d),
    ("conv_transpose2d", conv_transpose2d),
    ("copy_", copy_, lambda: torch_ge("2.4")),
    ("copysign", copysign),
    ("copysign.out", copysign_out),
    ("cos", cos),
    ("cos_", cos_),
    ("cosh", cosh),
    ("cosh.out", cosh_out),
    ("cosh_", cosh_),
    ("count_nonzero", count_nonzero),
    ("ctc_loss.IntList", ctc_loss, None, (AUTOGRAD_DISPATCH_KEY,)),
    ("ctc_loss.Tensor", ctc_loss, None, (AUTOGRAD_DISPATCH_KEY,)),
    ("cudnn_convolution", cudnn_convolution),
    ("cummax", cummax),
    ("cummin", cummin),
    ("cumprod", cumprod),
    ("cumprod_", cumprod_),
    ("cumsum", cumsum),
    ("cumsum.out", cumsum_out),
    ("diag", diag),
    ("diag_embed", diag_embed),
    ("diagonal_backward", diagonal_backward),
    ("diff", diff),
    ("digamma_", digamma_),
    ("div.Scalar", true_divide),
    ("div.Scalar_mode", div_mode),
    ("div.Tensor", true_divide),
    ("div.Tensor_mode", div_mode),
    ("div.out", true_divide_out),
    ("div_.Scalar", true_divide_),
    ("div_.Scalar_mode", div_mode_),
    ("div_.Tensor", true_divide_),
    ("div_.Tensor_mode", div_mode_),
    ("divide.Scalar", true_divide),
    ("divide.Scalar_mode", div_mode),
    ("divide.Tensor", true_divide),
    ("divide.Tensor_mode", div_mode),
    ("divide_.Scalar", true_divide_),
    ("divide_.Scalar_mode", div_mode_),
    ("divide_.Tensor", true_divide_),
    ("divide_.Tensor_mode", div_mode_),
    ("dot", dot),
    ("elu", elu),
    ("elu_", elu_),
    ("elu_backward", elu_backward),
    ("embedding", embedding),
    ("embedding_backward", embedding_backward),
    ("embedding_dense_backward", embedding_dense_backward),
    ("eq.Scalar", eq_scalar),
    ("eq.Tensor", eq),
    ("equal", equal),
    ("erf", erf),
    ("erf_", erf_),
    ("exp", exp),
    ("exp.out", exp_out),
    ("exp2", exp2),
    ("exp2_", exp2_),
    ("exp_", exp_),
    ("expm1", expm1),
    ("expm1.out", expm1_out),
    ("expm1_", expm1_),
    ("exponential_", exponential_),
    ("eye", eye),
    ("eye.m", eye_m),
    ("feature_dropout", feature_dropout),
    ("feature_dropout_", feature_dropout_),
    ("fill.Scalar", fill_scalar),
    ("fill.Scalar_out", fill_scalar_out),
    ("fill.Tensor", fill_tensor),
    ("fill.Tensor_out", fill_tensor_out),
    ("fill_.Scalar", fill_scalar_),
    ("fill_.Tensor", fill_tensor_),
    ("flip", flip),
    ("floor", floor),
    ("floor.out", floor_out),
    ("floor_", floor_),
    ("floor_divide", floor_divide),
    ("floor_divide.Scalar", floor_divide),
    ("floor_divide_.Scalar", floor_divide_),
    ("floor_divide_.Tensor", floor_divide_),
    ("fmin", fmin),
    ("fmin.out", fmin_out),
    ("fmod.Scalar", fmod_scalar),
    ("fmod.Tensor", fmod_tensor),
    ("fmod_", fmod_),
    ("fmod_.Scalar", fmod_scalar_),
    ("fmod_.Tensor", fmod_tensor_),
    ("full", full),
    ("full_like", full_like),
    ("gather", gather),
    ("gather_backward", gather_backward),
    ("gcd", gcd),
    ("gcd.out", gcd_out),
    ("ge.Scalar", ge_scalar),
    ("ge.Tensor", ge),
    ("gelu", gelu),
    ("gelu_", gelu_),
    ("gelu_backward", gelu_backward),
    ("glu", glu),
    ("glu_backward", glu_backward),
    ("greater.Scalar", greater_scalar),
    ("greater.Scalar_out", greater_scalar_out),
    ("greater.Tensor", greater),
    ("greater.out", greater_out),
    ("grid_sample", grid_sample),
    ("gt.Scalar", gt_scalar),
    ("gt.Tensor", gt),
    ("hardsigmoid", hardsigmoid),
    ("hardsigmoid.out", hardsigmoid_out),
    ("hardswish_", hardswish_),
    ("histc", histc),
    ("hstack", hstack),
    ("hypot", hypot),
    ("i0", i0),
    ("i0.out", i0_out),
    ("i0_", i0_),
    ("index.Tensor", index),
    ("index_add", index_add),
    ("index_add_", index_add_),
    ("index_copy", index_copy),
    ("index_copy_", index_copy_),
    ("index_put", index_put),
    ("index_put_", index_put_),
    ("index_reduce_", index_reduce_),
    ("index_select", index_select),
    ("isclose", isclose),
    ("isfinite", isfinite),
    ("isin.Scalar_Tensor", isin),
    ("isin.Tensor_Scalar", isin),
    ("isin.Tensor_Tensor", isin),
    ("isinf", isinf),
    ("isnan", isnan),
    ("isneginf", isneginf),
    ("isneginf.out", isneginf_out),
    ("kron", kron),
    ("le.Scalar", le_scalar),
    ("le.Tensor", le),
    ("leaky_relu", leaky_relu),
    ("leaky_relu.out", leaky_relu_out),
    ("leaky_relu_", leaky_relu_),
    ("lerp.Scalar", lerp_scalar),
    ("lerp.Tensor", lerp_tensor),
    ("lerp_.Scalar", lerp_scalar_),
    ("lerp_.Tensor", lerp_tensor_),
    ("lift_fresh_copy", lift_fresh_copy),
    ("linalg_vector_norm", vector_norm),
    ("linspace", linspace),
    ("log", log),
    ("log10", log10),
    ("log10.out", log10_out),
    ("log10_", log10_),
    ("log1p", log1p),
    ("log1p.out", log1p_out),
    ("log1p_", log1p_),
    ("log_sigmoid", log_sigmoid),
    ("logaddexp", logaddexp),
    ("logaddexp.out", logaddexp_out),
    ("logical_and", logical_and),
    ("logical_and_", logical_and_),
    ("logical_not", logical_not),
    ("logical_or", logical_or),
    ("logical_or_", logical_or_),
    ("logical_xor", logical_xor),
    ("logit", logit),
    ("logit.out", logit_out),
    ("logit_", logit_),
    ("logspace", logspace),
    ("logsumexp", logsumexp),
    ("lt.Scalar", lt_scalar),
    ("lt.Tensor", lt),
    ("margin_ranking_loss", margin_ranking_loss),
    ("masked_fill.Scalar", masked_fill),
    ("masked_fill.Tensor", masked_fill),
    ("masked_fill_.Scalar", masked_fill_),
    ("masked_fill_.Tensor", masked_fill_),
    ("masked_scatter", masked_scatter),
    ("masked_scatter_", masked_scatter_),
    ("masked_select", masked_select),
    ("max", max),
    ("max.dim", max_dim),
    ("max_pool2d_backward", max_pool2d_backward),
    ("max_pool2d_with_indices", max_pool2d_with_indices),
    ("max_pool3d_backward", max_pool3d_backward),
    ("max_pool3d_with_indices", max_pool3d_with_indices),
    ("maximum", maximum),
    ("mean", mean),
    ("mean.dim", mean_dim),
    ("median", median),
    ("median.dim", median_dim),
    ("median.dim_values", median_dim_values),
    ("median.out", median_out),
    ("min", min),
    ("min.dim", min_dim),
    ("minimum", minimum),
    ("mm", mm),
    ("mm.out", mm_out),
    ("mse_loss", mse_loss),
    ("mul.Tensor", mul),
    ("mul_.Tensor", mul_),
    ("multinomial", multinomial),
    ("mv", mv),
    ("nan_to_num", nan_to_num),
    ("nanmedian", nanmedian),
    ("nanmedian.dim", nanmedian_dim),
    ("nanmedian.dim_values", nanmedian_dim_values),
    ("nanmedian.out", nanmedian_out),
    ("native_batch_norm", batch_norm),
    ("native_batch_norm_backward", batch_norm_backward),
    ("native_dropout", dropout),
    ("native_dropout_backward", dropout_backward),
    ("native_group_norm", group_norm),
    ("native_group_norm_backward", group_norm_backward),
    ("native_layer_norm", layer_norm),
    ("native_layer_norm_backward", layer_norm_backward),
    ("ne.Scalar", ne_scalar),
    ("ne.Tensor", ne),
    ("neg", neg),
    ("neg_", neg_),
    ("new_full.Tensor", new_full),
    ("nll_loss2d_backward", nll_loss2d_backward),
    ("nll_loss2d_forward", nll_loss2d_forward),
    ("nll_loss_backward", nll_loss_backward),
    ("nll_loss_forward", nll_loss_forward),
    ("nll_loss_nd_backward", nll_loss_nd_backward),
    ("nll_loss_nd_forward", nll_loss_nd_forward),
    ("nonzero", nonzero),
    ("nonzero_numpy", nonzero_numpy),
    ("normal.Tensor_Tensor", normal_tensor_tensor),
    ("normal.Tensor_float", normal_tensor_float),
    ("normal.float_Tensor", normal_float_tensor),
    ("normal_", normal_),
    ("normed_cumsum", normed_cumsum),
    ("one_hot", one_hot),
    ("ones", ones),
    ("ones_like", ones_like),
    ("pad", pad),
    ("pixel_shuffle", pixel_shuffle),
    ("pixel_unshuffle", pixel_unshuffle),
    ("pixel_unshuffle.out", pixel_unshuffle_out),
    ("poisson", poisson),
    ("polar", polar),
    ("pow.Scalar", pow_scalar),
    ("pow.Tensor_Scalar", pow_tensor_scalar),
    ("pow.Tensor_Tensor", pow_tensor_tensor),
    ("pow_.Scalar", pow_tensor_scalar_),
    ("pow_.Tensor", pow_tensor_tensor_),
    ("prelu", prelu),
    ("prod", prod),
    ("prod.dim_int", prod_dim),
    ("quantile", quantile),
    ("rad2deg", rad2deg),
    ("rad2deg_", rad2deg_),
    ("rand", rand),
    ("rand_like", rand_like),
    ("randint", randint),
    ("randint_like", randint_like),
    ("randn", randn),
    ("randn_like", randn_like),
    ("randperm", randperm),
    ("reciprocal", reciprocal),
    ("reciprocal_", reciprocal_),
    ("reflection_pad1d", reflection_pad1d),
    ("reflection_pad1d.out", reflection_pad1d_out),
    ("reflection_pad1d_backward", reflection_pad1d_backward),
    ("reflection_pad2d", reflection_pad2d),
    ("reflection_pad2d.out", reflection_pad2d_out),
    ("relu", relu),
    ("relu6", relu6),
    ("relu_", relu_),
    ("remainder", remainder),
    ("remainder.Scalar", remainder),
    ("remainder.Scalar_Tensor", remainder),
    ("remainder.Tensor", remainder),
    ("remainder_.Scalar", remainder_),
    ("remainder_.Tensor", remainder_),
    ("renorm", renorm),
    ("renorm_", renorm_),
    ("repeat", repeat),
    ("repeat_interleave.Tensor", repeat_interleave_tensor),
    ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
    ("repeat_interleave.self_int", repeat_interleave_self_int),
    ("replication_pad1d", replication_pad1d),
    ("replication_pad1d.out", replication_pad1d_out),
    ("replication_pad3d", replication_pad3d),
    ("resolve_conj", resolve_conj),
    ("resolve_neg", resolve_neg),
    ("rms_norm", rms_norm),
    ("roll", roll),
    ("rot90", rot90),
    ("round", round),
    ("round.out", round_out),
    ("round_", round_),
    ("rrelu_with_noise_backward", rrelu_with_noise_backward),
    ("rsqrt", rsqrt),
    ("rsqrt_", rsqrt_),
    ("rsub.Scalar", rsub_scalar),
    ("rsub.Tensor", rsub_tensor),
    ("scaled_softmax_backward", scaled_softmax_backward),
    ("scaled_softmax_forward", scaled_softmax_forward),
    ("scatter.reduce", scatter),
    ("scatter.src", scatter),
    ("scatter_.reduce", scatter_),
    ("scatter_.src", scatter_),
    ("scatter_add_", scatter_add_),
    ("scatter_reduce.two", scatter_reduce),
    ("scatter_reduce.two_out", scatter_reduce_out),
    ("scatter_reduce_.two", scatter_reduce_),
    ("searchsorted.Scalar", searchsorted_scalar),
    ("searchsorted.Scalar_out", searchsorted_scalar_out),
    ("searchsorted.Tensor", searchsorted),
    ("searchsorted.Tensor_out", searchsorted_out),
    ("segment_reduce", segment_reduce),
    ("segment_reduce.out", segment_reduce_out),
    ("select_backward", select_backward),
    ("select_scatter", select_scatter),
    ("selu", selu),
    ("selu_", selu_),
    ("sgn_", sgn_),
    ("sigmoid", sigmoid),
    ("sigmoid_", sigmoid_),
    ("sigmoid_backward", sigmoid_backward),
    ("signbit", signbit),
    ("signbit.out", signbit_out),
    ("silu", silu),
    ("silu_", silu_),
    ("silu_backward", silu_backward),
    ("sin", sin),
    ("sin_", sin_),
    ("sinh_", sinh_),
    ("slice_backward", slice_backward),
    ("slice_scatter", slice_scatter),
    ("smooth_l1_loss", smooth_l1_loss),
    ("smooth_l1_loss.out", smooth_l1_loss_out),
    ("smooth_l1_loss_backward", smooth_l1_loss_backward),
    ("soft_margin_loss", soft_margin_loss),
    ("softplus", softplus),
    ("softshrink", softshrink),
    ("softshrink.out", softshrink_out),
    ("sort", sort),
    ("sort.stable", sort_stable),
    ("special_i0e", special_i0e),
    ("special_i0e.out", special_i0e_out),
    ("special_i1", special_i1),
    ("special_i1.out", special_i1_out),
    ("split_with_sizes_copy", split_with_sizes_copy),
    ("sqrt", sqrt),
    ("sqrt_", sqrt_),
    ("square", square),
    ("square.out", square_out),
    ("square_", square_),
    ("stack", stack),
    ("std.correction", std),
    ("sub.Tensor", sub),
    ("sub_.Tensor", sub_),
    ("sum", sum),
    ("sum.IntList_out", sum_dim_out),
    ("sum.dim_IntList", sum_dim),
    ("sum.out", sum_out),
    ("svd", svd),
    ("t_copy", t_copy),
    ("t_copy.out", t_copy_out),
    ("tan", tan),
    ("tan_", tan_),
    ("tanh", tanh),
    ("tanh_", tanh_),
    ("tanh_backward", tanh_backward),
    ("tensor_split", tensor_split),
    ("threshold", threshold),
    ("threshold_backward", threshold_backward),
    ("tile", tile),
    ("topk", topk),
    ("trace", trace),
    ("tril", tril),
    ("tril.out", tril_out),
    ("tril_", tril_),
    ("triu", triu),
    ("triu_", triu_),
    ("true_divide.Scalar", true_divide),
    ("true_divide.Tensor", true_divide),
    ("true_divide_.Scalar", true_divide_),
    ("true_divide_.Tensor", true_divide_),
    ("unfold_backward", unfold_backward),
    ("uniform_", uniform_),
    ("unique_consecutive", unique_consecutive),
    ("unique_dim", unique_dim),
    ("upsample_bicubic2d", upsample_bicubic2d),
    ("upsample_linear1d", upsample_linear1d),
    ("upsample_linear1d_backward", upsample_linear1d_backward),
    ("upsample_nearest1d", upsample_nearest1d),
    ("upsample_nearest2d", upsample_nearest2d),
    ("upsample_nearest3d", upsample_nearest3d),
    ("upsample_trilinear3d", upsample_trilinear3d),
    ("var", var),
    ("var.correction", var_correction),
    ("var.dim", var_dim),
    ("var_mean.correction", var_mean),
    ("vdot", vdot),
    ("view_copy", view_copy),
    ("vstack", vstack),
    ("where.self", where_self),
    ("where.self_out", where_self_out),
    ("zero", zero),
    ("zero.out", zero_out),
    ("zero_", zero_),
    ("zeros", zeros),
    ("zeros_like", zeros_like),
)
def _index_config_by_func(config, aliases=()):
    """Group op-table entries by the name of their implementation function.

    Args:
        config: Iterable of op-table entries shaped like ``_FULL_CONFIG`` items,
            i.e. ``(aten_name, impl, ...)``. Entries that are empty or have
            fewer than two elements are skipped.
        aliases: Iterable of ``(alias, target)`` pairs. For every pair whose
            ``target`` is already a key, the target's entries are also appended
            under ``alias`` (the target key itself is left as is).

    Returns:
        dict mapping function name -> list of entries, in ``config`` order.
        Implementations without a ``__name__`` are keyed by ``str(impl)``.
    """
    by_func = {}
    for item in config:
        if not item or len(item) < 2:
            continue
        impl = item[1]
        name = impl.__name__ if hasattr(impl, "__name__") else str(impl)
        by_func.setdefault(name, []).append(item)
    for alias, target in aliases:
        if target in by_func:
            by_func.setdefault(alias, []).extend(by_func[target])
    return by_func


# Friendly names for only_enable(include=[...]) when the registered impl is *.out:
# e.g. including "softmax" also picks up the entries implemented by softmax_out.
_FULL_CONFIG_ALIASES = (
    ("softmax", "softmax_out"),
    ("softmax_backward", "softmax_backward_out"),
    ("log_softmax", "log_softmax_out"),
    ("log_softmax_backward", "log_softmax_backward_out"),
)

# Cache mapping from function name -> list of _FULL_CONFIG entries for quick
# lookup; only_enable() hands it to the registrar as `full_config_by_func`.
# Built inside a helper so no loop variables leak into the package namespace.
FULL_CONFIG_BY_FUNC = _index_config_by_func(_FULL_CONFIG, _FULL_CONFIG_ALIASES)
def enable(
    lib=aten_lib,
    unused=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register all FlagGems ops except those explicitly excluded.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            module-level `aten_lib` (namespace "aten", kind "IMPL").
        unused: Which ops to skip; resolved via
            `resolve_user_setting(unused, "exclude")`. Supported forms:
            - list/tuple/set of function names (e.g., ["masked_fill", "mul"]).
            - str path to a YAML file ending with .yml/.yaml containing an
              `exclude:` list.
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]enable_configs.yaml if present.
        registrar: Registrar class, instantiated with `_FULL_CONFIG`; defaults
            to `GeneralOpRegistrar`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Side effects:
        Rebinds the module-level `current_work_registrar` to the new registrar
        instance and calls `setup_flaggems_logging`.

    Notes:
        - If the exclude list/YAML resolves to empty, all ops are registered.
    """
    global current_work_registrar
    exclude_ops = resolve_user_setting(unused, "exclude")
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=[],
        user_exclude_ops=exclude_ops,
        # Dedupe while keeping aten_patch_list's order so the registrar always
        # receives the same sequence (set() iteration order is not guaranteed).
        cpp_patched_ops=list(dict.fromkeys(aten_patch_list)),
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
def only_enable(
    lib=aten_lib,
    include=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register only the specified FlagGems ops and skip the rest.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            module-level `aten_lib` (namespace "aten", kind "IMPL").
        include: Which ops to register; resolved via
            `resolve_user_setting(include, "include")`. Supported forms:
            - list/tuple/set of function names (e.g., ["rms_norm", "softmax"]).
            - str path to a YAML file ending with .yml/.yaml (expects a list or
              an `include:` key).
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]only_enable_configs.yaml if present.
        registrar: Registrar class, instantiated with `_FULL_CONFIG` and
            `FULL_CONFIG_BY_FUNC`; defaults to `GeneralOpRegistrar`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Classic usage:
        - Only register a few ops:
            only_enable(include=["rms_norm", "softmax"])
        - Use vendor default YAML:
            only_enable(include="default")  # or include=None
        - Use a custom YAML:
            only_enable(include="/path/to/only_enable.yaml")

    Notes:
        - If the include list/YAML resolves to empty, a UserWarning is emitted
          (attributed to the caller) and the function returns without
          registering anything; `current_work_registrar` is left unchanged.
        - Names that match no known op are resolved by the registrar through
          `full_config_by_func`. NOTE(review): confirm the registrar warns in
          that case.
    """
    include_ops = resolve_user_setting(include, "include")
    if not include_ops:
        # stacklevel=2 points the warning at the caller of only_enable().
        warnings.warn(
            "only_enable failed: No include entries resolved from list or yaml.",
            stacklevel=2,
        )
        return

    global current_work_registrar
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=include_ops,
        user_exclude_ops=[],
        # Dedupe while keeping aten_patch_list's order so the registrar always
        # receives the same sequence (set() iteration order is not guaranteed).
        cpp_patched_ops=list(dict.fromkeys(aten_patch_list)),
        full_config_by_func=FULL_CONFIG_BY_FUNC,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)
class use_gems:
    """Context manager that registers FlagGems ops for the duration of a block.

    Each instance registers into its own `torch.library.Library("aten", "IMPL")`
    rather than the module-level `aten_lib`; on exit that library is destroyed
    (via `Library._destroy()` on torch >= 2.5) and the instance drops it.

    The 'include' parameter has higher priority than 'exclude'.
    When 'include' is non-empty, use_gems will not process 'exclude'; an empty
    or unsupported 'include' falls back to `enable(unused=exclude)`.

    Example:
        with use_gems(include=["rms_norm"]) as gems:
            ...
            gems.experimental_ops  # the flag_gems.experimental_ops module
    """

    def __init__(self, exclude=None, include=None, record=False, once=False, path=None):
        self.lib = torch.library.Library("aten", "IMPL")
        # Anything that is not a list/tuple/set/str (including None) means "no filter".
        self.exclude = exclude if isinstance(exclude, (list, tuple, set, str)) else []
        self.include = include if isinstance(include, (list, tuple, set, str)) else []
        self.registrar = GeneralOpRegistrar
        self.record = record
        self.once = once
        self.path = path
        # Registrar that was active when the block was entered; restored on exit.
        self._prev_registrar = None

    def __enter__(self):
        """Register the selected ops into `self.lib` and return this instance."""
        self._prev_registrar = current_work_registrar
        if self.include:
            only_enable(
                lib=self.lib,
                include=self.include,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        else:
            enable(
                lib=self.lib,
                unused=self.exclude,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        # Returning self makes `with use_gems() as g:` usable (e.g. g.experimental_ops).
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Drop this instance's registrations; exceptions are not suppressed."""
        global current_work_registrar
        if torch_ge("2.5"):
            self.lib._destroy()
        del self.lib
        del self.exclude
        del self.include
        del self.registrar
        # Restore instead of `del`: deleting the module global made any later
        # exit (nested use_gems, or only_enable() returning early) raise
        # NameError, and left the name undefined for all_registered_ops().
        current_work_registrar = self._prev_registrar
        if self.record:
            teardown_flaggems_logging()

    @property
    def experimental_ops(self):
        """The `flag_gems.experimental_ops` module (imported lazily)."""
        import flag_gems.experimental_ops

        return flag_gems.experimental_ops
def all_registered_ops():
    """Return the ops reported by the active registrar.

    Delegates to `current_work_registrar.get_all_ops()`. The registrar is set by
    enable()/only_enable(); while it is still None the lookup fails.
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_ops()
def all_registered_keys():
    """Return the keys reported by the active registrar.

    Delegates to `current_work_registrar.get_all_keys()`. The registrar is set by
    enable()/only_enable(); while it is still None the lookup fails.
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_keys()
# Names exported by `from flag_gems import *`. Op implementations brought in by
# the star imports at the top of this module stay reachable as attributes
# (flag_gems.<name>) but are not re-exported by a star import.
__all__ = [
    "all_registered_keys",
    "all_registered_ops",
    "enable",
    "flagtune",
    "only_enable",
    "use_gems",
]