Coverage for src/flag_gems/ops/tensor_split.py: 9%

95 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3from typing import List, Union 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

# Module-level logger, named after this module so log records can be filtered per-op.
logger = logging.getLogger(name=__name__)

10 

11 

@triton.jit
def split_copy_kernel(
    input_ptr,
    output_ptr,
    num_elements,
    dim_size_input,
    dim_size_output,
    split_dim,
    dim_prod_pre,
    dim_prod_post,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one split of a contiguous input tensor into a contiguous output.

    The output is treated as a row-major ``(pre, dim_size_output, post)`` block.
    Element ``(p, s, q)`` is read from ``input_ptr`` at the linear offset of
    ``(p, s, q)`` in a ``(pre, dim_size_input, post)`` layout.

    The index formula contains no split-start term. The caller must therefore
    pass ``input_ptr`` already advanced to the first element of the split,
    i.e. by ``split_start * dim_prod_post`` elements.

    Shape scalars are runtime arguments rather than ``tl.constexpr``. This lets
    one compiled kernel serve every split size and input shape instead of
    recompiling for each distinct value.

    ``split_dim`` and ``dim_prod_pre`` are not used by the indexing. They are
    kept so existing call sites remain valid.
    """
    # 64-bit offsets so inputs with >= 2**31 elements do not overflow int32.
    pid = tl.program_id(0).to(tl.int64)
    idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = idx < num_elements

    # Decompose the output linear index into (pre_idx, split_idx, post_idx).
    # Dividing by dim_prod_post first, then by dim_size_output, keeps every
    # intermediate in int64 (no int32 product of two scalars).
    row = idx // dim_prod_post
    post_idx = idx % dim_prod_post
    split_idx = row % dim_size_output
    pre_idx = row // dim_size_output

    # Same coordinates in the input, whose split axis has dim_size_input entries.
    input_idx = (pre_idx * dim_size_input + split_idx) * dim_prod_post + post_idx

    data = tl.load(input_ptr + input_idx, mask=mask)
    tl.store(output_ptr + idx, data, mask=mask)

54 

55 

def _split_bounds(
    indices_or_sections: Union[int, List[int]], dim_size: int
) -> List[tuple]:
    """Return one ``(start, length)`` pair per output split along one axis.

    Follows ``torch.tensor_split`` semantics:

    * An int ``n`` yields ``n`` sections. The first ``dim_size % n`` sections
      get one extra element.
    * A list/tuple of indices yields ``len(indices) + 1`` pieces. Piece ``k``
      covers ``[indices[k-1]:indices[k]]`` (with ``0`` and ``dim_size`` as the
      outer edges) under Python slice rules: negative values wrap, out-of-range
      values clamp, and a decreasing pair yields an empty piece.

    Raises:
        ValueError: if an int section count is not positive.
        TypeError: if a list/tuple holds non-integers, or the argument has an
            unsupported type.
    """
    if isinstance(indices_or_sections, int):
        n = indices_or_sections
        if n <= 0:
            raise ValueError(f"indices_or_sections must be positive, got {n}")
        base_size, remainder = divmod(dim_size, n)
        bounds = []
        start = 0
        for i in range(n):
            length = base_size + 1 if i < remainder else base_size
            bounds.append((start, length))
            start += length
        return bounds

    if isinstance(indices_or_sections, (list, tuple)):
        indices = list(indices_or_sections)
        if not all(isinstance(i, int) for i in indices):
            raise TypeError("All elements in indices_or_sections must be integers")
        edges = [0, *indices, dim_size]
        bounds = []
        for lo, hi in zip(edges[:-1], edges[1:]):
            # slice.indices applies the same wrap/clamp rules as tensor slicing.
            start, stop, _ = slice(lo, hi).indices(dim_size)
            bounds.append((start, max(stop - start, 0)))
        return bounds

    raise TypeError(
        f"indices_or_sections must be int, list, tuple, or tensor, got {type(indices_or_sections)}"
    )


def tensor_split(
    input: torch.Tensor,
    indices_or_sections: Union[int, List[int], torch.Tensor],
    dim: int = 0,
) -> List[torch.Tensor]:
    """
    Split a tensor into multiple sub-tensors along a given dimension.

    This implementation uses Triton kernels to copy data to the output tensors.
    Note: Unlike torch.tensor_split which returns views, this returns copies.

    Args:
        input: tensor to split; must have at least one dimension.
        indices_or_sections: either a section count (int, or 0-d tensor), or
            split indices (list/tuple of ints, or 1-d tensor). Split indices
            use slice semantics like ``torch.tensor_split``.
        dim: dimension to split along; negative values count from the end.

    Returns:
        A list of newly allocated tensors, one per split. They have the same
        dtype and device as ``input``, and ``input``'s shape except along
        ``dim``.

    Raises:
        TypeError: if ``input`` is not a tensor or the spec has a bad type.
        RuntimeError: if ``input`` is 0-dimensional.
        IndexError: if ``dim`` is out of range.
        ValueError: for a non-positive section count, or a tensor spec with
            more than one dimension.
    """
    logger.debug("GEMS tensor_split")

    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Expected tensor, got {type(input)}")

    ndim = input.ndim
    if ndim == 0:
        raise RuntimeError(
            "tensor_split expected at least a 1-dimensional tensor, "
            "but got a tensor with 0 dims"
        )
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    dim = dim % ndim

    # A tensor spec is read on the host: 0-d means a section count, 1-d a list
    # of split indices.
    if isinstance(indices_or_sections, torch.Tensor):
        spec = indices_or_sections.cpu()
        if spec.ndim > 1:
            raise ValueError("indices_or_sections tensor must be 0-d or 1-d")
        indices_or_sections = spec.item() if spec.ndim == 0 else spec.tolist()

    shape = input.shape
    dim_size = shape[dim]
    bounds = _split_bounds(indices_or_sections, dim_size)

    # Element counts before / after the split axis (empty products are 1).
    dim_prod_pre = shape[:dim].numel()
    dim_prod_post = shape[dim + 1 :].numel()

    # Make the input contiguous once (not per split) and flatten it, so a split's
    # start can be expressed as a plain element offset.
    flat_input = input.contiguous().view(-1)

    # Standard block size for element-wise copy kernel
    BLOCK_SIZE = 1024

    output_tensors = []
    for start, length in bounds:
        output_shape = list(shape)
        output_shape[dim] = length
        output_tensor = torch.empty(
            output_shape, dtype=input.dtype, device=input.device
        )

        total_elements = output_tensor.numel()
        if total_elements > 0:
            # The kernel indexes relative to its input pointer without a split
            # offset, so hand it a view that begins at this split's first
            # element (start * dim_prod_post elements into the contiguous data).
            split_input = flat_input[start * dim_prod_post :]
            grid = (triton.cdiv(total_elements, BLOCK_SIZE),)
            split_copy_kernel[grid](
                split_input,
                output_tensor,
                total_elements,
                dim_size,  # dim_size_input
                length,  # dim_size_output
                dim,  # split_dim
                dim_prod_pre,
                dim_prod_post,
                BLOCK_SIZE=BLOCK_SIZE,
            )

        output_tensors.append(output_tensor)

    return output_tensors