Coverage for src/flag_gems/ops/cat.py: 66%
140 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
2from typing import List, Tuple, Union
4import torch
5import triton
6import triton.language as tl
# Module-level logger; `cat` and `cat_out` below emit a debug line through it on every call.
logger = logging.getLogger(__name__)
@triton.jit
def cat_copy_func_kernel_4(
    out_ptr,
    in_ptr_a,
    in_ptr_b,
    in_ptr_c,
    in_ptr_d,
    dim_size_in_a,
    dim_size_in_b,
    dim_size_in_c,
    dim_size_in_d,
    dim_size_out,
    dim_prod_post,
    dim_offset_a,
    dim_offset_b,
    dim_offset_c,
    dim_offset_d,
    total_elements_a,
    total_elements_b,
    total_elements_c,
    total_elements_d,
    BLOCK_X: tl.constexpr,
):
    """Copy up to four inputs into their slices of the concatenated output.

    Grid axis 1 (``pid_y``) selects which input slot (a, b, c, d) this program
    copies; grid axis 0 tiles that input's flat element range in chunks of
    ``BLOCK_X``. Input flat index ``idx`` is split into ``(pre, dim, post)``
    around the concatenation dimension and written to output flat index
    ``pre * dim_size_out * dim_prod_post + (dim + dim_offset) * dim_prod_post + post``,
    i.e. both input and output are indexed as contiguous buffers.

    Args:
        out_ptr: output buffer (indexed as contiguous).
        in_ptr_a..d: input buffers (indexed as contiguous).
        dim_size_in_a..d: size of each input along the concatenation dimension.
        dim_size_out: size of the output along the concatenation dimension.
        dim_prod_post: product of the sizes of all dimensions after it.
        dim_offset_a..d: start position of each input along that dimension in the output.
        total_elements_a..d: element count of each input; bounds the load/store mask.
        BLOCK_X: number of elements handled per program along axis 0.
    """
    pid_x = tl.program_id(0)
    pid_y = tl.program_id(1)

    # Select the per-input arguments for the slot this program row handles.
    if pid_y == 0:
        in_ptr = in_ptr_a
        dim_size_in = dim_size_in_a
        dim_offset = dim_offset_a
        total_elements = total_elements_a
    elif pid_y == 1:
        in_ptr = in_ptr_b
        dim_size_in = dim_size_in_b
        dim_offset = dim_offset_b
        total_elements = total_elements_b
    elif pid_y == 2:
        in_ptr = in_ptr_c
        dim_size_in = dim_size_in_c
        dim_offset = dim_offset_c
        total_elements = total_elements_c
    else:
        in_ptr = in_ptr_d
        dim_size_in = dim_size_in_d
        dim_offset = dim_offset_d
        total_elements = total_elements_d

    # Use int64 offsets: pid_x * BLOCK_X and the output index below can exceed
    # 2**31 - 1 for large tensors, which would silently wrap in int32.
    block_start = pid_x.to(tl.int64) * BLOCK_X
    idx = block_start + tl.arange(0, BLOCK_X)
    mask = idx < total_elements

    pre_idx = idx // (dim_size_in * dim_prod_post)
    dim_idx = (idx // dim_prod_post) % dim_size_in
    post_idx = idx % dim_prod_post

    out_idx = (
        pre_idx * dim_size_out * dim_prod_post
        + (dim_idx + dim_offset) * dim_prod_post
        + post_idx
    )

    data = tl.load(in_ptr + idx, mask=mask)
    tl.store(out_ptr + out_idx, data, mask=mask)
def _cat_run_kernel(
    A: List[torch.Tensor],
    dim: int,
    out_shape: List[int],
    out: torch.Tensor,
):
    """Concatenate ``A`` along ``dim`` into ``out`` using ``cat_copy_func_kernel_4``.

    The kernel has four input slots, so ``A`` is processed in batches of up to
    four tensors. Each batch is one launch whose grid axis 1 has one row per
    real tensor in that batch.

    Args:
        A: tensors to concatenate. Assumed (as prepared by
            ``_cat_build_working_list``) to share ``out``'s dtype and device and
            to agree with each other in every dimension except ``dim``.
        dim: non-negative concatenation dimension.
        out_shape: shape of the concatenated result.
        out: destination of shape ``out_shape``. The kernel indexes it as a
            contiguous buffer, so it must be contiguous.
    """
    BLOCK = 1024
    SLOTS = 4  # number of input slots in cat_copy_func_kernel_4

    # Loop-invariant: these depend only on the shared shape, not on the batch.
    dim_size_out = out_shape[dim]
    dim_prod_post = 1
    for size in A[0].shape[dim + 1 :]:
        dim_prod_post *= size

    dim_offset = 0  # start of the next input along `dim` inside `out`
    for start in range(0, len(A), SLOTS):
        # The kernel indexes inputs as contiguous buffers.
        batch = [t.contiguous() for t in A[start : start + SLOTS]]

        slots = []  # (tensor, dim_size_in, dim_offset, total_elements) per slot
        for tensor in batch:
            size_in = tensor.shape[dim]
            slots.append((tensor, size_in, dim_offset, tensor.numel()))
            dim_offset += size_in
        # Fill unused slots with dummies; grid axis 1 only spans len(batch)
        # rows, so the kernel never touches them.
        slots.extend([(batch[0], 0, 0, 0)] * (SLOTS - len(batch)))

        max_elements = max(slot[3] for slot in slots)
        if max_elements == 0:
            # Every tensor in this batch is empty: nothing to copy, skip the launch.
            continue

        tensors, sizes_in, offsets, numels = zip(*slots)
        grid = (triton.cdiv(max_elements, BLOCK), len(batch))
        cat_copy_func_kernel_4[grid](
            out,
            *tensors,
            *sizes_in,
            dim_size_out,
            dim_prod_post,
            *offsets,
            *numels,
            BLOCK_X=BLOCK,
        )
164def _cat_build_working_list(
165 A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int
166):
167 """Returns (mode, payload) where mode is 'single'|'empty'|'multi'."""
168 if len(A) == 0:
169 raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")
170 if len(A) == 1:
171 return "single", A[0]
173 device = A[0].device
174 dtype = A[0].dtype
175 A = list(A)
176 for i in range(len(A) - 1, -1, -1):
177 if A[i].shape == torch.Size([0]):
178 A.pop(i)
179 if len(A) == 0:
180 return "empty", torch.tensor([], device=device, dtype=dtype)
181 if len(A) == 1:
182 return "single", A[0]
184 assert dim >= -A[0].ndim and dim < A[0].ndim, f"Invalid dim: {dim}"
185 dim %= A[0].ndim
187 inp_shapes = [list(_.shape) for _ in A]
188 inp0_shape = inp_shapes[0]
189 for s in inp_shapes[1:]:
190 if len(s) != len(inp0_shape):
191 raise RuntimeError(
192 f"Tensors must have same number of dimensions: got {len(inp0_shape)} and {len(s)}"
193 )
194 for tensor_idx, inp_shape in enumerate(inp_shapes):
195 for idx, (common_length, length) in enumerate(zip(inp0_shape, inp_shape)):
196 if idx != dim and length != common_length:
197 raise RuntimeError(
198 f"Sizes of tensors must match except in dimension {dim}. "
199 f"Expected size {common_length} but got size {length} for tensor number "
200 f"{tensor_idx} in the list"
201 )
203 dtypes = [t.dtype for t in A]
204 dtype = dtypes[0]
205 for dt in dtypes[1:]:
206 dtype = torch.promote_types(dtype, dt)
207 A = [t.to(dtype) if t.dtype != dtype else t for t in A]
209 shapes = [t.shape for t in A]
210 cat_dim_sizes = [s[dim] for s in shapes]
211 out_shape = list(shapes[0])
212 out_shape[dim] = sum(cat_dim_sizes)
213 return "multi", (A, dim, out_shape, dtype, A[0].device)
def cat_out(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int = 0,
    *,
    out: torch.Tensor,
) -> torch.Tensor:
    """Concatenate ``A`` along ``dim`` into ``out`` (the ``cat.out`` overload).

    ``out`` is resized to the result shape when it differs. The result is cast
    into ``out.dtype`` when the promoted input dtype can be cast to it.

    Returns:
        ``out``.

    Raises:
        RuntimeError: if the promoted input dtype cannot be cast to
            ``out.dtype``, or any validation error from the input checks.
    """
    logger.debug("GEMS CAT_OUT")
    mode, payload = _cat_build_working_list(A, dim)
    if mode in ("single", "empty"):
        t = payload
        out.resize_(t.shape)
        out.copy_(t)  # copy_ casts to out.dtype when the dtypes differ
        return out

    A, dim, out_shape, dtype, device = payload
    if out.dtype != dtype and not torch.can_cast(dtype, out.dtype):
        raise RuntimeError(f"cat.out: expected out dtype {dtype}, got {out.dtype}")
    if list(out.shape) != out_shape:
        out.resize_(out_shape)

    # The kernel writes `out` as a dense contiguous buffer of `dtype` on
    # `device`. Otherwise (e.g. a transposed `out`) the writes would land in
    # the wrong places, so stage through a temporary and let copy_ handle
    # strides, dtype cast and device transfer.
    if out.is_contiguous() and out.dtype == dtype and out.device == device:
        _cat_run_kernel(A, dim, out_shape, out)
    else:
        staged = torch.empty(out_shape, dtype=dtype, device=device)
        _cat_run_kernel(A, dim, out_shape, staged)
        out.copy_(staged)
    return out
def cat(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate the tensors in ``A`` along ``dim`` into a new tensor.

    Legacy 1-D empty tensors (shape ``[0]``) are skipped. The result always has
    its own storage and never aliases an input.

    Raises:
        RuntimeError: any validation error from the input checks (empty list,
            mismatched ranks or sizes).
    """
    logger.debug("GEMS CAT")
    mode, payload = _cat_build_working_list(A, dim)
    if mode == "single":
        # The payload may be the caller's own tensor. torch.cat always returns
        # a new tensor, so copy it; otherwise in-place edits of the result would
        # leak into the input.
        return payload.clone()
    if mode == "empty":
        return payload  # already a freshly created tensor

    A, dim, out_shape, dtype, device = payload
    out = torch.empty(out_shape, dtype=dtype, device=device)
    _cat_run_kernel(A, dim, out_shape, out)
    return out