Coverage for src/flag_gems/ops/lift_fresh_copy.py: 39%
56 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8import flag_gems
10logger = logging.getLogger(__name__)
@triton.jit
def _copy_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Flat 1-D copy: program `pid` copies elements
    # [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE) from in_ptr to out_ptr.
    # No stride handling is done, so both buffers must be laid out
    # contiguously over `n_elements` items.
    #
    # The program id is promoted to int64 before scaling. With int32
    # arithmetic, `pid * BLOCK_SIZE` wraps once n_elements exceeds 2**31 - 1.
    # The wrapped (negative) offsets would still pass the `< n_elements` mask
    # and read/write out of bounds.
    pid = tl.program_id(axis=0).to(tl.int64)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Mask off the tail of the last block so nothing past n_elements is touched.
    mask = offsets < n_elements
    x = tl.load(in_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)
def lift_fresh_copy(*args, **kwargs):
    """Return a fresh copy of the input tensor (``lift_fresh_copy``).

    The input tensor is located as follows:

    - ``args[0]`` if it is a Tensor;
    - otherwise ``kwargs["self"]`` if that is a Tensor;
    - otherwise the first Tensor among all positional and keyword values.

    Returns:
        A new contiguous tensor with the same shape, dtype and device as the
        input, backed by its own storage (never aliasing the input).

    Raises:
        ValueError: if no Tensor argument is found, or the tensor's device
            type is not ``flag_gems.device``.
    """
    logger.debug("GEMS LIFT_FRESH_COPY")

    if args and isinstance(args[0], torch.Tensor):
        x = args[0]
    elif isinstance(kwargs.get("self"), torch.Tensor):
        x = kwargs["self"]
    else:
        x = next(
            (v for v in (*args, *kwargs.values()) if isinstance(v, torch.Tensor)),
            None,
        )
    if x is None:
        raise ValueError("lift_fresh_copy expects a Tensor argument")

    if x.device.type != flag_gems.device:
        raise ValueError(
            f"lift_fresh_copy Triton kernel requires a {flag_gems.device} tensor"
        )

    if not x.is_contiguous():
        # `contiguous()` materializes a brand-new contiguous buffer for a
        # strided input, which is already the fresh copy we need. Running the
        # Triton kernel on top of it would allocate and copy a second time.
        return x.contiguous()

    out = torch.empty_like(x, memory_format=torch.contiguous_format)
    n_elements = x.numel()
    if n_elements == 0:
        # Nothing to copy; skip the (empty-grid) kernel launch.
        return out

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _copy_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024)
    # `out` already has x's shape; no reshaping view is needed.
    return out
def lift_fresh_copy_out(x: torch.Tensor, out: torch.Tensor = None):
    """Copy ``x`` into ``out`` (out-variant of ``lift_fresh_copy``).

    Args:
        x: Source tensor; its device type must be ``flag_gems.device``.
        out: Destination tensor. If ``None``, a new contiguous tensor is
            allocated. Otherwise it must be on ``flag_gems.device`` with the
            same dtype as ``x``. It is resized in place to ``x.shape`` when
            the shapes differ, and it always receives the copied data, even
            when it is not contiguous.

    Returns:
        ``out`` itself (or the newly allocated tensor), holding a copy of ``x``.

    Raises:
        ValueError: if ``x`` is not a Tensor, or a device/dtype check fails.
    """
    logger.debug("GEMS LIFT_FRESH_COPY_OUT")
    if not isinstance(x, torch.Tensor):
        raise ValueError("lift_fresh_copy_out expects 'x' to be a Tensor")
    if x.device.type != flag_gems.device:
        raise ValueError(
            f"lift_fresh_copy_out Triton kernel requires {flag_gems.device} tensors"
        )

    # The kernel copies a flat buffer, so the source must be contiguous.
    x_contig = x.contiguous()

    if out is None:
        out = torch.empty_like(x_contig, memory_format=torch.contiguous_format)
    else:
        if out.device.type != flag_gems.device:
            raise ValueError(f"Output tensor 'out' must be on {flag_gems.device}")
        if out.dtype != x_contig.dtype:
            raise ValueError("Output tensor 'out' must have the same dtype as input")
        # Compare shapes, not just numel: a same-numel but differently shaped
        # `out` must still end up with the input's shape.
        if out.shape != x_contig.shape:
            out.resize_(x_contig.shape)

    n_elements = x_contig.numel()
    if n_elements == 0:
        return out

    # A strided `out` cannot be written by the flat kernel directly. Stage the
    # result in a contiguous temporary and copy it back, so the caller's
    # tensor (not a detached copy of it) receives the data.
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty_like(out, memory_format=torch.contiguous_format)

    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _copy_kernel[grid](x_contig, dst, n_elements, BLOCK_SIZE=1024)

    if dst is not out:
        out.copy_(dst)
    return out