Coverage for src/flag_gems/ops/as_strided_copy.py: 65%
161 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
8from flag_gems.utils.shape_utils import MemOverlap, has_internal_overlapping
10logger = logging.getLogger(__name__)
12_FALLBACK_KEYSET = torch._C.DispatchKeySet(
13 torch._C.DispatchKey.CompositeExplicitAutograd
14)
15_MAX_TRITON_ELEMENTS = torch.iinfo(torch.int32).max
16_BLOCK_SIZE = 512
17_BLOCK_M = 16
18_BLOCK_N = 16
# Identity element-wise kernel. _launch_as_strided_copy uses it for tensors of
# rank > 3, passing the destination tensor as out0=..., so each element of the
# single tensor input x is written unchanged to the output.
# NOTE(review): assumes pointwise_dynamic generates a launcher handling
# arbitrary ranks/strides, and that promotion_methods=[(0, "DEFAULT")] takes
# the result dtype from argument 0 -- confirm against flag_gems.utils.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _as_strided_copy_kernel(x):
    return x
@triton.jit
def _as_strided_copy_1d_kernel(
    input,
    out,
    input_stride_0,
    out_stride_0,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Strided copy of a rank-1 tensor. Each program covers BLOCK_SIZE
    # consecutive logical indices and maps each index through the source and
    # destination strides independently.
    # Logical indices stay below n_elements (callers gate on
    # _MAX_TRITON_ELEMENTS), but index * stride may exceed int32, so the
    # address arithmetic is widened to int64.
    block_start = tl.program_id(0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    in_range = idx < n_elements
    idx64 = idx.to(tl.int64)
    src_ptrs = input + idx64 * input_stride_0
    dst_ptrs = out + idx64 * out_stride_0
    tl.store(dst_ptrs, tl.load(src_ptrs, mask=in_range), mask=in_range)
@triton.jit
def _as_strided_copy_2d_kernel(
    input,
    out,
    input_stride_0,
    input_stride_1,
    out_stride_0,
    out_stride_1,
    dim_0,
    dim_1,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Strided copy of a rank-2 tensor, one BLOCK_M x BLOCK_N tile per program:
    # grid axis 0 walks row tiles, grid axis 1 walks column tiles.
    # Indices are widened to int64 before being multiplied by the strides.
    row_tile = tl.program_id(0)
    col_tile = tl.program_id(1)
    rows = (row_tile * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    cols = (col_tile * BLOCK_N + tl.arange(0, BLOCK_N)).to(tl.int64)
    row_idx = rows[:, None]
    col_idx = cols[None, :]
    tile_mask = (row_idx < dim_0) & (col_idx < dim_1)
    src_ptrs = input + (row_idx * input_stride_0 + col_idx * input_stride_1)
    dst_ptrs = out + (row_idx * out_stride_0 + col_idx * out_stride_1)
    tile = tl.load(src_ptrs, mask=tile_mask)
    tl.store(dst_ptrs, tile, mask=tile_mask)
@triton.jit
def _as_strided_copy_3d_kernel(
    input,
    out,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    dim_1,
    dim_2,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Strided copy of a rank-3 tensor over a flat 1-D grid: each lane takes a
    # row-major linear index, splits it into (i, j, k) coordinates using the
    # sizes of the two inner dimensions, and applies each tensor's strides.
    # The outer dimension size is implied by n_elements.
    linear = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = linear < n_elements
    linear = linear.to(tl.int64)
    # Peel off the innermost coordinate first, then the middle and outer ones.
    k = linear % dim_2
    outer = linear // dim_2
    j = outer % dim_1
    i = outer // dim_1
    src_offsets = i * input_stride_0 + j * input_stride_1 + k * input_stride_2
    dst_offsets = i * out_stride_0 + j * out_stride_1 + k * out_stride_2
    values = tl.load(input + src_offsets, mask=in_range)
    tl.store(out + dst_offsets, values, mask=in_range)
99def _is_float8(dtype: torch.dtype) -> bool:
100 return str(dtype).startswith("torch.float8_")
103def _has_lazy_metadata(tensor: torch.Tensor) -> bool:
104 is_neg = getattr(tensor, "is_neg", lambda: False)
105 return tensor.is_conj() or is_neg()
108def _make_as_strided_view(
109 input: torch.Tensor,
110 size,
111 stride,
112 storage_offset,
113) -> torch.Tensor:
114 # Reuse PyTorch's view construction to match its validation and None-offset semantics.
115 if storage_offset is None:
116 return torch.as_strided(input, size, stride)
117 return torch.as_strided(input, size, stride, storage_offset)
def _native_copy_(out: torch.Tensor, src: torch.Tensor):
    """Run ``out.copy_(src)`` by redispatching aten::copy_ with _FALLBACK_KEYSET.

    This bypasses the normal dispatch entry point so FlagGems' own copy
    kernels are not re-entered. The copy is always synchronous
    (non_blocking=False).
    """
    copy_overload = torch.ops.aten.copy_.default
    non_blocking = False
    return copy_overload.redispatch(_FALLBACK_KEYSET, out, src, non_blocking)
def _fallback_as_strided_copy(input, size, stride, storage_offset=None):
    """Materialize ``input.as_strided(...)`` into a new tensor without Triton.

    The view is built first so invalid view arguments raise before anything
    is allocated. The copy goes straight to the native ``copy_``, so
    unsupported CUDA dtypes are not routed back into FlagGems copy kernels.
    """
    source = _make_as_strided_view(input, size, stride, storage_offset)
    result = torch.empty(tuple(size), dtype=input.dtype, device=input.device)
    if result.numel() == 0:
        return result
    _native_copy_(result, source)
    return result
def _fallback_as_strided_copy_out(input, size, stride, storage_offset=None, *, out):
    """Copy ``input.as_strided(...)`` into ``out`` using the native ``copy_``.

    If ``out`` shares storage with ``input``, or its own memory overlaps
    internally (or the overlap cannot be ruled out), the source is first
    staged into a fresh temporary. That way the final copy never reads
    values it has already overwritten.

    Returns:
        ``out``.
    """
    source = _make_as_strided_view(input, size, stride, storage_offset)
    must_stage = torch._C._is_alias_of(input, out) or (
        has_internal_overlapping(out) != MemOverlap.No
    )
    if must_stage:
        staging = torch.empty(tuple(size), dtype=input.dtype, device=input.device)
        if staging.numel():
            _native_copy_(staging, source)
        source = staging
    _native_copy_(out, source)
    return out
def _can_use_triton(input: torch.Tensor, out: torch.Tensor) -> bool:
    """Return True if the generic strided Triton kernels may copy input -> out.

    Rejected cases (these go to another path instead):
      * non-strided layouts, device or dtype mismatch, quantized tensors;
      * complex and float8 dtypes;
      * tensors with a lazy conjugate/negative bit;
      * outputs with more than ``_MAX_TRITON_ELEMENTS`` elements (the kernels
        compute logical indices in int32).
    """
    if input.layout != torch.strided or out.layout != torch.strided:
        return False
    if input.device != out.device or input.dtype != out.dtype:
        return False
    if input.is_quantized or out.is_quantized:
        return False
    if input.is_complex() or _is_float8(input.dtype):
        return False
    # The kernels copy raw storage bytes and would silently ignore a lazy
    # conj/neg bit (e.g. the real-dtype neg view returned by x.conj().imag).
    # Mirror _can_use_byte_triton and leave these to the native fallback.
    if _has_lazy_metadata(input) or _has_lazy_metadata(out):
        return False
    return out.numel() <= _MAX_TRITON_ELEMENTS
def _can_use_byte_triton(input: torch.Tensor, out: torch.Tensor) -> bool:
    """Return True if a float8 copy can run through the uint8-reinterpret path.

    All of the following must hold:
      * both tensors are strided, on the same device, with the same dtype;
      * the dtype is float8 with one-byte elements;
      * neither tensor has a lazy conj/neg bit;
      * the output fits in ``_MAX_TRITON_ELEMENTS`` elements.
    """
    both_strided = input.layout == torch.strided and out.layout == torch.strided
    if not both_strided:
        return False
    if input.device != out.device or input.dtype != out.dtype:
        return False
    return (
        _is_float8(input.dtype)
        and input.element_size() == 1
        and out.element_size() == 1
        and not (_has_lazy_metadata(input) or _has_lazy_metadata(out))
        and out.numel() <= _MAX_TRITON_ELEMENTS
    )
# CUDA limits gridDim.y (and gridDim.z) to 65535 blocks; gridDim.x is not the
# constraint here.
_MAX_CUDA_GRID_Y = 65535


def _launch_as_strided_copy(view: torch.Tensor, out: torch.Tensor):
    """Copy ``view`` into ``out`` with the Triton kernel matching its rank.

    Ranks 0-3 use the hand-written strided kernels. Higher ranks go through
    the ``pointwise_dynamic`` identity kernel. Callers reach this only after
    ``_can_use_triton`` or ``_can_use_byte_triton`` accepted the pair, so
    ``out.numel()`` is at most ``_MAX_TRITON_ELEMENTS``.

    Args:
        view: Source tensor (the ``as_strided`` view of the input).
        out: Destination tensor with the same shape as ``view``.

    Returns:
        ``out`` for ranks 0-3. For higher ranks, the value returned by the
        ``pointwise_dynamic`` kernel called with ``out0=out``.
    """
    dim = view.dim()
    if dim == 0:
        # Single element: one program, one-wide block, strides irrelevant.
        _as_strided_copy_1d_kernel[(1,)](
            view,
            out,
            0,
            0,
            1,
            BLOCK_SIZE=1,
        )
    elif dim == 1:
        n_elements = view.numel()
        grid = (triton.cdiv(n_elements, _BLOCK_SIZE),)
        _as_strided_copy_1d_kernel[grid](
            view,
            out,
            view.stride(0),
            out.stride(0),
            n_elements,
            BLOCK_SIZE=_BLOCK_SIZE,
        )
    elif dim == 2:
        dim_0, dim_1 = view.shape
        col_tiles = triton.cdiv(dim_1, _BLOCK_N)
        if col_tiles <= _MAX_CUDA_GRID_Y:
            grid = (triton.cdiv(dim_0, _BLOCK_M), col_tiles)
            _as_strided_copy_2d_kernel[grid](
                view,
                out,
                view.stride(0),
                view.stride(1),
                out.stride(0),
                out.stride(1),
                dim_0,
                dim_1,
                BLOCK_M=_BLOCK_M,
                BLOCK_N=_BLOCK_N,
            )
        else:
            # Too many column tiles for grid axis 1 (the launch would fail).
            # Treat the tensor as shape (1, dim_0, dim_1) with a zero outer
            # stride, and use the 3-D kernel, which only uses grid axis 0.
            n_elements = view.numel()
            grid = (triton.cdiv(n_elements, _BLOCK_SIZE),)
            _as_strided_copy_3d_kernel[grid](
                view,
                out,
                0,
                view.stride(0),
                view.stride(1),
                0,
                out.stride(0),
                out.stride(1),
                dim_0,
                dim_1,
                n_elements,
                BLOCK_SIZE=_BLOCK_SIZE,
            )
    elif dim == 3:
        n_elements = view.numel()
        grid = (triton.cdiv(n_elements, _BLOCK_SIZE),)
        _as_strided_copy_3d_kernel[grid](
            view,
            out,
            view.stride(0),
            view.stride(1),
            view.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            view.shape[1],
            view.shape[2],
            n_elements,
            BLOCK_SIZE=_BLOCK_SIZE,
        )
    else:
        return _as_strided_copy_kernel(view, out0=out)
    return out
def _launch_byte_as_strided_copy(view: torch.Tensor, out: torch.Tensor):
    """Copy one-byte (float8) tensors by reinterpreting both sides as uint8.

    This sidesteps Triton's fp8 scalar codegen. Returns ``out``.
    """

    def _as_bytes(tensor):
        # The dtype-view API requires at least one logical dimension on some
        # builds, so 0-d tensors are lifted to shape (1,) first.
        base = tensor.reshape(1) if tensor.dim() == 0 else tensor
        return base.view(torch.uint8)

    _launch_as_strided_copy(_as_bytes(view), _as_bytes(out))
    return out
def as_strided_copy(input, size, stride, storage_offset=None):
    """Return a new tensor holding a copy of ``input.as_strided(size, stride, storage_offset)``.

    Non-CUDA inputs are handled with ``clone``. On CUDA, the copy uses one of:
      * the generic strided Triton kernels;
      * the uint8-reinterpreting Triton path for float8 dtypes;
      * otherwise, the natively dispatched ``copy_``.

    Args:
        input: Tensor whose storage the strided view is taken from.
        size: Shape of the view and of the result.
        stride: Strides of the view, in elements.
        storage_offset: Element offset into ``input``'s storage. ``None``
            keeps ``input``'s own offset, as ``torch.as_strided`` does.

    Returns:
        A freshly allocated contiguous tensor of shape ``size``.

    Raises:
        Whatever ``torch.as_strided`` raises for invalid view arguments.
    """
    logger.debug("GEMS AS_STRIDED_COPY")
    if input.device.type != "cuda":
        view = _make_as_strided_view(input, size, stride, storage_offset)
        return view.clone(memory_format=torch.contiguous_format)

    out = torch.empty(size, dtype=input.dtype, device=input.device)
    # Build the view even when there is nothing to copy, so invalid arguments
    # still raise PyTorch's own validation errors.
    view = _make_as_strided_view(input, size, stride, storage_offset)
    if out.numel() == 0:
        return out

    if _can_use_triton(view, out):
        return _launch_as_strided_copy(view, out)
    if _can_use_byte_triton(view, out):
        return _launch_byte_as_strided_copy(view, out)
    # Reuse the view and output already built above instead of rebuilding the
    # view and allocating a second full-size result. Go straight to the native
    # copy_ so unsupported dtypes do not re-enter FlagGems copy kernels.
    _native_copy_(out, view)
    return out
def as_strided_copy_out(input, size, stride, storage_offset=None, *, out):
    """Copy ``input.as_strided(size, stride, storage_offset)`` into ``out``.

    ``out`` must already have ``input``'s dtype. It is resized to ``size``
    when its shape differs. When ``out`` shares storage with ``input``, or has
    (possibly) internally overlapping memory, the source is materialized
    before writing.

    Returns:
        ``out`` (for rank > 3 Triton copies, the pointwise kernel's return
        value, presumably ``out`` itself).

    Raises:
        RuntimeError: if ``out.dtype`` differs from ``input.dtype``.
    """
    logger.debug("GEMS AS_STRIDED_COPY_OUT")
    if out.dtype != input.dtype:
        # Match PyTorch's strict out-dtype contract without measuring native fallback.
        raise RuntimeError(
            f"Expected out tensor to have dtype {input.dtype}, but got {out.dtype} instead"
        )

    requested_shape = tuple(size)
    if tuple(out.shape) != requested_shape:
        out.resize_(requested_shape)

    if out.numel() == 0:
        # Nothing to copy, but still let PyTorch validate the view arguments.
        _make_as_strided_view(input, size, stride, storage_offset)
        return out

    def _overlaps():
        return (
            torch._C._is_alias_of(input, out)
            or has_internal_overlapping(out) != MemOverlap.No
        )

    if input.device.type != "cuda":
        source = _make_as_strided_view(input, size, stride, storage_offset)
        if _overlaps():
            # Snapshot the source so the copy never reads values it has
            # already overwritten in out.
            source = source.clone(memory_format=torch.contiguous_format)
        out.copy_(source)
        return out

    if _overlaps():
        return _fallback_as_strided_copy_out(
            input, size, stride, storage_offset, out=out
        )

    source = _make_as_strided_view(input, size, stride, storage_offset)
    if _can_use_triton(source, out):
        return _launch_as_strided_copy(source, out)
    if _can_use_byte_triton(source, out):
        return _launch_byte_as_strided_copy(source, out)
    return _fallback_as_strided_copy_out(input, size, stride, storage_offset, out=out)