Coverage for src/flag_gems/runtime/backend/_sunrise/ops/to.py: 0%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2from typing import Optional
4import torch
5import triton
7from flag_gems.utils import pointwise_dynamic
# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module (leading dots stripped from the module name).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Dispatch key set handed to `redispatch` by `_fallback_to_copy` so that the
# call lands on ATen's CompositeExplicitAutograd implementation of `_to_copy`.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(torch._C.DispatchKey.CompositeExplicitAutograd)
# Element-wise identity kernel used by `to_copy` for same-device copies.
# One tensor input; output dtype promotion follows the "DEFAULT" method keyed
# on argument 0.
# NOTE(review): the dtype cast presumably happens when the generated wrapper
# stores into the pre-allocated `out0` tensor -- confirm against
# `pointwise_dynamic`.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _to_copy_func(x):
    # The scalar body forwards its input unchanged.
    return x
def _resolve_dtype(x: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.dtype:
    """Return the dtype the copy should produce.

    Args:
        x: Source tensor; its dtype is used when ``dtype`` is ``None``.
        dtype: Requested dtype, or ``None`` to keep the source dtype.

    Returns:
        ``x.dtype`` when ``dtype`` is ``None``, otherwise ``dtype`` itself.

    Raises:
        TypeError: If ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    # Reject anything that is not a real torch.dtype up front.
    if dtype is not None and not isinstance(dtype, torch.dtype):
        raise TypeError(f"Unsupported dtype argument type: {type(dtype)!r}")
    return x.dtype if dtype is None else dtype
def _resolve_device(x: torch.Tensor, device: Optional[torch.device]) -> torch.device:
    """Return the device the copy should live on.

    Args:
        x: Source tensor; its device is used when ``device`` is ``None``.
        device: Requested device, or ``None`` to stay on the source device.

    Returns:
        ``x.device`` when ``device`` is ``None``; otherwise ``device`` passed
        through the ``torch.device`` constructor.
    """
    return x.device if device is None else torch.device(device)
def _normalize_memory_format(
    memory_format: Optional[torch.memory_format],
) -> torch.memory_format:
    """Map a missing memory format to ``torch.preserve_format``.

    Args:
        memory_format: Requested memory format, or ``None``.

    Returns:
        ``torch.preserve_format`` when ``memory_format`` is ``None``,
        otherwise ``memory_format`` unchanged.
    """
    return torch.preserve_format if memory_format is None else memory_format
def _allocate_preserve_format(x: torch.Tensor, empty_kwargs: dict) -> torch.Tensor:
    """Allocate an uninitialised output tensor for a preserve-format copy of ``x``.

    Args:
        x: Source tensor whose size (and, when possible, strides) are reused.
        empty_kwargs: Extra keyword arguments forwarded to the allocation
            call (``to_copy`` passes ``dtype`` and ``device``).

    Returns:
        A new tensor with ``x``'s size. When ``x`` is non-overlapping and
        dense its exact strides are reused; otherwise the layout is delegated
        to ``torch.empty_like`` with ``torch.preserve_format``.
    """
    source_is_dense = torch.ops.aten.is_non_overlapping_and_dense(x)
    if not source_is_dense:
        return torch.empty_like(x, memory_format=torch.preserve_format, **empty_kwargs)
    # Dense, non-overlapping source: mirror its strides exactly.
    return torch.empty_strided(x.size(), x.stride(), **empty_kwargs)
def _fallback_to_copy(
    x: torch.Tensor,
    *,
    dtype: torch.dtype,
    layout,
    device: torch.device,
    pin_memory,
    non_blocking: bool,
    memory_format: torch.memory_format,
):
    """Run ATen's ``_to_copy`` by redispatching with ``_FALLBACK_KEYSET``.

    All keyword arguments are forwarded unchanged to
    ``torch.ops.aten._to_copy.default``.

    Returns:
        Whatever the redispatched ``_to_copy`` call returns.
    """
    aten_to_copy = torch.ops.aten._to_copy.default
    forwarded = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
        "non_blocking": non_blocking,
        "memory_format": memory_format,
    }
    return aten_to_copy.redispatch(_FALLBACK_KEYSET, x, **forwarded)
# func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
# bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
def to_copy(
    x,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    non_blocking=False,
    memory_format=None,
):
    """Copy ``x`` to the requested dtype, device and memory format.

    Three paths are taken, in this order:

    1. Complex source or target dtype: the conversion is staged through CPU
       using the ATen fallback (see the comment at that branch).
    2. The target device differs from ``x.device``, or both are CPU: the ATen
       fallback handles the whole copy.
    3. Otherwise (same non-CPU device, non-complex): the output is allocated
       here and filled by the pointwise kernel ``_to_copy_func``.

    Args:
        x: Source tensor; must be strided and not quantized.
        dtype: Target dtype, or ``None`` to keep ``x.dtype``.
        layout: Must be ``None`` or ``torch.strided``.
        device: Target device, or ``None`` to keep ``x.device``.
        pin_memory: ``None`` or ``False``; ``True`` is not supported.
        non_blocking: Forwarded to the ATen fallback. On the complex path it
            only applies to the final CPU-to-device hop.
        memory_format: Target memory format; ``None`` means
            ``torch.preserve_format``.

    Returns:
        The converted tensor.

    Raises:
        NotImplementedError: For non-strided layouts, ``pin_memory=True`` or
            quantized inputs.
        TypeError: If ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    if (layout is not None and layout != torch.strided) or x.layout != torch.strided:
        raise NotImplementedError(
            "FlagGems to_copy currently supports strided tensors only."
        )
    # Only an explicit request for pinned memory is unsupported. An explicit
    # ``pin_memory=False`` asks for the same thing as the default ``None`` and
    # must not be rejected.
    if pin_memory:
        raise NotImplementedError(
            "FlagGems to_copy does not yet support pin_memory=True."
        )
    if x.is_quantized:
        raise NotImplementedError(
            "Quantized tensors are not supported in FlagGems to_copy yet."
        )

    target_dtype = _resolve_dtype(x, dtype)
    target_device = _resolve_device(x, device)
    target_memory_format = _normalize_memory_format(memory_format)

    # PTPU can hold complex tensors, but the same-device ptpu cast path drops into
    # a backend copy_ implementation that does not handle complex. Stage through CPU
    # to avoid ptpu complex copy_/view_as_real gaps.
    if x.dtype.is_complex or target_dtype.is_complex:
        logger.debug("GEMS_SUNRISE _TO_COPY COMPLEX VIA CPU")
        cpu = torch.device("cpu")
        shared = {
            "layout": layout,
            "pin_memory": pin_memory,
            "memory_format": target_memory_format,
        }

        staged = x
        if x.device.type != "cpu":
            # The staged tensor is read by the host in the very next call, so
            # this hop must be blocking: a non-blocking device-to-host copy
            # may return before the data has actually arrived.
            staged = _fallback_to_copy(
                x, dtype=x.dtype, device=cpu, non_blocking=False, **shared
            )

        # Pure CPU dtype conversion; non_blocking has no meaning here.
        converted = _fallback_to_copy(
            staged, dtype=target_dtype, device=cpu, non_blocking=False, **shared
        )
        if target_device.type == "cpu":
            return converted

        # Only the last hop produces the tensor handed back to the caller, so
        # it is the only one that honours the caller's non_blocking request.
        return _fallback_to_copy(
            converted,
            dtype=target_dtype,
            device=target_device,
            non_blocking=non_blocking,
            **shared,
        )

    # Cross-device copies and CPU-to-CPU copies are left to ATen.
    crosses_device = target_device != x.device
    cpu_only = x.device.type == "cpu" and target_device.type == "cpu"
    if crosses_device or cpu_only:
        return _fallback_to_copy(
            x,
            dtype=target_dtype,
            layout=layout,
            device=target_device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    logger.debug("GEMS_SUNRISE _TO_COPY")
    empty_kwargs = {"dtype": target_dtype, "device": target_device}

    if target_memory_format is torch.preserve_format:
        out = _allocate_preserve_format(x, empty_kwargs)
    else:
        out = torch.empty_like(x, memory_format=target_memory_format, **empty_kwargs)

    # The identity kernel writes x into the pre-allocated output tensor.
    return _to_copy_func(x, out0=out)