Coverage for src/flag_gems/ops/unique_dim.py: 54%

428 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import triton_lang_extension as ext 

9from flag_gems.utils.libentry import libentry 

10 

11logger = logging.getLogger(__name__) 

12 

# Upper bound on the per-program tile width used by the row-hash and
# adjacent-row comparison launches; longer rows are split into several chunks.
_UNIQUE_DIM_COMPARE_BLOCK_SIZE = 1024
# Tile width for the gather, inverse-scatter, counts and integer key-build
# launches.
_UNIQUE_DIM_GATHER_BLOCK_SIZE = 1024
# Largest row count handled by the single-launch group-id scan kernel.
_UNIQUE_DIM_GROUP_SCAN_BLOCK_SIZE = 4096
# Largest key count sorted by the single-launch rank-sort kernel. Above this we
# delegate to ``torch.sort`` which, under FlagGems op interception, dispatches to
# the backend's Triton radix sort. Rank-sort is O(N^2) but a single launch, so it
# is much cheaper than a 16-pass int64 radix sort for tiny shapes.
_UNIQUE_DIM_RANK_SORT_MAX_KEYS = 2048
# Rows with at least this many elements get a per-row hash computed first, so
# adjacent rows with different hashes can skip the element-wise comparison.
_UNIQUE_DIM_HASH_MIN_ROW_LEN = 1024
# Smaller tile for the fused key kernel's float branches: their int64 bit-twiddle
# temporaries overflow the Ascend unified buffer at the default tile size.
_UNIQUE_DIM_BUILD_KEY_FLOAT_BLOCK_SIZE = 256


# Per-column bit budgets for dtypes that have an order-preserving mapping into
# int64. The encodings let us pack a per-row ``group_id`` together
# with a single column's key into one int64 that, when compared as signed
# int64, matches the lex order over ``(group_id, signed_value)``.
# Despite the name, the table also lists the 16- and 32-bit float dtypes;
# dtypes that are absent (lookup returns ``None``) take the generic fallback.
_INT_DTYPE_BITS = {
    torch.bool: 1,
    torch.int8: 8,
    torch.uint8: 8,
    torch.int16: 16,
    torch.int32: 32,
    torch.float16: 16,
    torch.bfloat16: 16,
    torch.float32: 32,
}

42 

43 

@libentry()
@triton.jit
def _unique_dim_argsort_rank_kernel(
    keys_ptr: tl.tensor,
    indices_ptr: tl.tensor,
    sorted_keys_ptr: tl.tensor,
    num_keys: int,
    BLOCK_SIZE: tl.constexpr,
    STORE_SORTED_KEYS: tl.constexpr,
):
    """Rank-sort step: each program places one key at its sorted position.

    The program handling key ``i`` counts how many keys order before it (a
    smaller value, or an equal value at a lower index) and writes ``i`` to
    ``indices_ptr[rank]``. When ``STORE_SORTED_KEYS`` is set the key itself is
    also written to ``sorted_keys_ptr[rank]``. ``BLOCK_SIZE`` must cover all
    ``num_keys`` keys because every program scans the whole key array.
    """
    key_idx = ext.program_id(0)
    lanes = tl.arange(0, BLOCK_SIZE)
    in_range = lanes < num_keys

    pivot = tl.load(keys_ptr + key_idx)
    # Padded lanes read back the pivot itself; they are masked out below.
    others = tl.load(keys_ptr + lanes, mask=in_range, other=pivot)
    strictly_less = others < pivot
    # Equal keys keep their original relative order (lower index first).
    earlier_tie = (others == pivot) & (lanes < key_idx)
    precedes = (strictly_less | earlier_tie) & in_range
    slot = tl.sum(precedes.to(tl.int32), axis=0)
    tl.store(indices_ptr + slot, key_idx)
    if STORE_SORTED_KEYS:
        tl.store(sorted_keys_ptr + slot, pivot)

65 

66 

@libentry()
@triton.jit
def _unique_dim_gather_1d_kernel(
    input_ptr: tl.tensor,
    index_ptr: tl.tensor,
    output_ptr: tl.tensor,
    num_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Element-wise gather: ``output[i] = input[index[i]]`` for ``i < num_elements``."""
    block = ext.program_id(0)
    lane = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    live = lane < num_elements
    # Masked-off lanes get index 0 but never dereference it (load is masked).
    src = tl.load(index_ptr + lane, mask=live, other=0)
    gathered = tl.load(input_ptr + src, mask=live)
    tl.store(output_ptr + lane, gathered, mask=live)

82 

83 

@libentry()
@triton.jit
def _unique_dim_group_id_kernel(
    composite_ptr: tl.tensor,
    group_id_ptr: tl.tensor,
    last_group_id_ptr: tl.tensor,
    num_rows: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Single-program scan turning sorted keys into dense group ids.

    ``group_id[i]`` is the number of positions ``j`` in ``1..i`` whose key
    differs from its predecessor, i.e. the count of distinct keys strictly
    before position ``i``. The id of the final row is additionally written to
    ``last_group_id_ptr`` so the host can read it with one scalar transfer.
    ``BLOCK_SIZE`` must be at least ``num_rows``; the kernel does not read
    ``program_id`` and covers the whole array from one program.
    """
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_rows
    cur = tl.load(composite_ptr + offsets, mask=mask, other=0)
    prev_offsets = tl.where(offsets == 0, 0, offsets - 1)
    # Only lanes 1..num_rows-1 have a predecessor inside the buffer. Padded
    # lanes (offsets >= num_rows) must stay masked too, otherwise they would
    # read ``composite_ptr[offsets - 1]`` beyond the end of the allocation.
    has_prev = mask & (offsets > 0)
    prev = tl.load(composite_ptr + prev_offsets, mask=has_prev, other=cur)
    # Subtract-then-compare-to-zero rather than a direct tensor comparison.
    diff = ((cur - prev) != 0) & mask
    diff = tl.where(offsets == 0, False, diff)
    group_id = tl.cumsum(diff.to(tl.int64), axis=0)
    tl.store(group_id_ptr + offsets, group_id, mask=mask)
    # Select the last valid lane's id via a masked sum (no dynamic indexing).
    last = tl.sum(tl.where(offsets == num_rows - 1, group_id, 0), axis=0)
    tl.store(last_group_id_ptr, last)

104 

105 

@libentry()
@triton.jit
def _unique_dim_row_hash_chunk_kernel(
    flat_ptr: tl.tensor,
    chunk_hash_ptr: tl.tensor,
    num_rows: int,
    row_len: int,
    num_chunks: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Partial row hash for one ``(row, chunk)`` tile.

    Each element is converted to int64, combined with a term that depends on
    its column position, and the tile's contributions are summed into
    ``chunk_hash[row, chunk]``. Columns past ``row_len`` contribute zero.
    """
    row_id = ext.program_id(0)
    chunk_id = ext.program_id(1)
    cols = chunk_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = cols < row_len

    elems = tl.load(flat_ptr + row_id * row_len + cols, mask=valid, other=0)
    elems64 = elems.to(tl.int64)
    cols64 = cols.to(tl.int64)
    # Position-dependent mixing so permuted rows do not hash alike.
    mixed = (elems64 + (cols64 + 1) * 1009 + 9176) * 131071
    mixed = tl.where(valid, mixed, 0)
    tl.store(chunk_hash_ptr + row_id * num_chunks + chunk_id, tl.sum(mixed, axis=0))

127 

128 

@libentry()
@triton.jit
def _unique_dim_row_hash_reduce_kernel(
    chunk_hash_ptr: tl.tensor,
    row_hash_ptr: tl.tensor,
    num_chunks: int,
    BLOCK_CHUNKS: tl.constexpr,
):
    """Sum one row's per-chunk partial hashes into ``row_hash[row]``."""
    row_id = ext.program_id(0)
    chunk_ids = tl.arange(0, BLOCK_CHUNKS)
    present = chunk_ids < num_chunks
    partials = tl.load(
        chunk_hash_ptr + row_id * num_chunks + chunk_ids, mask=present, other=0
    )
    tl.store(row_hash_ptr + row_id, tl.sum(partials, axis=0))

142 

143 

@libentry()
@triton.jit
def _unique_dim_row_chunk_diff_kernel(
    flat_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,
    row_chunk_diff_ptr: tl.tensor,
    num_rows: int,
    row_len: int,
    num_chunks: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Flag whether a sorted row differs from its predecessor within one chunk.

    Writes 1 to ``row_chunk_diff[row, chunk]`` when any element of this column
    chunk differs between sorted row ``row`` and sorted row ``row - 1``, else
    0. Sorted row 0 has no predecessor: its chunk 0 is set to 1 and its other
    chunks to 0.
    """
    sorted_pos = ext.program_id(0)
    chunk_id = ext.program_id(1)
    cols = chunk_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = cols < row_len

    flag = tl.full((), 0, dtype=tl.int32)
    if sorted_pos == 0:
        flag = tl.where(chunk_id == 0, 1, 0)
    else:
        # Map sorted positions back to rows of the original flat buffer.
        this_row = tl.load(sorted_indices_ptr + sorted_pos)
        prior_row = tl.load(sorted_indices_ptr + sorted_pos - 1)
        this_vals = tl.load(flat_ptr + this_row * row_len + cols, mask=valid)
        prior_vals = tl.load(flat_ptr + prior_row * row_len + cols, mask=valid)
        mismatch = (this_vals != prior_vals) & valid
        any_mismatch = tl.sum(mismatch.to(tl.int32), axis=0) != 0
        flag = any_mismatch.to(tl.int32)
    tl.store(row_chunk_diff_ptr + sorted_pos * num_chunks + chunk_id, flag)

172 

173 

@libentry()
@triton.jit
def _unique_dim_row_chunk_diff_hash_kernel(
    flat_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,
    row_hash_ptr: tl.tensor,
    row_chunk_diff_ptr: tl.tensor,
    num_rows: int,
    row_len: int,
    num_chunks: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Adjacent-row chunk comparison with a row-hash shortcut.

    Produces the same ``row_chunk_diff[row, chunk]`` flags as the plain
    chunk-diff kernel, but first compares the two rows' precomputed hashes.
    If the hashes differ, the row is marked different via chunk 0 without
    loading any row data; only equal hashes fall through to the element-wise
    comparison of this chunk.
    """
    sorted_pos = ext.program_id(0)
    chunk_id = ext.program_id(1)
    cols = chunk_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = cols < row_len

    flag = tl.full((), 0, dtype=tl.int32)
    if sorted_pos == 0:
        # No predecessor: mark as a group start through chunk 0.
        flag = tl.where(chunk_id == 0, 1, 0)
    else:
        this_row = tl.load(sorted_indices_ptr + sorted_pos)
        prior_row = tl.load(sorted_indices_ptr + sorted_pos - 1)
        this_hash = tl.load(row_hash_ptr + this_row)
        prior_hash = tl.load(row_hash_ptr + prior_row)
        if this_hash != prior_hash:
            flag = tl.where(chunk_id == 0, 1, 0)
        else:
            this_vals = tl.load(flat_ptr + this_row * row_len + cols, mask=valid)
            prior_vals = tl.load(flat_ptr + prior_row * row_len + cols, mask=valid)
            mismatch = (this_vals != prior_vals) & valid
            any_mismatch = tl.sum(mismatch.to(tl.int32), axis=0) != 0
            flag = any_mismatch.to(tl.int32)
    tl.store(row_chunk_diff_ptr + sorted_pos * num_chunks + chunk_id, flag)

208 

209 

@libentry()
@triton.jit
def _unique_dim_row_diff_reduce_kernel(
    row_chunk_diff_ptr: tl.tensor,
    is_first_ptr: tl.tensor,
    num_chunks: int,
    BLOCK_CHUNKS: tl.constexpr,
):
    """OR-reduce a row's chunk flags: ``is_first[row]`` is set if any chunk is non-zero."""
    row_id = ext.program_id(0)
    chunk_ids = tl.arange(0, BLOCK_CHUNKS)
    present = chunk_ids < num_chunks
    flags = tl.load(
        row_chunk_diff_ptr + row_id * num_chunks + chunk_ids, mask=present, other=0
    )
    tl.store(is_first_ptr + row_id, tl.sum(flags, axis=0) != 0)

223 

224 

@libentry()
@triton.jit
def _unique_dim_row_single_chunk_first_kernel(
    flat_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,
    is_first_ptr: tl.tensor,
    num_rows: int,
    row_len: int,
    BLOCK_SIZE: tl.constexpr,
):
    """First-of-group mask for rows that fit in a single tile.

    Sorted row 0 is always marked first. Every other sorted row is marked
    first when any of its ``row_len`` elements differs from the preceding
    sorted row. ``BLOCK_SIZE`` must be at least ``row_len``.
    """
    sorted_pos = ext.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    valid = cols < row_len

    starts_group = tl.full((), True, dtype=tl.int1)
    if sorted_pos != 0:
        this_row = tl.load(sorted_indices_ptr + sorted_pos)
        prior_row = tl.load(sorted_indices_ptr + sorted_pos - 1)
        this_vals = tl.load(flat_ptr + this_row * row_len + cols, mask=valid)
        prior_vals = tl.load(flat_ptr + prior_row * row_len + cols, mask=valid)
        mismatch = (this_vals != prior_vals) & valid
        starts_group = tl.sum(mismatch.to(tl.int32), axis=0) != 0
    tl.store(is_first_ptr + sorted_pos, starts_group)

248 

249 

@libentry()
@triton.jit
def _unique_dim_gather_moved_kernel(
    flat_ptr: tl.tensor,
    unique_indices_ptr: tl.tensor,
    output_ptr: tl.tensor,
    num_unique: int,
    row_len: int,
    BLOCK_SIZE: tl.constexpr,
):
    # Row gather, one program per (output row, column chunk): the source row
    # index is read once as a scalar from ``unique_indices`` and a contiguous
    # column span of that row is copied into the same span of the output row.
    # Keeping the column offsets contiguous avoids the per-element integer
    # divide/modulo and scattered indexing of a flat-offset gather, which
    # dominate NPU time.
    out_row = ext.program_id(0)
    chunk_id = ext.program_id(1)
    cols = chunk_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = cols < row_len

    source_row = tl.load(unique_indices_ptr + out_row)
    span = tl.load(flat_ptr + source_row * row_len + cols, mask=valid)
    tl.store(output_ptr + out_row * row_len + cols, span, mask=valid)

273 

274 

@libentry()
@triton.jit
def _unique_dim_inverse_permutation_kernel(
    sorted_indices_ptr: tl.tensor,
    inverse_ptr: tl.tensor,
    num_rows: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter positions through a permutation: ``inverse[sorted_indices[i]] = i``."""
    block = ext.program_id(0)
    position = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    live = position < num_rows
    target = tl.load(sorted_indices_ptr + position, mask=live, other=0)
    tl.store(inverse_ptr + target, position.to(tl.int64), mask=live)

288 

289 

@libentry()
@triton.jit
def _unique_dim_inverse_kernel(
    sorted_indices_ptr: tl.tensor,
    inverse_sorted_ptr: tl.tensor,
    inverse_ptr: tl.tensor,
    num_rows: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter per-sorted-position values back to original row order.

    Performs ``inverse[sorted_indices[i]] = inverse_sorted[i]`` for every
    ``i < num_rows``.
    """
    block = ext.program_id(0)
    position = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    live = position < num_rows
    target = tl.load(sorted_indices_ptr + position, mask=live)
    group = tl.load(inverse_sorted_ptr + position, mask=live)
    tl.store(inverse_ptr + target, group, mask=live)

305 

306 

@libentry()
@triton.jit
def _unique_dim_counts_kernel(
    first_positions_ptr: tl.tensor,
    counts_ptr: tl.tensor,
    num_rows: int,
    num_unique: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Group sizes from group start positions.

    ``counts[g] = first_positions[g + 1] - first_positions[g]``, where the
    entry after the last group is taken to be ``num_rows``.
    """
    block = ext.program_id(0)
    group = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    live = group < num_unique
    start = tl.load(first_positions_ptr + group, mask=live)
    # The last group has no successor; it ends at ``num_rows``.
    end = tl.load(
        first_positions_ptr + group + 1,
        mask=(group + 1) < num_unique,
        other=num_rows,
    )
    tl.store(counts_ptr + group, end - start, mask=live)

326 

327 

328def _triton_num_warps(block_size: int) -> int: 

329 if block_size >= 8192: 

330 return 8 

331 if block_size >= 2048: 

332 return 4 

333 return 1 

334 

335 

def _monotonic_key_bits(dtype: torch.dtype):
    """Look up the per-element key width (in bits) for ``dtype``.

    Returns ``None`` for dtypes that have no entry in ``_INT_DTYPE_BITS``,
    i.e. that cannot be mapped into a monotonic int64 key.
    """
    if dtype in _INT_DTYPE_BITS:
        return _INT_DTYPE_BITS[dtype]
    return None

340 

341 

# Monotonic-remap kinds for the fused key-build kernel. The kernel receives the
# kind as a compile-time constant and compares it against the literal values
# below, so these numbers must stay in sync with the kernel's branches.
_REMAP_INT = 0  # signed/unsigned int: value + KEY_OFFSET
_REMAP_FP16 = 1  # 16-bit float: order-preserving bit twiddle
_REMAP_FP32 = 2  # 32-bit float: order-preserving bit twiddle

346 

347 

def _remap_info(flat: torch.Tensor):
    """Describe how the fused key kernel should read and remap ``flat``.

    Returns ``(int_view, remap_kind, key_offset)``:

    * ``int_view`` is the tensor the kernel loads from. Bool and float inputs
      are reinterpreted as an integer dtype of the same width; integer inputs
      are returned as-is.
    * ``remap_kind`` is one of the ``_REMAP_*`` constants.
    * ``key_offset`` is the bias added to integer values so the key is
      non-negative (zero for bool, unsigned and float inputs).

    Raises ``NotImplementedError`` for any other dtype. The remap itself
    happens on-device.
    """
    dtype = flat.dtype

    # Floats are bit-cast; the kernel applies the order-preserving twiddle.
    if dtype == torch.float32:
        return flat.view(torch.int32), _REMAP_FP32, 0
    if dtype == torch.float16 or dtype == torch.bfloat16:
        return flat.view(torch.int16), _REMAP_FP16, 0

    if dtype == torch.bool:
        return flat.view(torch.uint8), _REMAP_INT, 0

    # Integers are loaded directly; signed types get a bias of 2**(bits - 1).
    integer_bias = {
        torch.uint8: 0,
        torch.int8: 1 << 7,
        torch.int16: 1 << 15,
        torch.int32: 1 << 31,
    }
    if dtype in integer_bias:
        return flat, _REMAP_INT, integer_bias[dtype]

    raise NotImplementedError(dtype)

371 

372 

@libentry()
@triton.jit
def _unique_dim_build_key_kernel(
    flat_ptr: tl.tensor,
    indices_ptr: tl.tensor,
    group_id_ptr: tl.tensor,
    out_ptr: tl.tensor,
    num_rows: int,
    row_stride: int,
    col: int,
    KEY_OFFSET: tl.constexpr,
    KEY_SCALE: tl.constexpr,
    REMAP_KIND: tl.constexpr,
    FIRST: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Build one cascade pass' composite key in a single launch.

    For ``FIRST`` (first column) the key is just the column's monotonic remap.
    Otherwise the row is fetched through the current permutation ``indices`` and
    the running ``group_id`` prefix is folded in as ``group_id * key_scale +
    value`` (multiply/add rather than shift/or, matching the rest of the file).

    Float inputs arrive bit-cast to integers. Their sign-magnitude bits are
    turned into an ascending unsigned key by flipping all bits of negative
    values and only the sign bit of non-negative values. Negative zero is
    folded onto positive zero first, so values that compare equal receive the
    same key.

    This fuses what was a ``select -> contiguous -> cast -> add -> gather ->
    mul -> add`` chain of separate ops into one kernel, which is the dominant
    per-pass host/launch cost on backends with a native sort.
    """
    pid = ext.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_rows

    if FIRST:
        row = offsets.to(tl.int64)
    else:
        row = tl.load(indices_ptr + offsets, mask=mask, other=0)
    base = row * row_stride + col

    if REMAP_KIND == 0:  # _REMAP_INT
        x = tl.load(flat_ptr + base, mask=mask, other=0).to(tl.int64)
        val = x + KEY_OFFSET
    elif REMAP_KIND == 1:  # _REMAP_FP16
        bits = tl.load(flat_ptr + base, mask=mask, other=0).to(tl.int64) & 0xFFFF
        # -0.0 (only the sign bit set) equals +0.0 numerically; without this
        # fold the two would get different keys and split an equal group.
        bits = tl.where(bits == 0x8000, 0, bits)
        sign = (bits & 0x8000) != 0
        val = tl.where(sign, bits ^ 0xFFFF, bits ^ 0x8000)
    else:  # _REMAP_FP32
        bits = tl.load(flat_ptr + base, mask=mask, other=0).to(tl.int64) & 0xFFFFFFFF
        # Same signed-zero fold as the 16-bit branch.
        bits = tl.where(bits == 0x80000000, 0, bits)
        sign = (bits & 0x80000000) != 0
        val = tl.where(sign, bits ^ 0xFFFFFFFF, bits ^ 0x80000000)

    if FIRST:
        out = val
    else:
        gid = tl.load(group_id_ptr + offsets, mask=mask, other=0)
        out = gid * KEY_SCALE + val
    tl.store(out_ptr + offsets, out, mask=mask)

428 

429 

def _build_composite_key(
    flat_view: torch.Tensor,
    col: int,
    indices: torch.Tensor | None,
    group_id: torch.Tensor | None,
    num_rows: int,
    row_stride: int,
    key_offset: int,
    key_scale: int,
    remap_kind: int,
) -> torch.Tensor:
    """Launch the fused key kernel for cascade pass ``col``.

    Returns a new int64 tensor of ``num_rows`` keys. On the first pass
    ``indices`` and ``group_id`` are ``None``; on later passes they carry the
    current permutation and the running group ids.
    """
    is_first_pass = indices is None
    keys = torch.empty(num_rows, dtype=torch.int64, device=flat_view.device)

    # The kernel signature always takes both pointers. On the first pass they
    # are never dereferenced (guarded by ``FIRST``), so any valid tensor works
    # as a stand-in.
    if is_first_pass:
        perm_arg = flat_view
        gid_arg = flat_view
    else:
        perm_arg = indices
        gid_arg = group_id

    # Float remaps need several int64 temporaries per element, which overflow
    # the Ascend unified buffer at the default tile; integers keep the full
    # tile.
    if remap_kind == _REMAP_INT:
        tile = _UNIQUE_DIM_GATHER_BLOCK_SIZE
    else:
        tile = _UNIQUE_DIM_BUILD_KEY_FLOAT_BLOCK_SIZE

    launch_grid = (triton.cdiv(num_rows, tile), 1, 1)
    with torch_device_fn.device(flat_view.device.index):
        _unique_dim_build_key_kernel[launch_grid](
            flat_view,
            perm_arg,
            gid_arg,
            keys,
            num_rows,
            row_stride,
            col,
            KEY_OFFSET=key_offset,
            KEY_SCALE=key_scale,
            REMAP_KIND=remap_kind,
            FIRST=is_first_pass,
            BLOCK_SIZE=tile,
            num_warps=4,
        )
    return keys

478 

479 

def _triton_gather_1d(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Return ``values[indices]`` as a new 1D tensor via the gather kernel.

    The result has ``indices.numel()`` elements with the dtype and device of
    ``values``. No kernel is launched when ``indices`` is empty.
    """
    count = indices.numel()
    gathered = torch.empty(count, dtype=values.dtype, device=values.device)
    if count != 0:
        tile = _UNIQUE_DIM_GATHER_BLOCK_SIZE
        launch_grid = (triton.cdiv(count, tile), 1, 1)
        with torch_device_fn.device(values.device.index):
            _unique_dim_gather_1d_kernel[launch_grid](
                values,
                indices,
                gathered,
                count,
                BLOCK_SIZE=tile,
                num_warps=4,
            )
    return gathered

496 

497 

def _argsort_keys(keys: torch.Tensor):
    """Ascending argsort of a 1D key tensor.

    Returns ``(perm, sorted_keys)`` where ``sorted_keys = keys[perm]``.

    Up to ``_UNIQUE_DIM_RANK_SORT_MAX_KEYS`` keys are handled by the
    single-launch Triton rank-sort kernel, which orders equal keys by their
    original index. Larger inputs delegate to ``torch.sort`` with its default
    arguments (no ``stable`` flag is passed); under FlagGems op interception
    this dispatches to the backend's Triton radix sort.
    """
    count = keys.numel()
    device = keys.device
    if count == 0:
        return torch.empty(0, dtype=torch.int64, device=device), keys

    if count > _UNIQUE_DIM_RANK_SORT_MAX_KEYS:
        ordered, order = torch.sort(keys)
        return order, ordered

    order = torch.empty(count, dtype=torch.int64, device=device)
    ordered = torch.empty_like(keys)
    # Every program scans all keys, so the tile must cover the whole array.
    tile = triton.next_power_of_2(count)
    with torch_device_fn.device(device.index):
        _unique_dim_argsort_rank_kernel[(count, 1, 1)](
            keys.contiguous(),
            order,
            ordered,
            count,
            BLOCK_SIZE=tile,
            STORE_SORTED_KEYS=True,
            num_warps=_triton_num_warps(tile),
        )
    return order, ordered

527 

528 

def _group_id_from_sorted(sorted_keys: torch.Tensor):
    """Assign dense group ids to an ascending key tensor.

    Returns ``(group_id, last_group_id)``. ``group_id[i]`` counts the distinct
    key values strictly before position ``i``; ``last_group_id`` is the Python
    int value of ``group_id[-1]``, or ``-1`` for empty input. Reading it
    forces a device-to-host transfer.

    Inputs of at most ``_UNIQUE_DIM_GROUP_SCAN_BLOCK_SIZE`` rows use the
    single-launch scan kernel. Longer inputs take an adjacent difference in
    int64, test it against the scalar zero, and run ``torch.cumsum`` (a
    FlagGems multi-block scan under op interception). The subtraction form is
    deliberate: the registered tensor-vs-tensor comparison op would route
    int64 through float32 and lose precision around ``2**24``.
    """
    count = sorted_keys.numel()
    device = sorted_keys.device
    if count == 0:
        return torch.empty(0, dtype=torch.int64, device=device), -1

    if count > _UNIQUE_DIM_GROUP_SCAN_BLOCK_SIZE:
        step = sorted_keys[1:] - sorted_keys[:-1]
        changed = (step != 0).to(torch.int64)
        leading_zero = torch.zeros(1, dtype=torch.int64, device=device)
        ids = torch.cat([leading_zero, torch.cumsum(changed, dim=0)])
        return ids, int(ids[-1].item())

    ids = torch.empty(count, dtype=torch.int64, device=device)
    last_id = torch.empty((), dtype=torch.int64, device=device)
    tile = triton.next_power_of_2(count)
    with torch_device_fn.device(device.index):
        _unique_dim_group_id_kernel[(1, 1, 1)](
            sorted_keys,
            ids,
            last_id,
            count,
            BLOCK_SIZE=tile,
            num_warps=_triton_num_warps(tile),
        )
    return ids, int(last_id.item())

570 

571 

def _lex_argsort_rows_composite(flat: torch.Tensor):
    """Lex-sort rows using one packed ``(group_id, column key)`` sort per column.

    Mirrors the way ATen's CUDA ``unique_dim`` does a single comparator-driven
    sort: each cascade step performs *one* argsort on an int64 key that holds
    the current lex prefix (as a dense group id) in the high part and this
    column's order-preserving value in the low part. The loop stops as soon as
    every row has its own group; for random data that happens after one or two
    columns even when there are many columns.

    Returns ``None`` when the dtype has no monotonic int64 mapping. Otherwise
    returns ``(indices, all_unique)`` where ``indices`` is the row permutation
    and ``all_unique`` is True when every row was found to be distinct.
    """
    key_bits = _monotonic_key_bits(flat.dtype)
    if key_bits is None:
        return None

    num_rows, num_cols = flat.shape
    device = flat.device
    if num_cols == 0 or num_rows <= 1:
        identity = torch.arange(num_rows, dtype=torch.int64, device=device)
        # Zero-width rows are not reported as unique; a lone row is.
        return identity, num_cols != 0

    key_scale = 1 << key_bits
    flat_view, remap_kind, key_offset = _remap_info(flat)
    order = None
    group_id = None
    for col in range(num_cols):
        # One fused launch builds ``group_id * key_scale + monotonic(value)``,
        # reading rows through the current permutation once one exists.
        keys = _build_composite_key(
            flat_view,
            col,
            order,
            group_id,
            num_rows,
            num_cols,
            key_offset,
            key_scale,
            remap_kind,
        )
        perm, sorted_keys = _argsort_keys(keys)
        if col == 0:
            order = perm
        else:
            # Compose the new permutation with the one accumulated so far.
            order = _triton_gather_1d(order, perm)
        group_id, last_group_id = _group_id_from_sorted(sorted_keys)
        # Early exit: every row already has a unique lex prefix.
        if last_group_id == num_rows - 1:
            return order, True
    return order, False

622 

623 

def _lex_argsort_rows_cascade(flat: torch.Tensor) -> torch.Tensor:
    """Generic-dtype fallback: one stable sort per column, last column first.

    Sorting by the least significant column first and relying on stability
    leaves the rows in lexicographic order after the first column has been
    processed. Used for dtypes without a monotonic int64 remap. Returns the
    int64 row permutation.
    """
    num_rows, num_cols = flat.shape
    order = torch.arange(num_rows, dtype=torch.int64, device=flat.device)
    if num_cols == 0 or num_rows <= 1:
        return order

    # Transpose once so each column is a contiguous 1D slice.
    columns = flat.t().contiguous()
    for col in reversed(range(num_cols)):
        column_keys = _triton_gather_1d(columns[col], order)
        # Stability keeps the ordering from previously processed columns.
        _, perm = torch.sort(column_keys, stable=True)
        order = _triton_gather_1d(order, perm)
    return order

639 

640 

def _lex_argsort_rows(flat: torch.Tensor) -> tuple[torch.Tensor, bool]:
    """Return ``(indices, all_unique)`` for a lexicographic sort of 2D rows.

    The composite-key path is tried first. When it reports the dtype as
    unsupported (``None``), the generic cascade is used and ``all_unique`` is
    reported as False.
    """
    result = _lex_argsort_rows_composite(flat)
    if result is None:
        return _lex_argsort_rows_cascade(flat), False
    return result

647 

648 

def _unique_dim_row_hash(flat: torch.Tensor) -> torch.Tensor:
    """Compute one int64 hash per row of a 2D tensor.

    Runs in two launches: per-(row, chunk) partial hashes, then a per-row sum
    of those partials.
    """
    num_rows, row_len = flat.shape
    device = flat.device
    tile = min(_UNIQUE_DIM_COMPARE_BLOCK_SIZE, triton.next_power_of_2(row_len))
    chunks_per_row = triton.cdiv(row_len, tile)
    reduce_tile = triton.next_power_of_2(chunks_per_row)

    partial = torch.empty(
        (num_rows, chunks_per_row), dtype=torch.int64, device=device
    )
    hashes = torch.empty(num_rows, dtype=torch.int64, device=device)
    with torch_device_fn.device(device.index):
        _unique_dim_row_hash_chunk_kernel[(num_rows, chunks_per_row, 1)](
            flat,
            partial,
            num_rows,
            row_len,
            chunks_per_row,
            BLOCK_SIZE=tile,
            num_warps=_triton_num_warps(tile),
        )
        _unique_dim_row_hash_reduce_kernel[(num_rows, 1, 1)](
            partial,
            hashes,
            chunks_per_row,
            BLOCK_CHUNKS=reduce_tile,
            num_warps=_triton_num_warps(reduce_tile),
        )
    return hashes

675 

676 

def _unique_dim_first_mask(flat: torch.Tensor, sorted_indices: torch.Tensor):
    """Return a bool mask marking the first row of each group in sorted order.

    ``mask[i]`` is True when sorted row ``i`` differs from sorted row
    ``i - 1`` (position 0 is always True). Rows that fit in one tile are
    compared in a single launch; longer rows are compared chunk by chunk and
    the chunk flags reduced per row, with a row-hash shortcut when the rows
    are at least ``_UNIQUE_DIM_HASH_MIN_ROW_LEN`` long.
    """
    num_rows, row_len = flat.shape
    device = flat.device
    if num_rows == 1 or row_len == 0:
        # A single row, or rows with no elements: one group starting at 0.
        only_first = torch.zeros(num_rows, dtype=torch.bool, device=device)
        only_first[0] = True
        return only_first

    tile = min(_UNIQUE_DIM_COMPARE_BLOCK_SIZE, triton.next_power_of_2(row_len))
    chunks_per_row = triton.cdiv(row_len, tile)
    tile_warps = _triton_num_warps(tile)
    is_first = torch.empty(num_rows, dtype=torch.bool, device=device)

    if chunks_per_row == 1:
        with torch_device_fn.device(device.index):
            _unique_dim_row_single_chunk_first_kernel[(num_rows, 1, 1)](
                flat,
                sorted_indices,
                is_first,
                num_rows,
                row_len,
                BLOCK_SIZE=tile,
                num_warps=tile_warps,
            )
        return is_first

    chunk_flags = torch.empty(
        (num_rows, chunks_per_row), dtype=torch.int32, device=device
    )
    launch_grid = (num_rows, chunks_per_row, 1)
    row_hash = None
    if row_len >= _UNIQUE_DIM_HASH_MIN_ROW_LEN:
        row_hash = _unique_dim_row_hash(flat)
    reduce_tile = triton.next_power_of_2(chunks_per_row)

    with torch_device_fn.device(device.index):
        if row_hash is not None:
            _unique_dim_row_chunk_diff_hash_kernel[launch_grid](
                flat,
                sorted_indices,
                row_hash,
                chunk_flags,
                num_rows,
                row_len,
                chunks_per_row,
                BLOCK_SIZE=tile,
                num_warps=tile_warps,
            )
        else:
            _unique_dim_row_chunk_diff_kernel[launch_grid](
                flat,
                sorted_indices,
                chunk_flags,
                num_rows,
                row_len,
                chunks_per_row,
                BLOCK_SIZE=tile,
                num_warps=tile_warps,
            )
        # Collapse the per-chunk flags into one bool per row.
        _unique_dim_row_diff_reduce_kernel[(num_rows, 1, 1)](
            chunk_flags,
            is_first,
            chunks_per_row,
            BLOCK_CHUNKS=reduce_tile,
            num_warps=_triton_num_warps(reduce_tile),
        )
    return is_first

740 

741 

def _unique_dim_gather_output(
    moved: torch.Tensor,
    unique_indices: torch.Tensor,
    dim: int,
    input_shape: torch.Size,
) -> torch.Tensor:
    """Materialize the output tensor from the selected slices.

    ``moved`` is the input with the unique dimension moved to the front.
    The slices listed in ``unique_indices`` are copied, in that order, into a
    new tensor whose leading dimension is then moved back to ``dim``. With no
    selected slices an empty tensor of the final shape is returned directly.
    """
    selected = unique_indices.numel()
    if selected == 0:
        empty_shape = (*input_shape[:dim], selected, *input_shape[dim + 1 :])
        return torch.empty(empty_shape, dtype=moved.dtype, device=moved.device)

    row_len = moved[0].numel()
    rows = moved.reshape(moved.shape[0], row_len)
    gathered = torch.empty(
        (selected, *moved.shape[1:]),
        dtype=moved.dtype,
        device=moved.device,
    )
    tile = _UNIQUE_DIM_GATHER_BLOCK_SIZE
    launch_grid = (selected, triton.cdiv(row_len, tile), 1)
    with torch_device_fn.device(moved.device.index):
        _unique_dim_gather_moved_kernel[launch_grid](
            rows,
            unique_indices,
            gathered,
            selected,
            row_len,
            BLOCK_SIZE=tile,
            num_warps=4,
        )
    return gathered.movedim(0, dim)

775 

776 

def _unique_dim_inverse_from_permutation(sorted_indices: torch.Tensor) -> torch.Tensor:
    """Invert a permutation: the result satisfies ``inverse[sorted_indices[i]] = i``.

    This is the inverse mapping for the all-unique case. It is done as a plain
    1D scatter (no per-element column predicate), which is correct on every
    backend; the fused gather+scatter variant miscompiles its masked inverse
    store on some Ascend/NPU backends.
    """
    count = sorted_indices.numel()
    inverse = torch.empty_like(sorted_indices)
    if count != 0:
        tile = _UNIQUE_DIM_GATHER_BLOCK_SIZE
        launch_grid = (triton.cdiv(count, tile), 1, 1)
        with torch_device_fn.device(sorted_indices.device.index):
            _unique_dim_inverse_permutation_kernel[launch_grid](
                sorted_indices,
                inverse,
                count,
                BLOCK_SIZE=tile,
                num_warps=4,
            )
    return inverse

798 

799 

def _unique_dim_inverse(
    sorted_indices: torch.Tensor,
    is_first: torch.Tensor,
) -> torch.Tensor:
    """Build the inverse mapping from original rows to unique-group ids.

    The group id of each sorted position is the running count of group starts
    minus one; those ids are then scattered back to the original row
    positions through ``sorted_indices``.
    """
    count = sorted_indices.numel()
    device = sorted_indices.device
    inverse = torch.empty(count, dtype=torch.int64, device=device)
    if count == 0:
        return inverse

    # Inclusive prefix sum of the start flags, shifted to be zero-based.
    group_of_sorted = torch.cumsum(is_first.to(torch.int64), dim=0) - 1
    tile = _UNIQUE_DIM_GATHER_BLOCK_SIZE
    launch_grid = (triton.cdiv(count, tile), 1, 1)
    with torch_device_fn.device(device.index):
        _unique_dim_inverse_kernel[launch_grid](
            sorted_indices,
            group_of_sorted,
            inverse,
            count,
            BLOCK_SIZE=tile,
            num_warps=4,
        )
    return inverse

824 

825 

def _unique_dim_unique_indices(
    sorted_indices: torch.Tensor,
    is_first: torch.Tensor,
) -> torch.Tensor:
    """Return the original row index of the first row in every sorted group."""
    # Sorted positions at which a new group begins.
    group_starts = torch.nonzero(is_first, as_tuple=False).flatten()
    return _triton_gather_1d(sorted_indices, group_starts)

833 

834 

def _unique_dim_unique_indices_and_inverse(
    sorted_indices: torch.Tensor,
    is_first: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(unique_indices, inverse_indices)`` for a sorted grouping.

    Convenience wrapper that calls the two single-purpose helpers in turn.
    """
    representatives = _unique_dim_unique_indices(sorted_indices, is_first)
    inverse = _unique_dim_inverse(sorted_indices, is_first)
    return representatives, inverse

842 

843 

def _unique_dim_counts(
    is_first: torch.Tensor,
    num_rows: int,
) -> torch.Tensor:
    """Return the size of each group as an int64 tensor.

    Group sizes are the gaps between consecutive group start positions, with
    ``num_rows`` closing the last group.
    """
    group_starts = torch.nonzero(is_first, as_tuple=False).flatten()
    groups = group_starts.numel()
    sizes = torch.empty(groups, dtype=torch.int64, device=is_first.device)
    if groups != 0:
        tile = _UNIQUE_DIM_GATHER_BLOCK_SIZE
        launch_grid = (triton.cdiv(groups, tile), 1, 1)
        with torch_device_fn.device(is_first.device.index):
            _unique_dim_counts_kernel[launch_grid](
                group_starts,
                sizes,
                num_rows,
                groups,
                BLOCK_SIZE=tile,
                num_warps=4,
            )
    return sizes

865 

866 

def unique_dim(
    input: torch.Tensor,
    dim: int,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
):
    """Dimension-aware ``torch.unique`` (a.k.a. ``aten::unique_dim``).

    Treats each slice along ``dim`` as a single element, returning the unique
    slices, an optional inverse mapping of shape ``(input.size(dim),)`` and an
    optional per-unique count tensor of shape ``(output.size(dim),)``.

    Args:
        input: Tensor to deduplicate.
        dim: Dimension whose slices are compared; negative values wrap.
        sorted: Accepted for signature compatibility; this implementation
            does not read it.
        return_inverse: Compute the inverse mapping.
        return_counts: Compute the per-slice occurrence counts.

    Returns:
        Always a 3-tuple ``(output, inverse_indices, counts)``. The latter two
        are empty int64 tensors when not requested.

    Raises:
        IndexError: If ``dim`` is outside the valid range for ``input``.
    """
    logger.debug("GEMS UNIQUE_DIM")

    # A 0-d tensor is treated as having a single dimension when wrapping.
    ndim = max(input.ndim, 1)
    wrapped_dim = dim + ndim if dim < 0 else dim
    if not 0 <= wrapped_dim < ndim:
        # Report the caller's original ``dim`` and the effective valid range,
        # not the partially wrapped value.
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    dim = wrapped_dim

    device = input.device
    size_dim = input.size(dim) if input.ndim > 0 else input.numel()

    if size_dim == 0:
        # Nothing to deduplicate: return a copy plus empty index tensors.
        output = input.clone()
        inverse_indices = torch.empty(0, dtype=torch.int64, device=device)
        counts = torch.empty(0, dtype=torch.int64, device=device)
        return output, inverse_indices, counts

    # Bring the unique dimension to the front and view each slice as one row.
    moved = input.movedim(dim, 0).contiguous()
    flat = moved.reshape(size_dim, -1)

    sorted_indices, all_unique = _lex_argsort_rows(flat)

    inverse_indices = torch.empty(0, dtype=torch.int64, device=device)
    counts = torch.empty(0, dtype=torch.int64, device=device)

    if all_unique:
        # Fast path: the sort already proved every row distinct, so the
        # output is just the permuted input and no adjacent-row comparison
        # is needed.
        if return_counts:
            counts = torch.ones(size_dim, dtype=torch.int64, device=device)
        if return_inverse:
            inverse_indices = _unique_dim_inverse_from_permutation(sorted_indices)
        output = _unique_dim_gather_output(moved, sorted_indices, dim, input.shape)
        return output, inverse_indices, counts

    is_first = _unique_dim_first_mask(flat, sorted_indices)
    if return_inverse:
        unique_in_orig, inverse_indices = _unique_dim_unique_indices_and_inverse(
            sorted_indices,
            is_first,
        )
    else:
        unique_in_orig = _unique_dim_unique_indices(sorted_indices, is_first)

    if return_counts:
        counts = _unique_dim_counts(is_first, size_dim)

    output = _unique_dim_gather_output(moved, unique_in_orig, dim, input.shape)

    return output, inverse_indices, counts