Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/__init__.py: 0%

182 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1from .abs import abs, abs_ 

2from .acos import acos 

3from .add import add, add_ 

4from .addcdiv import addcdiv 

5from .addcmul import addcmul 

6from .addmm import addmm, addmm_out 

7from .addmv import addmv, addmv_out 

8from .addr import addr 

9from .all import all, all_dim, all_dims 

10from .amax import amax 

11from .angle import angle 

12from .any import any, any_dim, any_dims 

13from .apply_repetition_penalties import apply_repetition_penalties 

14from .arange import arange, arange_start 

15from .argmax import argmax 

16from .argmin import argmin 

17from .atan import atan, atan_ 

18from .attention import ( 

19 ScaleDotProductAttention, 

20 flash_attention_forward, 

21 flash_attn_varlen_func, 

22 scaled_dot_product_attention, 

23 scaled_dot_product_attention_backward, 

24 scaled_dot_product_attention_forward, 

25) 

26from .avg_pool2d import avg_pool2d, avg_pool2d_backward 

27from .baddbmm import baddbmm 

28from .batch_norm import batch_norm, batch_norm_backward 

29from .bitwise_and import ( 

30 bitwise_and_scalar, 

31 bitwise_and_scalar_, 

32 bitwise_and_scalar_tensor, 

33 bitwise_and_tensor, 

34 bitwise_and_tensor_, 

35) 

36from .bitwise_left_shift import bitwise_left_shift 

37from .bitwise_not import bitwise_not, bitwise_not_ 

38from .bitwise_or import ( 

39 bitwise_or_scalar, 

40 bitwise_or_scalar_, 

41 bitwise_or_scalar_tensor, 

42 bitwise_or_tensor, 

43 bitwise_or_tensor_, 

44) 

45from .bitwise_right_shift import bitwise_right_shift 

46from .bmm import bmm, bmm_out 

47from .cat import cat 

48from .celu import celu, celu_ 

49from .clamp import clamp, clamp_, clamp_min, clamp_min_, clamp_tensor, clamp_tensor_ 

50from .contiguous import contiguous 

51from .conv1d import conv1d 

52from .conv2d import conv2d 

53from .conv3d import conv3d 

54from .conv_depthwise2d import _conv_depthwise2d 

55from .copy import copy, copy_ 

56from .cos import cos, cos_ 

57from .count_nonzero import count_nonzero 

58from .cummax import cummax 

59from .cummin import cummin 

60from .cumsum import cumsum, cumsum_out, normed_cumsum 

61from .diag import diag 

62from .diag_embed import diag_embed 

63from .diagonal import diagonal_backward 

64from .digamma_ import digamma_ 

65from .div import ( 

66 div_mode, 

67 div_mode_, 

68 floor_divide, 

69 floor_divide_, 

70 remainder, 

71 remainder_, 

72 true_divide, 

73 true_divide_, 

74 true_divide_out, 

75) 

76from .dot import dot 

77from .dropout import dropout, dropout_backward 

78from .elu import elu, elu_, elu_backward 

79from .embedding import embedding, embedding_backward 

80from .eq import eq, eq_scalar 

81from .erf import erf, erf_ 

82from .exp import exp, exp_, exp_out 

83from .exp2 import exp2, exp2_ 

84from .exponential_ import exponential_ 

85from .eye import eye 

86from .eye_m import eye_m 

87from .fill import fill_scalar, fill_scalar_, fill_tensor, fill_tensor_ 

88from .flip import flip 

89from .full import full 

90from .full_like import full_like 

91from .gather import gather, gather_backward 

92from .ge import ge, ge_scalar 

93from .gelu import gelu, gelu_, gelu_backward 

94from .get_scheduler_metadata import get_scheduler_metadata 

95from .glu import glu, glu_backward 

96from .groupnorm import group_norm, group_norm_backward 

97from .gt import gt, gt_scalar 

98from .hadamard_transform import hadamard_transform 

99from .hstack import hstack 

100from .index import index 

101from .index_add import index_add, index_add_ 

102from .index_put import index_put, index_put_ 

103from .index_select import index_select 

104from .isclose import allclose, isclose 

105from .isfinite import isfinite 

106from .isin import isin 

107from .isinf import isinf 

108from .isnan import isnan 

109from .kron import kron 

110from .layernorm import layer_norm, layer_norm_backward 

111from .le import le, le_scalar 

112from .lerp import lerp_scalar, lerp_scalar_, lerp_tensor, lerp_tensor_ 

113from .linspace import linspace 

114from .log import log 

115from .log_sigmoid import log_sigmoid 

116from .log_softmax import log_softmax, log_softmax_backward 

117from .logical_and import logical_and 

118from .logical_not import logical_not 

119from .logical_or import logical_or 

120from .logical_xor import logical_xor 

121from .logspace import logspace 

122from .lt import lt, lt_scalar 

123from .masked_fill import masked_fill, masked_fill_ 

124from .masked_scatter import masked_scatter, masked_scatter_ 

125from .masked_select import masked_select 

126from .matmul_bf16 import matmul_bf16 

127from .matmul_int8 import matmul_int8 

128from .max import max, max_dim 

129from .max_pool2d_with_indices import max_pool2d_backward, max_pool2d_with_indices 

130from .maximum import maximum 

131from .mean import mean, mean_dim 

132from .min import min, min_dim 

133from .minimum import minimum 

134from .mm import mm, mm_out 

135from .mse_loss import mse_loss 

136from .mul import mul, mul_ 

137from .multinomial import multinomial 

138from .mv import mv, mv_cluster 

139from .nan_to_num import nan_to_num 

140from .ne import ne, ne_scalar 

141from .neg import neg, neg_ 

142from .nllloss import ( 

143 nll_loss2d_backward, 

144 nll_loss2d_forward, 

145 nll_loss_backward, 

146 nll_loss_forward, 

147) 

148from .nonzero import nonzero 

149from .normal import ( 

150 normal_, 

151 normal_float_tensor, 

152 normal_tensor_float, 

153 normal_tensor_tensor, 

154) 

155from .ones import ones 

156from .ones_like import ones_like 

157from .pad import constant_pad_nd, pad 

158from .per_token_group_quant_fp8 import SUPPORTED_FP8_DTYPE, per_token_group_quant_fp8 

159from .polar import polar 

160from .pow import ( 

161 pow_scalar, 

162 pow_tensor_scalar, 

163 pow_tensor_scalar_, 

164 pow_tensor_tensor, 

165 pow_tensor_tensor_, 

166) 

167from .prod import prod, prod_dim 

168from .quantile import quantile 

169from .rand import rand 

170from .rand_like import rand_like 

171from .randn import randn 

172from .randn_like import randn_like 

173from .randperm import randperm 

174from .reciprocal import reciprocal, reciprocal_ 

175from .reflection_pad1d import reflection_pad1d, reflection_pad1d_out 

176from .reflection_pad2d import reflection_pad2d, reflection_pad2d_out 

177from .relu import relu, relu_ 

178from .repeat import repeat 

179from .repeat_interleave import ( 

180 repeat_interleave_self_int, 

181 repeat_interleave_self_tensor, 

182 repeat_interleave_tensor, 

183) 

184from .resolve_conj import resolve_conj 

185from .resolve_neg import resolve_neg 

186from .rms_norm import rms_norm, rms_norm_backward, rms_norm_forward 

187from .round import round, round_, round_out 

188from .rsqrt import rsqrt, rsqrt_ 

189from .rsub import rsub 

190from .scaled_softmax import scaled_softmax_backward, scaled_softmax_forward 

191from .scatter import scatter, scatter_ 

192from .scatter_add_ import scatter_add_ 

193from .select_scatter import select_scatter 

194from .sigmoid import sigmoid, sigmoid_, sigmoid_backward 

195from .silu import silu, silu_, silu_backward 

196from .sin import sin, sin_ 

197from .slice_scatter import slice_scatter 

198from .soft_margin_loss import soft_margin_loss, soft_margin_loss_out 

199from .softmax import softmax, softmax_backward 

200from .softplus import softplus 

201from .softshrink import softshrink, softshrink_out 

202from .sort import sort, sort_stable 

203from .sqrt import sqrt, sqrt_ 

204from .stack import stack 

205from .std import std 

206from .sub import sub, sub_ 

207from .sum import sum, sum_dim, sum_dim_out, sum_out 

208from .tan import tan, tan_ 

209from .tanh import tanh, tanh_, tanh_backward 

210from .threshold import threshold, threshold_backward 

211from .tile import tile 

212from .to import to_copy 

213from .topk import topk 

214from .trace import trace 

215from .triu import triu 

216from .uniform import uniform_ 

217from .unique import _unique2 

218from .upsample_bicubic2d_aa import _upsample_bicubic2d_aa 

219from .upsample_linear1d import upsample_linear1d 

220from .upsample_nearest1d import upsample_nearest1d 

221from .upsample_nearest2d import upsample_nearest2d 

222from .var_mean import var_mean 

223from .vdot import vdot 

224from .vector_norm import vector_norm 

225from .vstack import vstack 

226from .weightnorm import weight_norm_interface, weight_norm_interface_backward 

227from .where import where_scalar_other, where_scalar_self, where_self, where_self_out 

228from .zero import zero, zero_out 

229from .zeros import zeros 

230from .zeros_like import zeros_like 

231 

232__all__ = [ 

233 "_conv_depthwise2d", 

234 "digamma_", 

235 "soft_margin_loss", 

236 "soft_margin_loss_out", 

237 "softshrink", 

238 "softshrink_out", 

239 "_unique2", 

240 "_upsample_bicubic2d_aa", 

241 "apply_repetition_penalties", 

242 "abs", 

243 "abs_", 

244 "acos", 

245 "add", 

246 "add_", 

247 "addcdiv", 

248 "addcmul", 

249 "addmm", 

250 "addmm_out", 

251 "addmv", 

252 "addmv_out", 

253 "addr", 

254 "all", 

255 "all_dim", 

256 "all_dims", 

257 "allclose", 

258 "amax", 

259 "angle", 

260 "any", 

261 "any_dim", 

262 "any_dims", 

263 "arange", 

264 "arange_start", 

265 "argmax", 

266 "argmin", 

267 "atan", 

268 "atan_", 

269 "avg_pool2d", 

270 "avg_pool2d_backward", 

271 "baddbmm", 

272 "batch_norm", 

273 "batch_norm_backward", 

274 "bitwise_and_scalar", 

275 "bitwise_and_scalar_", 

276 "bitwise_and_scalar_tensor", 

277 "bitwise_and_tensor", 

278 "bitwise_and_tensor_", 

279 "bitwise_left_shift", 

280 "bitwise_not", 

281 "bitwise_not_", 

282 "bitwise_or_scalar", 

283 "bitwise_or_scalar_", 

284 "bitwise_or_scalar_tensor", 

285 "bitwise_or_tensor", 

286 "bitwise_or_tensor_", 

287 "bitwise_right_shift", 

288 "bmm", 

289 "bmm_out", 

290 "cat", 

291 "celu", 

292 "celu_", 

293 "clamp", 

294 "clamp_", 

295 "clamp_tensor", 

296 "clamp_tensor_", 

297 "clamp_min", 

298 "clamp_min_", 

299 "constant_pad_nd", 

300 "contiguous", 

301 "conv1d", 

302 "conv2d", 

303 "conv3d", 

304 "copy", 

305 "copy_", 

306 "cos", 

307 "cos_", 

308 "count_nonzero", 

309 "cummax", 

310 "cummin", 

311 "cumprod_", 

312 "cumsum", 

313 "cumsum_out", 

314 "diag", 

315 "diag_embed", 

316 "diagonal_backward", 

317 "div_mode", 

318 "div_mode_", 

319 "dot", 

320 "dropout", 

321 "dropout_backward", 

322 "elu", 

323 "elu_", 

324 "elu_backward", 

325 "embedding", 

326 "embedding_backward", 

327 "eq", 

328 "eq_scalar", 

329 "erf", 

330 "erf_", 

331 "exp", 

332 "exp_", 

333 "exp_out", 

334 "exp2", 

335 "exp2_", 

336 "exponential_", 

337 "eye", 

338 "eye_m", 

339 "fill_scalar", 

340 "fill_scalar_", 

341 "fill_tensor", 

342 "fill_tensor_", 

343 "flash_attention_forward", 

344 "flash_attn_varlen_func", 

345 "flip", 

346 "floor_divide", 

347 "floor_divide_", 

348 "full", 

349 "full_like", 

350 "gather", 

351 "gather_backward", 

352 "ge", 

353 "ge_scalar", 

354 "gelu", 

355 "gelu_", 

356 "gelu_backward", 

357 "get_scheduler_metadata", 

358 "glu", 

359 "glu_backward", 

360 "group_norm", 

361 "group_norm_backward", 

362 "gt", 

363 "gt_scalar", 

364 "hstack", 

365 "hadamard_transform", 

366 "index", 

367 "index_add", 

368 "index_add_", 

369 "index_put", 

370 "index_put_", 

371 "index_select", 

372 "isclose", 

373 "isfinite", 

374 "isin", 

375 "isinf", 

376 "isnan", 

377 "kron", 

378 "layer_norm", 

379 "layer_norm_backward", 

380 "le", 

381 "le_scalar", 

382 "lerp_scalar", 

383 "lerp_scalar_", 

384 "lerp_tensor", 

385 "lerp_tensor_", 

386 "linspace", 

387 "log", 

388 "log_sigmoid", 

389 "log_softmax", 

390 "log_softmax_backward", 

391 "logical_and", 

392 "logical_not", 

393 "logical_or", 

394 "logical_xor", 

395 "logspace", 

396 "lt", 

397 "lt_scalar", 

398 "matmul_bf16", 

399 "matmul_int8", 

400 "masked_fill", 

401 "masked_fill_", 

402 "masked_scatter", 

403 "masked_scatter_", 

404 "masked_select", 

405 "max", 

406 "max_dim", 

407 "maximum", 

408 "max_pool2d_with_indices", 

409 "max_pool2d_backward", 

410 "mean", 

411 "mean_dim", 

412 "min", 

413 "min_dim", 

414 "minimum", 

415 "mm", 

416 "mm_out", 

417 "mse_loss", 

418 "mul", 

419 "mul_", 

420 "multinomial", 

421 "mv", 

422 "mv_cluster", 

423 "nan_to_num", 

424 "ne", 

425 "ne_scalar", 

426 "neg", 

427 "neg_", 

428 "nll_loss_backward", 

429 "nll_loss_forward", 

430 "nll_loss2d_backward", 

431 "nll_loss2d_forward", 

432 "nonzero", 

433 "normal_float_tensor", 

434 "normal_tensor_float", 

435 "normal_tensor_tensor", 

436 "normal_", 

437 "normed_cumsum", 

438 "ones", 

439 "ones_like", 

440 "pad", 

441 "per_token_group_quant_fp8", 

442 "polar", 

443 "pow_scalar", 

444 "pow_tensor_scalar", 

445 "pow_tensor_scalar_", 

446 "pow_tensor_tensor", 

447 "pow_tensor_tensor_", 

448 "prod", 

449 "prod_dim", 

450 "quantile", 

451 "rand", 

452 "rand_like", 

453 "randn", 

454 "randn_like", 

455 "randperm", 

456 "reciprocal", 

457 "reciprocal_", 

458 "reflection_pad1d", 

459 "reflection_pad1d_out", 

460 "reflection_pad2d", 

461 "reflection_pad2d_out", 

462 "relu", 

463 "relu_", 

464 "remainder", 

465 "remainder_", 

466 "repeat", 

467 "repeat_interleave_self_int", 

468 "repeat_interleave_self_tensor", 

469 "repeat_interleave_tensor", 

470 "resolve_conj", 

471 "resolve_neg", 

472 "round", 

473 "round_", 

474 "round_out", 

475 "rms_norm", 

476 "rms_norm_backward", 

477 "rms_norm_forward", 

478 "rsqrt", 

479 "rsqrt_", 

480 "rsub", 

481 "scaled_dot_product_attention", 

482 "scaled_dot_product_attention_backward", 

483 "scaled_dot_product_attention_forward", 

484 "scaled_softmax_backward", 

485 "scaled_softmax_forward", 

486 "scatter", 

487 "scatter_", 

488 "scatter_add_", 

489 "select_scatter", 

490 "sigmoid", 

491 "sigmoid_", 

492 "sigmoid_backward", 

493 "silu", 

494 "silu_", 

495 "silu_backward", 

496 "sin", 

497 "sin_", 

498 "slice_scatter", 

499 "softmax", 

500 "softmax_backward", 

501 "softplus", 

502 "sort", 

503 "sort_stable", 

504 "sqrt", 

505 "sqrt_", 

506 "stack", 

507 "std", 

508 "sub", 

509 "sub_", 

510 "sum", 

511 "sum_dim", 

512 "sum_dim_out", 

513 "sum_out", 

514 "ScaleDotProductAttention", 

515 "SUPPORTED_FP8_DTYPE", 

516 "tan", 

517 "tan_", 

518 "tanh", 

519 "tanh_", 

520 "tanh_backward", 

521 "threshold", 

522 "threshold_backward", 

523 "tile", 

524 "to_copy", 

525 "topk", 

526 "trace", 

527 "triu", 

528 "true_divide", 

529 "true_divide_out", 

530 "true_divide_", 

531 "uniform_", 

532 "upsample_linear1d", 

533 "upsample_nearest1d", 

534 "upsample_nearest2d", 

535 "var_mean", 

536 "vdot", 

537 "vector_norm", 

538 "vstack", 

539 "weight_norm_interface", 

540 "weight_norm_interface_backward", 

541 "where_scalar_other", 

542 "where_scalar_self", 

543 "where_self", 

544 "where_self_out", 

545 "zero", 

546 "zero_out", 

547 "zeros", 

548 "zeros_like", 

549]