Coverage for src/flag_gems/ops/to.py: 87%
52 statements
coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2from typing import Optional
4import torch
5import triton
7from flag_gems.utils import pointwise_dynamic
logger = logging.getLogger(__name__)

# Redispatch target used whenever the Triton path cannot handle a request:
# PyTorch's own CompositeExplicitAutograd kernel for aten._to_copy.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)

# float8_e8m0fnu only exists in some PyTorch versions; keep None when it is
# missing so dtype comparisons against it can be skipped entirely.
_FLOAT8_E8M0FNU = torch.float8_e8m0fnu if hasattr(torch, "float8_e8m0fnu") else None
# Identity element function. to_copy invokes it as `_to_copy_func(x, out0=out)`,
# so pointwise_dynamic writes each element of `x` into the preallocated `out`.
# NOTE(review): presumably the dtype conversion happens when the generated
# kernel stores into `out0` (whose dtype is the target dtype) — confirm.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _to_copy_func(x):
    return x
def _resolve_dtype(x: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.dtype:
    """Return the dtype the copy should produce.

    Args:
        x: source tensor; its dtype is used when ``dtype`` is not given.
        dtype: requested ``torch.dtype`` or ``None``.

    Raises:
        TypeError: if ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if dtype is None:
        return x.dtype
    raise TypeError(f"Unsupported dtype argument type: {type(dtype)!r}")
def _resolve_device(x: torch.Tensor, device: Optional[torch.device]) -> torch.device:
    """Return the device the copy should land on.

    ``None`` keeps ``x``'s current device. Any other value is normalized through
    ``torch.device(...)``, so strings such as ``"cpu"`` are accepted as well.
    """
    return x.device if device is None else torch.device(device)
def _normalize_memory_format(
    memory_format: Optional[torch.memory_format],
) -> torch.memory_format:
    """Map an unspecified memory format to ``torch.preserve_format``; pass others through."""
    return torch.preserve_format if memory_format is None else memory_format
def _allocate_preserve_format(x: torch.Tensor, empty_kwargs: dict) -> torch.Tensor:
    """Allocate an uninitialized output laid out like ``x`` (preserve_format semantics).

    ``torch.empty_like`` with ``memory_format=torch.preserve_format`` already
    implements the rule this helper needs: if ``x`` is non-overlapping and dense,
    its strides are copied exactly; otherwise PyTorch picks a safe layout. The
    previous version re-implemented the dense branch by hand on top of a separate
    ``aten.is_non_overlapping_and_dense`` query, so delegating removes both the
    duplication and that extra op dependency.

    Args:
        x: source tensor whose sizes (and, when safe, strides) are mirrored.
        empty_kwargs: extra factory kwargs (``dtype``, ``device``) forwarded to
            the allocation.

    Returns:
        A new, uninitialized tensor with the same shape as ``x``.
    """
    return torch.empty_like(x, memory_format=torch.preserve_format, **empty_kwargs)
def _fallback_to_copy(
    x,
    *,
    dtype,
    layout,
    device,
    pin_memory,
    non_blocking,
    memory_format,
):
    """Redispatch ``aten._to_copy`` to PyTorch's CompositeExplicitAutograd kernel.

    Single place for every case the Triton path in ``to_copy`` does not handle.
    """
    return torch.ops.aten._to_copy.default.redispatch(
        _FALLBACK_KEYSET,
        x,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        non_blocking=non_blocking,
        memory_format=memory_format,
    )


# func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
# bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
def to_copy(
    x,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    non_blocking=False,
    memory_format=None,
):
    """FlagGems implementation of ``aten._to_copy``.

    Same-device copies of strided, non-complex tensors on a non-CPU device run
    through the Triton identity kernel ``_to_copy_func``, writing into a freshly
    allocated output of the target dtype/memory format. Complex dtypes,
    float8_e8m0fnu, device transfers, and CPU-to-CPU copies are redispatched to
    PyTorch's own implementation.

    Raises:
        NotImplementedError: for non-strided layouts, a truthy ``pin_memory``,
            or quantized inputs.
        TypeError: if ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    # Only the dense strided kernel exists today; other layouts are rejected.
    # NOTE(review): presumably the FlagGems registration layer turns this
    # NotImplementedError into a PyTorch fallback — confirm.
    if (layout is not None and layout != torch.strided) or x.layout != torch.strided:
        raise NotImplementedError(
            "FlagGems to_copy currently supports strided tensors only."
        )
    # pin_memory=False (like None) requests nothing; only actual pinning is
    # unsupported. The previous `is not None` check also rejected False.
    if pin_memory:
        raise NotImplementedError(
            "FlagGems to_copy does not yet support pin_memory=True."
        )
    if x.is_quantized:
        raise NotImplementedError(
            "Quantized tensors are not supported in FlagGems to_copy yet."
        )

    target_dtype = _resolve_dtype(x, dtype)
    target_device = _resolve_device(x, device)
    target_memory_format = _normalize_memory_format(memory_format)

    # Triton supports neither complex dtypes nor float8_e8m0fnu.
    unsupported_dtype = x.dtype.is_complex or target_dtype.is_complex
    if _FLOAT8_E8M0FNU is not None:
        unsupported_dtype = unsupported_dtype or _FLOAT8_E8M0FNU in (
            x.dtype,
            target_dtype,
        )
    # Device transfers (d2h/h2d etc.) and CPU-only copies rely on PyTorch. When
    # the devices are equal, checking the source device type alone suffices.
    needs_torch_device_path = target_device != x.device or x.device.type == "cpu"

    if unsupported_dtype or needs_torch_device_path:
        return _fallback_to_copy(
            x,
            dtype=target_dtype,
            layout=layout,
            device=target_device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    logger.debug("GEMS TO_COPY")
    empty_kwargs = {"dtype": target_dtype, "device": target_device}

    if target_memory_format is torch.preserve_format:
        out = _allocate_preserve_format(x, empty_kwargs)
    else:
        out = torch.empty_like(x, memory_format=target_memory_format, **empty_kwargs)

    return _to_copy_func(x, out0=out)