Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/to.py: 0%
69 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import os
3from typing import Optional
5import torch
6import triton
7from _kunlunxin.utils.codegen_config_utils import CodeGenConfig
9from ..utils.pointwise_dynamic import pointwise_dynamic
# Per-module logger, a child of the package-wide "flag_gems" logger
# (identical to ``logging.getLogger("flag_gems").getChild(...)``).
logger = logging.getLogger("flag_gems." + __name__.lstrip("."))

# Dispatch key set handed to ``redispatch`` so that cases the Triton kernel does
# not cover are served by PyTorch's CompositeExplicitAutograd kernel instead.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)
# Identity elementwise kernel used by ``to_copy`` for inputs that are not
# bfloat16. ``to_copy`` invokes it as ``_to_copy_func(x, out0=out)`` with a
# preallocated ``out`` of the target dtype, so the values of ``x`` are written
# into ``out``.
# NOTE(review): the dtype conversion presumably happens when the kernel generated
# by ``pointwise_dynamic`` stores into ``out0`` — confirm in pointwise_dynamic.
@pointwise_dynamic(
    is_tensor=[
        True,
    ],
    promotion_methods=[(0, "DEFAULT")],
)
@triton.jit
def _to_copy_func(x):
    return x
# Code-generation settings for the bfloat16 copy kernel
# (``_to_copy_func_close_interleave``); sets ``prefer_1d_tile=True`` and the
# kunlunxin-specific ``isCloseInterleave=True``.
# NOTE(review): the positional arguments presumably map to CodeGenConfig's
# max_tile_size, max_grid_size, max_num_warps_per_cta and prefer_block_pointer
# (in that order) — confirm against _kunlunxin.utils.codegen_config_utils.
close_interleave_config = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
    isCloseInterleave=True,
)
# Same identity kernel as ``_to_copy_func``, but generated with
# ``close_interleave_config``. ``to_copy`` selects it when the source tensor is
# bfloat16 and calls it as ``_to_copy_func_close_interleave(x, out0=out)``.
# NOTE(review): why bfloat16 needs the close-interleave codegen is not visible
# here — presumably a kunlunxin backend requirement; confirm.
@pointwise_dynamic(
    is_tensor=[
        True,
    ],
    promotion_methods=[(0, "DEFAULT")],
    config=close_interleave_config,
)
@triton.jit
def _to_copy_func_close_interleave(x):
    return x
def _resolve_dtype(x: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.dtype:
    """Pick the dtype the copy should produce.

    ``None`` keeps ``x.dtype``; otherwise ``dtype`` must already be a
    ``torch.dtype`` and is returned unchanged.

    Raises:
        TypeError: if ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    if dtype is not None and not isinstance(dtype, torch.dtype):
        raise TypeError(f"Unsupported dtype argument type: {type(dtype)!r}")
    return x.dtype if dtype is None else dtype
def _resolve_device(x: torch.Tensor, device: Optional[torch.device]) -> torch.device:
    """Pick the device the copy should target.

    ``None`` keeps ``x.device``; anything else is normalized through
    ``torch.device`` (so a string such as ``"cpu"`` is accepted as well).
    """
    return x.device if device is None else torch.device(device)
def _normalize_memory_format(
    memory_format: Optional[torch.memory_format],
) -> torch.memory_format:
    """Treat an omitted memory format as ``torch.preserve_format``; pass others through."""
    return torch.preserve_format if memory_format is None else memory_format
def _allocate_preserve_format(x: torch.Tensor, empty_kwargs: dict) -> torch.Tensor:
    """Allocate an uninitialized tensor shaped like ``x`` under preserve_format rules.

    ``empty_kwargs`` is forwarded to the allocator (``to_copy`` passes ``dtype``
    and ``device``). When ``x`` is non-overlapping and dense its exact strides
    are replicated; otherwise PyTorch's ``empty_like`` picks the layout.
    """
    can_copy_strides = torch.ops.aten.is_non_overlapping_and_dense(x)
    if not can_copy_strides:
        # Replicating strides of an overlapping/gapped tensor is unsafe, so defer
        # to PyTorch's best-effort layout suggestion.
        return torch.empty_like(x, memory_format=torch.preserve_format, **empty_kwargs)
    return torch.empty_strided(x.size(), x.stride(), **empty_kwargs)
def _call_with_env(env_overrides, fn, *args, **kwargs):
    """Call ``fn(*args, **kwargs)`` with ``env_overrides`` temporarily set in ``os.environ``.

    Every overridden variable is restored to its previous value (or removed if it
    was previously unset) even when ``fn`` raises, so the overrides never leak
    and a user-provided value is never wiped.
    """
    saved = {key: os.environ.get(key) for key in env_overrides}
    os.environ.update(env_overrides)
    try:
        return fn(*args, **kwargs)
    finally:
        for key, previous in saved.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous


# func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
# bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
def to_copy(
    x: torch.Tensor,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    non_blocking=False,
    memory_format=None,
):
    """Return a new tensor with the values of ``x`` (``aten::_to_copy``).

    Same-device copies of strided, real-valued, non-CPU tensors run through a
    Triton identity kernel. Complex dtypes, cross-device copies and CPU tensors
    are redispatched to PyTorch's CompositeExplicitAutograd kernel.

    Args:
        x: Source tensor.
        dtype: Target dtype; ``None`` keeps ``x.dtype``.
        layout: Must be ``None`` or ``torch.strided``.
        device: Target device; ``None`` keeps ``x.device``.
        pin_memory: ``None``/``False``; requesting pinned memory is unsupported.
        non_blocking: Only forwarded to the PyTorch fallback.
        memory_format: Output memory format; ``None`` means ``torch.preserve_format``.

    Returns:
        A newly allocated tensor holding the values of ``x``.

    Raises:
        NotImplementedError: for non-strided layouts, ``pin_memory=True`` or
            quantized inputs.
        TypeError: if ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    # Only the dense strided kernel is implemented; other layouts are rejected.
    if (layout is not None and layout != torch.strided) or x.layout != torch.strided:
        raise NotImplementedError(
            "FlagGems to_copy currently supports strided tensors only."
        )
    # pin_memory=False is a no-op request; only an actual pinning request is unsupported.
    if pin_memory:
        raise NotImplementedError(
            "FlagGems to_copy does not yet support pin_memory=True."
        )
    if x.is_quantized:
        raise NotImplementedError(
            "Quantized tensors are not supported in FlagGems to_copy yet."
        )

    target_dtype = _resolve_dtype(x, dtype)
    target_device = _resolve_device(x, device)
    target_memory_format = _normalize_memory_format(memory_format)

    # Hand the copy back to PyTorch when:
    #  * either dtype is complex (Triton on kunlunxin does not support complex dtypes);
    #  * the copy crosses devices (d2h/h2d etc.);
    #  * the tensor is on the CPU (target_device == x.device at that point, so both are CPU).
    if (
        x.dtype.is_complex
        or target_dtype.is_complex
        or target_device != x.device
        or x.device.type == "cpu"
    ):
        return torch.ops.aten._to_copy.default.redispatch(
            _FALLBACK_KEYSET,
            x,
            dtype=target_dtype,
            layout=layout,
            device=target_device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    logger.debug("GEMS _TO_COPY")
    empty_kwargs = {"dtype": target_dtype, "device": target_device}
    if target_memory_format is torch.preserve_format:
        out = _allocate_preserve_format(x, empty_kwargs)
    else:
        out = torch.empty_like(x, memory_format=target_memory_format, **empty_kwargs)

    # bfloat16 sources use the kernel generated with close_interleave_config.
    to_dtype_fn = (
        _to_copy_func_close_interleave if x.dtype == torch.bfloat16 else _to_copy_func
    )

    # NOTE(review): the TRITONXPU_* variables are presumably read by the kunlunxin
    # Triton toolchain while building/launching the kernel — confirm. 8-byte output
    # elements additionally need TRITONXPU_ELEMBYTES=8.
    env_overrides = {"TRITONXPU_BF16_FAST": "1"}
    if out.element_size() == 8:
        env_overrides["TRITONXPU_ELEMBYTES"] = "8"
    return _call_with_env(env_overrides, to_dtype_fn, x, out0=out)