Coverage for src/flag_gems/runtime/backend/_sunrise/ops/to.py: 0%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2from typing import Optional
4import torch
5import triton
7from flag_gems.utils import pointwise_dynamic
# Logger for this module, keyed by the module's import name.
logger = logging.getLogger(__name__)

# Dispatch key set passed to `redispatch` when an operation is handed back to
# PyTorch's own `_to_copy` implementation instead of the FlagGems kernel.
_FALLBACK_DISPATCH_KEY = torch._C.DispatchKey.CompositeExplicitAutograd
_FALLBACK_KEYSET = torch._C.DispatchKeySet(_FALLBACK_DISPATCH_KEY)
# Element-wise identity kernel. It has a single tensor operand; the
# `pointwise_dynamic` wrapper is configured with the "DEFAULT" promotion method
# for operand 0, and the caller supplies the destination tensor via `out0`.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _to_copy_func(x):
    return x
def _resolve_dtype(x: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.dtype:
    """Return the dtype the copy should produce.

    Args:
        x: Source tensor; its dtype is used when ``dtype`` is ``None``.
        dtype: Requested dtype, or ``None`` to keep the source dtype.

    Returns:
        ``x.dtype`` when no dtype was requested, otherwise ``dtype`` itself.

    Raises:
        TypeError: If ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    # Validate first so the happy path collapses into a single expression.
    if dtype is not None and not isinstance(dtype, torch.dtype):
        raise TypeError(f"Unsupported dtype argument type: {type(dtype)!r}")
    return x.dtype if dtype is None else dtype
def _resolve_device(x: torch.Tensor, device: Optional[torch.device]) -> torch.device:
    """Return the device the copy should live on.

    Args:
        x: Source tensor; its device is used when ``device`` is ``None``.
        device: Requested device (anything ``torch.device`` accepts), or ``None``.

    Returns:
        ``x.device`` when no device was requested, otherwise
        ``torch.device(device)``.
    """
    return x.device if device is None else torch.device(device)
def _normalize_memory_format(
    memory_format: Optional[torch.memory_format],
) -> torch.memory_format:
    """Map an omitted memory format to ``torch.preserve_format``.

    Args:
        memory_format: Requested memory format, or ``None`` when unspecified.

    Returns:
        ``torch.preserve_format`` for ``None``; otherwise the argument unchanged.
    """
    return torch.preserve_format if memory_format is None else memory_format
def _allocate_preserve_format(x: torch.Tensor, empty_kwargs: dict) -> torch.Tensor:
    """Allocate an uninitialized output for ``x`` under preserve_format semantics.

    Args:
        x: Tensor whose size (and, when possible, strides) the output mirrors.
        empty_kwargs: Extra keyword arguments forwarded to the allocation call.

    Returns:
        A new uninitialized tensor with the same size as ``x``.
    """
    replicable = torch.ops.aten.is_non_overlapping_and_dense(x)
    if not replicable:
        # Copying the strides verbatim is unsafe here, so let PyTorch choose a
        # layout for the new tensor.
        return torch.empty_like(x, memory_format=torch.preserve_format, **empty_kwargs)
    # Non-overlapping and dense: reproduce the exact size/stride pair.
    return torch.empty_strided(x.size(), x.stride(), **empty_kwargs)
def _fallback_to_copy(
    x: torch.Tensor,
    *,
    dtype: torch.dtype,
    layout,
    device: torch.device,
    pin_memory,
    non_blocking: bool,
    memory_format: torch.memory_format,
):
    """Run ``aten::_to_copy`` through PyTorch by redispatching with the fallback keyset.

    All keyword arguments are forwarded unchanged to ``_to_copy``.

    Args:
        x: Source tensor.
        dtype: Target dtype.
        layout: Target layout (forwarded as-is).
        device: Target device.
        pin_memory: Pinned-memory request (forwarded as-is).
        non_blocking: Whether the copy may be asynchronous.
        memory_format: Target memory format.

    Returns:
        The result of the redispatched ``_to_copy`` call.
    """
    aten_to_copy = torch.ops.aten._to_copy.default
    forwarded = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
        "non_blocking": non_blocking,
        "memory_format": memory_format,
    }
    return aten_to_copy.redispatch(_FALLBACK_KEYSET, x, **forwarded)
79# func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
80# bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
def _to_copy_complex_via_cpu(
    x,
    *,
    target_dtype,
    target_device,
    layout,
    pin_memory,
    non_blocking,
    memory_format,
):
    """Copy/cast a tensor when the source or target dtype is complex, staging through CPU."""
    cpu = torch.device("cpu")
    shared = {
        "layout": layout,
        "pin_memory": pin_memory,
        "memory_format": memory_format,
    }

    staged = x
    if x.device.type != "cpu":
        # This hop must be blocking regardless of the caller's `non_blocking`:
        # the CPU cast below reads the staged buffer immediately, so an
        # asynchronous device-to-host transfer could still be in flight and the
        # cast would observe incomplete data.
        staged = _fallback_to_copy(
            x,
            dtype=x.dtype,
            device=cpu,
            non_blocking=False,
            **shared,
        )

    # The dtype conversion itself happens on the CPU.
    converted = _fallback_to_copy(
        staged,
        dtype=target_dtype,
        device=cpu,
        non_blocking=non_blocking,
        **shared,
    )
    if target_device.type == "cpu":
        return converted

    # Move the converted data to the requested device; the dtype already matches.
    return _fallback_to_copy(
        converted,
        dtype=target_dtype,
        device=target_device,
        non_blocking=non_blocking,
        **shared,
    )


def to_copy(
    x,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    non_blocking=False,
    memory_format=None,
):
    """FlagGems (Sunrise backend) implementation of ``aten::_to_copy``.

    Only the same-device, non-CPU, non-complex case is handled by the Triton
    identity kernel. Every other supported case is delegated to PyTorch via
    ``_fallback_to_copy``.

    Args:
        x: Source tensor. Must be strided and not quantized.
        dtype: Target dtype, or ``None`` to keep ``x.dtype``.
        layout: Target layout; only ``None`` or ``torch.strided`` is accepted.
        device: Target device, or ``None`` to keep ``x.device``.
        pin_memory: Must be ``None``; any explicit value is rejected.
        non_blocking: Forwarded to PyTorch on the delegated paths. The
            intermediate device-to-CPU staging copy of the complex path is
            always blocking.
        memory_format: Target memory format; ``None`` means
            ``torch.preserve_format``.

    Returns:
        A tensor holding the data of ``x`` with the requested dtype, device and
        memory format.

    Raises:
        NotImplementedError: For non-strided layouts, an explicit
            ``pin_memory`` value, or quantized inputs.
        TypeError: If ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    # Only the dense strided kernel is implemented; other layouts are rejected
    # with NotImplementedError rather than handled here.
    if (layout is not None and layout != torch.strided) or x.layout != torch.strided:
        raise NotImplementedError(
            "FlagGems to_copy currently supports strided tensors only."
        )
    # NOTE(review): this rejects every explicit value, including
    # pin_memory=False, although the message only mentions True. Confirm
    # whether False should be accepted.
    if pin_memory is not None:
        raise NotImplementedError(
            "FlagGems to_copy does not yet support pin_memory=True."
        )
    if x.is_quantized:
        raise NotImplementedError(
            "Quantized tensors are not supported in FlagGems to_copy yet."
        )

    target_dtype = _resolve_dtype(x, dtype)
    target_device = _resolve_device(x, device)
    target_memory_format = _normalize_memory_format(memory_format)

    # PTPU can hold complex tensors, but the same-device ptpu cast path drops
    # into a backend copy_ implementation that does not handle complex. Stage
    # through CPU to avoid ptpu complex copy_/view_as_real gaps.
    if x.dtype.is_complex or target_dtype.is_complex:
        logger.debug("GEMS_SUNRISE _TO_COPY COMPLEX VIA CPU")
        return _to_copy_complex_via_cpu(
            x,
            target_dtype=target_dtype,
            target_device=target_device,
            layout=layout,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    crosses_devices = target_device != x.device
    cpu_only = x.device.type == "cpu" and target_device.type == "cpu"
    if crosses_devices or cpu_only:
        # Device transfers (d2h/h2d etc.) and pure-CPU copies rely on PyTorch's
        # implementation.
        return _fallback_to_copy(
            x,
            dtype=target_dtype,
            layout=layout,
            device=target_device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    logger.debug("GEMS_SUNRISE _TO_COPY")
    empty_kwargs = {"dtype": target_dtype, "device": target_device}

    if target_memory_format is torch.preserve_format:
        out = _allocate_preserve_format(x, empty_kwargs)
    else:
        out = torch.empty_like(x, memory_format=target_memory_format, **empty_kwargs)

    # The identity kernel writes the elements of `x` into `out`.
    return _to_copy_func(x, out0=out)