Coverage for src/flag_gems/ops/cat.py: 54%

193 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2from typing import List, Tuple, Union 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8logger = logging.getLogger(__name__) 

9 

10 

def _is_float8_e8m0fnu(dtype: torch.dtype) -> bool:
    """Return True when ``dtype`` is ``torch.float8_e8m0fnu``.

    The test compares the dtype's string form instead of referencing the
    ``torch.float8_e8m0fnu`` attribute, so it does not require that attribute
    to exist on the installed torch build.
    """
    dtype_name = str(dtype)
    return dtype_name == "torch.float8_e8m0fnu"

13 

14 

def _should_use_uint8_view_path(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
) -> bool:
    """Decide whether ``A`` should be concatenated as raw bytes.

    Returns True only when ``A`` is non-empty, its first tensor has dtype
    ``torch.float8_e8m0fnu`` with a one-byte element size, and every other
    tensor shares that exact dtype and also has a one-byte element size.
    Such inputs can be reinterpreted as ``torch.uint8`` and copied bytewise.
    """
    if not A:
        return False

    ref_dtype = A[0].dtype
    head_ok = _is_float8_e8m0fnu(ref_dtype) and A[0].element_size() == 1
    if not head_ok:
        return False

    return all(
        other.dtype == ref_dtype and other.element_size() == 1 for other in A[1:]
    )

29 

30 

def _cat_build_working_list_uint8_view(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
):
    """Build the cat work list over ``torch.uint8`` views of ``A``.

    Each input is reinterpreted with ``Tensor.view(torch.uint8)`` and the views
    are passed to ``_cat_build_working_list``. The dtype of the first input is
    returned as well, so the caller can view the result back to it.

    Returns:
        ``(mode, payload, original_dtype)``, where ``mode`` and ``payload`` are
        exactly what ``_cat_build_working_list`` returned for the byte views.
    """
    source_dtype = A[0].dtype
    byte_views = [t.view(torch.uint8) for t in A]
    result_mode, result_payload = _cat_build_working_list(byte_views, dim)
    return result_mode, result_payload, source_dtype

39 

40 

@triton.jit
def cat_copy_func_kernel_4(
    out_ptr,
    in_ptr_a,
    in_ptr_b,
    in_ptr_c,
    in_ptr_d,
    dim_size_in_a,
    dim_size_in_b,
    dim_size_in_c,
    dim_size_in_d,
    dim_size_out,
    dim_prod_post,
    dim_offset_a: tl.int64,
    dim_offset_b: tl.int64,
    dim_offset_c: tl.int64,
    dim_offset_d: tl.int64,
    total_elements_a,
    total_elements_b,
    total_elements_c,
    total_elements_d,
    BLOCK_X: tl.constexpr,
):
    # Copies up to four inputs (slots a-d) into their slices of `out_ptr`.
    #
    # Grid axis 1 selects the input slot (0 -> a, 1 -> b, 2 -> c, anything
    # else -> d). Grid axis 0 walks that input's flat elements in chunks of
    # BLOCK_X. Inputs and output are addressed with flat offsets, so both are
    # treated as contiguous.
    #
    # Viewing an input as (pre, dim_size_in, dim_prod_post), its element
    # (p, d, q) is stored at (p, d + dim_offset, q) of the output viewed as
    # (pre, dim_size_out, dim_prod_post).
    pid_x = tl.program_id(0)
    pid_y = tl.program_id(1)

    if pid_y == 0:
        in_ptr = in_ptr_a
        dim_size_in = dim_size_in_a
        dim_offset = tl.cast(dim_offset_a, tl.int64)
        total_elements = total_elements_a
    elif pid_y == 1:
        in_ptr = in_ptr_b
        dim_size_in = dim_size_in_b
        dim_offset = tl.cast(dim_offset_b, tl.int64)
        total_elements = total_elements_b
    elif pid_y == 2:
        in_ptr = in_ptr_c
        dim_size_in = dim_size_in_c
        dim_offset = tl.cast(dim_offset_c, tl.int64)
        total_elements = total_elements_c
    else:
        in_ptr = in_ptr_d
        dim_size_in = dim_size_in_d
        dim_offset = tl.cast(dim_offset_d, tl.int64)
        total_elements = total_elements_d

    # program_id is int32. Promote before scaling by BLOCK_X so that every index
    # below is computed in int64. The flat input index and the output offset
    # (pre_idx * dim_size_out * dim_prod_post) can both exceed 2**31 - 1 for
    # large tensors; in int32 they would wrap around silently.
    block_start = pid_x.to(tl.int64) * BLOCK_X
    idx = block_start + tl.arange(0, BLOCK_X)
    mask = idx < total_elements

    # Split idx into (pre, dim, post) coordinates. Dividing by dim_prod_post
    # first avoids forming the scalar product dim_size_in * dim_prod_post,
    # which could itself overflow when both scalars arrive as int32.
    row = idx // dim_prod_post
    pre_idx = row // dim_size_in
    dim_idx = row % dim_size_in
    post_idx = idx % dim_prod_post

    out_idx = (
        pre_idx * dim_size_out * dim_prod_post
        + (dim_idx + dim_offset) * dim_prod_post
        + post_idx
    )

    data = tl.load(in_ptr + idx, mask=mask)
    tl.store(out_ptr + out_idx, data, mask=mask)

106 

107 

def _cat_run_kernel(
    A: List[torch.Tensor],
    dim: int,
    out_shape: List[int],
    out: torch.Tensor,
):
    """Concatenate ``A`` along ``dim`` into ``out`` with ``cat_copy_func_kernel_4``.

    Inputs are processed in batches of up to four tensors per kernel launch. In
    a batch with fewer than four tensors, the unused kernel slots get a dummy
    pointer and zero sizes. Grid axis 1 spans only the real tensors, so the
    kernel never reads the padded slots.

    Args:
        A: tensors of equal rank whose sizes are assumed to match except
            along ``dim``. Each one is made contiguous before launch.
        dim: concatenation axis. It must be non-negative, because it is used
            directly to slice ``out_shape``.
        out_shape: shape of the result. ``out_shape[dim]`` is the total length
            along ``dim``; sizes after ``dim`` are assumed equal to the inputs'.
        out: destination tensor of shape ``out_shape``. If it is not
            contiguous, the kernel fills a contiguous scratch buffer that is
            then copied into ``out``.
    """
    BLOCK = 1024
    SLOTS = 4  # number of input slots in cat_copy_func_kernel_4

    if len(A) == 0:
        return

    # The kernel writes through flat contiguous offsets. A strided `out` (for
    # example a caller-supplied out= tensor) would receive values in the wrong
    # positions, so such an `out` is filled via a contiguous buffer instead.
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty_like(out, memory_format=torch.contiguous_format)

    # Launch arguments that do not depend on the batch.
    dim_size_out = out_shape[dim]
    dim_prod_post = 1
    for size in out_shape[dim + 1 :]:
        dim_prod_post *= size

    dim_offset = 0
    for start in range(0, len(A), SLOTS):
        batch = [t.contiguous() for t in A[start : start + SLOTS]]
        n_real = len(batch)
        n_pad = SLOTS - n_real

        dim_sizes_in = [t.shape[dim] for t in batch]
        numels = [t.numel() for t in batch]
        dim_offsets = []
        for size in dim_sizes_in:
            dim_offsets.append(dim_offset)
            dim_offset += size

        # Padding for unused slots. These are never dereferenced because grid
        # axis 1 is only n_real wide.
        ptrs = batch + [batch[0]] * n_pad
        dim_sizes_in += [0] * n_pad
        dim_offsets += [0] * n_pad
        numels += [0] * n_pad

        grid = (triton.cdiv(max(numels), BLOCK), n_real)
        cat_copy_func_kernel_4[grid](
            dst,
            *ptrs,
            *dim_sizes_in,
            dim_size_out,
            dim_prod_post,
            *dim_offsets,
            *numels,
            BLOCK_X=BLOCK,
        )

    if dst is not out:
        out.copy_(dst)

192 

193 

def _cat_build_working_list(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int
):
    """Validate ``torch.cat`` inputs and classify the work to be done.

    Follows ``torch.cat``'s legacy rule that 1-D tensors of shape ``(0,)`` are
    skipped, so they can be mixed with tensors of any rank. The result dtype is
    the type promotion of *all* inputs, including those skipped tensors, as in
    ``torch.cat``.

    Args:
        A: non-empty sequence of tensors.
        dim: concatenation dimension; may be negative.

    Returns:
        ``(mode, payload)``, where ``mode`` is one of:

        * ``"single"``: ``payload`` is the only tensor to concatenate.
          - If ``A`` held exactly one tensor, it is returned unchanged.
          - Otherwise the surviving tensor is converted to the promoted dtype
            (this returns the same tensor when it already has that dtype).
        * ``"empty"``: every input was skipped. ``payload`` is a new 1-D empty
          tensor with the promoted dtype, on ``A[0]``'s device.
        * ``"multi"``: ``payload`` is ``(tensors, dim, out_shape, dtype,
          device)``.
          - ``tensors`` are converted to the promoted ``dtype``.
          - ``dim`` is normalized to be non-negative.
          - ``out_shape`` is the concatenated shape.

    Raises:
        RuntimeError: if ``A`` is empty, or if the remaining tensors differ in
            rank or in any size other than ``dim``.
        AssertionError: if ``dim`` is out of range for the remaining tensors.
    """
    if len(A) == 0:
        raise RuntimeError("torch.cat(): expected a non-empty list of Tensors")
    if len(A) == 1:
        return "single", A[0]

    device = A[0].device
    # Promote over every input, including the legacy empties dropped below,
    # the same way torch.cat computes its result dtype.
    dtype = A[0].dtype
    for t in A[1:]:
        dtype = torch.promote_types(dtype, t.dtype)

    # Legacy rule: 1-D tensors of shape (0,) take no part in the shape logic.
    A = [t for t in A if t.shape != torch.Size([0])]
    if len(A) == 0:
        return "empty", torch.tensor([], device=device, dtype=dtype)
    if len(A) == 1:
        return "single", A[0].to(dtype)

    ndim = A[0].ndim
    assert -ndim <= dim < ndim, f"Invalid dim: {dim}"
    dim %= ndim

    inp_shapes = [list(t.shape) for t in A]
    inp0_shape = inp_shapes[0]
    for s in inp_shapes[1:]:
        if len(s) != len(inp0_shape):
            raise RuntimeError(
                f"Tensors must have same number of dimensions: got {len(inp0_shape)} and {len(s)}"
            )
    for tensor_idx, inp_shape in enumerate(inp_shapes):
        for idx, (common_length, length) in enumerate(zip(inp0_shape, inp_shape)):
            if idx != dim and length != common_length:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected size {common_length} but got size {length} for tensor number "
                    f"{tensor_idx} in the list"
                )

    A = [t if t.dtype == dtype else t.to(dtype) for t in A]

    out_shape = list(A[0].shape)
    out_shape[dim] = sum(t.shape[dim] for t in A)
    return "multi", (A, dim, out_shape, dtype, A[0].device)

244 

245 

def cat_out(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int = 0,
    *,
    out: torch.Tensor,
) -> torch.Tensor:
    """Concatenate ``A`` along ``dim`` into the caller-provided ``out``.

    Mirrors ``torch.cat(A, dim, out=out)``. When ``_should_use_uint8_view_path``
    accepts ``A`` (float8_e8m0fnu inputs), the tensors are concatenated as
    ``torch.uint8`` views, and ``out`` is written through a ``torch.uint8`` view
    of itself. Otherwise the regular work list is used.

    * "single" / "empty" results: ``out`` is resized to the result's shape and
      the result is copied in.
    * "multi" results: ``out.dtype`` must equal the result dtype. ``out`` is
      resized only when its shape differs, then filled by the kernel.

    Returns:
        ``out``.

    Raises:
        RuntimeError: if ``out.dtype`` differs from the result dtype in the
            multi-tensor case, or from input validation.
    """
    logger.debug("GEMS CAT_OUT")
    A = list(A)

    via_bytes = _should_use_uint8_view_path(A)
    if via_bytes:
        mode, payload, original_dtype = _cat_build_working_list_uint8_view(A, dim)
    else:
        mode, payload = _cat_build_working_list(A, dim)
        original_dtype = None  # only consulted on the byte path

    if mode in ("single", "empty"):
        result = payload.view(original_dtype) if via_bytes else payload
        out.resize_(result.shape)
        if mode == "single" and out.dtype != result.dtype:
            out.copy_(result.to(out.dtype))
        else:
            out.copy_(result)
        return out

    tensors, cat_dim, out_shape, result_dtype, _ = payload
    expected_dtype = original_dtype if via_bytes else result_dtype
    if out.dtype != expected_dtype:
        raise RuntimeError(
            f"cat.out: expected out dtype {expected_dtype}, got {out.dtype}"
        )
    if list(out.shape) != out_shape:
        out.resize_(out_shape)
    destination = out.view(torch.uint8) if via_bytes else out
    _cat_run_kernel(tensors, cat_dim, out_shape, destination)
    return out

303 

304 

def cat(
    A: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]], dim: int = 0
) -> torch.Tensor:
    """Concatenate ``A`` along ``dim`` into a newly allocated tensor.

    Mirrors ``torch.cat(A, dim)``. float8_e8m0fnu inputs accepted by
    ``_should_use_uint8_view_path`` are concatenated as ``torch.uint8`` views,
    and the result is viewed back to the original dtype.

    The result never aliases an input. When only one tensor is left to
    concatenate, it is cloned rather than returned directly. This matches
    ``torch.cat``, which always returns a new tensor, so in-place updates of
    the result cannot leak into the caller's input.
    """
    logger.debug("GEMS CAT")
    A = list(A)
    if _should_use_uint8_view_path(A):
        mode, payload, original_dtype = _cat_build_working_list_uint8_view(A, dim)
        if mode == "single":
            # payload is a uint8 view of a caller's tensor. Clone the raw
            # bytes first, then view back, so the result owns its storage.
            return payload.clone().view(original_dtype)
        if mode == "empty":
            return payload.view(original_dtype)

        A_u8, dim, out_shape, _, device = payload
        out_u8 = torch.empty(out_shape, dtype=torch.uint8, device=device)
        _cat_run_kernel(A_u8, dim, out_shape, out_u8)
        return out_u8.view(original_dtype)

    mode, payload = _cat_build_working_list(A, dim)
    if mode == "single":
        # payload may be one of the caller's tensors; never hand it back as-is.
        return payload.clone()
    if mode == "empty":
        return payload

    A, dim, out_shape, dtype, device = payload
    out = torch.empty(out_shape, dtype=dtype, device=device)
    _cat_run_kernel(A, dim, out_shape, out)
    return out