Coverage for src/flag_gems/runtime/backend/_mthreads/ops/__init__.py: 0%
44 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1from torch_musa import current_device, get_device_capability
3from .all import all, all_dim, all_dims
4from .amax import amax
5from .any import any, any_dim, any_dims
6from .arange import arange, arange_start
7from .argmin import argmin
8from .batch_norm import batch_norm, batch_norm_backward
9from .celu import celu
10from .conv2d import conv2d
11from .dropout import dropout, dropout_backward
12from .gather import gather, gather_backward
13from .index_add import index_add, index_add_
14from .index_put import index_put, index_put_
15from .index_select import index_select
16from .log import log
17from .log_softmax import log_softmax, log_softmax_backward
18from .max import max, max_dim
19from .min import min, min_dim
20from .normal import normal_
21from .one_hot import one_hot
22from .ones import ones
23from .ones_like import ones_like
24from .prod import prod, prod_dim
25from .rand import rand
26from .rand_like import rand_like
27from .randn import randn
28from .randn_like import randn_like
29from .randperm import randperm
30from .repeat import repeat
31from .repeat_interleave import (
32 repeat_interleave_self_int,
33 repeat_interleave_self_tensor,
34 repeat_interleave_tensor,
35)
36from .resolve_conj import resolve_conj
37from .sort import sort, sort_stable
38from .tile import tile
39from .w8a8_block_fp8_matmul import w8a8_block_fp8_matmul
40from .zeros import zero_, zeros
41from .zeros_like import zeros_like
# Public names re-exported by this package. Every entry corresponds to a name
# imported at the top of this module; one op family per line, original order
# preserved. This must stay a mutable list because further names are appended
# to it below when the device capability check passes.
__all__ = [
    "amax",
    "all", "all_dim", "all_dims",
    "any", "any_dim", "any_dims",
    "arange", "arange_start",
    "argmin",
    "batch_norm", "batch_norm_backward",
    # "celu_" is not exported: only celu is imported from .celu.
    "celu",
    "conv2d",
    "dropout", "dropout_backward",
    "gather", "gather_backward",
    "index_add", "index_add_",
    "index_put", "index_put_",
    "index_select",
    "log",
    "log_softmax", "log_softmax_backward",
    "max", "max_dim",
    "min", "min_dim",
    "normal_",
    "one_hot",
    "ones", "ones_like",
    "prod", "prod_dim",
    "rand", "rand_like",
    "randn", "randn_like",
    "randperm",
    "repeat",
    "repeat_interleave_self_int",
    "repeat_interleave_self_tensor",
    "repeat_interleave_tensor",
    "resolve_conj",
    "sort", "sort_stable",
    "tile",
    "w8a8_block_fp8_matmul",
    "zero_", "zeros",
    "zeros_like",
]
# The ops below are only imported and exported when the first element of the
# current device's capability tuple is at least 3.
# NOTE(review): presumably that element is the major architecture version,
# mirroring torch.cuda.get_device_capability -- confirm against torch_musa.
if get_device_capability(current_device())[0] >= 3:
    from .addmm import addmm  # noqa: F401
    from .bmm import bmm  # noqa: F401
    from .gelu import gelu  # noqa: F401
    from .mm import mm  # noqa: F401
    from .tanh import tanh  # noqa: F401

    # In-place extension of the existing list, so the same list object is kept.
    __all__ += ["addmm", "bmm", "gelu", "mm", "tanh"]