Coverage for src/flag_gems/ops/tensor_split.py: 9%
95 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
# Generated by KernelGen: https://github.com/flagos-ai/KernelGen
import logging
import math
from typing import List, Tuple, Union

import torch
import triton
import triton.language as tl
# Logger named after this module.
logger = logging.getLogger(__name__)
@triton.jit
def split_copy_kernel(
    input_ptr,
    output_ptr,
    num_elements,
    dim_size_input,
    dim_size_output,
    split_dim,
    dim_prod_pre,
    dim_prod_post,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one split (chunk) of a contiguous input into a contiguous output.

    Both buffers are addressed as row-major ``[pre, dim, post]`` blocks.
    The output has ``dim_size_output`` entries along the split dimension.
    The input has ``dim_size_input``. Output element ``idx`` is read from
    ``input_ptr + (pre_idx * dim_size_input + split_idx) * dim_prod_post
    + post_idx``.

    No split offset is added here. ``input_ptr`` must already point at the
    first element of the chunk, i.e. it must be advanced by
    ``split_start * dim_prod_post`` elements from the start of the input.

    ``dim_size_output`` and ``dim_prod_post`` are used as divisors. Do not
    launch this kernel for an empty chunk.

    Args:
        input_ptr: Base pointer of the chunk inside the contiguous input.
        output_ptr: Pointer to the contiguous output chunk.
        num_elements: Number of elements in the output chunk.
        dim_size_input: Size of the split dimension in the full input.
        dim_size_output: Size of the split dimension in this chunk.
        split_dim: Unused; kept for call-site compatibility.
        dim_prod_pre: Unused; kept for call-site compatibility.
        dim_prod_post: Product of the sizes after the split dimension.
        BLOCK_SIZE: Number of elements handled by each program.
    """
    # Shape values are ordinary runtime arguments (not tl.constexpr), so a
    # new kernel is not JIT-compiled for every distinct shape / split size.
    # The program id is widened to int64 so the index math below cannot
    # wrap around for tensors with more than 2**31 - 1 elements.
    pid = tl.program_id(0).to(tl.int64)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = idx < num_elements

    # Decompose the flat output index into (pre_idx, split_idx, post_idx).
    # Dividing step by step avoids forming dim_size_output * dim_prod_post
    # as a product of two (possibly 32-bit) scalar arguments.
    row = idx // dim_prod_post
    post_idx = idx % dim_prod_post
    split_idx = row % dim_size_output
    pre_idx = row // dim_size_output

    # Same coordinates in the input, whose split dimension has
    # dim_size_input entries per pre_idx.
    input_idx = (pre_idx * dim_size_input + split_idx) * dim_prod_post + post_idx

    data = tl.load(input_ptr + input_idx, mask=mask)
    tl.store(output_ptr + idx, data, mask=mask)
def _tensor_split_bounds(
    indices_or_sections: Union[int, List[int], torch.Tensor],
    dim_size: int,
) -> List[Tuple[int, int]]:
    """Return ``(start, length)`` along the split dim for every output chunk.

    Follows ``torch.tensor_split``:

    * An int ``n`` gives ``n`` chunks. The first ``dim_size % n`` chunks hold
      one extra element.
    * A list/tuple of ints ``[i0, i1, ..., ik]`` gives the chunks
      ``[:i0], [i0:i1], ..., [ik:]`` with Python slice semantics:
      - negative indices wrap;
      - out-of-range indices are clamped;
      - a non-increasing pair yields an empty chunk.
    * A 0-d tensor is treated as an int, a 1-d tensor as a list.

    Raises:
        ValueError: Non-positive section count, or a tensor argument with
            more than one dimension.
        TypeError: Unsupported ``indices_or_sections`` type, or non-int
            indices.
    """
    if isinstance(indices_or_sections, torch.Tensor):
        # Index values are needed on the host to size the outputs.
        host = indices_or_sections.cpu()
        if host.ndim > 1:
            raise ValueError("indices_or_sections tensor must be 0-d or 1-d")
        indices_or_sections = host.item() if host.ndim == 0 else host.tolist()

    if isinstance(indices_or_sections, int):
        n = indices_or_sections
        if n <= 0:
            raise ValueError(f"indices_or_sections must be positive, got {n}")
        base_size, remainder = divmod(dim_size, n)
        bounds = []
        start = 0
        for i in range(n):
            length = base_size + 1 if i < remainder else base_size
            bounds.append((start, length))
            start += length
        return bounds

    if isinstance(indices_or_sections, (list, tuple)):
        indices = list(indices_or_sections)
        if not all(isinstance(i, int) for i in indices):
            raise TypeError("All elements in indices_or_sections must be integers")
        # Chunk k is input[edges[k]:edges[k + 1]]. slice.indices applies the
        # wrap-negative-then-clamp rules of step-1 slicing.
        edges = [0, *indices, dim_size]
        bounds = []
        for lo, hi in zip(edges[:-1], edges[1:]):
            start, stop, _ = slice(lo, hi).indices(dim_size)
            bounds.append((start, max(stop - start, 0)))
        return bounds

    raise TypeError(
        f"indices_or_sections must be int, list, tuple, or tensor, got {type(indices_or_sections)}"
    )


def tensor_split(
    input: torch.Tensor,
    indices_or_sections: Union[int, List[int], torch.Tensor],
    dim: int = 0,
) -> List[torch.Tensor]:
    """
    Split a tensor into multiple sub-tensors along a given dimension.

    Chunk boundaries follow ``torch.tensor_split``. Each non-empty chunk is
    copied into a freshly allocated tensor by ``split_copy_kernel``.
    Note: Unlike torch.tensor_split which returns views, this returns copies.

    Args:
        input: Tensor to split; must have at least one dimension.
        indices_or_sections: Number of sections (int or 0-d tensor), or split
            indices (list/tuple of ints or 1-d tensor).
        dim: Dimension to split along; negative values count from the end.

    Returns:
        One newly allocated tensor per chunk, in order along ``dim``.

    Raises:
        TypeError: ``input`` is not a tensor, or ``indices_or_sections`` has
            an unsupported type.
        RuntimeError: ``input`` is 0-dimensional.
        IndexError: ``dim`` is out of range.
        ValueError: Invalid ``indices_or_sections`` value.
    """
    logger.debug("GEMS tensor_split")

    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Expected tensor, got {type(input)}")

    ndim = input.ndim
    if ndim == 0:
        # Same check as torch.tensor_split; also avoids `dim % 0` below.
        raise RuntimeError(
            "tensor_split expected at least a 1-dimensional tensor, "
            "but got a tensor with 0 dims"
        )
    if not -ndim <= dim < ndim:
        # A bare `dim % ndim` would silently wrap an out-of-range dim.
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    dim %= ndim
    dim_size = input.shape[dim]

    # Computed before any allocation, so empty inputs still get the correct
    # number of chunks with the correct size along `dim`.
    bounds = _tensor_split_bounds(indices_or_sections, dim_size)

    dim_prod_pre = math.prod(input.shape[:dim])
    dim_prod_post = math.prod(input.shape[dim + 1 :])
    # Make the input contiguous once (not once per chunk) and flatten it so
    # a chunk's start can be expressed as a flat element offset.
    src_flat = input.contiguous().view(-1)
    # Standard block size for element-wise copy kernel
    BLOCK_SIZE = 1024

    output_tensors = []
    for start, length in bounds:
        output_shape = list(input.shape)
        output_shape[dim] = length
        output_tensor = torch.empty(
            output_shape, dtype=input.dtype, device=input.device
        )
        num_elements = output_tensor.numel()
        # Empty chunks need no copy. This covers both a zero length along
        # `dim` and an input that is empty along another dimension.
        if num_elements > 0:
            # The kernel adds no split offset itself, so advance the base
            # pointer to the first element of this chunk.
            chunk_src = src_flat[start * dim_prod_post :]
            grid = (triton.cdiv(num_elements, BLOCK_SIZE),)
            split_copy_kernel[grid](
                chunk_src,
                output_tensor,
                num_elements,
                dim_size,  # dim_size_input
                length,  # dim_size_output
                dim,  # split_dim
                dim_prod_pre,
                dim_prod_post,
                BLOCK_SIZE=BLOCK_SIZE,
            )
        output_tensors.append(output_tensor)

    return output_tensors