Coverage for src/flag_gems/ops/concatenate.py: 100%
8 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
2from typing import List, Tuple, Union
4import torch
6from flag_gems.ops.cat import cat
# Module-level logger named after this module; concatenate() uses it to emit
# a debug trace on every call.
logger = logging.getLogger(name=__name__)
def concatenate(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int = 0,
) -> torch.Tensor:
    """Join the tensors in ``A`` along dimension ``dim``.

    Alias of ``cat`` (imported from ``flag_gems.ops.cat``), exposed under the
    name and signature of the ATen schema
    ``aten::concatenate(Tensor[] tensors, int dim=0) -> Tensor``.

    Args:
        A: Tuple or list of tensors to join.
        dim: Dimension to join along; defaults to 0.

    Returns:
        The tensor produced by ``cat(A, dim=dim)``.
    """
    logger.debug("GEMS CONCATENATE")
    joined = cat(A, dim=dim)
    return joined