Coverage for src/flag_gems/ops/view_copy.py: 61%
46 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
logger = logging.getLogger(__name__)


# Flat element-wise copy: each program instance copies one BLOCK_SIZE-wide
# chunk of `n_elements` contiguous elements from `src_ptr` to `dst_ptr`.
# Lanes whose offset is >= n_elements are masked off for both the load and
# the store, so the final partial chunk is handled safely.
@triton.jit
def _view_copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # tl.program_id is int32. Widen it before multiplying so that offsets do
    # not wrap around for tensors with more than 2**31 - 1 elements.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    vals = tl.load(src_ptr + offsets, mask=mask)
    tl.store(dst_ptr + offsets, vals, mask=mask)
def _resolve_view_size(size, n_elements):
    """Normalize ``size`` to a tuple and resolve at most one ``-1`` entry.

    Args:
        size: A bare ``int``/``torch.SymInt``, or a list/tuple (including
            ``torch.Size``) whose entries are ints or ``torch.SymInt``.
        n_elements: Number of elements in the source tensor.

    Returns:
        A tuple of dimensions whose product equals ``n_elements``.

    Raises:
        RuntimeError: if a dimension is smaller than -1, more than one
            dimension is -1, the -1 dimension cannot be inferred (another
            dimension is 0), or the element counts do not match.
    """
    # aten::view_copy takes SymInt[]. A bare int/SymInt is accepted as a
    # 1-D shape so that `-1 in size` below never sees a non-sequence.
    if isinstance(size, (int, torch.SymInt)):
        size = (size,)
    size = tuple(int(s) if isinstance(s, torch.SymInt) else s for s in size)

    for s in size:
        if s < -1:
            raise RuntimeError(f"view_copy: invalid shape dimension {s} in {size}")

    # Handle -1 (infer this dimension)
    if -1 in size:
        if size.count(-1) > 1:
            raise RuntimeError(f"view_copy: only one dimension can be -1, got {size}")
        known_numel = 1
        for s in size:
            if s != -1:
                known_numel *= s
        # A 0 in the known dims makes the -1 dimension ambiguous. Without
        # this check the division below raises ZeroDivisionError.
        if known_numel == 0:
            raise RuntimeError(
                f"view_copy: cannot infer the -1 dimension of shape {size} "
                f"for a tensor of size {n_elements}"
            )
        # Floor division; a non-divisible count is caught by the check below.
        size = tuple(n_elements // known_numel if s == -1 else s for s in size)

    # Validate total number of elements matches
    target_numel = 1
    for s in size:
        target_numel *= s
    if n_elements != target_numel:
        raise RuntimeError(
            f"view_copy: cannot reshape tensor of size {n_elements} into shape {size}"
        )
    return size


def view_copy(x: torch.Tensor, size) -> torch.Tensor:
    """Wrapper for aten::view_copy.

    Creates and returns a copy of `x` with the specified shape. This is like
    view() but always returns a new, contiguous tensor instead of an alias.

    Args:
        x: Source tensor. It is made contiguous first if needed, so
            non-contiguous inputs are accepted.
        size: Target shape. It may contain at most one -1, which is inferred
            from ``x.numel()``.

    Returns:
        A new tensor of shape ``size`` with the same dtype and device as `x`,
        holding `x`'s elements in row-major order.

    Raises:
        RuntimeError: if ``size`` is invalid or incompatible with ``x.numel()``.
    """
    logger.debug("GEMS VIEW_COPY")
    n_elements = x.numel()
    size = _resolve_view_size(size, n_elements)

    # A freshly allocated tensor with the default memory format is already
    # contiguous, so it can be the flat destination of the kernel directly.
    out = torch.empty(size, dtype=x.dtype, device=x.device)
    if n_elements == 0:
        return out

    # The kernel copies a flat range. `contiguous()` is a no-op (returns `x`)
    # when `x` is already contiguous.
    src = x.contiguous()

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    with torch_device_fn.device(x.device):
        _view_copy_kernel[grid](src, out, n_elements, BLOCK_SIZE=1024)
    return out