Coverage for src/flag_gems/ops/concatenate.py: 100%

8 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2from typing import List, Tuple, Union 

3 

4import torch 

5 

6from flag_gems.ops.cat import cat 

7 

8logger = logging.getLogger(__name__) 

9 

10 

def concatenate(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Join a sequence of tensors along one dimension.

    Alias for ``torch.cat`` that forwards directly to the flag_gems ``cat``
    implementation. Its signature mirrors
    ``aten::concatenate(Tensor[] tensors, int dim=0) -> Tensor``.

    Args:
        A: The tensors to join, supplied as a tuple or a list.
        dim: The dimension along which the tensors are joined (default 0).

    Returns:
        Whatever ``cat(A, dim=dim)`` produces; no extra processing is applied.
    """
    logger.debug("GEMS CONCATENATE")
    # All shape/dtype checks and dim normalization are left to ``cat``.
    joined = cat(A, dim=dim)
    return joined