Coverage for src/flag_gems/runtime/backend/_mthreads/ops/repeat_interleave.py: 0%

232 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as tle 

10from flag_gems.utils.pointwise_dynamic import pointwise_dynamic 

11from flag_gems.utils.shape_utils import c_contiguous_stride 

12from flag_gems.utils.tensor_wrapper import StridedBuffer 

13 

# Per-module logger placed under the mthreads ops namespace, keyed by this
# module's basename.
logger = logging.getLogger(
    "flag_gems.runtime.backend._mthreads.ops." + __name__.rpartition(".")[2]
)

# repeat_interleave.self_{int,Tensor} are CompositeImplicitAutograd ops.
# Overriding them directly bypasses their decomposition and breaks the
# gradient, so when gradients may be needed the wrappers redispatch to this
# keyset to run the decomposed forward (and its backward) instead.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeImplicitAutograd
)

25 

26 

@pointwise_dynamic(num_inputs=1, promotion_methods=[(0, "DEFAULT")])
@triton.jit
def copy_func(x):
    """Element-wise identity.

    Wrapped by ``pointwise_dynamic`` (one input, DEFAULT promotion on input 0),
    so calling the instantiated wrapper with ``out0=`` copies a possibly
    strided / stride-0 input view into the output buffer.
    """
    return x

31 

32 

def repeat_interleave_self_int(inp, repeats, dim=None, *, output_size=None):
    """Repeat every slice of ``inp`` along ``dim`` ``repeats`` times.

    Implements ``aten::repeat_interleave.self_int``. The input is viewed with
    an extra stride-0 axis of length ``repeats`` inserted right after ``dim``,
    and that view is copied into a contiguous output.

    Args:
        inp: Input tensor.
        repeats: Non-negative number of repetitions applied to every slice.
        dim: Dimension to repeat along. ``None`` flattens ``inp`` first and
            repeats along dim 0.
        output_size: Optional expected output size along ``dim``. It is only
            checked against the computed size, never used to size the output.

    Returns:
        A new contiguous tensor whose size along ``dim`` is
        ``inp.size(dim) * repeats``.

    Raises:
        IndexError: If ``dim`` is out of range.
        RuntimeError: If ``repeats`` is negative or ``output_size`` does not
            match the computed size.
    """
    logger.debug("GEMS_MTHREADS REPEAT_INTERLEAVE_SELF_INT")
    if torch.is_grad_enabled():
        # Run the composite decomposition so autograd is preserved.
        return torch.ops.aten.repeat_interleave.self_int.redispatch(
            _FALLBACK_KEYSET, inp, repeats, dim, output_size=output_size
        )
    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    # Reject negative counts up front. Otherwise they surface as a negative
    # dimension in torch.empty, or (for an input empty along dim) as a view
    # with a negative extent.
    if repeats < 0:
        raise RuntimeError("repeats can not be negative")

    inp_shape = list(inp.shape)
    inp_stride = list(inp.stride())
    output_shape = list(inp.shape)

    if dim < 0:
        dim = dim + len(inp_shape)

    output_shape[dim] *= repeats

    if output_size is not None and output_size != output_shape[dim]:
        raise RuntimeError(
            "repeat_interleave: Invalid output_size, expected {} but got {}".format(
                output_shape[dim], output_size
            )
        )

    output = torch.empty(output_shape, dtype=inp.dtype, device=inp.device)

    # Nothing to copy: the output is empty along ``dim``.
    if repeats == 0:
        return output

    # Read the input as [..., size(dim), repeats, ...] with a stride-0 repeat
    # axis, so every repeat reads the same slice. The contiguous output has
    # exactly the same memory layout as a contiguous tensor of that shape, so
    # copying the view into it interleaves the repeats.
    in_view_stride = inp_stride[: dim + 1] + [0] + inp_stride[dim + 1 :]
    out_view_shape = inp_shape[: dim + 1] + [repeats] + inp_shape[dim + 1 :]
    out_view_stride = c_contiguous_stride(out_view_shape)

    in_view = StridedBuffer(inp, out_view_shape, in_view_stride)
    out_view = StridedBuffer(output, out_view_shape, out_view_stride)
    ndim = len(out_view_shape)
    copy_func.instantiate(ndim)(in_view, out0=out_view)
    return output

79 

80 

@triton.jit
def repeat_interleave_tensor_kernel(
    repeats_ptr, cumsum_ptr, out_ptr, size, BLOCK_SIZE: tl.constexpr
):
    """Write index ``i`` into ``repeats[i]`` consecutive output slots.

    Program ``i`` fills ``out[cumsum[i] - repeats[i] : cumsum[i]]`` with ``i``,
    where ``cumsum`` is the inclusive prefix sum of ``repeats``. It writes
    ``BLOCK_SIZE`` slots per iteration.
    """
    idx = tle.program_id(0)
    valid = idx < size
    count = tl.load(repeats_ptr + idx, valid, other=0)
    end = tl.load(cumsum_ptr + idx, valid, other=0)

    tl.device_assert(count >= 0, "repeats can not be negative")

    # First output slot owned by this index.
    dst = out_ptr + (end - count)
    for base in range(0, count, BLOCK_SIZE):
        lane = base + tl.arange(0, BLOCK_SIZE)
        tl.store(dst + lane, idx, mask=lane < count)

98 

99 

def repeat_interleave_tensor(repeats, *, output_size=None):
    """Implement ``aten::repeat_interleave.Tensor``.

    Build a 1-D tensor in which index ``i`` appears ``repeats[i]`` times,
    in order; e.g. ``repeats = [2, 0, 1]`` yields ``[0, 0, 2]``.

    Args:
        repeats: 1-D integer tensor of per-index counts.
        output_size: Accepted for signature compatibility; not used here.

    Returns:
        A tensor of ``repeats.dtype`` on ``repeats.device`` with
        ``sum(repeats)`` elements.

    Raises:
        AssertionError: If ``repeats`` is not 1-D or its sum is negative.
    """
    logger.debug("GEMS_MTHREADS REPEAT_INTERLEAVE_TENSOR")

    assert repeats.ndim == 1, "repeat_interleave only accept 1D vector as repeat"

    size = repeats.size(0)
    # An empty ``repeats`` has no cumsum[-1] to read; the result is empty.
    if size == 0:
        return torch.empty((0,), dtype=repeats.dtype, device=repeats.device)

    cumsum = repeats.cumsum(axis=0)
    # Host sync: the output length is needed to allocate ``out``.
    result_size = cumsum[-1].item()

    assert result_size >= 0, "repeats can not be negative"

    out = torch.empty((result_size,), dtype=repeats.dtype, device=repeats.device)

    # One program per input index.
    grid = (size,)
    BLOCK_SIZE = 32
    with torch_device_fn.device(repeats.device):
        repeat_interleave_tensor_kernel[grid](
            repeats,
            cumsum,
            out,
            size,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=1,
        )
    return out

125 

126 

@libentry()
@triton.jit
def fused_repeat_interleave_dim0_kernel(
    inp_ptr,
    out_ptr,
    cumsum_ptr,
    num_input_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused kernel for repeat_interleave with dim=0.
    Each program handles one input row and copies to all its repeated output positions.

    Input-centric: program ``pid`` reads input row ``pid`` (``row_size``
    elements starting at ``pid * row_size``) and writes it to output rows
    ``[cumsum[pid - 1], cumsum[pid])``, with the start taken as 0 for
    ``pid == 0``. Here ``cumsum`` is the inclusive prefix sum of the repeat
    counts. Columns are handled ``BLOCK_SIZE`` at a time; each loaded chunk
    is stored once per repeated output row. Assumes row-major contiguous
    input/output buffers.
    """
    pid = tle.program_id(0)

    # Surplus programs (if any) have no input row.
    if pid >= num_input_rows:
        return

    # Get output row range for this input row
    # Row 0 has no predecessor, so the masked load falls back to 0.
    row_idx_mask = pid > 0
    start_row_idx = tl.load(cumsum_ptr + pid - 1, mask=row_idx_mask, other=0)
    end_row_idx = tl.load(cumsum_ptr + pid)

    # A zero repeat count means this input row contributes nothing.
    num_of_rows = end_row_idx - start_row_idx
    if num_of_rows == 0:
        return

    # Calculate input row offset
    inp_row_offset = pid * row_size

    # Process columns in blocks
    for col_block in range(0, tl.cdiv(row_size, BLOCK_SIZE)):
        col_offsets = col_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        col_mask = col_offsets < row_size

        # Load from input
        cur_inp = tl.load(
            inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0
        )

        # Store to each output row
        # (the chunk is loaded once and reused for every repeat)
        for cur_row in range(0, num_of_rows):
            output_row_index = start_row_idx + cur_row
            output_row_offsets = output_row_index * row_size + col_offsets
            tl.store(out_ptr + output_row_offsets, cur_inp, mask=col_mask)

172 

173 

@libentry()
@triton.jit
def fused_repeat_interleave_output_centric_kernel(
    inp_ptr,
    out_ptr,
    cumsum_ptr,
    num_input_rows,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Output-centric kernel for repeat_interleave with dim=0.

    Uses a 2D grid ``(num_output_rows, num_col_chunks)``. Each program copies
    one ``BLOCK_SIZE``-wide column chunk of one output row. The source input
    row is found by binary search over ``cumsum``, the inclusive prefix sum
    of the repeat counts.
    """
    out_row_idx = tle.program_id(0)
    col_chunk_idx = tle.program_id(1)

    if out_row_idx >= num_output_rows:
        return

    # Binary search to find input row index
    # Find the smallest i such that cumsum[i] > out_row_idx
    low = 0
    high = num_input_rows
    while low < high:
        mid = (low + high) // 2
        cumsum_mid = tl.load(cumsum_ptr + mid)
        if cumsum_mid <= out_row_idx:
            low = mid + 1
        else:
            high = mid

    # ``low`` starts from a Python int literal, so Triton carries it as int32.
    # Widen before multiplying by row_size so element offsets into tensors
    # with >= 2**31 elements do not wrap around.
    inp_row_idx = low.to(tl.int64)

    # Calculate column offsets for this chunk
    col_offset = col_chunk_idx * BLOCK_SIZE
    col_offsets = col_offset + tl.arange(0, BLOCK_SIZE)
    col_mask = col_offsets < row_size

    # Load from input
    inp_row_offset = inp_row_idx * row_size
    cur_inp = tl.load(inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0)

    # Store to output (64-bit offset for the same reason as above)
    out_row_offset = out_row_idx.to(tl.int64) * row_size
    tl.store(out_ptr + out_row_offset + col_offsets, cur_inp, mask=col_mask)

221 

222 

@libentry()
@triton.jit
def fused_repeat_interleave_1d_bsearch_kernel(
    inp_ptr,
    out_ptr,
    cumsum_ptr,
    num_input_rows,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """1D output-centric kernel with binary search.

    Each program copies one complete output row, looping over its columns
    ``BLOCK_SIZE`` at a time; intended for large row sizes. The source input
    row is found by binary search over ``cumsum``, the inclusive prefix sum
    of the repeat counts.
    """
    out_row_idx = tle.program_id(0)

    if out_row_idx >= num_output_rows:
        return

    # Binary search: smallest i such that cumsum[i] > out_row_idx
    low = 0
    high = num_input_rows
    while low < high:
        mid = (low + high) // 2
        cumsum_mid = tl.load(cumsum_ptr + mid)
        if cumsum_mid <= out_row_idx:
            low = mid + 1
        else:
            high = mid

    # ``low`` starts from a Python int literal, so Triton carries it as int32.
    # Widen before multiplying by row_size so element offsets into tensors
    # with >= 2**31 elements do not wrap around.
    inp_row_idx = low.to(tl.int64)

    # Calculate row offsets (64-bit for the same reason)
    inp_row_offset = inp_row_idx * row_size
    out_row_offset = out_row_idx.to(tl.int64) * row_size

    # Process all columns in blocks
    for col_offset in range(0, row_size, BLOCK_SIZE):
        col_offsets = col_offset + tl.arange(0, BLOCK_SIZE)
        col_mask = col_offsets < row_size

        cur_inp = tl.load(
            inp_ptr + inp_row_offset + col_offsets, mask=col_mask, other=0.0
        )
        tl.store(out_ptr + out_row_offset + col_offsets, cur_inp, mask=col_mask)

269 

270 

@libentry()
@triton.jit
def fused_repeat_interleave_with_indices_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy one column chunk of an output row using a precomputed row map.

    Grid is ``(num_output_rows, num_col_chunks)``. ``index_ptr[r]`` holds the
    input row that feeds output row ``r``. No launcher in this file currently
    uses this kernel.
    """
    row = tle.program_id(0)
    chunk = tle.program_id(1)

    if row >= num_output_rows:
        return

    src_row = tl.load(index_ptr + row)

    cols = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = cols < row_size

    vals = tl.load(inp_ptr + src_row * row_size + cols, mask=in_bounds, other=0.0)
    tl.store(out_ptr + row * row_size + cols, vals, mask=in_bounds)

305 

306 

@libentry()
@triton.jit
def fused_repeat_interleave_large_row_kernel(
    inp_ptr,
    out_ptr,
    index_ptr,
    num_output_rows,
    row_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy a whole output row using a precomputed row map.

    One program per output row; columns are copied ``BLOCK_SIZE`` at a time.
    ``index_ptr[r]`` holds the input row that feeds output row ``r``. No
    launcher in this file currently uses this kernel.
    """
    row = tle.program_id(0)

    if row >= num_output_rows:
        return

    src_row = tl.load(index_ptr + row)
    src_base = inp_ptr + src_row * row_size
    dst_base = out_ptr + row * row_size

    for start in range(0, row_size, BLOCK_SIZE):
        cols = start + tl.arange(0, BLOCK_SIZE)
        keep = cols < row_size
        vals = tl.load(src_base + cols, mask=keep, other=0.0)
        tl.store(dst_base + cols, vals, mask=keep)

342 

343 

def fused_repeat_interleave_dim0(inp, repeats, dim):
    """Repeat rows of ``inp`` along ``dim`` by per-row counts in ``repeats``.

    The kernels index rows as ``row * row_size`` in a contiguous buffer, so
    ``dim`` must be the outermost dimension. The only caller in this module
    passes ``dim == 0``.

    Strategy selection:
      1. Small tensors (row_size < 512 and < 512 output rows): input-centric
         kernel, one program per input row.
      2. Medium rows: output-centric 2-D grid (output row x column chunk),
         each program binary-searching ``cumsum`` for its source row.
      3. Large rows (row_size >= 16384): output-centric 1-D grid, one program
         per output row, which amortizes the binary search.

    Args:
        inp: Input tensor; made contiguous before launching.
        repeats: 1-D integer tensor with one non-negative count per slice
            along ``dim``. It is assumed the caller has checked that it has
            ``inp.size(dim)`` entries.
        dim: Dimension to repeat along (expected to be 0).

    Returns:
        A new contiguous tensor with ``sum(repeats)`` entries along ``dim``.

    Raises:
        AssertionError: If ``repeats`` is not 1-D.
        RuntimeError: If any entry of ``repeats`` is negative.
    """
    logger.debug("GEMS_MTHREADS FUSED_REPEAT_INTERLEAVE_DIM0")

    assert repeats.ndim == 1, "repeat_interleave only accept 1D vector as repeat"

    out_shape = list(inp.shape)

    # Empty repeats: there is no cumsum[-1] to read and nothing to copy.
    if repeats.numel() == 0:
        out_shape[dim] = 0
        return torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # Negative counts make cumsum non-monotonic. That silently corrupts both
    # the input-centric row ranges and the binary searches below.
    if bool((repeats < 0).any()):
        raise RuntimeError("repeats can not be negative")

    # Inclusive prefix sum: input row i maps to output rows [cumsum[i-1], cumsum[i]).
    cumsum = repeats.cumsum(axis=0)
    total_output_rows = cumsum[-1].item()

    out_shape[dim] = total_output_rows
    out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)

    # Nothing to copy: zero output rows, or a zero-sized trailing dim. In the
    # latter case row_size == 0 would give BLOCK_SIZE == 0
    # (triton.next_power_of_2(0) == 0), which breaks triton.cdiv and tl.arange.
    if out.numel() == 0:
        return out

    # Flatten non-dim dimensions for easier indexing. Deriving row_size from
    # the (non-empty) output avoids dividing by a zero-row input.
    num_input_rows = inp.shape[dim]
    row_size = out.numel() // total_output_rows

    # Make input contiguous for efficient access
    inp_contig = inp.contiguous()

    if row_size < 512 and total_output_rows < 512:
        # Small tensor: use input-centric kernel
        BLOCK_SIZE = min(triton.next_power_of_2(row_size), 4096)

        if BLOCK_SIZE <= 256:
            num_warps = 2
        elif BLOCK_SIZE <= 512:
            num_warps = 4
        else:
            num_warps = 8

        grid = (num_input_rows,)

        with torch_device_fn.device(inp.device):
            fused_repeat_interleave_dim0_kernel[grid](
                inp_contig,
                out,
                cumsum,
                num_input_rows,
                row_size,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )
    elif row_size >= 16384:
        # Large row size: use 1D grid with binary search
        # This reduces total number of programs and amortizes binary search cost
        BLOCK_SIZE = 2048
        num_warps = 16

        grid = (total_output_rows,)

        with torch_device_fn.device(inp.device):
            fused_repeat_interleave_1d_bsearch_kernel[grid](
                inp_contig,
                out,
                cumsum,
                num_input_rows,
                total_output_rows,
                row_size,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )
    else:
        # Medium row size: use 2D grid with binary search
        BLOCK_SIZE = min(triton.next_power_of_2(row_size), 1024)
        num_col_chunks = triton.cdiv(row_size, BLOCK_SIZE)

        if BLOCK_SIZE <= 256:
            num_warps = 2
        elif BLOCK_SIZE <= 512:
            num_warps = 4
        else:
            num_warps = 8

        grid = (total_output_rows, num_col_chunks)

        with torch_device_fn.device(inp.device):
            fused_repeat_interleave_output_centric_kernel[grid](
                inp_contig,
                out,
                cumsum,
                num_input_rows,
                total_output_rows,
                row_size,
                BLOCK_SIZE=BLOCK_SIZE,
                num_warps=num_warps,
            )

    return out

447 

448 

def repeat_interleave_self_tensor(inp, repeats, dim=None, *, output_size=None):
    """Repeat slices of ``inp`` along ``dim`` with counts taken from ``repeats``.

    Implements ``aten::repeat_interleave.self_Tensor``.

    Args:
        inp: Input tensor.
        repeats: 0-dim or 1-dim integer tensor. A 0-dim or single-element
            tensor applies one count to every slice; otherwise it must have
            ``inp.size(dim)`` entries.
        dim: Dimension to repeat along; ``None`` flattens ``inp`` and uses
            dim 0.
        output_size: Forwarded to the autograd fallback and to the scalar-count
            path, where it is checked against the computed size. Ignored by
            the per-slice path.

    Returns:
        A new tensor whose size along ``dim`` is the total repeat count.

    Raises:
        IndexError: If ``dim`` is out of range.
        RuntimeError: If ``repeats`` has more than one dimension or its length
            differs from ``inp.size(dim)``.
    """
    logger.debug("GEMS_MTHREADS REPEAT_INTERLEAVE_SELF_TENSOR")
    if torch.is_grad_enabled():
        # Run the composite decomposition so autograd is preserved.
        return torch.ops.aten.repeat_interleave.self_Tensor.redispatch(
            _FALLBACK_KEYSET, inp, repeats, dim, output_size=output_size
        )

    if dim is None:
        inp = inp.flatten()
        dim = 0
    else:
        if (dim < -inp.ndim) or (dim >= inp.ndim):
            raise IndexError(
                "Dimension out of range (expected to be in range of [{}, {}], but got {})".format(
                    -inp.ndim, inp.ndim - 1, dim
                )
            )

    if repeats.ndim == 0 or (repeats.ndim == 1 and repeats.size(0) == 1):
        # A single count applies to every slice: use the scalar path.
        return repeat_interleave_self_int(
            inp, repeats.item(), dim=dim, output_size=output_size
        )
    elif repeats.ndim > 1:
        raise RuntimeError("repeats must be 0-dim or 1-dim tensor")

    inp_shape = list(inp.shape)
    if dim < 0:
        dim = dim + len(inp_shape)

    if repeats.size(0) != inp_shape[dim]:
        raise RuntimeError(
            "repeats must have the same size as input along dim, but got "
            "repeats.size(0) = {} and input.size({}) = {}".format(
                repeats.size(0), dim, inp_shape[dim]
            )
        )

    # Empty ``repeats`` is only valid (checked above) when the input is empty
    # along ``dim``, so the result is an empty copy of the (possibly
    # flattened) input. Handled here because the helpers below read
    # ``cumsum[-1]``, which does not exist for an empty tensor.
    if repeats.size(0) == 0:
        return inp.clone()

    # Use fused kernel for dim=0
    if dim == 0:
        return fused_repeat_interleave_dim0(inp, repeats, dim)

    # For other dimensions, build the index map and gather with index_select
    indices = repeat_interleave_tensor(repeats)
    res = torch.index_select(inp, dim, indices)

    return res

497 return res