Coverage for src/flag_gems/ops/copy.py: 71%
65 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
2from typing import Optional
4import torch
5import triton
7from flag_gems.utils import pointwise_dynamic
# Module-level logger; the copy entry points emit debug records through it.
logger = logging.getLogger(__name__)

# Dispatch key set passed to ``torch.ops.aten.copy_.default.redispatch`` whenever
# ``copy_`` decides not to run the Triton kernel and defers to PyTorch instead.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)

# ``torch.float8_e8m0fnu`` when the installed torch exposes that attribute,
# otherwise ``None`` (so dtype comparisons against it must be guarded).
_FLOAT8_E8M0FNU = getattr(torch, "float8_e8m0fnu", None)
# Element-wise identity used as the copy kernel: it takes one tensor input and
# returns it unchanged. ``copy_`` instantiates it for the source rank and calls
# it with ``out0=dst``.
# NOTE(review): presumably any dtype conversion to the destination happens when
# the generated wrapper stores into ``out0`` -- confirm against
# ``pointwise_dynamic``.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _copy_kernel(src):
    return src
24def _can_use_triton(dst: torch.Tensor, src: torch.Tensor) -> bool:
25 if dst.layout != torch.strided or src.layout != torch.strided:
26 return False
27 if dst.device != src.device:
28 return False
29 if dst.is_quantized or src.is_quantized:
30 return False
31 if src.is_complex() or dst.is_complex():
32 # Preserve PyTorch's behaviour of warning when casting complex to real
33 # by forcing the redispatch path, which issues the warning internally.
34 return False
35 if _FLOAT8_E8M0FNU is not None and (
36 src.dtype == _FLOAT8_E8M0FNU or dst.dtype == _FLOAT8_E8M0FNU
37 ):
38 # Triton does not support float8 yet, so defer to PyTorch which has a reference implementation.
39 return False
40 return True
43def _expand_like(src: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
44 if src.shape == target_shape:
45 return src
46 return src.expand(target_shape)
def copy(
    template: torch.Tensor, src: torch.Tensor, *, non_blocking: Optional[bool] = False
):
    """Out-of-place copy: return a new tensor shaped like ``template`` holding ``src``.

    A fresh tensor is allocated with ``template``'s size, stride, dtype and
    device, then filled through ``copy_``. ``non_blocking`` is coerced with
    ``bool()`` before being forwarded, so ``None`` is treated as ``False``.
    """
    logger.debug("GEMS COPY (functional)")

    geometry = (template.size(), template.stride())
    result = torch.empty_strided(
        *geometry, dtype=template.dtype, device=template.device
    )

    blocking_flag = bool(non_blocking)
    copy_(result, src, non_blocking=blocking_flag)
    return result
def _redispatch_copy_(dst: torch.Tensor, src: torch.Tensor, non_blocking: bool):
    """Defer the copy to PyTorch by redispatching ``aten::copy_`` with the fallback keyset."""
    return torch.ops.aten.copy_.default.redispatch(
        _FALLBACK_KEYSET, dst, src, non_blocking
    )


def copy_(dst: torch.Tensor, src: torch.Tensor, non_blocking: bool = False):
    """Copy ``src`` into ``dst`` in place and return ``dst``.

    The Triton kernel is used only for the cases it is known to handle; every
    other case is redispatched to PyTorch's own ``copy_``.

    Args:
        dst: Destination tensor, written in place.
        src: Source tensor, or a Python ``int``/``float``/``bool`` scalar, which
            is first wrapped in a tensor on ``dst``'s device.
        non_blocking: Forwarded to PyTorch on the redispatch paths. The Triton
            path does not use it.

    Returns:
        ``dst`` on the Triton and no-op paths, otherwise whatever the PyTorch
        call (``zero_`` or the redispatched ``copy_``) returns.

    Raises:
        TypeError: ``src`` is neither a tensor nor a supported Python scalar.
        RuntimeError: ``dst`` is a ZeroTensor, or the shapes cannot be
            broadcast to ``dst``'s shape.
    """
    if isinstance(src, (int, float, bool)):
        src = torch.tensor(src, device=dst.device)
    elif not isinstance(src, torch.Tensor):
        # A single formatted message: passing the type as a second exception
        # argument made the error render as a tuple.
        raise TypeError(f"unsupported src type for copy_: {type(src)}")

    # this is the same as PyTorch's check
    if dst._is_zerotensor():
        raise RuntimeError("ZeroTensors are immutable. Call clone() before copy_.")
    if src._is_zerotensor():
        return dst.zero_()

    if torch._C._is_alias_of(dst, src):
        # Align with PyTorch: if metadata fully matches, this is a no-op.
        same_view = (
            dst.storage_offset() == src.storage_offset()
            and dst.stride() == src.stride()
            and dst.size() == src.size()
            and dst.dtype == src.dtype
            and dst.device == src.device
            and dst.is_conj() == src.is_conj()
            and dst.is_neg() == src.is_neg()
        )
        if same_view:
            return dst
        # Otherwise defer to PyTorch for well-defined semantics on overlapping writes.
        return _redispatch_copy_(dst, src, non_blocking)

    int32_max = 2**31 - 1
    if (
        # Element counts above the int32 range are left to PyTorch.
        src.numel() > int32_max
        or dst.numel() > int32_max
        # Covers layout, device, quantized, complex and float8_e8m0fnu cases.
        or not _can_use_triton(dst, src)
        # Respect PyTorch behaviour: empty tensors should still validate broadcast.
        or dst.numel() == 0
    ):
        return _redispatch_copy_(dst, src, non_blocking)

    logger.debug("GEMS COPY_")

    # Raises RuntimeError by itself when the shapes are not broadcastable.
    broadcast_shape = torch.broadcast_shapes(dst.shape, src.shape)
    if torch.Size(broadcast_shape) != dst.shape:
        raise RuntimeError(
            f"The broadcast shape {broadcast_shape} does not match destination shape {tuple(dst.shape)}"
        )

    expanded_src = _expand_like(src, dst.shape)

    overload = _copy_kernel.instantiate(expanded_src.ndim)
    overload(expanded_src, out0=dst)
    return dst
128 return dst