Coverage for src/flag_gems/runtime/backend/_sunrise/__init__.py: 0%
400 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import os
3import torch_ptpu # noqa: F401
4from backend_utils import VendorInfoBase
6from .monkey_patch import apply_sunrise_monkey_patches
# Static description of the Sunrise backend, built with the VendorInfoBase
# helper imported from backend_utils.
vendor_info = VendorInfoBase(
    vendor_name="sunrise",
    device_name="ptpu",
    device_query_cmd="pt_smi",
    triton_extra_name="tang",
    dispatch_key="PrivateUse1",
)

# Empty tuple: this backend lists no operators here.
CUSTOMIZED_UNUSED_OPS = ()
def _sunrise_rebuild_ptpu_tensor_from_cpu(
    tensor_cls, cpu_tensor, device, requires_grad
):
    """Recreate a tensor on ``device`` from a CPU copy.

    Args:
        tensor_cls: Class the rebuilt object should have (``Parameter``,
            ``torch.Tensor`` or another tensor subclass).
        cpu_tensor: Tensor holding the data to move.
        device: Target device passed to ``Tensor.to``.
        requires_grad: Value assigned to the rebuilt tensor's
            ``requires_grad`` flag.

    Returns:
        A ``Parameter`` when ``tensor_cls`` is ``Parameter``; otherwise the
        moved tensor, viewed as ``tensor_cls`` when that conversion succeeds.
    """
    import torch
    from torch.nn.parameter import Parameter

    moved = cpu_tensor.to(device=device)

    # Parameters are rebuilt through their constructor, which takes the
    # requires_grad flag directly.
    if tensor_cls == Parameter:
        return Parameter(moved, requires_grad=requires_grad)

    already_right_type = tensor_cls in (torch.Tensor, type(moved))
    if not already_right_type:
        # Best effort: keep the plain tensor if the subclass view fails.
        try:
            converted = moved.as_subclass(tensor_cls)
        except Exception:
            converted = moved
        moved = converted

    moved.requires_grad = requires_grad
    return moved
def _should_stage_ptpu_tensor_for_multiprocessing(tensor):
    """Return True when ``tensor`` is a strided, non-nested tensor on "ptpu".

    Anything that is not a ``torch.Tensor`` yields False.
    """
    import torch

    if not isinstance(tensor, torch.Tensor):
        return False
    if tensor.device.type != "ptpu":
        return False
    if tensor.layout != torch.strided:
        return False
    return not tensor.is_nested
def _sunrise_monkey_patch_enabled():
    """Report whether the Sunrise monkey patches should be applied.

    Reads ``FLAG_GEMS_SUNRISE_ENABLE_MONKEY_PATCH`` (default ``"1"``). After
    trimming whitespace and lower-casing, the values ``0``, ``false``, ``no``
    and ``off`` disable the patches; every other value enables them.
    """
    raw_setting = os.getenv("FLAG_GEMS_SUNRISE_ENABLE_MONKEY_PATCH", "1")
    normalized = raw_setting.strip().lower()
    disabling_values = ("0", "false", "no", "off")
    return normalized not in disabling_values
# Runs at import time: apply the Sunrise monkey patches unless the environment
# switch read by _sunrise_monkey_patch_enabled() turns them off.
if _sunrise_monkey_patch_enabled():
    apply_sunrise_monkey_patches()
# [Sunrise fix] Aten lower needed.
# Operator names consumed by _install_autograd_dispatch_patch(): when one of
# these keys is registered for the sunrise vendor, the Autograd dispatch key is
# appended to the extra dispatch keys of that registration.
CUSTOMIZED_AUTOGRAD_OPS = (
    "absolute",
    "arcsinh_",
    "arcsinh",
    "arcsinh.out",
    "arctanh_",
    "clip",
    "clip_",
    "concatenate",
    "conj_physical",
    "diag",
    "diff",
    "__ior__.Tensor",
    "__ior__.Scalar",
    "__or__.Tensor",
    "__or__.Scalar",
    "embedding_backward",
    "feature_dropout",
    "feature_dropout_",
    "gather_backward",
    "greater.Tensor",
    "greater.Scalar",
    "greater.Scalar_out",
    "hstack",
    "isclose",
    "isfinite",
    "kron",
    "log_sigmoid",
    "margin_ranking_loss",
    "nonzero_numpy",
    "pad",
    "prelu",
    "quantile",
    "relu6",
    "repeat_interleave.self_int",
    "repeat_interleave.self_Tensor",
    "resolve_conj",
    "resolve_neg",
    "selu",
    "selu_",
    "square",
    "square_",
    "square.out",
    "svd",
    "tile",
    "vstack",
)
def _sunrise_extra_config_entries():
    """Return the extra ``(op_name, implementation)`` pairs for Sunrise.

    Some ops are not registered by the common library either, so they are
    kept here for now so that the tests pass.
    """
    from .ops import (
        amax_out,
        amin,
        amin_out,
        aminmax_out,
        clamp_min,
        clamp_min_,
        clamp_min_out,
        hypot_out,
    )

    # Insertion order of the dict is the order of the returned pairs.
    implementations_by_op = {
        "amax.out": amax_out,
        "amin": amin,
        "amin.out": amin_out,
        "aminmax.out": aminmax_out,
        "clamp_min.Tensor": clamp_min,
        "clamp_min.Tensor_out": clamp_min_out,
        "clamp_min_.Tensor": clamp_min_,
        "hypot.out": hypot_out,
    }
    return tuple(implementations_by_op.items())
def _install_autograd_dispatch_patch():
    """Wrap ``GeneralOpRegistrar.register_impl`` to add the Autograd key.

    For registrars whose device vendor is Sunrise, any key listed in
    ``CUSTOMIZED_AUTOGRAD_OPS`` is registered with the Autograd dispatch key
    added to its extra dispatch keys. The patch is applied once per class;
    the original method is kept on ``_sunrise_original_register_impl``.
    """
    import torch

    from flag_gems.runtime.op_registrar import GeneralOpRegistrar

    register_cls = GeneralOpRegistrar
    already_patched = getattr(
        register_cls, "_sunrise_autograd_dispatch_patched", False
    )
    if already_patched:
        return

    autograd_ops = frozenset(CUSTOMIZED_AUTOGRAD_OPS)
    autograd_key = torch._C.DispatchKey.Autograd.name
    original_register_impl = register_cls.register_impl

    def register_impl(self, key, fn, extra_dispatch_keys=()):
        is_sunrise_device = self.device.vendor_name == vendor_info.vendor_name
        if is_sunrise_device and key in autograd_ops:
            dispatch_keys = list(extra_dispatch_keys)
            if autograd_key not in dispatch_keys:
                dispatch_keys.append(autograd_key)
            extra_dispatch_keys = tuple(dispatch_keys)

        return original_register_impl(self, key, fn, extra_dispatch_keys)

    register_cls.register_impl = register_impl
    register_cls._sunrise_autograd_dispatch_patched = True
    register_cls._sunrise_original_register_impl = original_register_impl
def _install_register_config_patch():
    """Wrap ``GeneralOpRegistrar.__init__`` to merge in the Sunrise extras.

    The wrapper appends the pairs from ``_sunrise_extra_config_entries()``
    whose op name is not already in ``config``. When ``full_config_by_func``
    is given, it also adds them to a copy of that mapping, keyed by the
    implementation's ``__name__``. The patch is applied once per class; the
    original initializer is kept on ``_sunrise_original_init``.
    """
    from flag_gems.runtime.op_registrar import GeneralOpRegistrar

    register_cls = GeneralOpRegistrar
    already_patched = getattr(register_cls, "_sunrise_config_patched", False)
    if already_patched:
        return

    original_init = register_cls.__init__

    def _extend_config(config, full_config_by_func):
        extras = _sunrise_extra_config_entries()
        known_names = {entry[0] for entry in config}
        additions = tuple(entry for entry in extras if entry[0] not in known_names)
        merged_config = tuple(config) + additions

        if full_config_by_func is None:
            return merged_config, None

        # Copy each value list so the caller's mapping is left untouched.
        merged_map = {}
        for func_name, entries in full_config_by_func.items():
            merged_map[func_name] = list(entries)

        for entry in extras:
            impl = entry[1]
            if hasattr(impl, "__name__"):
                func_name = impl.__name__
            else:
                func_name = str(impl)
            bucket = merged_map.setdefault(func_name, [])
            if entry not in bucket:
                bucket.append(entry)
        return merged_config, merged_map

    def __init__(
        self,
        config,
        user_include_ops=None,
        user_exclude_ops=None,
        cpp_patched_ops=None,
        lib=None,
        full_config_by_func=None,
    ):
        merged = _extend_config(config, full_config_by_func)
        config, full_config_by_func = merged
        return original_init(
            self,
            config,
            user_include_ops=user_include_ops,
            user_exclude_ops=user_exclude_ops,
            cpp_patched_ops=cpp_patched_ops,
            lib=lib,
            full_config_by_func=full_config_by_func,
        )

    register_cls.__init__ = __init__
    register_cls._sunrise_config_patched = True
    register_cls._sunrise_original_init = original_init
def _install_typed_ptr_device_patch():
    """Replace ``TypedPtr``'s constructor and factories with device-aware ones.

    After patching, a ``TypedPtr`` stores ``ptr``, ``dtype`` and ``device``,
    and both class-level factories pass the source tensor's device along.
    The patch is applied once per class.
    """
    from flag_gems.utils.tensor_wrapper import TypedPtr

    if getattr(TypedPtr, "_sunrise_device_patched", False):
        return

    def __init__(self, ptr, dtype, device=None):
        self.ptr, self.dtype, self.device = ptr, dtype, device

    def from_tensor(cls, tensor, offset=0):
        # Address of element `offset`, measured in the tensor's element size.
        address = tensor.data_ptr() + tensor.element_size() * offset
        return cls(address, tensor.dtype, tensor.device)

    def reinterpret_tensor(cls, tensor, dtype, offset=0):
        # Here the offset is measured in items of the reinterpreted dtype.
        address = tensor.data_ptr() + dtype.itemsize * offset
        return cls(address, dtype, tensor.device)

    TypedPtr.__init__ = __init__
    TypedPtr.from_tensor = classmethod(from_tensor)
    TypedPtr.reinterpret_tensor = classmethod(reinterpret_tensor)
    TypedPtr._sunrise_device_patched = True
def _install_pointwise_dynamic_complex_patch():
    """Patch ``PointwiseDynamicFunction`` with Sunrise-specific behavior.

    The replacements defined below are assigned onto the class at the end of
    this function. They add:

    * a CPU fallback in ``_call_real_impl`` for float64 "ptpu" tensors when
      ``runtime.device.support_fp64`` is false;
    * a CPU fallback for complex128 results under the same condition;
    * ``StridedBuffer``-based or CPU-based substitutes used when
      ``view_as_real`` / ``view_as_complex`` raise a ``NotImplementedError``
      mentioning the op and "ptpu";
    * a ``use_fast_path`` that also accepts non-``torch.Tensor`` buffers.

    The patch is applied at most once, guarded by the class attribute
    ``_sunrise_complex_patched``.
    """
    import torch

    from flag_gems.utils.pointwise_dynamic import ComplexMode, PointwiseDynamicFunction
    from flag_gems.utils.shape_utils import all_the_same_shape, all_the_same_stride
    from flag_gems.utils.tensor_wrapper import StridedBuffer

    if getattr(PointwiseDynamicFunction, "_sunrise_complex_patched", False):
        return

    def _tensor_is_contiguous(tensor):
        # Real tensors answer for themselves; other buffer objects are checked
        # by walking shape/stride from the innermost dimension outwards.
        if isinstance(tensor, torch.Tensor):
            return tensor.is_contiguous()
        expected_stride = 1
        for size, stride in zip(reversed(tensor.shape), reversed(tensor.stride())):
            # Size-1 dimensions are skipped: their stride is not compared.
            if size == 1:
                continue
            if stride != expected_stride:
                return False
            expected_stride *= size
        return True

    # Only add the method when StridedBuffer does not already define one.
    if not hasattr(StridedBuffer, "is_contiguous"):
        StridedBuffer.is_contiguous = _tensor_is_contiguous

    def _call_real_impl(self, *args, _skip_tensor_check=False, **kwargs):
        from flag_gems import runtime

        if not runtime.device.support_fp64:
            # Look for the first float64 tensor on "ptpu", positional
            # arguments first, then keyword arguments.
            ptpu_tensor = next(
                (
                    arg
                    for arg in args
                    if isinstance(arg, torch.Tensor)
                    and arg.device.type == "ptpu"
                    and arg.dtype == torch.float64
                ),
                None,
            )
            if ptpu_tensor is None:
                ptpu_tensor = next(
                    (
                        value
                        for value in kwargs.values()
                        if isinstance(value, torch.Tensor)
                        and value.device.type == "ptpu"
                        and value.dtype == torch.float64
                    ),
                    None,
                )
            if ptpu_tensor is not None:
                # Run the scalar function on CPU copies of every tensor input.
                cpu_args = tuple(
                    arg.cpu() if isinstance(arg, torch.Tensor) else arg for arg in args
                )
                # Keyword arguments whose name starts with "out" are not
                # forwarded to the CPU call.
                cpu_kwargs = {
                    key: (value.cpu() if isinstance(value, torch.Tensor) else value)
                    for key, value in kwargs.items()
                    if not key.startswith("out")
                }
                # Use the wrapped callable's ``fn`` attribute when it has one.
                py_fn = getattr(self._scalar_fn, "fn", self._scalar_fn)
                result = py_fn(*cpu_args, **cpu_kwargs)
                out = kwargs.get("out0")
                if out is not None:
                    # NOTE(review): this path calls ``result.to``, so it
                    # assumes a single-tensor result whenever "out0" is given.
                    out.copy_(result.to(out.device))
                    return out
                if isinstance(result, tuple):
                    return tuple(
                        item.to(ptpu_tensor.device)
                        if isinstance(item, torch.Tensor)
                        else item
                        for item in result
                    )
                if isinstance(result, torch.Tensor):
                    return result.to(ptpu_tensor.device)
                return result

        # Regular path: prepare arguments, instantiate the overload for the
        # resulting rank and unwrap its output.
        ndim, args, kwargs = self.prepare_args(
            *args, _skip_tensor_check=_skip_tensor_check, **kwargs
        )
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        return self._unwrap(out)

    def _is_missing_backend_view_op(self, exc, aten_op):
        # True for a NotImplementedError whose message names both the aten op
        # and "ptpu".
        message = str(exc)
        return (
            isinstance(exc, NotImplementedError)
            and aten_op in message
            and "ptpu" in message
        )

    def _complex_real_view_buffer(self, tensor):
        # StridedBuffer over the complex tensor with a trailing dimension of
        # size 2, the real dtype, and every original stride doubled.
        real_dtype = tensor.dtype.to_real()
        shape = tuple(tensor.shape) + (2,)
        strides = tuple(stride * 2 for stride in tensor.stride()) + (1,)
        return StridedBuffer(tensor, shape=shape, strides=strides, dtype=real_dtype)

    def _complex_component_buffers(self, tensor):
        # Two StridedBuffers with the complex tensor's shape and doubled
        # strides; the second is created with offset=1.
        # NOTE(review): presumably offset is counted in real-dtype elements,
        # making the second buffer the imaginary part; confirm against
        # StridedBuffer.
        real_dtype = tensor.dtype.to_real()
        strides = tuple(stride * 2 for stride in tensor.stride())
        real = StridedBuffer(
            tensor, shape=tensor.shape, strides=strides, dtype=real_dtype
        )
        imag = StridedBuffer(
            tensor,
            shape=tensor.shape,
            strides=strides,
            dtype=real_dtype,
            offset=1,
        )
        return real, imag

    def _view_as_real_for_kernel(self, tensor):
        # Prefer torch.view_as_real; fall back to the StridedBuffer view only
        # for the "missing backend op" error, re-raising anything else.
        try:
            return torch.view_as_real(tensor)
        except Exception as exc:
            if self._is_missing_backend_view_op(exc, "aten::view_as_real"):
                return self._complex_real_view_buffer(tensor)
            raise

    def _split_complex_components(self, tensor):
        # Return (real, imag) views, with the same fallback rule as above.
        try:
            real_view = torch.view_as_real(tensor)
            return real_view[..., 0], real_view[..., 1]
        except Exception as exc:
            if self._is_missing_backend_view_op(exc, "aten::view_as_real"):
                return self._complex_component_buffers(tensor)
            raise

    def _view_as_complex_result(self, tensor):
        # When the backend lacks view_as_complex, build the complex tensor on
        # CPU and move it back to the original device.
        try:
            return torch.view_as_complex(tensor.contiguous())
        except Exception as exc:
            if self._is_missing_backend_view_op(exc, "aten::view_as_complex"):
                return torch.view_as_complex(tensor.cpu().contiguous()).to(
                    tensor.device
                )
            raise

    def _cross_components_for_kernel(self, tensor):
        # Same idea for view_as_real: on the "missing backend op" error the
        # real view is computed on CPU and moved back to the device.
        try:
            real_view = torch.view_as_real(tensor)
        except Exception as exc:
            if not self._is_missing_backend_view_op(exc, "aten::view_as_real"):
                raise
            real_view = torch.view_as_real(tensor.cpu()).to(tensor.device)
        return real_view[..., 0], real_view[..., 1]

    def _cpu_fallback_value(self, value):
        # Tensors are copied to CPU; every other value passes through.
        if isinstance(value, torch.Tensor):
            return value.cpu()
        return value

    def _should_cpu_fallback_complex(self, result_dtype, device):
        # Fall back only for complex128 results on a non-CPU device when the
        # runtime reports no fp64 support.
        if device is None or device.type == "cpu":
            return False
        if result_dtype != torch.complex128:
            return False
        from flag_gems import runtime

        return not runtime.device.support_fp64

    def _cpu_fallback_complex_dispatch(self, args, kwargs, device):
        # Only positional arguments are forwarded to the CPU call; from
        # kwargs, just "out0" is read, as the destination.
        cpu_args = tuple(self._cpu_fallback_value(arg) for arg in args)
        out = kwargs.get("out0")
        py_fn = getattr(self._scalar_fn, "fn", self._scalar_fn)
        result = py_fn(*cpu_args)
        if out is not None:
            out.copy_(result.to(out.device))
            return out
        if isinstance(result, torch.Tensor):
            return result.to(device)
        if isinstance(result, tuple):
            return tuple(
                item.to(device) if isinstance(item, torch.Tensor) else item
                for item in result
            )
        return result

    def _call_complex_dispatch(self, *args, **kwargs):
        strategy = self.complex_strategy
        operands, others = self._split_args(args)

        device = self._infer_device(operands)
        result_dtype = self._infer_complex_dtype(operands)

        # The CPU fallback is checked before any other strategy.
        if self._should_cpu_fallback_complex(result_dtype, device):
            return self._cpu_fallback_complex_dispatch(args, kwargs, device)

        if strategy.tensorize_scalars and strategy.fallback_target is not None:
            operands = self._tensorize_scalar_operands(operands, result_dtype, device)
            new_args = self._merge_args(operands, others)
            return strategy.fallback_target(*new_args, **kwargs)

        for i in list(operands.keys()):
            operands[i] = self._to_complex_tensor(operands[i], result_dtype, device)

        # Broadcast all operands together, then store them back under their
        # original keys (in sorted key order).
        complex_tensors = [operands[i] for i in sorted(operands.keys())]
        complex_tensors = torch.broadcast_tensors(*complex_tensors)
        for idx, key in enumerate(sorted(operands.keys())):
            operands[key] = complex_tensors[idx]

        classification = self._classify_complex_inputs(operands)

        if strategy.mode == ComplexMode.CROSS and classification == "all_complex":
            return self._call_complex_cross(operands, result_dtype)
        if classification in ("all_complex", "mixed"):
            return self._call_complex_elementwise(
                operands, others, result_dtype, kwargs
            )
        new_args = self._merge_args(operands, others)
        return self._call_real_impl(*new_args, **kwargs)

    def _call_complex_elementwise(self, operands, others, result_dtype, kwargs):
        real_tensors = {
            i: self._view_as_real_for_kernel(t) for i, t in operands.items()
        }
        out_kwargs = dict(kwargs)
        out_complex = out_kwargs.get("out0")
        if out_complex is None:
            # No destination supplied: allocate one shaped like the operand
            # with the smallest key.
            first_operand = operands[sorted(operands.keys())[0]]
            out_complex = torch.empty(
                first_operand.shape,
                dtype=result_dtype,
                device=first_operand.device,
            )
            out_kwargs["out0"] = out_complex

        # The kernel writes through a real view of the complex output; the
        # complex tensor itself is what gets returned.
        out_kwargs["out0"] = self._view_as_real_for_kernel(out_complex)
        new_args = self._merge_args(real_tensors, others)
        self._call_real_impl(*new_args, _skip_tensor_check=True, **out_kwargs)
        return out_complex

    def _call_complex_cross(self, operands, result_dtype):
        # Uses the operands at the two smallest keys.
        sorted_keys = sorted(operands.keys())
        a_tensor, b_tensor = operands[sorted_keys[0]], operands[sorted_keys[1]]
        ar, ai = self._cross_components_for_kernel(a_tensor)
        br, bi = self._cross_components_for_kernel(b_tensor)

        # Bring both operands' components to a common real dtype.
        common_dtype = torch.promote_types(ar.dtype, br.dtype)
        if ar.dtype != common_dtype:
            ar, ai = ar.to(common_dtype), ai.to(common_dtype)
        if br.dtype != common_dtype:
            br, bi = br.to(common_dtype), bi.to(common_dtype)

        cross_kernel = self.complex_strategy.cross_kernel
        real, imag = cross_kernel._call_real_impl(
            ar, ai, br, bi, _skip_tensor_check=True
        )
        # Stack the two parts on a new last dimension and reinterpret the
        # result as complex.
        out = torch.stack((real, imag), dim=-1)
        return self._view_as_complex_result(out).to(result_dtype)

    def use_fast_path(tensors):
        # Requires equal shapes. Beyond that, either every input is
        # contiguous, or all inputs are torch.Tensors with identical strides
        # and the first one is non-overlapping and dense.
        if not all_the_same_shape(tensors):
            return False
        if all(_tensor_is_contiguous(tensor) for tensor in tensors):
            return True
        return (
            all(isinstance(tensor, torch.Tensor) for tensor in tensors)
            and all_the_same_stride(tensors)
            and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
        )

    # Install every replacement on the class, then mark it as patched.
    PointwiseDynamicFunction._call_real_impl = _call_real_impl
    PointwiseDynamicFunction._is_missing_backend_view_op = _is_missing_backend_view_op
    PointwiseDynamicFunction._complex_real_view_buffer = _complex_real_view_buffer
    PointwiseDynamicFunction._complex_component_buffers = _complex_component_buffers
    PointwiseDynamicFunction._view_as_real_for_kernel = _view_as_real_for_kernel
    PointwiseDynamicFunction._split_complex_components = _split_complex_components
    PointwiseDynamicFunction._view_as_complex_result = _view_as_complex_result
    PointwiseDynamicFunction._cross_components_for_kernel = _cross_components_for_kernel
    PointwiseDynamicFunction._cpu_fallback_value = _cpu_fallback_value
    PointwiseDynamicFunction._should_cpu_fallback_complex = _should_cpu_fallback_complex
    PointwiseDynamicFunction._cpu_fallback_complex_dispatch = (
        _cpu_fallback_complex_dispatch
    )
    PointwiseDynamicFunction._call_complex_dispatch = _call_complex_dispatch
    PointwiseDynamicFunction._call_complex_elementwise = _call_complex_elementwise
    PointwiseDynamicFunction._call_complex_cross = _call_complex_cross
    PointwiseDynamicFunction.use_fast_path = staticmethod(use_fast_path)
    PointwiseDynamicFunction._sunrise_complex_patched = True
def _install_pointwise_dynamic_post_import_hook():
    """Install a one-shot import hook that patches ``pointwise_dynamic``.

    ``builtins.__import__`` is wrapped so that, right after
    ``flag_gems.utils.pointwise_dynamic`` has been imported (directly, or via
    ``from flag_gems.utils import pointwise_dynamic``), the complex-number
    patch from ``_install_pointwise_dynamic_complex_patch()`` is applied.
    The patch is attempted at most once per module object; the hook removes
    itself before patching.

    Installing is a no-op while ``builtins._sunrise_pointwise_import_hook_installed``
    is true.
    """
    import builtins
    import sys

    if getattr(builtins, "_sunrise_pointwise_import_hook_installed", False):
        return

    original_import = builtins.__import__

    def maybe_patch():
        module = sys.modules.get("flag_gems.utils.pointwise_dynamic")
        if module is None or getattr(module, "_sunrise_complex_patch_attempted", False):
            return
        # Mark first so a failure while patching is not retried on every
        # later import.
        module._sunrise_complex_patch_attempted = True
        # Uninstall only when this hook is still the active one. If another
        # hook was installed on top of it, restoring ``original_import`` would
        # silently drop that hook; in that case this wrapper stays in the
        # chain and is inert from now on because of the flag set above.
        if builtins.__import__ is import_with_sunrise_pointwise_patch:
            builtins.__import__ = original_import
            builtins._sunrise_pointwise_import_hook_installed = False
        _install_pointwise_dynamic_complex_patch()

    def import_with_sunrise_pointwise_patch(
        name, globals=None, locals=None, fromlist=(), level=0
    ):
        module = original_import(name, globals, locals, fromlist, level)
        # A plain ``import a.b`` statement passes ``fromlist=None``, so the
        # membership test must tolerate None.
        if name == "flag_gems.utils.pointwise_dynamic" or (
            name == "flag_gems.utils" and "pointwise_dynamic" in (fromlist or ())
        ):
            maybe_patch()
        return module

    builtins.__import__ = import_with_sunrise_pointwise_patch
    builtins._sunrise_pointwise_import_hook_installed = True
def _install_ptpu_manual_seed_patch():
    """Give ``torch.ptpu`` a ``manual_seed_all`` and an ``_is_in_bad_fork``.

    Does nothing when ``torch.ptpu`` is missing or was already patched
    (``_sunrise_manual_seed_patched``).
    """
    import torch

    ptpu_mod = getattr(torch, "ptpu", None)
    if ptpu_mod is None:
        return
    if getattr(ptpu_mod, "_sunrise_manual_seed_patched", False):
        return

    def _is_in_bad_fork():
        return False

    def manual_seed_all(seed):
        masked_seed = int(seed) & 0xFFFFFFFFFFFFFFFF

        # PTPU exposes RNG state get/set but not a Python seed API. The runtime
        # state is a 16-byte blob where the low 8 bytes act as the seed and the
        # high 8 bytes can be reset to zero for a fresh sequence start.
        # `torch.manual_seed()` can be called under `with torch.device("ptpu")`,
        # so build the state explicitly on CPU for `torch.ptpu.set_rng_state()`.
        seed_bytes = list(masked_seed.to_bytes(8, "little"))
        state = torch.tensor(seed_bytes + [0] * 8, dtype=torch.uint8, device="cpu")

        device_total = ptpu_mod.device_count()
        for device_idx in range(device_total):
            ptpu_mod.set_rng_state(state, device_idx)

    ptpu_mod._is_in_bad_fork = _is_in_bad_fork
    ptpu_mod.manual_seed_all = manual_seed_all
    ptpu_mod._sunrise_manual_seed_patched = True
def _install_ptpu_default_generators_patch():
    """Attach a ``default_generators`` container to ``torch.ptpu``.

    Does nothing when ``torch.ptpu`` is missing, already exposes
    ``default_generators``, or was already patched. The container creates
    one lightweight generator object per device index on first access; each
    generator reads and writes RNG state through ``torch.ptpu``.
    """
    import torch

    ptpu_mod = getattr(torch, "ptpu", None)
    if ptpu_mod is None:
        return
    if hasattr(ptpu_mod, "default_generators"):
        return
    if getattr(ptpu_mod, "_sunrise_default_generators_patched", False):
        return

    class _SunrisePtpuGenerator:
        def __init__(self, device_idx):
            self.device_idx = int(device_idx)
            self.device = torch.device("ptpu", self.device_idx)

        def get_state(self):
            # Hand back an independent CPU copy of the device RNG state.
            device_state = ptpu_mod.get_rng_state(self.device_idx)
            return device_state.detach().cpu().clone()

        def set_state(self, state):
            if not isinstance(state, torch.Tensor):
                raise TypeError("PTPU RNG state must be a torch.Tensor")
            if state.dtype != torch.uint8:
                raise TypeError("PTPU RNG state must be a torch.uint8 tensor")
            cpu_state = state.detach().cpu().contiguous()
            ptpu_mod.set_rng_state(cpu_state, self.device_idx)

        def manual_seed(self, seed):
            # 16-byte state: the seed's 8 little-endian bytes, then 8 zeros.
            masked_seed = int(seed) & 0xFFFFFFFFFFFFFFFF
            seed_bytes = list(masked_seed.to_bytes(8, "little"))
            state = torch.tensor(
                seed_bytes + [0] * 8, dtype=torch.uint8, device="cpu"
            )
            self.set_state(state)
            return self

    class _SunrisePtpuDefaultGenerators:
        def __init__(self):
            self._generators = {}

        def _normalize_device(self, device):
            if device is None:
                return int(ptpu_mod.current_device())
            # Strings are parsed into a torch.device and handled below.
            if isinstance(device, str):
                device = torch.device(device)
            if not isinstance(device, torch.device):
                return int(device)
            if device.type != "ptpu":
                raise RuntimeError(f"Expected a ptpu device, got {device}")
            if device.index is None:
                return int(ptpu_mod.current_device())
            return int(device.index)

        def __getitem__(self, device):
            device_idx = self._normalize_device(device)
            device_count = int(ptpu_mod.device_count())
            # Negative indices count from the end.
            if device_idx < 0:
                device_idx += device_count
            if not 0 <= device_idx < device_count:
                raise IndexError(
                    f"PTPU device index {device_idx} is out of range "
                    f"for {device_count} devices"
                )
            cache = self._generators
            if device_idx not in cache:
                cache[device_idx] = _SunrisePtpuGenerator(device_idx)
            return cache[device_idx]

        def __iter__(self):
            total = len(self)
            position = 0
            while position < total:
                yield self[position]
                position += 1

        def __len__(self):
            return int(ptpu_mod.device_count())

    ptpu_mod.default_generators = _SunrisePtpuDefaultGenerators()
    ptpu_mod._sunrise_default_generators_patched = True
def _install_ptpu_multiprocessing_reduction_patch():
    """Register a tensor reducer that stages PTPU tensors through CPU.

    Tensors accepted by ``_should_stage_ptpu_tensor_for_multiprocessing`` are
    reduced to a CPU copy plus the information needed to rebuild them on
    their original device. All other tensors go through the original
    ``reduce_tensor``. The patch is applied once, guarded by
    ``reductions._sunrise_ptpu_reduce_tensor_patched``.
    """
    import multiprocessing.reduction as mp_reduction

    import torch
    import torch.multiprocessing.reductions as reductions
    from torch.nn.parameter import Parameter

    if getattr(reductions, "_sunrise_ptpu_reduce_tensor_patched", False):
        return

    original_reduce_tensor = reductions.reduce_tensor

    # Keep this in `_sunrise/__init__.py` instead of `monkey_patch.py` because
    # multiprocessing reducers are registered eagerly in Python's global
    # pickling table. This is import-time runtime wiring, not a call-site-level
    # torch API fallback that can be caught and retried after a NotImplemented.
    def reduce_tensor_with_ptpu_cpu_staging(tensor):
        needs_staging = _should_stage_ptpu_tensor_for_multiprocessing(tensor)
        if not needs_staging:
            return original_reduce_tensor(tensor)

        is_non_leaf_with_grad = tensor.requires_grad and not tensor.is_leaf
        if is_non_leaf_with_grad:
            raise RuntimeError(
                "Cowardly refusing to serialize non-leaf tensor which requires_grad, "
                "since autograd does not support crossing process boundaries. "
                "If you just want to transfer the data, call detach() on the tensor "
                "before serializing (e.g., putting it on the queue)."
            )

        reductions.check_serializing_named_tensor(tensor)
        torch.utils.hooks.warn_if_has_hooks(tensor)

        # The rebuild function is looked up on the reductions module at
        # reduce time, where it is stored below.
        rebuild = reductions._sunrise_rebuild_ptpu_tensor_from_cpu
        rebuild_args = (
            type(tensor),
            tensor.detach().cpu(),
            tensor.device,
            tensor.requires_grad,
        )
        return (rebuild, rebuild_args)

    reductions._sunrise_rebuild_ptpu_tensor_from_cpu = (
        _sunrise_rebuild_ptpu_tensor_from_cpu
    )
    reductions._sunrise_original_reduce_tensor = original_reduce_tensor
    reductions.reduce_tensor = reduce_tensor_with_ptpu_cpu_staging

    # Same registration order as before: torch's tensor classes first, then
    # torch.Tensor, then Parameter.
    classes_to_register = (*torch._tensor_classes, torch.Tensor, Parameter)
    for tensor_cls in classes_to_register:
        mp_reduction.register(tensor_cls, reduce_tensor_with_ptpu_cpu_staging)
    reductions._sunrise_ptpu_reduce_tensor_patched = True
# Import-time wiring: each installer below guards itself against being
# applied twice.
_install_ptpu_default_generators_patch()
_install_ptpu_manual_seed_patch()
_install_autograd_dispatch_patch()
# Some ops are not registered by the common library either, so they are kept
# here for now so that the tests pass.
_install_register_config_patch()
_install_pointwise_dynamic_post_import_hook()


# NOTE(review): "*" is not the name of an attribute of this module, so
# `from <this module> import *` would fail on it. Confirm whether the backend
# loader relies on this literal value before changing it.
__all__ = ["*"]