Coverage for src/flag_gems/ops/cat.py: 54%
193 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2from typing import List, Tuple, Union
4import torch
5import triton
6import triton.language as tl
8logger = logging.getLogger(__name__)
11def _is_float8_e8m0fnu(dtype: torch.dtype) -> bool:
12 return str(dtype) == "torch.float8_e8m0fnu"
def _should_use_uint8_view_path(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
) -> bool:
    """Decide whether the inputs should be concatenated through uint8 views.

    Returns ``True`` only when ``A`` is non-empty, its first tensor has the
    ``float8_e8m0fnu`` dtype with one-byte elements, and every remaining
    tensor shares that dtype and element size. Any other input yields
    ``False``.
    """
    if not A:
        return False

    head = A[0]
    head_dtype = head.dtype
    if not _is_float8_e8m0fnu(head_dtype):
        return False
    if head.element_size() != 1:
        return False

    # Every remaining tensor must match the first one byte for byte.
    return all(
        other.dtype == head_dtype and other.element_size() == 1 for other in A[1:]
    )
def _cat_build_working_list_uint8_view(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
):
    """Run ``_cat_build_working_list`` on uint8 reinterpretations of ``A``.

    Every tensor in ``A`` is viewed as ``torch.uint8`` before being handed to
    ``_cat_build_working_list``.

    Returns:
        ``(mode, payload, original_dtype)`` where ``mode`` and ``payload`` are
        whatever ``_cat_build_working_list`` produced for the uint8 views and
        ``original_dtype`` is the dtype of ``A[0]`` before reinterpretation.
    """
    source_dtype = A[0].dtype
    byte_views = [item.view(torch.uint8) for item in A]
    planned = _cat_build_working_list(byte_views, dim)
    mode, payload = planned
    return mode, payload, source_dtype
@triton.jit
def cat_copy_func_kernel_4(
    out_ptr,
    in_ptr_a,
    in_ptr_b,
    in_ptr_c,
    in_ptr_d,
    dim_size_in_a,
    dim_size_in_b,
    dim_size_in_c,
    dim_size_in_d,
    dim_size_out,
    dim_prod_post,
    dim_offset_a: tl.int64,
    dim_offset_b: tl.int64,
    dim_offset_c: tl.int64,
    dim_offset_d: tl.int64,
    total_elements_a,
    total_elements_b,
    total_elements_c,
    total_elements_d,
    BLOCK_X: tl.constexpr,
):
    """Copy up to four inputs into their slices of a concatenated output.

    Grid axis 1 selects which of the four inputs (a, b, c, d) a program
    handles; grid axis 0 selects a block of ``BLOCK_X`` consecutive flat
    element indices of that input. Inputs and output are addressed by flat
    offsets, i.e. they are treated as densely packed buffers.

    Args:
        out_ptr: Pointer to the output buffer.
        in_ptr_a..in_ptr_d: Pointers to the input buffers.
        dim_size_in_a..dim_size_in_d: Extent of each input along the
            concatenation dimension.
        dim_size_out: Extent of the output along the concatenation dimension.
        dim_prod_post: Product of the extents that follow the concatenation
            dimension.
        dim_offset_a..dim_offset_d: Position along the output's concatenation
            dimension at which each input starts.
        total_elements_a..total_elements_d: Element count of each input; used
            to mask lanes beyond the end of the input.
        BLOCK_X: Number of elements processed by one program.
    """
    # Promote the block index to 64 bit *before* any multiplication: both the
    # flat input index and the `pre_idx * dim_size_out * dim_prod_post` term of
    # the output index would otherwise be evaluated in 32-bit arithmetic and
    # could wrap around for buffers with more than 2**31 elements.
    pid_x = tl.program_id(0).to(tl.int64)
    pid_y = tl.program_id(1)

    if pid_y == 0:
        in_ptr = in_ptr_a
        dim_size_in = dim_size_in_a
        dim_offset = tl.cast(dim_offset_a, tl.int64)
        total_elements = total_elements_a
    elif pid_y == 1:
        in_ptr = in_ptr_b
        dim_size_in = dim_size_in_b
        dim_offset = tl.cast(dim_offset_b, tl.int64)
        total_elements = total_elements_b
    elif pid_y == 2:
        in_ptr = in_ptr_c
        dim_size_in = dim_size_in_c
        dim_offset = tl.cast(dim_offset_c, tl.int64)
        total_elements = total_elements_c
    else:
        in_ptr = in_ptr_d
        dim_size_in = dim_size_in_d
        dim_offset = tl.cast(dim_offset_d, tl.int64)
        total_elements = total_elements_d

    block_start = pid_x * BLOCK_X
    offsets = tl.arange(0, BLOCK_X)
    idx = block_start + offsets
    # Lanes past the end of this input neither load nor store.
    mask = idx < total_elements

    # Split the flat input index into (pre, dim, post) coordinates.
    pre_idx = idx // (dim_size_in * dim_prod_post)
    dim_idx = (idx // dim_prod_post) % dim_size_in
    post_idx = idx % dim_prod_post

    # Re-assemble the flat output index with the cat-dimension coordinate
    # shifted by this input's starting offset.
    out_idx = (
        pre_idx * dim_size_out * dim_prod_post
        + (dim_idx + dim_offset) * dim_prod_post
        + post_idx
    )

    data = tl.load(in_ptr + idx, mask=mask)
    tl.store(out_ptr + out_idx, data, mask=mask)
def _cat_run_kernel(
    A: List[torch.Tensor],
    dim: int,
    out_shape: List[int],
    out: torch.Tensor,
):
    """Copy every tensor of ``A`` into ``out`` along ``dim``.

    The inputs are processed in batches of four, one launch of
    ``cat_copy_func_kernel_4`` per batch. Each input is made contiguous
    before launch. The kernel addresses ``out`` by flat offsets, so ``out``
    is expected to be contiguous with shape ``out_shape``.

    Args:
        A: Tensors to concatenate; all share the extents after ``dim``
            (only ``A[0]`` is consulted for them).
        dim: Concatenation dimension, used directly as an index into shapes.
        out_shape: Shape of the concatenated result.
        out: Destination tensor, written in place.
    """
    if not A:
        return

    BLOCK = 1024
    SLOTS = 4  # the kernel signature has exactly four input slots

    # Geometry shared by every launch; computed once instead of per batch.
    dim_size_out = out_shape[dim]
    dim_prod_post = 1
    for d in range(dim + 1, A[0].ndim):
        dim_prod_post *= A[0].shape[d]

    dim_offset = 0
    for start in range(0, len(A), SLOTS):
        batch = [t.contiguous() for t in A[start : start + SLOTS]]
        dim_sizes = [t.shape[dim] for t in batch]
        numels = [t.numel() for t in batch]

        # Starting position of each input along the output's cat dimension.
        dim_offsets = []
        for size in dim_sizes:
            dim_offsets.append(dim_offset)
            dim_offset += size

        max_elements_in_batch = max(numels)
        if max_elements_in_batch == 0:
            # Nothing to copy; avoid launching with a zero-sized grid.
            continue

        # Unused slots get a valid dummy pointer and zero sizes. They are
        # never executed because the grid's second axis only spans the real
        # inputs of this batch.
        padding = SLOTS - len(batch)
        pointers = batch + [batch[0]] * padding
        zeros = [0] * padding

        grid = (triton.cdiv(max_elements_in_batch, BLOCK), len(batch))
        cat_copy_func_kernel_4[grid](
            out,
            *pointers,
            *(dim_sizes + zeros),
            dim_size_out,
            dim_prod_post,
            *(dim_offsets + zeros),
            *(numels + zeros),
            BLOCK_X=BLOCK,
        )
194def _cat_build_working_list(
195 A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int
196):
197 """Returns (mode, payload) where mode is 'single'|'empty'|'multi'."""
198 if len(A) == 0:
199 raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")
200 if len(A) == 1:
201 return "single", A[0]
203 device = A[0].device
204 dtype = A[0].dtype
205 A = list(A)
206 for i in range(len(A) - 1, -1, -1):
207 if A[i].shape == torch.Size([0]):
208 A.pop(i)
209 if len(A) == 0:
210 return "empty", torch.tensor([], device=device, dtype=dtype)
211 if len(A) == 1:
212 return "single", A[0]
214 assert dim >= -A[0].ndim and dim < A[0].ndim, f"Invalid dim: {dim}"
215 dim %= A[0].ndim
217 inp_shapes = [list(_.shape) for _ in A]
218 inp0_shape = inp_shapes[0]
219 for s in inp_shapes[1:]:
220 if len(s) != len(inp0_shape):
221 raise RuntimeError(
222 f"Tensors must have same number of dimensions: got {len(inp0_shape)} and {len(s)}"
223 )
224 for tensor_idx, inp_shape in enumerate(inp_shapes):
225 for idx, (common_length, length) in enumerate(zip(inp0_shape, inp_shape)):
226 if idx != dim and length != common_length:
227 raise RuntimeError(
228 f"Sizes of tensors must match except in dimension {dim}. "
229 f"Expected size {common_length} but got size {length} for tensor number "
230 f"{tensor_idx} in the list"
231 )
233 dtypes = [t.dtype for t in A]
234 dtype = dtypes[0]
235 for dt in dtypes[1:]:
236 dtype = torch.promote_types(dtype, dt)
237 A = [t.to(dtype) if t.dtype != dtype else t for t in A]
239 shapes = [t.shape for t in A]
240 cat_dim_sizes = [s[dim] for s in shapes]
241 out_shape = list(shapes[0])
242 out_shape[dim] = sum(cat_dim_sizes)
243 return "multi", (A, dim, out_shape, dtype, A[0].device)
def cat_out(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int = 0,
    *,
    out: torch.Tensor,
) -> torch.Tensor:
    """Concatenate ``A`` along ``dim`` and write the result into ``out``.

    Inputs that qualify for the uint8 view path (see
    ``_should_use_uint8_view_path``) are copied through ``torch.uint8`` views;
    all other inputs are copied in their promoted dtype.

    Args:
        A: Tensors to concatenate.
        dim: Concatenation dimension.
        out: Destination tensor. It is resized when its shape differs from
            the result shape.

    Returns:
        ``out``.

    Raises:
        RuntimeError: If more than one tensor has to be copied and
            ``out.dtype`` differs from the result dtype, plus anything raised
            by ``_cat_build_working_list``.
    """
    logger.debug("GEMS CAT_OUT")
    A = list(A)

    use_uint8_view = _should_use_uint8_view_path(A)
    if use_uint8_view:
        mode, payload, original_dtype = _cat_build_working_list_uint8_view(A, dim)
    else:
        mode, payload = _cat_build_working_list(A, dim)
        original_dtype = None

    if mode == "single" or mode == "empty":
        t = payload.view(original_dtype) if use_uint8_view else payload
        out.resize_(t.shape)
        if mode == "single" and out.dtype != t.dtype:
            out.copy_(t.to(out.dtype))
        else:
            out.copy_(t)
        return out

    tensors, dim, out_shape, dtype, _ = payload
    expected_dtype = original_dtype if use_uint8_view else dtype
    if out.dtype != expected_dtype:
        raise RuntimeError(
            f"cat.out: expected out dtype {expected_dtype}, got {out.dtype}"
        )
    if list(out.shape) != out_shape:
        out.resize_(out_shape)

    # The kernel writes flat offsets, which is only correct for a contiguous
    # destination. A strided `out` is filled through a contiguous staging
    # buffer and then copied back.
    if out.is_contiguous():
        dest = out
    else:
        dest = torch.empty(out_shape, dtype=out.dtype, device=out.device)

    kernel_out = dest.view(torch.uint8) if use_uint8_view else dest
    _cat_run_kernel(tensors, dim, out_shape, kernel_out)

    if dest is not out:
        out.copy_(dest)
    return out
def cat(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate the tensors in ``A`` along ``dim``.

    Inputs that qualify for the uint8 view path (see
    ``_should_use_uint8_view_path``) are copied as ``torch.uint8`` and the
    result is viewed back as the original dtype. All other inputs go through
    ``_cat_build_working_list`` and are copied in the dtype it reports.

    When the working list collapses to a single tensor or to nothing, the
    payload produced by the helper is returned directly, without launching
    the copy kernel.
    """
    logger.debug("GEMS CAT")
    tensors = list(A)

    if not _should_use_uint8_view_path(tensors):
        mode, payload = _cat_build_working_list(tensors, dim)
        if mode in ("single", "empty"):
            return payload

        inputs, cat_dim, result_shape, result_dtype, result_device = payload
        result = torch.empty(result_shape, dtype=result_dtype, device=result_device)
        _cat_run_kernel(inputs, cat_dim, result_shape, result)
        return result

    # uint8 view path: plan and copy raw bytes, then reinterpret the result.
    mode, payload, source_dtype = _cat_build_working_list_uint8_view(tensors, dim)
    if mode in ("single", "empty"):
        return payload.view(source_dtype)

    byte_inputs, cat_dim, result_shape, _, result_device = payload
    byte_result = torch.empty(result_shape, dtype=torch.uint8, device=result_device)
    _cat_run_kernel(byte_inputs, cat_dim, result_shape, byte_result)
    return byte_result.view(source_dtype)