Coverage for src/flag_gems/runtime/backend/_ascend/ops/cat.py: 0%
76 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2from typing import List, Tuple, Union
4import torch
logger = logging.getLogger(f'flag_gems.runtime._ascend.ops.{__name__.split(".")[-1]}')


def _is_legacy_empty(t: torch.Tensor) -> bool:
    """Return True for 1-D tensors of size 0, which cat skips entirely."""
    return t.shape == torch.Size([0])


def _promoted_dtype(tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]]) -> torch.dtype:
    """Return the dtype obtained by promoting the dtypes of all ``tensors``.

    Every input takes part, including the legacy empty tensors that cat later
    skips for shape purposes.
    """
    dtype = tensors[0].dtype
    for t in tensors[1:]:
        dtype = torch.promote_types(dtype, t.dtype)
    return dtype


def _prepare_cat(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int
) -> Union[Tuple[List[torch.Tensor], int, List[int]], None]:
    """Filter and validate cat inputs and compute the concatenated shape.

    Legacy empty tensors (shape ``(0,)``) are dropped. ``dim`` is validated
    against the rank of the first remaining tensor and wrapped to be
    non-negative. All remaining tensors must share that rank and agree in
    every dimension except ``dim``.

    Returns:
        ``(tensors, dim, out_shape)`` for the remaining tensors, or ``None``
        when every input was a legacy empty tensor.

    Raises:
        IndexError: if ``dim`` is out of range for the remaining tensors.
        RuntimeError: if the tensors differ in rank or in a non-``dim`` size.
    """
    tensors = [t for t in A if not _is_legacy_empty(t)]
    if not tensors:
        return None

    ndim = tensors[0].ndim
    # Explicit raise instead of `assert`: asserts vanish under `python -O`,
    # after which `dim % ndim` would silently wrap an invalid dim.
    if not -ndim <= dim < ndim:
        raise IndexError(f"Invalid dim: {dim}")
    dim = dim % ndim

    shapes = [list(t.shape) for t in tensors]
    ref_shape = shapes[0]
    # Check ranks first, so the zip() in the size check below never truncates.
    for shape in shapes[1:]:
        if len(shape) != len(ref_shape):
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {len(ref_shape)} and {len(shape)}"
            )
    for tensor_idx, shape in enumerate(shapes):
        for idx, (common_length, length) in enumerate(zip(ref_shape, shape)):
            if idx != dim and length != common_length:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {common_length} but got size {length} for tensor number "
                    f"{tensor_idx} in the list"
                )

    out_shape = list(ref_shape)
    out_shape[dim] = sum(s[dim] for s in shapes)
    return tensors, dim, out_shape


def cat(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate the tensors in ``A`` along ``dim`` into a newly allocated tensor.

    Intended to match ``torch.cat``:
    - 1-D tensors of size 0 are skipped.
    - ``dim`` may be negative.
    - The result dtype is the promotion of every input dtype.

    Args:
        A: tensors to concatenate.
        dim: dimension to concatenate along.

    Returns:
        A fresh tensor that never aliases an input, even when only one tensor
        remains after skipping empties. If every input is a legacy empty
        tensor, returns an empty 1-D tensor on ``A[0]``'s device.

    Raises:
        RuntimeError: if ``A`` is empty, or if the inputs differ in rank or in
            a dimension other than ``dim``.
        IndexError: if ``dim`` is out of range.
    """
    logger.debug("GEMS_ASCEND CAT")

    # Check before touching A[0]; otherwise an empty list raises IndexError.
    if len(A) == 0:
        raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")

    dtype = _promoted_dtype(A)
    prepared = _prepare_cat(A, dim)
    if prepared is None:
        return torch.tensor([], device=A[0].device, dtype=dtype)

    tensors, dim, out_shape = prepared
    # Always allocate and copy, including the single-tensor case, so callers
    # can modify the result in place without mutating an input.
    out = torch.empty(out_shape, dtype=dtype, device=tensors[0].device)
    _cat_fill(out, tensors, dim)
    return out


def _cat_fill(out: torch.Tensor, A: List[torch.Tensor], dim: int) -> None:
    """Copy each tensor of ``A`` into consecutive slices of ``out`` along ``dim``.

    ``out`` must already have the concatenated shape. Tensor ``i`` is written
    at ``dim`` offsets ``[sum(sizes[:i]), sum(sizes[:i + 1]))``, where
    ``sizes`` are the inputs' extents along ``dim``.
    """
    idx = [slice(None)] * out.ndim
    offset = 0
    for a in A:
        # NOTE(review): slice assignment accepts strided sources. Presumably
        # this copy exists for the Ascend backend; confirm it is still needed.
        a = a.contiguous()
        idx[dim] = slice(offset, offset + a.shape[dim])
        out[tuple(idx)] = a
        offset += a.shape[dim]


def cat_out(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int = 0,
    *,
    out: torch.Tensor,
) -> torch.Tensor:
    """Concatenate the tensors in ``A`` along ``dim`` into ``out``.

    ``out`` is resized in place to the concatenated shape and then filled.
    Values are cast to ``out``'s dtype on assignment. 1-D tensors of size 0
    are skipped; if every input is one, ``out`` is resized to shape ``(0,)``.

    Returns:
        ``out`` itself.

    Raises:
        RuntimeError: if ``A`` is empty, or if the inputs differ in rank or in
            a dimension other than ``dim``.
        IndexError: if ``dim`` is out of range.
    """
    logger.debug("GEMS_ASCEND CAT_OUT")

    if len(A) == 0:
        raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")

    prepared = _prepare_cat(A, dim)
    if prepared is None:
        out.resize_(0)
        return out

    tensors, dim, out_shape = prepared
    out.resize_(out_shape)
    _cat_fill(out, tensors, dim)
    return out