Coverage for src/flag_gems/ops/t_copy.py: 53%
70 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
2import logging
4import torch
5import triton
6import triton.language as tl
8import flag_gems
10logger = logging.getLogger(__name__)
@triton.jit
def t_copy_2d_kernel(
    in_ptr,
    out_ptr,
    in_stride_0,
    in_stride_1,
    out_stride_0,
    out_stride_1,
    M,  # input dim0
    N,  # input dim1
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Write the transpose of a strided (M, N) input into an (N, M) output.

    Each program handles one tile of the output. Note the naming: grid
    axis 0 / ``BLOCK_M`` walks the output rows (bounded by ``N``) and grid
    axis 1 / ``BLOCK_N`` walks the output columns (bounded by ``M``).

    Args:
        in_ptr: Pointer to the input data.
        out_ptr: Pointer to the output data.
        in_stride_0, in_stride_1: Element strides of the input.
        out_stride_0, out_stride_1: Element strides of the output.
        M: Size of input dim 0 (number of output columns).
        N: Size of input dim 1 (number of output rows).
        BLOCK_M: Tile extent along the output rows.
        BLOCK_N: Tile extent along the output columns.
    """
    # Widen the program ids to int64 *before* scaling them by the block
    # size; doing the multiplication in int32 and widening afterwards would
    # let the tile base wrap around for very large dimensions.
    pid_m = tl.program_id(0).to(tl.int64)
    pid_n = tl.program_id(1).to(tl.int64)

    i = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # out rows [0..N)
    j = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # out cols [0..M)

    i64 = i[None, :]  # shape [1, BLOCK_M], already int64
    j64 = j[:, None]  # shape [BLOCK_N, 1], already int64

    # out shape = (N, M); lanes past either edge are masked off.
    mask = (i64 < N) & (j64 < M)

    # in index = (j, i) -> in_offset = j*in_stride_0 + i*in_stride_1
    in_offsets = j64 * in_stride_0 + i64 * in_stride_1
    # out index = (i, j) -> out_offset = i*out_stride_0 + j*out_stride_1
    out_offsets = i64 * out_stride_0 + j64 * out_stride_1

    x = tl.load(in_ptr + in_offsets, mask=mask)
    tl.store(out_ptr + out_offsets, x, mask=mask)
@triton.jit
def copy_1d_strided_kernel(
    in_ptr,
    out_ptr,
    in_stride,
    out_stride,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy ``N`` elements between two strided 1-D buffers.

    Element ``k`` is read from ``in_ptr + k * in_stride`` and written to
    ``out_ptr + k * out_stride``; each program handles ``BLOCK_SIZE``
    consecutive values of ``k``.

    Args:
        in_ptr: Pointer to the input data.
        out_ptr: Pointer to the output data.
        in_stride: Element stride between consecutive input items.
        out_stride: Element stride between consecutive output items.
        N: Number of elements to copy.
        BLOCK_SIZE: Number of elements handled per program.
    """
    # Widen the program id to int64 *before* multiplying by BLOCK_SIZE so
    # the element index cannot wrap around in int32 for very large N.
    pid = tl.program_id(0).to(tl.int64)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < N  # lanes past the end do nothing

    in_idx = offs * in_stride
    out_idx = offs * out_stride
    x = tl.load(in_ptr + in_idx, mask=mask)
    tl.store(out_ptr + out_idx, x, mask=mask)
def _launch_t_copy_kernel(inp: torch.Tensor, out: torch.Tensor):
    """Validate ``inp``/``out`` and launch the Triton kernel that fills ``out``.

    * 0-D input: copies the single element.
    * 1-D input: strided element-wise copy.
    * 2-D input of shape (M, N): writes the transpose into an (N, M) ``out``.

    All checks are raised explicitly (not via ``assert``) so they remain
    active under ``python -O``; a mis-shaped ``out`` would otherwise make
    the kernel write outside the output buffer.

    Args:
        inp: Source tensor with at most 2 dimensions.
        out: Destination tensor, written in place.

    Raises:
        ValueError: If either tensor is not on the ``flag_gems.device`` type.
        AssertionError: If dtypes differ or ``out`` has the wrong shape/size.
        RuntimeError: If ``inp`` has more than 2 dimensions.
    """
    if inp.device.type != flag_gems.device or out.device.type != flag_gems.device:
        raise ValueError(f"t_copy kernels require {flag_gems.device} tensors")
    if inp.dtype != out.dtype:
        raise AssertionError("dtype mismatch between input and output")

    dim = inp.dim()
    if dim == 0:
        # Scalar copy: exactly one element is written at out's first slot,
        # so out must actually have room for it.
        if out.numel() != 1:
            raise AssertionError("Output size mismatch for 0D t_copy")
        copy_1d_strided_kernel[(1,)](
            inp,
            out,
            0,
            0,
            1,
            BLOCK_SIZE=1,
        )
    elif dim == 1:
        n = inp.numel()
        # Check the shape before touching out.stride(0): the kernel only
        # honours a single output stride, so out must be 1-D of length n.
        if out.dim() != 1 or out.shape[0] != n:
            raise AssertionError("Output size mismatch for 1D t_copy")
        if n == 0:
            return  # nothing to copy; avoid launching an empty grid
        block_size = 1024
        copy_1d_strided_kernel[(triton.cdiv(n, block_size),)](
            inp,
            out,
            inp.stride(0),
            out.stride(0),
            n,
            BLOCK_SIZE=block_size,
        )
    elif dim == 2:
        M, N = inp.shape  # input dims
        # out should be (N, M)
        if out.dim() != 2 or out.shape[0] != N or out.shape[1] != M:
            raise AssertionError(
                "Output shape must be (input.size(1), input.size(0)) for t_copy"
            )
        if M == 0 or N == 0:
            return  # nothing to copy; avoid launching an empty grid
        in_s0, in_s1 = inp.stride()
        out_s0, out_s1 = out.stride()
        block_m = 32
        block_n = 32
        # Grid axis 0 tiles the output rows (N), axis 1 the output cols (M).
        grid = (triton.cdiv(N, block_m), triton.cdiv(M, block_n))
        t_copy_2d_kernel[grid](
            inp,
            out,
            in_s0,
            in_s1,
            out_s0,
            out_s1,
            M,
            N,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
        )
    else:
        raise RuntimeError("t_copy expects a tensor with <= 2 dims")
def t_copy_out(
    input: torch.Tensor,
    out: torch.Tensor,
    memory_format: torch.memory_format | None = None,
):
    """Fill a caller-provided ``out`` tensor via the t_copy kernels.

    Args:
        input: Source tensor.
        out: Destination tensor; written in place.
        memory_format: Accepted for signature compatibility; this function
            does not read it.

    Returns:
        The same ``out`` tensor that was passed in.
    """
    logger.debug("GEMS T_COPY_OUT")
    _launch_t_copy_kernel(inp=input, out=out)
    return out
def t_copy(input: torch.Tensor, memory_format: torch.memory_format | None = None):
    """Allocate a fresh contiguous output and fill it via the t_copy kernels.

    A 2-D input of shape (M, N) gets an (N, M) output; 0-D and 1-D inputs
    get an output of the same shape.

    Args:
        input: Source tensor with at most 2 dimensions.
        memory_format: Accepted for signature compatibility; this function
            does not read it.

    Returns:
        The newly allocated output tensor.

    Raises:
        RuntimeError: If ``input`` has more than 2 dimensions.
    """
    logger.debug("GEMS T_COPY")
    rank = input.dim()
    # Reject unsupported ranks up front, before allocating anything.
    if rank > 2:
        raise RuntimeError("t_copy expects a tensor with <= 2 dims")

    if rank == 2:
        rows, cols = input.shape
        result = torch.empty(
            (cols, rows),
            dtype=input.dtype,
            device=input.device,
            memory_format=torch.contiguous_format,
        )
    elif rank == 1:
        result = torch.empty_like(input, memory_format=torch.contiguous_format)
    else:
        # rank == 0: scalar tensor
        result = torch.empty((), dtype=input.dtype, device=input.device)

    _launch_t_copy_kernel(input, result)
    return result