Coverage for src/flag_gems/ops/unique_consecutive.py: 50%

168 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import triton_lang_extension as ext 

9from flag_gems.utils.libentry import libentry 

10 

11logger = logging.getLogger(__name__) 

12 

13 

@libentry()
@triton.jit
def simple_unique_consecutive_flat_kernel(
    data_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    unique_size_ptr: tl.tensor,  # out
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Single-tile unique_consecutive over a flat buffer.

    Every lane handles one position in ``[0, tile_size)``; lanes at or beyond
    ``num_tasks`` are masked out of all loads and stores. No program id is
    read, so only the first ``tile_size`` elements are ever visited.

    Outputs:
        data_out_ptr: first element of each consecutive run, compacted.
        inverse_indices_ptr: (only if ``return_inverse``) output slot of every
            input position.
        idx_ptr: (only if ``return_counts``) input position at which each run
            starts, indexed by output slot.
        unique_size_ptr: element 0 receives the number of runs.
    """
    lane = tl.arange(0, tile_size)
    in_bounds = lane < num_tasks
    has_prev = lane > 0

    # Lane 0 has no predecessor; point it at itself so the address is valid.
    prev_lane = tl.where(has_prev, lane - 1, 0)
    cur = tl.load(data_ptr + lane, mask=in_bounds)
    prev = tl.load(data_ptr + prev_lane, mask=in_bounds)

    # 1 where a new run starts (lane 0 always starts one), 0 otherwise.
    starts_group = tl.where(has_prev, cur != prev, 1)

    # Inclusive prefix sum is the 1-based run number; shift to a 0-based slot.
    running = tl.cumsum(starts_group)
    slot = running - 1

    # The prefix sum at the last valid lane is the total number of runs.
    is_last = lane == num_tasks - 1
    tl.store(unique_size_ptr + tl.zeros_like(lane), running, mask=is_last)

    # Only the first element of each run writes into the compacted outputs.
    first_of_group = starts_group.to(tl.int1) & in_bounds
    tl.store(data_out_ptr + slot, cur, mask=first_of_group)

    if return_inverse:
        tl.store(inverse_indices_ptr + lane, slot, mask=in_bounds)

    if return_counts:
        tl.store(idx_ptr + slot, lane, mask=first_of_group)

59 

60 

@triton.jit
def output_counts_impl(
    global_pid,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Turn run start positions into run lengths for one tile.

    ``idx_ptr[k]`` is read as the start position of run ``k``. The length of
    run ``k`` is the distance to the next start, or to ``origin_num_tasks``
    for the final run. ``num_tasks`` is the number of runs.
    """
    pos = global_pid * tile_size + tl.arange(0, tile_size)
    valid = pos < num_tasks

    succ = pos + 1
    has_succ = succ < num_tasks

    start = tl.load(idx_ptr + pos, mask=valid)
    next_start = tl.load(idx_ptr + succ, mask=has_succ)

    # The last run has no successor, so it extends to the end of the input.
    run_length = tl.where(has_succ, next_start - start, origin_num_tasks - start)

    tl.store(counts_ptr + pos, run_length, mask=valid)

88 

89 

@libentry()
@triton.jit
def output_counts_kernel(
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Kernel entry point: each program handles ``tiles_per_cta`` tiles.

    A program starts at the tile equal to its own id and then advances by the
    number of launched programs on every iteration.
    """
    first_tile = ext.program_id(0)
    stride = ext.num_programs(0)
    for step in range(0, tiles_per_cta):
        output_counts_impl(
            first_tile + step * stride,
            idx_ptr,
            origin_num_tasks,
            counts_ptr,
            num_tasks,
            tile_size,
        )

112 

113 

@triton.jit
def local_ne_consecutive_impl(
    global_pid,
    data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Compute ne_result (whether each element differs from previous) for a tile.

    Writes one flag per in-range element to ``ne_result_ptr`` and the number
    of set flags in this tile to ``tile_sum_ptr[global_pid]``. The tile-sum
    store is skipped when ``global_pid`` is not below ``global_ctas_num``.
    """
    r = tl.arange(0, tile_size)
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks
    # Element 0 has no predecessor; clamp so the address stays valid.
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)

    # load current and previous
    a = tl.load(data_ptr + i0, mask=mask)
    b = tl.load(data_ptr + i0_prev, mask=mask)

    # compute ne_result (the very first element always starts a new group)
    ne_result = tl.where(i0 > 0, a != b, 1)

    # store ne_result
    tl.store(ne_result_ptr + i0, ne_result, mask=mask)

    # store tile_sum
    # Masked loads without `other` leave out-of-range lanes undefined, so
    # zero those lanes before reducing; otherwise a partially filled tile
    # could publish a garbage count.
    tile_sum = tl.sum(tl.where(mask, ne_result, 0))
    tile_sum_mask = global_pid < global_ctas_num
    tl.store(tile_sum_ptr + global_pid, tile_sum, mask=tile_sum_mask)

144 

145 

@libentry()
@triton.jit
def local_ne_consecutive_kernel(
    data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Kernel entry point for the per-tile "differs from previous" pass.

    Each program processes ``tiles_per_cta`` tiles, beginning at the tile
    equal to its program id and stepping by the number of launched programs.
    """
    first_tile = ext.program_id(0)
    stride = ext.num_programs(0)
    for step in range(0, tiles_per_cta):
        local_ne_consecutive_impl(
            first_tile + step * stride,
            data_ptr,
            ne_result_ptr,
            tile_sum_ptr,
            global_ctas_num,
            num_tasks,
            tile_size,
        )

170 

171 

@triton.jit
def global_cumsum_consecutive_impl(
    global_pid,
    total,
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: tl.constexpr,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Resolve global output slots for one tile and scatter the results.

    ``total`` is the number of run starts already accounted for by the caller.
    It is increased by the tile sums of tiles in
    ``[global_pid - ctas_num, global_pid)`` and the updated value is returned.
    The tile that equals ``global_ctas_num - 1`` additionally overwrites its
    own ``tile_sum`` entry with the overall number of runs.
    """
    lane = tl.arange(0, tile_size)
    pos = global_pid * tile_size + lane
    valid = pos < num_tasks

    values = tl.load(data_ptr + pos, mask=valid)

    # Gather the tile sums that precede this tile and have not yet been
    # folded into `total`.
    tile_ids = tl.arange(0, next_power_global_ctas_num)
    not_before_window = tile_ids >= global_pid - ctas_num
    before_this_tile = tile_ids < global_pid
    inside_table = (tile_ids >= 0) & (tile_ids < global_ctas_num)
    carried = tl.load(
        tile_sum_ptr + tile_ids,
        mask=not_before_window & before_this_tile & inside_table,
        other=0,
    )
    total += tl.sum(carried)

    # Inclusive prefix sum of the run-start flags inside this tile.
    flags = tl.load(ne_result_ptr + pos, mask=valid)
    is_start = flags.to(tl.int1)
    local_rank = tl.cumsum(flags.to(tl.int32))

    # The final tile publishes the grand total through its tile_sum entry;
    # only the lane holding the last valid element performs the store.
    if global_pid == global_ctas_num - 1:
        at_end = pos == num_tasks - 1
        tl.store(
            tile_sum_ptr + global_pid + tl.zeros_like(lane),
            tl.where(at_end, total + local_rank, local_rank),
            mask=at_end,
        )

    # 1-based global run number -> 0-based output slot.
    global_rank = local_rank + total
    slot = global_rank - 1

    # Only the first element of each run writes to the compacted outputs.
    emit = is_start & valid
    tl.store(data_out_ptr + slot, values, mask=emit)

    if return_inverse:
        tl.store(inverse_indices_ptr + pos, slot, mask=valid)

    if return_counts:
        tl.store(idx_ptr + slot, pos, mask=emit)

    return total

242 

243 

@libentry()
@triton.jit
def global_cumsum_consecutive_kernel(
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Kernel entry point for the global prefix-sum and scatter pass.

    With ``one_tile_per_cta`` each program handles exactly the tile matching
    its program id, starting from a running total of 0. Otherwise a program
    walks ``tiles_per_cta`` tiles, stepping by the number of launched
    programs and threading the running total from one tile to the next.

    The ``ctas_num`` argument is not consulted; the number of launched
    programs is queried from the runtime instead.
    """
    this_program = ext.program_id(0)
    launched_programs = ext.num_programs(0)
    if one_tile_per_cta:
        global_cumsum_consecutive_impl(
            this_program,
            0,
            ne_result_ptr,
            tile_sum_ptr,
            data_ptr,
            data_out_ptr,
            inverse_indices_ptr,
            idx_ptr,
            launched_programs,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_inverse,
            return_counts,
        )
    else:
        # Running count of run starts seen in all tiles before the current one.
        total = tl.zeros([1], dtype=tl.int64)
        for step in range(0, tiles_per_cta):
            tile = this_program + step * launched_programs
            total = global_cumsum_consecutive_impl(
                tile,
                total,
                ne_result_ptr,
                tile_sum_ptr,
                data_ptr,
                data_out_ptr,
                inverse_indices_ptr,
                idx_ptr,
                launched_programs,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_inverse,
                return_counts,
            )

304 

305 

def simple_unique_consecutive_flat(
    data: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Run unique_consecutive on a small flat tensor with one program.

    The whole input is covered by a single tile whose size is the next power
    of two of ``data.numel()``.

    Returns:
        ``(output, inverse_indices, counts)``; the last two are ``None`` when
        the corresponding flag is false.
    """
    n = data.numel()
    device = data.device
    single_program = (1, 1, 1)

    def _index_buffer(wanted):
        # int64 scratch/output buffer, allocated only when requested.
        if not wanted:
            return None
        return torch.empty(n, dtype=torch.int64, device=device)

    # Worst case every element is its own run, so outputs are input-sized.
    data_out = torch.empty_like(data)
    inverse_indices = _index_buffer(return_inverse)
    idx = _index_buffer(return_counts)
    unique_size = torch.empty([1], dtype=torch.int64, device=device)

    with torch_device_fn.device(device.index):
        simple_unique_consecutive_flat_kernel[single_program](
            data,
            data_out,
            inverse_indices,
            idx,
            unique_size,
            return_inverse,
            return_counts,
            n,
            tile_size=triton.next_power_of_2(n),
            num_warps=8,
        )

    # Reading the scalar back tells us how much of the buffers is populated.
    out_size = unique_size.item()
    if not return_counts:
        return data_out[:out_size], inverse_indices, None

    # Run start positions -> run lengths.
    starts = idx[:out_size]
    counts = torch.empty_like(starts)
    with torch_device_fn.device(device.index):
        output_counts_kernel[single_program](
            starts,
            n,
            counts,
            num_tasks=out_size,
            tiles_per_cta=1,
            tile_size=triton.next_power_of_2(out_size),
            num_warps=8,
        )

    return data_out[:out_size], inverse_indices, counts

361 

362 

def large_unique_consecutive_flat(
    data: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Handle larger inputs with a multi-kernel approach.

    Pass 1 (``local_ne_consecutive_kernel``) flags every element that differs
    from its predecessor and records the number of flags per tile.
    Pass 2 (``global_cumsum_consecutive_kernel``) turns the flags into global
    output slots and scatters values, inverse indices and run starts.
    Optional pass 3 (``output_counts_kernel``) converts run starts to counts.

    Returns:
        ``(output, inverse_indices, counts)``; the last two are ``None`` when
        the corresponding flag is false.
    """
    num_tasks = data.numel()
    next_power_num_tasks = triton.next_power_of_2(num_tasks)
    tile_size = min(8192, next_power_num_tasks)
    global_ctas_num = triton.cdiv(num_tasks, tile_size)

    # For moderately sized inputs shrink the tile (but not below a floor) and
    # recompute how many tiles are needed.
    if global_ctas_num <= 8192:
        min_tile_size = 512 if global_ctas_num > 32 else 256
        tile_size = max(
            min_tile_size,
            min(triton.next_power_of_2(global_ctas_num), next_power_num_tasks),
        )
        global_ctas_num = triton.cdiv(num_tasks, tile_size)

    next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
    # Cap the number of launched programs; beyond the cap each program
    # iterates over several tiles.
    ctas_num = global_ctas_num if global_ctas_num < 32768 else 8192
    tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
    num_warps = 8 if tiles_per_cta == 1 else 32
    grid = (ctas_num, 1, 1)
    device = data.device

    # allocate tensors
    ne_result = torch.empty(num_tasks, dtype=torch.bool, device=device)
    tile_sum = torch.empty(global_ctas_num, dtype=torch.int64, device=device)
    data_out = torch.empty_like(data)
    inverse_indices = (
        torch.empty(num_tasks, dtype=torch.int64, device=device)
        if return_inverse
        else None
    )
    idx = (
        torch.empty(num_tasks, dtype=torch.int64, device=device)
        if return_counts
        else None
    )

    # launch kernels
    with torch_device_fn.device(device.index):
        local_ne_consecutive_kernel[grid](
            data,
            ne_result,
            tile_sum,
            global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            num_warps=num_warps,
        )
        global_cumsum_consecutive_kernel[grid](
            ne_result,
            tile_sum,
            data,
            data_out,
            inverse_indices,
            idx,
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            one_tile_per_cta=tiles_per_cta == 1,
            return_inverse=return_inverse,
            return_counts=return_counts,
            num_warps=num_warps,
        )
        # The last tile overwrote its tile_sum entry with the total number of
        # runs during pass 2.
        out_size = tile_sum[-1].item()

    counts = None
    if return_counts:
        idx = idx[:out_size]
        counts = torch.empty_like(idx)
        # Launch under an explicit device guard, mirroring the small-input
        # path, so the kernel always targets the device that owns `data`.
        with torch_device_fn.device(device.index):
            output_counts_kernel[grid](
                idx,
                num_tasks,
                counts,
                out_size,
                tiles_per_cta,
                tile_size,
                num_warps=num_warps,
            )

    return data_out[:out_size], inverse_indices, counts

450 

451 

def unique_consecutive(
    input: torch.Tensor,
    return_inverse: bool = False,
    return_counts: bool = False,
    dim: int = None,
):
    """
    Eliminates all but the first element from every consecutive group of equivalent elements.

    Args:
        input: the input tensor
        return_inverse: Whether to return inverse indices
        return_counts: Whether to return counts for each unique element
        dim: the dimension to apply unique. If None, the unique of the flattened input is returned.

    Returns:
        (Tensor, Tensor (optional), Tensor (optional)): output, inverse_indices, counts.
        ``inverse_indices`` has the same shape as ``input`` (including for
        empty inputs); entries that were not requested are ``None``.

    NOTE(review): when ``dim`` is given the call is forwarded to
    ``torch.unique_consecutive`` and its result is returned unchanged, so the
    shape of the return value in that case is whatever PyTorch produces —
    confirm callers handle both forms.
    """
    logger.debug("GEMS UNIQUE_CONSECUTIVE")

    if dim is not None:
        # For dim-wise unique_consecutive, fall back to PyTorch for now
        # This could be implemented with a more complex kernel
        return torch.unique_consecutive(
            input,
            return_inverse=return_inverse,
            return_counts=return_counts,
            dim=dim,
        )

    # Flatten input for the None dim case
    flat_input = input.ravel()
    num_tasks = flat_input.numel()

    if num_tasks == 0:
        # Handle empty input without launching any kernel. The inverse tensor
        # keeps the input's shape (e.g. (0, 3)) so it matches the reshaping
        # applied on the non-empty path below.
        output = torch.empty(0, dtype=input.dtype, device=input.device)
        inverse_indices = (
            torch.empty(input.shape, dtype=torch.int64, device=input.device)
            if return_inverse
            else None
        )
        counts = (
            torch.empty(0, dtype=torch.int64, device=input.device)
            if return_counts
            else None
        )
        return output, inverse_indices, counts

    # Choose algorithm based on input size: up to 8192 elements fit in the
    # single-tile kernel, anything larger uses the multi-kernel path.
    if num_tasks <= 8192:
        output, inverse_indices, counts = simple_unique_consecutive_flat(
            flat_input, return_inverse, return_counts
        )
    else:
        output, inverse_indices, counts = large_unique_consecutive_flat(
            flat_input, return_inverse, return_counts
        )

    # Reshape inverse_indices to match input shape
    if inverse_indices is not None:
        inverse_indices = inverse_indices.view_as(input)

    return output, inverse_indices, counts