Coverage for src/flag_gems/__init__.py: 91%

80 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# ruff: noqa: F405 

2import warnings 

3 

4import torch 

5from packaging import version 

6 

7from flag_gems import testing # noqa: F401 

8from flag_gems import runtime 

9from flag_gems.config import aten_patch_list, resolve_user_setting 

10from flag_gems.experimental_ops import * # noqa: F403 

11from flag_gems.fused import * # noqa: F403 

12from flag_gems.logging_utils import setup_flaggems_logging, teardown_flaggems_logging 

13from flag_gems.modules import * # noqa: F403 

14from flag_gems.ops import * # noqa: F403 

15from flag_gems.patches import * # noqa: F403 

16from flag_gems.runtime import flagtune 

17from flag_gems.runtime.backend import SpecOpRegistrar 

18from flag_gems.runtime.op_registrar import GeneralOpRegistrar 

19 

__version__ = "5.0.2"
# Device/backend descriptors taken from the runtime at import time.
device = runtime.device.name
vendor_name = runtime.device.vendor_name
backend_info = runtime.device
# Module-wide aten IMPL library; the default `lib` target of enable()/only_enable().
aten_lib = torch.library.Library("aten", "IMPL")

# Register all ops in the current backend with SpecOpRegistrar to support architecture-specialized implementations
# NOTE(review): this is handed the live module globals(), presumably so it can
# rebind op names to vendor/arch-specialized kernels before _FULL_CONFIG below
# captures them — keep it ahead of that table; confirm against SpecOpRegistrar.
SpecOpRegistrar(registry=globals(), vendor=vendor_name).apply()

# Default registrar class used by enable()/only_enable().
registrar = GeneralOpRegistrar
# Registrar instance created by the latest enable()/only_enable() call;
# starts out as None until one of them runs.
current_work_registrar = None
# Dispatch-key name passed as the 4th element of some _FULL_CONFIG entries
# (see the ctc_loss entries).
AUTOGRAD_DISPATCH_KEY = torch._C.DispatchKey.Autograd.name

32 

33 

def torch_ge(v):
    """Return True if the installed torch version is at least ``v``.

    Both versions are parsed with ``packaging.version`` so that comparisons
    are numeric per component (e.g. "2.10" counts as newer than "2.5").

    Args:
        v: Minimum version string, e.g. "2.4".
    """
    installed = version.parse(torch.__version__)
    required = version.parse(v)
    return installed >= required

36 

37 

# Master registration table handed to the registrar by enable()/only_enable().
# Each entry is a tuple:
#   (aten_op_name[.overload], implementation[, condition[, dispatch_keys]])
# - condition: optional zero-arg callable (or None); presumably the registrar
#   skips the entry when it returns False (used for torch-version gating) —
#   confirm against GeneralOpRegistrar.
# - dispatch_keys: optional tuple of dispatch-key names (see ctc_loss, which is
#   registered with AUTOGRAD_DISPATCH_KEY).
# Keep every aten op name listed exactly once.
_FULL_CONFIG = (
    ("__ior__.Scalar", bitwise_or_scalar_),
    ("__ior__.Tensor", bitwise_or_tensor_),
    ("__or__.Scalar", bitwise_or_scalar),
    ("__or__.Tensor", bitwise_or_tensor),
    ("_assert_async", _assert_async),
    ("_conv_depthwise2d", _conv_depthwise2d),
    ("_euclidean_dist", _euclidean_dist),
    ("_flash_attention_forward", flash_attention_forward),
    (
        "_functional_sym_constrain_range_for_size",
        _functional_sym_constrain_range_for_size,
    ),
    ("_grouped_mm", group_mm),
    ("_index_put_impl_", _index_put_impl_),
    ("_is_all_true", _is_all_true),
    ("_log_softmax", log_softmax),
    ("_log_softmax.out", log_softmax_out),
    ("_log_softmax_backward_data", log_softmax_backward),
    ("_log_softmax_backward_data.out", log_softmax_backward_out),
    ("_safe_softmax", _safe_softmax),
    ("_scaled_mm", scaled_mm, lambda: torch_ge("2.5")),
    ("_scaled_mm.out", scaled_mm_out, lambda: torch_ge("2.5")),
    ("_segment_reduce_backward", _segment_reduce_backward),
    ("_segment_reduce_backward.out", _segment_reduce_backward_out),
    ("_softmax", softmax),
    ("_softmax.out", softmax_out),
    ("_softmax_backward_data", softmax_backward),
    ("_softmax_backward_data.out", softmax_backward_out),
    ("_to_copy", to_copy, lambda: torch_ge("2.4")),
    ("_unique2", _unique2),
    ("_upsample_bicubic2d_aa", _upsample_bicubic2d_aa),
    ("_upsample_bicubic2d_aa_backward", _upsample_bicubic2d_aa_backward),
    ("_upsample_nearest_exact1d", _upsample_nearest_exact1d),
    ("_weight_norm_interface", weight_norm_interface),
    ("_weight_norm_interface_backward", weight_norm_interface_backward),
    ("abs", abs),
    ("abs_", abs_),
    ("absolute", absolute),
    ("acos", acos),
    ("add.Tensor", add),
    ("add_.Tensor", add_),
    ("add_rms_norm", add_rms_norm),
    ("addcdiv", addcdiv),
    ("addcdiv.out", addcdiv_out),
    ("addcmul", addcmul),
    ("addcmul.out", addcmul_out),
    ("addmm", addmm),
    ("addmm.dtype", addmm_dtype),
    ("addmm.dtype_out", addmm_dtype_out),
    ("addmm.out", addmm_out),
    ("addmv", addmv),
    ("addmv.out", addmv_out),
    ("addr", addr),
    ("affine_grid_generator", affine_grid_generator),
    ("alias_copy", alias_copy),
    ("all", all),
    ("all.dim", all_dim),
    ("all.dims", all_dims),
    ("allclose", allclose),
    ("amax", amax),
    ("aminmax", aminmax),
    ("angle", angle),
    ("any", any),
    ("any.dim", any_dim),
    ("any.dims", any_dims),
    ("arange", arange),
    ("arange.start", arange_start),
    ("arange.start_step", arange_start),
    ("arcsinh", arcsinh),
    ("arcsinh.out", arcsinh_out),
    ("arcsinh_", arcsinh_),
    ("arctanh_", arctanh_),
    ("argmax", argmax),
    ("argmin", argmin),
    ("argsort", argsort),
    ("as_strided_copy", as_strided_copy),
    ("as_strided_copy.out", as_strided_copy_out),
    ("asinh", asinh),
    ("asinh.out", asinh_out),
    ("asinh_", asinh_),
    ("atan", atan),
    ("atan2", atan2),
    ("atan2.out", atan2_out),
    ("atan_", atan_),
    ("atanh", atanh),
    ("avg_pool2d", avg_pool2d),
    ("avg_pool2d_backward", avg_pool2d_backward),
    ("avg_pool3d", avg_pool3d),
    ("avg_pool3d_backward", avg_pool3d_backward),
    ("baddbmm", baddbmm),
    ("bernoulli_.float", bernoulli_),
    ("bincount", bincount),
    ("bitwise_and.Scalar", bitwise_and_scalar),
    ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
    ("bitwise_and.Tensor", bitwise_and_tensor),
    ("bitwise_and_.Scalar", bitwise_and_scalar_),
    ("bitwise_and_.Tensor", bitwise_and_tensor_),
    ("bitwise_left_shift", bitwise_left_shift),
    ("bitwise_not", bitwise_not),
    ("bitwise_not_", bitwise_not_),
    ("bitwise_or.Scalar", bitwise_or_scalar),
    ("bitwise_or.Scalar_Tensor", bitwise_or_scalar_tensor),
    ("bitwise_or.Tensor", bitwise_or_tensor),
    ("bitwise_or_.Scalar", bitwise_or_scalar_),
    ("bitwise_or_.Tensor", bitwise_or_tensor_),
    ("bitwise_right_shift", bitwise_right_shift),
    ("bmm", bmm),
    ("bmm.out", bmm_out),
    ("cat", cat),
    ("cat.out", cat_out),
    ("cauchy", cauchy),
    ("cauchy_", cauchy_),
    ("ceil", ceil),
    ("ceil.out", ceil_out),
    ("ceil_", ceil_),
    ("celu", celu),
    ("celu_", celu_),
    ("clamp", clamp),
    ("clamp.Tensor", clamp_tensor),
    ("clamp_", clamp_),
    ("clamp_.Tensor", clamp_tensor_),
    ("clamp_max", clamp_max),
    ("clamp_max_", clamp_max_),
    ("clamp_min", clamp_min),
    ("clamp_min_", clamp_min_),
    ("clip", clip),
    ("clip_", clip_),
    ("col2im", col2im),
    ("concatenate", concatenate),
    ("conj_physical", conj_physical),
    ("constant_pad_nd", constant_pad_nd),
    # ("contiguous", contiguous),
    ("conv1d", conv1d),
    ("conv1d.padding", conv1d),
    ("conv2d", conv2d),
    ("conv2d.padding", conv2d),
    ("conv3d", conv3d),
    ("conv3d.padding", conv3d),
    ("conv_transpose1d", conv_transpose1d),
    ("conv_transpose2d", conv_transpose2d),
    ("copy_", copy_, lambda: torch_ge("2.4")),
    ("copysign", copysign),
    ("copysign.out", copysign_out),
    ("cos", cos),
    ("cos_", cos_),
    ("cosh", cosh),
    ("cosh.out", cosh_out),
    ("cosh_", cosh_),
    ("count_nonzero", count_nonzero),
    ("ctc_loss.IntList", ctc_loss, None, (AUTOGRAD_DISPATCH_KEY,)),
    ("ctc_loss.Tensor", ctc_loss, None, (AUTOGRAD_DISPATCH_KEY,)),
    ("cudnn_convolution", cudnn_convolution),
    ("cummax", cummax),
    ("cummin", cummin),
    ("cumprod", cumprod),
    ("cumprod_", cumprod_),
    ("cumsum", cumsum),
    ("cumsum.out", cumsum_out),
    ("diag", diag),
    ("diag_embed", diag_embed),
    ("diagonal_backward", diagonal_backward),
    ("diff", diff),
    ("digamma_", digamma_),
    ("div.Scalar", true_divide),
    ("div.Scalar_mode", div_mode),
    ("div.Tensor", true_divide),
    ("div.Tensor_mode", div_mode),
    ("div.out", true_divide_out),
    ("div_.Scalar", true_divide_),
    ("div_.Scalar_mode", div_mode_),
    ("div_.Tensor", true_divide_),
    ("div_.Tensor_mode", div_mode_),
    ("divide.Scalar", true_divide),
    ("divide.Scalar_mode", div_mode),
    ("divide.Tensor", true_divide),
    ("divide.Tensor_mode", div_mode),
    ("divide_.Scalar", true_divide_),
    ("divide_.Scalar_mode", div_mode_),
    ("divide_.Tensor", true_divide_),
    ("divide_.Tensor_mode", div_mode_),
    ("dot", dot),
    ("elu", elu),
    ("elu_", elu_),
    ("elu_backward", elu_backward),
    ("embedding", embedding),
    ("embedding_backward", embedding_backward),
    ("embedding_dense_backward", embedding_dense_backward),
    ("eq.Scalar", eq_scalar),
    ("eq.Tensor", eq),
    ("equal", equal),
    ("erf", erf),
    ("erf_", erf_),
    ("exp", exp),
    ("exp.out", exp_out),
    ("exp2", exp2),
    ("exp2_", exp2_),
    ("exp_", exp_),
    ("expm1", expm1),
    ("expm1.out", expm1_out),
    ("expm1_", expm1_),
    ("exponential_", exponential_),
    ("eye", eye),
    ("eye.m", eye_m),
    ("feature_dropout", feature_dropout),
    ("feature_dropout_", feature_dropout_),
    ("fill.Scalar", fill_scalar),
    ("fill.Scalar_out", fill_scalar_out),
    ("fill.Tensor", fill_tensor),
    ("fill.Tensor_out", fill_tensor_out),
    ("fill_.Scalar", fill_scalar_),
    ("fill_.Tensor", fill_tensor_),
    ("flip", flip),
    ("floor", floor),
    ("floor.out", floor_out),
    ("floor_", floor_),
    ("floor_divide", floor_divide),
    ("floor_divide.Scalar", floor_divide),
    ("floor_divide_.Scalar", floor_divide_),
    ("floor_divide_.Tensor", floor_divide_),
    ("fmin", fmin),
    ("fmin.out", fmin_out),
    ("fmod.Scalar", fmod_scalar),
    ("fmod.Tensor", fmod_tensor),
    ("fmod_", fmod_),
    ("fmod_.Scalar", fmod_scalar_),
    ("fmod_.Tensor", fmod_tensor_),
    ("full", full),
    ("full_like", full_like),
    ("gather", gather),
    ("gather_backward", gather_backward),
    ("gcd", gcd),
    ("gcd.out", gcd_out),
    ("ge.Scalar", ge_scalar),
    ("ge.Tensor", ge),
    ("gelu", gelu),
    ("gelu_", gelu_),
    ("gelu_backward", gelu_backward),
    ("glu", glu),
    ("glu_backward", glu_backward),
    ("greater.Scalar", greater_scalar),
    ("greater.Scalar_out", greater_scalar_out),
    ("greater.Tensor", greater),
    ("greater.out", greater_out),
    ("grid_sample", grid_sample),
    ("gt.Scalar", gt_scalar),
    ("gt.Tensor", gt),
    ("hardsigmoid", hardsigmoid),
    ("hardsigmoid.out", hardsigmoid_out),
    ("hardswish_", hardswish_),
    ("histc", histc),
    ("hstack", hstack),
    ("hypot", hypot),
    ("i0", i0),
    ("i0.out", i0_out),
    ("i0_", i0_),
    ("index.Tensor", index),
    ("index_add", index_add),
    ("index_add_", index_add_),
    ("index_copy", index_copy),
    ("index_copy_", index_copy_),
    ("index_put", index_put),
    ("index_put_", index_put_),
    ("index_reduce_", index_reduce_),
    ("index_select", index_select),
    ("isclose", isclose),
    ("isfinite", isfinite),
    ("isin.Scalar_Tensor", isin),
    ("isin.Tensor_Scalar", isin),
    ("isin.Tensor_Tensor", isin),
    ("isinf", isinf),
    ("isnan", isnan),
    ("isneginf", isneginf),
    ("isneginf.out", isneginf_out),
    ("kron", kron),
    ("le.Scalar", le_scalar),
    ("le.Tensor", le),
    ("leaky_relu", leaky_relu),
    ("leaky_relu.out", leaky_relu_out),
    ("leaky_relu_", leaky_relu_),
    ("lerp.Scalar", lerp_scalar),
    ("lerp.Tensor", lerp_tensor),
    ("lerp_.Scalar", lerp_scalar_),
    ("lerp_.Tensor", lerp_tensor_),
    ("lift_fresh_copy", lift_fresh_copy),
    ("linalg_vector_norm", vector_norm),
    ("linspace", linspace),
    ("log", log),
    ("log1p", log1p),
    ("log1p.out", log1p_out),
    ("log10", log10),
    ("log10.out", log10_out),
    ("log10_", log10_),
    ("log1p_", log1p_),
    ("log_sigmoid", log_sigmoid),
    ("logaddexp", logaddexp),
    ("logaddexp.out", logaddexp_out),
    ("logical_and", logical_and),
    ("logical_and_", logical_and_),
    ("logical_not", logical_not),
    ("logical_or", logical_or),
    ("logical_or_", logical_or_),
    ("logical_xor", logical_xor),
    ("logit", logit),
    ("logit.out", logit_out),
    ("logit_", logit_),
    ("logspace", logspace),
    ("logsumexp", logsumexp),
    ("lt.Scalar", lt_scalar),
    ("lt.Tensor", lt),
    ("margin_ranking_loss", margin_ranking_loss),
    ("masked_fill.Scalar", masked_fill),
    ("masked_fill.Tensor", masked_fill),
    ("masked_fill_.Scalar", masked_fill_),
    ("masked_fill_.Tensor", masked_fill_),
    ("masked_scatter", masked_scatter),
    ("masked_scatter_", masked_scatter_),
    ("masked_select", masked_select),
    ("max", max),
    ("max.dim", max_dim),
    ("max_pool2d_backward", max_pool2d_backward),
    ("max_pool2d_with_indices", max_pool2d_with_indices),
    ("max_pool3d_backward", max_pool3d_backward),
    ("max_pool3d_with_indices", max_pool3d_with_indices),
    ("maximum", maximum),
    ("mean", mean),
    ("mean.dim", mean_dim),
    ("median", median),
    ("median.dim", median_dim),
    ("median.dim_values", median_dim_values),
    ("median.out", median_out),
    ("min", min),
    ("min.dim", min_dim),
    ("minimum", minimum),
    ("mm", mm),
    ("mm.out", mm_out),
    ("mse_loss", mse_loss),
    ("mul.Tensor", mul),
    ("mul_.Tensor", mul_),
    ("multinomial", multinomial),
    ("mv", mv),
    ("nan_to_num", nan_to_num),
    ("nanmedian", nanmedian),
    ("nanmedian.dim", nanmedian_dim),
    ("nanmedian.dim_values", nanmedian_dim_values),
    ("nanmedian.out", nanmedian_out),
    ("native_batch_norm", batch_norm),
    ("native_batch_norm_backward", batch_norm_backward),
    ("native_dropout", dropout),
    ("native_dropout_backward", dropout_backward),
    ("native_group_norm", group_norm),
    ("native_group_norm_backward", group_norm_backward),
    ("native_layer_norm", layer_norm),
    ("native_layer_norm_backward", layer_norm_backward),
    ("ne.Scalar", ne_scalar),
    ("ne.Tensor", ne),
    ("neg", neg),
    ("neg_", neg_),
    ("new_full.Tensor", new_full),
    ("nll_loss2d_backward", nll_loss2d_backward),
    ("nll_loss2d_forward", nll_loss2d_forward),
    ("nll_loss_backward", nll_loss_backward),
    ("nll_loss_forward", nll_loss_forward),
    ("nll_loss_nd_backward", nll_loss_nd_backward),
    ("nll_loss_nd_forward", nll_loss_nd_forward),
    ("nonzero", nonzero),
    ("nonzero_numpy", nonzero_numpy),
    ("normal.Tensor_Tensor", normal_tensor_tensor),
    ("normal.Tensor_float", normal_tensor_float),
    ("normal.float_Tensor", normal_float_tensor),
    ("normal_", normal_),
    ("normed_cumsum", normed_cumsum),
    ("one_hot", one_hot),
    ("ones", ones),
    ("ones_like", ones_like),
    ("pad", pad),
    ("pixel_shuffle", pixel_shuffle),
    ("pixel_unshuffle", pixel_unshuffle),
    ("pixel_unshuffle.out", pixel_unshuffle_out),
    ("poisson", poisson),
    ("polar", polar),
    ("pow.Scalar", pow_scalar),
    ("pow.Tensor_Scalar", pow_tensor_scalar),
    ("pow.Tensor_Tensor", pow_tensor_tensor),
    ("pow_.Scalar", pow_tensor_scalar_),
    ("pow_.Tensor", pow_tensor_tensor_),
    ("prelu", prelu),
    ("prod", prod),
    ("prod.dim_int", prod_dim),
    ("quantile", quantile),
    ("rad2deg", rad2deg),
    ("rad2deg_", rad2deg_),
    ("rand", rand),
    ("rand_like", rand_like),
    ("randint", randint),
    ("randint_like", randint_like),
    ("randn", randn),
    ("randn_like", randn_like),
    ("randperm", randperm),
    ("reciprocal", reciprocal),
    ("reciprocal_", reciprocal_),
    ("reflection_pad1d", reflection_pad1d),
    ("reflection_pad1d.out", reflection_pad1d_out),
    ("reflection_pad1d_backward", reflection_pad1d_backward),
    ("reflection_pad2d", reflection_pad2d),
    ("reflection_pad2d.out", reflection_pad2d_out),
    ("relu", relu),
    ("relu6", relu6),
    ("relu_", relu_),
    ("remainder", remainder),
    ("remainder.Scalar", remainder),
    ("remainder.Scalar_Tensor", remainder),
    ("remainder.Tensor", remainder),
    ("remainder_.Scalar", remainder_),
    ("remainder_.Tensor", remainder_),
    ("renorm", renorm),
    ("renorm_", renorm_),
    ("repeat", repeat),
    ("repeat_interleave.Tensor", repeat_interleave_tensor),
    ("repeat_interleave.self_Tensor", repeat_interleave_self_tensor),
    ("repeat_interleave.self_int", repeat_interleave_self_int),
    ("replication_pad1d", replication_pad1d),
    ("replication_pad1d.out", replication_pad1d_out),
    ("replication_pad3d", replication_pad3d),
    ("resolve_conj", resolve_conj),
    ("resolve_neg", resolve_neg),
    ("rms_norm", rms_norm),
    ("roll", roll),
    ("rot90", rot90),
    ("round", round),
    ("round.out", round_out),
    ("round_", round_),
    ("rrelu_with_noise_backward", rrelu_with_noise_backward),
    ("rsqrt", rsqrt),
    ("rsqrt_", rsqrt_),
    ("rsub.Scalar", rsub_scalar),
    ("rsub.Tensor", rsub_tensor),
    ("scaled_softmax_backward", scaled_softmax_backward),
    ("scaled_softmax_forward", scaled_softmax_forward),
    ("scatter.reduce", scatter),
    ("scatter.src", scatter),
    ("scatter_.reduce", scatter_),
    ("scatter_.src", scatter_),
    ("scatter_add_", scatter_add_),
    ("scatter_reduce.two", scatter_reduce),
    ("scatter_reduce.two_out", scatter_reduce_out),
    ("scatter_reduce_.two", scatter_reduce_),
    ("searchsorted.Scalar", searchsorted_scalar),
    ("searchsorted.Scalar_out", searchsorted_scalar_out),
    ("searchsorted.Tensor", searchsorted),
    ("searchsorted.Tensor_out", searchsorted_out),
    ("segment_reduce", segment_reduce),
    ("segment_reduce.out", segment_reduce_out),
    ("select_backward", select_backward),
    ("select_scatter", select_scatter),
    ("selu", selu),
    ("selu_", selu_),
    ("sgn_", sgn_),
    ("sigmoid", sigmoid),
    ("sigmoid_", sigmoid_),
    ("sigmoid_backward", sigmoid_backward),
    ("signbit", signbit),
    ("signbit.out", signbit_out),
    ("silu", silu),
    ("silu_", silu_),
    ("silu_backward", silu_backward),
    ("sin", sin),
    ("sin_", sin_),
    ("sinh_", sinh_),
    ("slice_backward", slice_backward),
    ("slice_scatter", slice_scatter),
    ("smooth_l1_loss", smooth_l1_loss),
    ("smooth_l1_loss.out", smooth_l1_loss_out),
    ("smooth_l1_loss_backward", smooth_l1_loss_backward),
    ("soft_margin_loss", soft_margin_loss),
    ("softplus", softplus),
    ("softshrink", softshrink),
    ("softshrink.out", softshrink_out),
    ("sort", sort),
    ("sort.stable", sort_stable),
    ("special_i0e", special_i0e),
    ("special_i0e.out", special_i0e_out),
    ("special_i1", special_i1),
    ("special_i1.out", special_i1_out),
    ("split_with_sizes_copy", split_with_sizes_copy),
    ("sqrt", sqrt),
    ("sqrt_", sqrt_),
    ("square", square),
    ("square.out", square_out),
    ("square_", square_),
    ("stack", stack),
    ("std.correction", std),
    ("sub.Tensor", sub),
    ("sub_.Tensor", sub_),
    ("sum", sum),
    ("sum.IntList_out", sum_dim_out),
    ("sum.dim_IntList", sum_dim),
    ("sum.out", sum_out),
    ("svd", svd),
    ("t_copy", t_copy),
    ("t_copy.out", t_copy_out),
    ("tan", tan),
    ("tan_", tan_),
    ("tanh", tanh),
    ("tanh_", tanh_),
    ("tanh_backward", tanh_backward),
    ("tensor_split", tensor_split),
    ("threshold", threshold),
    ("threshold_backward", threshold_backward),
    ("tile", tile),
    ("topk", topk),
    ("trace", trace),
    ("tril", tril),
    ("tril.out", tril_out),
    ("tril_", tril_),
    ("triu", triu),
    ("triu_", triu_),
    ("true_divide.Scalar", true_divide),
    ("true_divide.Tensor", true_divide),
    ("true_divide_.Scalar", true_divide_),
    ("true_divide_.Tensor", true_divide_),
    ("unfold_backward", unfold_backward),
    ("uniform_", uniform_),
    ("unique_consecutive", unique_consecutive),
    ("unique_dim", unique_dim),
    ("upsample_bicubic2d", upsample_bicubic2d),
    ("upsample_linear1d", upsample_linear1d),
    ("upsample_linear1d_backward", upsample_linear1d_backward),
    ("upsample_nearest1d", upsample_nearest1d),
    ("upsample_nearest2d", upsample_nearest2d),
    ("upsample_nearest3d", upsample_nearest3d),
    ("upsample_trilinear3d", upsample_trilinear3d),
    ("var", var),
    ("var.correction", var_correction),
    ("var.dim", var_dim),
    ("var_mean.correction", var_mean),
    ("vdot", vdot),
    ("view_copy", view_copy),
    ("vstack", vstack),
    ("where.self", where_self),
    ("where.self_out", where_self_out),
    ("zero", zero),
    ("zero.out", zero_out),
    ("zero_", zero_),
    ("zeros", zeros),
    ("zeros_like", zeros_like),
)

594 

def _index_config_by_func(config, aliases=()):
    """Group registration entries by the name of their implementation.

    Built inside a function so the loop variables do not leak into the
    package namespace (previously `fn`/`func_name` became public module
    attributes and could clobber same-named star-imported symbols).

    Args:
        config: Iterable of registration tuples shaped like `_FULL_CONFIG`
            entries, i.e. ``(aten_op_name, impl, ...)``.
        aliases: Iterable of ``(alias, target)`` name pairs. When ``target`` is
            present in the index, its entries are appended under ``alias``.
            Pairs are applied in order.

    Returns:
        dict mapping implementation name -> list of entries (in ``config``
        order). Entries that are empty or have fewer than two elements are
        skipped; implementations without ``__name__`` are keyed by ``str()``.
    """
    by_func = {}
    for item in config:
        if not item or len(item) < 2:
            continue
        impl = item[1]
        impl_name = impl.__name__ if hasattr(impl, "__name__") else str(impl)
        by_func.setdefault(impl_name, []).append(item)
    for alias, target in aliases:
        if target in by_func:
            by_func.setdefault(alias, []).extend(by_func[target])
    return by_func


# Cache mapping from function name -> list of _FULL_CONFIG entries for quick lookup
FULL_CONFIG_BY_FUNC = _index_config_by_func(
    _FULL_CONFIG,
    # Friendly names for only_enable(include=[...]) when the registered impl is *.out
    aliases=(
        ("softmax", "softmax_out"),
        ("softmax_backward", "softmax_backward_out"),
        ("log_softmax", "log_softmax_out"),
        ("log_softmax_backward", "log_softmax_backward_out"),
    ),
)

613 

614 

def enable(
    lib=aten_lib,
    unused=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register all FlagGems ops except those explicitly excluded.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            global `aten_lib` (IMPL mode).
        unused: Which ops to skip. Supported forms:
            - list/tuple/set of function names (e.g., ["masked_fill", "mul"]).
            - str path to a YAML file ending with .yml/.yaml containing an
              `exclude:` list.
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]enable_configs.yaml if present.
        registrar: Registrar class to instantiate; defaults to
            `GeneralOpRegistrar`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Side effects:
        Rebinds the module-level `current_work_registrar` to the new registrar
        instance (read by all_registered_ops()/all_registered_keys()) and sets
        up FlagGems logging.

    Notes:
        - If the exclude list/YAML resolves to empty, all ops are registered.
    """
    global current_work_registrar
    exclude_ops = resolve_user_setting(unused, "exclude")
    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=[],
        user_exclude_ops=exclude_ops,
        # set() drops duplicate names from aten_patch_list before handing it over.
        cpp_patched_ops=list(set(aten_patch_list)),
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)

652 

653 

def only_enable(
    lib=aten_lib,
    include=None,
    registrar=registrar,
    record=False,
    once=False,
    path=None,
):
    """Register only the specified FlagGems ops and skip the rest.

    Args:
        lib: torch.library.Library instance to register into. Defaults to the
            global `aten_lib` (IMPL mode).
        include: Which ops to register. Supported forms:
            - list/tuple/set of function names (e.g., ["rms_norm", "softmax"]).
            - str path to a YAML file ending with .yml/.yaml (expects a list or
              an `include:` key).
            - "default" or None: auto-load vendor/arch-specific
              runtime/backend/_<vendor>/[<arch>/]only_enable_configs.yaml if present.
        registrar: Registrar class to instantiate; defaults to
            `GeneralOpRegistrar`.
        record: Whether to enable FlagGems logging.
        once: When True, log only once.
        path: Optional log output path when recording.

    Classic usage:
        - Only register a few ops:
            only_enable(include=["rms_norm", "softmax"])
        - Use vendor default YAML:
            only_enable(include="default")  # or include=None
        - Use a custom YAML:
            only_enable(include="/path/to/only_enable.yaml")

    Notes:
        - If the include list/YAML resolves to empty, a warning is issued and
          the function returns without registering anything; the previously
          active `current_work_registrar` is left unchanged.
        - Matching the include names against known ops is not done here; the
          names and `FULL_CONFIG_BY_FUNC` are passed to the registrar.
    """
    global current_work_registrar
    include_ops = resolve_user_setting(include, "include")
    if not include_ops:
        # stacklevel=2 attributes the warning to the caller of only_enable().
        warnings.warn(
            "only_enable failed: No include entries resolved from list or yaml.",
            stacklevel=2,
        )
        return

    current_work_registrar = registrar(
        _FULL_CONFIG,
        user_include_ops=include_ops,
        user_exclude_ops=[],
        cpp_patched_ops=list(set(aten_patch_list)),
        # Function-name -> config-entries index (including the *_out aliases);
        # presumably used by the registrar to resolve include names — confirm.
        full_config_by_func=FULL_CONFIG_BY_FUNC,
        lib=lib,
    )
    setup_flaggems_logging(path=path, record=record, once=once)

707 

708 

class use_gems:
    """Context manager that enables FlagGems ops for the body of a ``with`` block.

    Each instance registers into its own ``torch.library.Library("aten", "IMPL")``
    instead of the module-wide `aten_lib`. On exit that library is destroyed
    (torch >= 2.5), and the previously active registrar is restored.

    The 'include' parameter has higher priority than 'exclude'.
    When 'include' is a non-empty list/tuple/set/str, use_gems calls
    only_enable() and does not process 'exclude'; otherwise it calls enable()
    with 'exclude'.

    Args:
        exclude: Ops to skip (same forms as ``enable(unused=...)``). Values that
            are not a list/tuple/set/str are replaced by an empty list.
        include: Ops to register exclusively (same forms as
            ``only_enable(include=...)``). Values that are not a
            list/tuple/set/str are replaced by an empty list.
        record: Whether to enable FlagGems logging (torn down on exit).
        once: When True, log only once.
        path: Optional log output path when recording.
    """

    def __init__(self, exclude=None, include=None, record=False, once=False, path=None):
        self.lib = torch.library.Library("aten", "IMPL")
        self.exclude = exclude if isinstance(exclude, (list, tuple, set, str)) else []
        self.include = include if isinstance(include, (list, tuple, set, str)) else []
        self.registrar = GeneralOpRegistrar
        self.record = record
        self.once = once
        self.path = path
        # Registrar that was active before __enter__; restored by __exit__.
        self._prev_registrar = None

    def __enter__(self):
        # Remember the currently active registrar (e.g. from a global enable()
        # or an enclosing use_gems) so __exit__ can put it back.
        self._prev_registrar = current_work_registrar
        if self.include:
            only_enable(
                lib=self.lib,
                include=self.include,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )
        else:
            enable(
                lib=self.lib,
                unused=self.exclude,
                registrar=self.registrar,
                record=self.record,
                once=self.once,
                path=self.path,
            )

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Tear down this context's registrations; exceptions are not suppressed."""
        global current_work_registrar
        # Library._destroy() is only called on torch >= 2.5; on older versions
        # the library object is just dropped. torch_ge() matches the version
        # checks used elsewhere in this module.
        if torch_ge("2.5"):
            self.lib._destroy()
        del self.lib
        del self.exclude
        del self.include
        del self.registrar
        # Restore instead of `del`: deleting the module global made later
        # all_registered_ops()/all_registered_keys() calls fail with NameError,
        # and made a second __exit__ (nested contexts, or an only_enable() that
        # returned early) raise NameError itself.
        current_work_registrar = self._prev_registrar
        if self.record:
            teardown_flaggems_logging()

    @property
    def experimental_ops(self):
        """Return the `flag_gems.experimental_ops` module (imported lazily)."""
        import flag_gems.experimental_ops

        return flag_gems.experimental_ops

761 

762 

def all_registered_ops():
    """Return ``get_all_ops()`` of the currently active registrar.

    The active registrar is the module-level `current_work_registrar`, set by
    the most recent enable()/only_enable() call. No check is made here: if
    no registrar is active, the attribute lookup fails.
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_ops()

765 

766 

def all_registered_keys():
    """Return ``get_all_keys()`` of the currently active registrar.

    The active registrar is the module-level `current_work_registrar`, set by
    the most recent enable()/only_enable() call. No check is made here: if
    no registrar is active, the attribute lookup fails.
    """
    active_registrar = current_work_registrar
    return active_registrar.get_all_keys()

769 

770 

# Names exported by `from flag_gems import *`. Op implementations brought in
# by the star imports above remain reachable as `flag_gems.<name>`.
__all__ = [
    "all_registered_keys",
    "all_registered_ops",
    "enable",
    "flagtune",
    "only_enable",
    "use_gems",
]