Coverage for src/flag_gems/runtime/backend/_sunrise/ops/unique_consecutive.py: 0%

191 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import triton_lang_extension as ext 

9from flag_gems.utils.libentry import libentry 

10 

11logger = logging.getLogger(__name__) 

12 

13_PTPU_SAFE_MAX_TILE_SIZE = 512 

14 

15 

@libentry()
@triton.jit
def simple_unique_consecutive_flat_kernel(
    data_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    unique_size_ptr: tl.tensor,  # out
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Single-tile unique_consecutive: the whole input is handled by one tile.

    Args:
        data_ptr: input elements (read).
        data_out_ptr: receives the first element of every consecutive run.
        inverse_indices_ptr: receives, per input position, the index of its
            run in the output; written only when ``return_inverse`` is set.
        idx_ptr: receives, per run, the input position where the run starts;
            written only when ``return_counts`` is set.
        unique_size_ptr: receives the number of runs at offset 0.
        return_inverse: compile-time switch for the inverse-index store.
        return_counts: compile-time switch for the run-start store.
        num_tasks: number of valid input elements.
        tile_size: compile-time tile width; lanes at or beyond ``num_tasks``
            are masked out of every load/store of per-element data.
    """
    offs = tl.arange(0, tile_size)
    in_bounds = offs < num_tasks

    # Every lane reads its own element and its left neighbour. Lane 0 has no
    # neighbour, so it re-reads offset 0; its flag is forced to 1 below.
    cur = tl.load(data_ptr + offs, mask=in_bounds)
    prev_offs = tl.where(offs > 0, offs - 1, 0)
    prev = tl.load(data_ptr + prev_offs, mask=in_bounds)

    # 1 where a new run begins, 0 where the element repeats its predecessor.
    is_start = tl.where(offs > 0, cur != prev, 1)
    running = tl.cumsum(is_start)

    # The inclusive scan is 1-based; output slots are 0-based.
    dst = running - 1

    # The scan value at the last valid lane is the total number of runs.
    is_last = offs == num_tasks - 1
    tl.store(unique_size_ptr + tl.zeros_like(offs), running, mask=is_last)

    # Only the first element of each run is scattered to the output.
    start_mask = is_start.to(tl.int1) & in_bounds
    tl.store(data_out_ptr + dst, cur, mask=start_mask)

    if return_inverse:
        tl.store(inverse_indices_ptr + offs, dst, mask=in_bounds)

    if return_counts:
        tl.store(idx_ptr + dst, offs, mask=start_mask)

61 

62 

@triton.jit
def output_counts_impl(
    global_pid,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Turn run start positions into run lengths for one tile.

    Args:
        global_pid: index of the tile to process.
        idx_ptr: start position of each run (read).
        origin_num_tasks: value the last run is measured against
            (``origin_num_tasks - start``).
        counts_ptr: receives the length of each run.
        num_tasks: number of valid entries behind ``idx_ptr``.
        tile_size: compile-time tile width.
    """
    lane = tl.arange(0, tile_size)
    pos = global_pid * tile_size + lane
    valid = pos < num_tasks

    # Start of this run.
    start = tl.load(idx_ptr + pos, mask=valid)

    # Start of the following run, when there is one.
    pos_next = pos + 1
    has_next = pos_next < num_tasks
    next_start = tl.load(idx_ptr + pos_next, mask=has_next)

    # Length is the gap to the next start; the final run extends to
    # origin_num_tasks instead.
    run_len = tl.where(has_next, next_start - start, origin_num_tasks - start)

    tl.store(counts_ptr + pos, run_len, mask=valid)

90 

91 

@libentry()
@triton.jit
def output_counts_kernel(
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Launch wrapper around ``output_counts_impl``.

    Each program processes ``tiles_per_cta`` tiles, starting at its own
    program id and stepping by the number of launched programs.
    """
    first_tile = ext.program_id(0)
    tile_stride = ext.num_programs(0)
    for step in range(tiles_per_cta):
        tile_id = first_tile + step * tile_stride
        output_counts_impl(
            tile_id, idx_ptr, origin_num_tasks, counts_ptr, num_tasks, tile_size
        )

114 

115 

@triton.jit
def local_ne_consecutive_impl(
    global_pid,
    data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Flag run starts inside one tile and record the tile's flag total.

    Args:
        global_pid: index of the tile to process.
        data_ptr: input elements (read).
        ne_result_ptr: receives, per element, whether it differs from its
            predecessor (position 0 is always flagged).
        tile_sum_ptr: receives the sum of the tile's flags at ``global_pid``.
        global_ctas_num: number of tiles; tiles at or beyond it do not write
            a tile sum.
        num_tasks: number of valid input elements.
        tile_size: compile-time tile width.
    """
    lane = tl.arange(0, tile_size)
    pos = global_pid * tile_size + lane
    valid = pos < num_tasks

    # Position 0 has no predecessor; it re-reads itself and is forced to 1.
    prev_pos = tl.where(pos > 0, pos - 1, 0)

    cur = tl.load(data_ptr + pos, mask=valid)
    prev = tl.load(data_ptr + prev_pos, mask=valid)

    is_start = tl.where(pos > 0, cur != prev, 1)
    tl.store(ne_result_ptr + pos, is_start, mask=valid)

    # NOTE(review): this reduction covers every lane of the tile, including
    # lanes with pos >= num_tasks whose loads were masked. Confirm that no
    # consumer depends on the sum of a partially filled tile.
    flags_in_tile = tl.sum(is_start)
    writes_sum = global_pid < global_ctas_num
    tl.store(tile_sum_ptr + global_pid, flags_in_tile, mask=writes_sum)

146 

147 

@libentry()
@triton.jit
def local_ne_consecutive_kernel(
    data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Launch wrapper around ``local_ne_consecutive_impl``.

    Each program processes ``tiles_per_cta`` tiles, starting at its own
    program id and stepping by the number of launched programs.
    """
    first_tile = ext.program_id(0)
    tile_stride = ext.num_programs(0)
    for step in range(tiles_per_cta):
        tile_id = first_tile + step * tile_stride
        local_ne_consecutive_impl(
            tile_id,
            data_ptr,
            ne_result_ptr,
            tile_sum_ptr,
            global_ctas_num,
            num_tasks,
            tile_size,
        )

172 

173 

@triton.jit
def global_cumsum_consecutive_impl(
    global_pid,
    total,
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: tl.constexpr,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Convert per-tile flags into global output slots and scatter one tile.

    Args:
        global_pid: index of the tile to process.
        total: flag count carried in by the caller; the sums of the tiles in
            the window ``[global_pid - ctas_num, global_pid)`` are added to it.
        ne_result_ptr: per-element run-start flags (read).
        tile_sum_ptr: per-tile flag sums (read); the entry of the last tile
            is overwritten with the overall number of runs.
        data_ptr: input elements (read).
        data_out_ptr: receives the first element of every run.
        inverse_indices_ptr: receives each position's output slot when
            ``return_inverse`` is set.
        idx_ptr: receives each run's start position when ``return_counts``
            is set.
        ctas_num: width of the window of preceding tiles to accumulate.
        global_ctas_num: number of tiles.
        next_power_global_ctas_num: compile-time length of the range used to
            address ``tile_sum_ptr``.
        num_tasks: number of valid input elements.
        tile_size: compile-time tile width.
        return_inverse: compile-time switch for the inverse-index store.
        return_counts: compile-time switch for the run-start store.

    Returns:
        ``total`` after adding the sums of the windowed preceding tiles.
    """
    base = global_pid * tile_size
    lane = tl.arange(0, tile_size)
    pos = base + lane
    valid = pos < num_tasks

    values = tl.load(data_ptr + pos, mask=valid)

    # Gather the sums of the tiles that precede this one, restricted to the
    # last `ctas_num` of them; everything else contributes 0.
    tile_ids = tl.arange(0, next_power_global_ctas_num)
    window = (
        (tile_ids >= global_pid - ctas_num)
        & (tile_ids < global_pid)
        & (tile_ids >= 0)
        & (tile_ids < global_ctas_num)
    )
    earlier_sums = tl.load(tile_sum_ptr + tile_ids, mask=window, other=0)

    total += tl.sum(earlier_sums)

    # Inclusive scan of this tile's flags.
    flags = tl.load(ne_result_ptr + pos, mask=valid)
    flags_i1 = flags.to(tl.int1)
    flags_i32 = flags.to(tl.int32)
    cumsum = tl.cumsum(flags_i32)

    # In the last tile, the global scan value at the final valid element is
    # the number of runs; it is stored into that tile's tile_sum entry.
    if global_pid == global_ctas_num - 1:
        at_last_element = pos == num_tasks - 1
        final_value = tl.where(at_last_element, total + cumsum, cumsum)
        tl.store(
            tile_sum_ptr + global_pid + tl.zeros_like(lane),
            final_value,
            mask=at_last_element,
        )
    cumsum += total

    # The scan is 1-based; output slots are 0-based.
    dst = cumsum - 1

    # Only the first element of each run is scattered.
    start_mask = flags_i1 & valid
    tl.store(data_out_ptr + dst, values, mask=start_mask)

    if return_inverse:
        tl.store(inverse_indices_ptr + pos, dst, mask=valid)

    if return_counts:
        tl.store(idx_ptr + dst, pos, mask=start_mask)

    return total

244 

245 

@libentry()
@triton.jit
def global_cumsum_consecutive_kernel(
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Launch wrapper around ``global_cumsum_consecutive_impl``.

    With ``one_tile_per_cta`` each program handles exactly the tile matching
    its program id, starting from a carried total of 0. Otherwise each
    program walks ``tiles_per_cta`` tiles, stepping by the number of launched
    programs and threading the running total from one tile to the next.
    """
    pid = ext.program_id(0)
    # The ctas_num argument is replaced by the actual launch grid size.
    ctas_num = ext.num_programs(0)
    if one_tile_per_cta:
        global_cumsum_consecutive_impl(
            pid,
            0,
            ne_result_ptr,
            tile_sum_ptr,
            data_ptr,
            data_out_ptr,
            inverse_indices_ptr,
            idx_ptr,
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_inverse,
            return_counts,
        )
    else:
        carried = tl.zeros([1], dtype=tl.int64)
        for step in range(tiles_per_cta):
            tile_id = pid + step * ctas_num
            carried = global_cumsum_consecutive_impl(
                tile_id,
                carried,
                ne_result_ptr,
                tile_sum_ptr,
                data_ptr,
                data_out_ptr,
                inverse_indices_ptr,
                idx_ptr,
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_inverse,
                return_counts,
            )

306 

307 

def simple_unique_consecutive_flat(
    data: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Run unique_consecutive on a flat tensor with a single-program launch.

    Args:
        data: flat input tensor.
        return_inverse: also produce, per input element, its output index.
        return_counts: also produce the length of each consecutive run.

    Returns:
        ``(output, inverse_indices, counts)``; the two optional entries are
        ``None`` when not requested.
    """
    n = data.numel()
    device = data.device
    single_program = (1, 1, 1)

    def _index_buffer(wanted: bool):
        # int64 scratch of input length, allocated only when requested.
        if not wanted:
            return None
        return torch.empty(n, dtype=torch.int64, device=device)

    data_out = torch.empty_like(data)
    inverse_indices = _index_buffer(return_inverse)
    idx = _index_buffer(return_counts)
    unique_size = torch.empty([1], dtype=torch.int64, device=device)

    with torch_device_fn.device(device.index):
        simple_unique_consecutive_flat_kernel[single_program](
            data,
            data_out,
            inverse_indices,
            idx,
            unique_size,
            return_inverse,
            return_counts,
            n,
            tile_size=triton.next_power_of_2(n),
            num_warps=8,
        )

    # Reading the run count back to the host fixes the output length.
    out_size = unique_size.item()
    if not return_counts:
        return data_out[:out_size], inverse_indices, None

    # Run lengths are derived from the run start positions.
    idx = idx[:out_size]
    counts = torch.empty_like(idx)
    with torch_device_fn.device(device.index):
        output_counts_kernel[single_program](
            idx,
            n,
            counts,
            num_tasks=out_size,
            tiles_per_cta=1,
            tile_size=triton.next_power_of_2(out_size),
            num_warps=8,
        )

    return data_out[:out_size], inverse_indices, counts

363 

364 

def large_unique_consecutive_flat(
    data: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Run unique_consecutive on a flat tensor using multiple tiles.

    On ``ptpu`` devices only the run-start flags are computed with a Triton
    kernel; the outputs are then assembled with torch operations. On other
    devices a second kernel performs the global scan and scatter, and a third
    one derives the counts.

    Args:
        data: flat input tensor.
        return_inverse: also produce, per input element, its output index.
        return_counts: also produce the length of each consecutive run.

    Returns:
        ``(output, inverse_indices, counts)``; the two optional entries are
        ``None`` when not requested.
    """
    num_tasks = data.numel()
    device = data.device
    pow2_num_tasks = triton.next_power_of_2(num_tasks)

    if device.type == "ptpu":
        # Sunrise/PTPU only changes the unstable large-input organization path.
        tile_size = min(_PTPU_SAFE_MAX_TILE_SIZE, pow2_num_tasks)
        global_ctas_num = triton.cdiv(num_tasks, tile_size)
        if global_ctas_num < 32768:
            ctas_num = global_ctas_num
        else:
            ctas_num = 8192
        tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
        num_warps = 8 if tiles_per_cta == 1 else 32
        grid = (ctas_num, 1, 1)

        ne_result = torch.empty(num_tasks, dtype=torch.bool, device=device)
        tile_sum = torch.empty(global_ctas_num, dtype=torch.int64, device=device)

        with torch_device_fn.device(device.index):
            local_ne_consecutive_kernel[grid](
                data,
                ne_result,
                tile_sum,
                global_ctas_num,
                num_tasks,
                tiles_per_cta=tiles_per_cta,
                tile_size=tile_size,
                num_warps=num_warps,
            )

        # Positions flagged as run starts select the output elements.
        starts = torch.nonzero(ne_result, as_tuple=False).flatten()
        output = torch.index_select(data, 0, starts)

        inverse_indices = None
        if return_inverse:
            # Inclusive scan of the flags, shifted to 0-based run indices.
            inverse_indices = torch.cumsum(ne_result.to(torch.int64), dim=0) - 1

        counts = None
        if return_counts:
            # Gaps between consecutive starts; the last run ends at num_tasks.
            tail = starts.new_tensor([num_tasks]) - starts[-1:]
            gaps = starts[1:] - starts[:-1]
            counts = torch.cat((gaps, tail))

        return output, inverse_indices, counts

    tile_size = min(8192, pow2_num_tasks)
    global_ctas_num = triton.cdiv(num_tasks, tile_size)

    if global_ctas_num <= 8192:
        # Few tiles: re-derive the tile size from the tile count, with a floor.
        if global_ctas_num > 32:
            min_tile_size = 512
        else:
            min_tile_size = 256
        candidate = min(triton.next_power_of_2(global_ctas_num), pow2_num_tasks)
        tile_size = max(min_tile_size, candidate)
        global_ctas_num = triton.cdiv(num_tasks, tile_size)

    next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
    if global_ctas_num < 32768:
        ctas_num = global_ctas_num
    else:
        ctas_num = 8192
    tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
    num_warps = 8 if tiles_per_cta == 1 else 32
    grid = (ctas_num, 1, 1)

    # Scratch and output buffers.
    ne_result = torch.empty(num_tasks, dtype=torch.bool, device=device)
    tile_sum = torch.empty(global_ctas_num, dtype=torch.int64, device=device)
    data_out = torch.empty_like(data)
    inverse_indices = None
    if return_inverse:
        inverse_indices = torch.empty(num_tasks, dtype=torch.int64, device=device)
    idx = None
    if return_counts:
        idx = torch.empty(num_tasks, dtype=torch.int64, device=device)

    with torch_device_fn.device(device.index):
        local_ne_consecutive_kernel[grid](
            data,
            ne_result,
            tile_sum,
            global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            num_warps=num_warps,
        )
        global_cumsum_consecutive_kernel[grid](
            ne_result,
            tile_sum,
            data,
            data_out,
            inverse_indices,
            idx,
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            one_tile_per_cta=tiles_per_cta == 1,
            return_inverse=return_inverse,
            return_counts=return_counts,
            num_warps=num_warps,
        )
        # The last tile_sum entry is read back as the output length.
        out_size = tile_sum[-1].item()

        counts = None
        if return_counts:
            idx = idx[:out_size]
            counts = torch.empty_like(idx)
            output_counts_kernel[grid](
                idx,
                num_tasks,
                counts,
                out_size,
                tiles_per_cta,
                tile_size,
                num_warps=num_warps,
            )

    return data_out[:out_size], inverse_indices, counts

492 

493 

def unique_consecutive(
    input: torch.Tensor,
    return_inverse: bool = False,
    return_counts: bool = False,
    dim: int = None,
):
    """
    Eliminates all but the first element from every consecutive group of equivalent elements.

    Args:
        input: the input tensor
        return_inverse: Whether to return inverse indices
        return_counts: Whether to return counts for each unique element
        dim: the dimension to apply unique. If None, the unique of the flattened input is returned.

    Returns:
        (Tensor, Tensor (optional), Tensor (optional)): output, inverse_indices, counts.
        When ``dim`` is None, ``inverse_indices`` (if requested) has the same
        shape as ``input``, including for empty inputs; entries that were not
        requested are None. When ``dim`` is given, the result of
        ``torch.unique_consecutive`` is returned unchanged.
    """
    logger.debug("GEMS UNIQUE_CONSECUTIVE")

    if dim is not None:
        # For dim-wise unique_consecutive, fall back to PyTorch for now
        # This could be implemented with a more complex kernel
        return torch.unique_consecutive(
            input,
            return_inverse=return_inverse,
            return_counts=return_counts,
            dim=dim,
        )

    # Flatten input for the None dim case
    flat_input = input.ravel()
    num_tasks = flat_input.numel()

    if num_tasks == 0:
        # Empty input: nothing to launch. inverse_indices keeps the input's
        # shape (e.g. (0, 3)) so that it agrees with the non-empty path below,
        # which reshapes the inverse indices with view_as(input).
        output = torch.empty(0, dtype=input.dtype, device=input.device)
        inverse_indices = (
            torch.empty(input.shape, dtype=torch.int64, device=input.device)
            if return_inverse
            else None
        )
        counts = (
            torch.empty(0, dtype=torch.int64, device=input.device)
            if return_counts
            else None
        )
        return output, inverse_indices, counts

    # Choose algorithm based on input size
    if num_tasks <= 8192:
        output, inverse_indices, counts = simple_unique_consecutive_flat(
            flat_input, return_inverse, return_counts
        )
    else:
        output, inverse_indices, counts = large_unique_consecutive_flat(
            flat_input, return_inverse, return_counts
        )

    # Reshape inverse_indices to match input shape
    if inverse_indices is not None:
        inverse_indices = inverse_indices.view_as(input)

    return output, inverse_indices, counts