Coverage for src/flag_gems/runtime/backend/_sunrise/ops/__init__.py: 0%

92 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1from .. import _install_typed_ptr_device_patch 

2from ._safe_softmax import _safe_softmax 

3from ._upsample_nearest_exact1d import _upsample_nearest_exact1d 

4from .abs import abs, abs_ 

5from .add import add, add_ 

6from .addmm import addmm, addmm_out 

7from .aminmax import amax, amax_out, amin, amin_out, aminmax, aminmax_out 

8from .angle import angle 

9from .arcsinh import arcsinh, arcsinh_out 

10from .attention import ( 

11 ScaleDotProductAttention, 

12 flash_attention_forward, 

13 flash_attn_varlen_func, 

14 scaled_dot_product_attention, 

15 scaled_dot_product_attention_backward, 

16 scaled_dot_product_attention_forward, 

17) 

18from .bitwise_and import ( 

19 bitwise_and_scalar, 

20 bitwise_and_scalar_, 

21 bitwise_and_scalar_tensor, 

22 bitwise_and_tensor, 

23 bitwise_and_tensor_, 

24) 

25from .bitwise_left_shift import ( 

26 bitwise_left_shift, 

27 bitwise_left_shift_, 

28 bitwise_left_shift_out, 

29) 

30from .bitwise_right_shift import ( 

31 bitwise_right_shift, 

32 bitwise_right_shift_, 

33 bitwise_right_shift_out, 

34) 

35from .clamp import ( 

36 clamp, 

37 clamp_, 

38 clamp_min, 

39 clamp_min_, 

40 clamp_min_out, 

41 clamp_tensor, 

42 clamp_tensor_, 

43) 

44from .conj_physical import conj_physical 

45from .conv2d import conv2d 

46from .cos import cos, cos_ 

47from .count_nonzero import count_nonzero 

48from .ctc_loss import ctc_loss 

49from .cumsum import cumsum, cumsum_out, normed_cumsum 

50from .div import ( 

51 div_mode, 

52 div_mode_, 

53 floor_divide, 

54 floor_divide_, 

55 remainder, 

56 remainder_, 

57 true_divide, 

58 true_divide_, 

59 true_divide_out, 

60) 

61from .dropout import dropout, dropout_backward 

62from .embedding import embedding, embedding_backward 

63from .eq import eq, eq_scalar, equal 

64from .exponential_ import exponential_ 

65from .fft import fft 

66from .fill import ( 

67 fill_scalar, 

68 fill_scalar_, 

69 fill_scalar_out, 

70 fill_tensor, 

71 fill_tensor_, 

72 fill_tensor_out, 

73) 

74from .gather import gather, gather_backward 

75from .ge import ge, ge_scalar 

76from .gelu import gelu, gelu_, gelu_backward 

77from .hypot import hypot, hypot_out 

78from .i0 import i0, i0_out 

79from .i0_ import i0_ 

80from .index_add import index_add, index_add_ 

81from .index_put import index_put, index_put_ 

82from .index_select import index_select 

83from .isin import isin 

84from .isnan import isnan 

85from .layernorm import layer_norm, layer_norm_backward 

86from .lift_fresh_copy import lift_fresh_copy, lift_fresh_copy_out 

87from .linspace import linspace 

88from .log_softmax import log_softmax, log_softmax_backward 

89from .logaddexp import logaddexp, logaddexp_out 

90from .logical_and import logical_and 

91from .logical_or import logical_or, logical_or_ 

92from .margin_ranking_loss import margin_ranking_loss 

93from .masked_select import masked_select 

94from .mean import mean, mean_dim 

95from .mul import mul, mul_ 

96from .multinomial import multinomial 

97from .mv import mv 

98from .neg import neg, neg_ 

99from .nonzero import nonzero 

100from .one_hot import one_hot 

101from .pad import constant_pad_nd, pad 

102from .polar import polar 

103from .pow import ( 

104 pow_scalar, 

105 pow_tensor_scalar, 

106 pow_tensor_scalar_, 

107 pow_tensor_tensor, 

108 pow_tensor_tensor_, 

109) 

110from .prelu import prelu 

111from .quantile import quantile 

112from .randperm import randperm 

113from .reflection_pad2d import reflection_pad2d 

114from .repeat import repeat 

115from .repeat_interleave import ( 

116 repeat_interleave_self_int, 

117 repeat_interleave_self_tensor, 

118 repeat_interleave_tensor, 

119) 

120from .resolve_neg import resolve_neg 

121from .rms_norm import rms_norm, rms_norm_backward, rms_norm_forward 

122from .scatter import scatter, scatter_ 

123from .scatter_reduce import scatter_reduce, scatter_reduce_, scatter_reduce_out 

124from .select_backward import select_backward 

125from .sigmoid import sigmoid, sigmoid_, sigmoid_backward 

126from .soft_margin_loss import soft_margin_loss, soft_margin_loss_out 

127from .softmax import softmax, softmax_backward 

128from .sort import sort, sort_stable 

129from .special_i0e import special_i0e, special_i0e_out 

130from .special_i1 import special_i1, special_i1_out 

131from .sub import sub, sub_ 

132from .sum import sum, sum_dim, sum_dim_out, sum_out 

133from .svd import svd 

134from .t_copy import t_copy, t_copy_out 

135from .tile import tile 

136from .to import to_copy 

137from .topk import topk 

138from .triu import triu 

139from .unique import _unique2 

140from .unique_consecutive import unique_consecutive 

141from .upsample_bicubic2d import upsample_bicubic2d 

142from .upsample_linear1d import upsample_linear1d 

143from .upsample_nearest2d import upsample_nearest2d 

144from .vdot import vdot 

145from .where import where_scalar_other, where_scalar_self, where_self, where_self_out 

146from .zero import zero, zero_out 

147 

# The typed-pointer device patch is applied here, after runtime initialization,
# rather than from _sunrise/__init__.py: importing tensor_wrapper at that earlier
# point would form a circular import through flag_gems.utils.
_install_typed_ptr_device_patch()


# Public names re-exported by this backend's ops package. Every op imported
# above is listed; the private patch helper is not.
__all__ = [
    "_safe_softmax",
    "_upsample_nearest_exact1d",
    "abs",
    "abs_",
    "add",
    "add_",
    "addmm",
    "addmm_out",
    "amin",
    "amin_out",
    "amax",
    "amax_out",
    "aminmax",
    "aminmax_out",
    "angle",
    "arcsinh",
    "arcsinh_out",
    "bitwise_and_scalar",
    "bitwise_and_scalar_",
    "bitwise_and_scalar_tensor",
    "bitwise_and_tensor",
    "bitwise_and_tensor_",
    "bitwise_left_shift",
    "bitwise_left_shift_",
    "bitwise_left_shift_out",
    "bitwise_right_shift",
    "bitwise_right_shift_",
    "bitwise_right_shift_out",
    "clamp",
    "clamp_",
    "clamp_tensor",
    "clamp_tensor_",
    "clamp_min",
    "clamp_min_",
    "clamp_min_out",
    "conv2d",
    "cos",
    "cos_",
    "count_nonzero",
    "conj_physical",
    "ctc_loss",
    "cumsum",
    "cumsum_out",
    "normed_cumsum",
    "div_mode",
    "div_mode_",
    "embedding",
    "embedding_backward",
    "floor_divide",
    "floor_divide_",
    "remainder",
    "remainder_",
    "true_divide",
    "true_divide_",
    "true_divide_out",
    "dropout",
    "dropout_backward",
    "eq",
    "eq_scalar",
    "equal",
    "exponential_",
    "fill_scalar",
    "fill_scalar_",
    "fill_scalar_out",
    "fill_tensor",
    "fill_tensor_",
    "fill_tensor_out",
    "flash_attention_forward",
    "flash_attn_varlen_func",
    "fft",
    "gather",
    "gather_backward",
    "ge",
    "ge_scalar",
    "gelu",
    "gelu_",
    "gelu_backward",
    "hypot",
    "hypot_out",
    "i0",
    "i0_out",
    "i0_",
    "index_add",
    "index_add_",
    "index_put",
    "index_put_",
    "index_select",
    "isin",
    "isnan",
    "layer_norm",
    "layer_norm_backward",
    "lift_fresh_copy",
    "lift_fresh_copy_out",
    "linspace",
    "log_softmax",
    "log_softmax_backward",
    "logaddexp",
    "logaddexp_out",
    "logical_and",
    "logical_or",
    "logical_or_",
    "margin_ranking_loss",
    "masked_select",
    "mean",
    "mean_dim",
    "mul",
    "mul_",
    "multinomial",
    "mv",
    "neg",
    "neg_",
    "nonzero",
    "one_hot",
    "pad",
    "polar",
    "constant_pad_nd",
    "pow_scalar",
    "pow_tensor_scalar",
    "pow_tensor_scalar_",
    "pow_tensor_tensor",
    "pow_tensor_tensor_",
    "prelu",
    "quantile",
    "randperm",
    "reflection_pad2d",
    "repeat",
    "repeat_interleave_self_int",
    "repeat_interleave_self_tensor",
    "repeat_interleave_tensor",
    "resolve_neg",
    "rms_norm",
    "rms_norm_forward",
    "rms_norm_backward",
    "scaled_dot_product_attention",
    "scaled_dot_product_attention_backward",
    "scaled_dot_product_attention_forward",
    "scatter",
    "scatter_",
    "scatter_reduce",
    "scatter_reduce_",
    "scatter_reduce_out",
    "select_backward",
    "sigmoid",
    "sigmoid_",
    "sigmoid_backward",
    "soft_margin_loss",
    "soft_margin_loss_out",
    "softmax",
    "softmax_backward",
    "sort",
    "sort_stable",
    "special_i0e",
    "special_i0e_out",
    "special_i1",
    "special_i1_out",
    "sub",
    "sub_",
    "svd",
    "sum",
    "sum_dim",
    "sum_dim_out",
    "sum_out",
    "t_copy",
    "t_copy_out",
    "ScaleDotProductAttention",
    "tile",
    "to_copy",
    "topk",
    "triu",
    "_unique2",
    "unique_consecutive",
    "upsample_bicubic2d",
    "upsample_linear1d",
    "upsample_nearest2d",
    "vdot",
    "where_scalar_other",
    "where_scalar_self",
    "where_self",
    "where_self_out",
    "zero",
    "zero_out",
]