Coverage for src/flag_gems/ops/unique_dim.py: 54%
428 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import triton_lang_extension as ext
9from flag_gems.utils.libentry import libentry
logger = logging.getLogger(__name__)

# Elements per program for the row-compare kernels.
_UNIQUE_DIM_COMPARE_BLOCK_SIZE = 1 << 10
# Elements per program for the 1D/row gather and scatter kernels.
_UNIQUE_DIM_GATHER_BLOCK_SIZE = 1 << 10
# Upper bound on the row count the single-launch group-id scan kernel covers.
_UNIQUE_DIM_GROUP_SCAN_BLOCK_SIZE = 1 << 12
# Key-count ceiling for the single-launch rank-sort kernel. Past this, sorting
# goes through ``torch.sort``, which under FlagGems op interception dispatches
# to the backend's Triton radix sort. The rank sort does O(N^2) work in one
# launch, so for tiny inputs it beats a 16-pass int64 radix sort.
_UNIQUE_DIM_RANK_SORT_MAX_KEYS = 1 << 11
# Row length from which duplicate detection pre-filters row pairs by hash.
_UNIQUE_DIM_HASH_MIN_ROW_LEN = 1 << 10
# Reduced tile for the float branches of the fused key kernel: their int64
# bit-twiddle temporaries overflow the Ascend unified buffer at the full tile.
_UNIQUE_DIM_BUILD_KEY_FLOAT_BLOCK_SIZE = 1 << 8

# Key width (bits) per dtype for the composite-key sort. Each dtype here has an
# order-preserving mapping to a non-negative int64, so a per-row ``group_id``
# can be packed with one column's key into a single int64 whose signed
# comparison matches lexicographic order on ``(group_id, value)``.
_INT_DTYPE_BITS = {
    torch.bool: 1,
    torch.int8: 8,
    torch.uint8: 8,
    torch.int16: 16,
    torch.int32: 32,
    torch.float16: 16,
    torch.bfloat16: 16,
    torch.float32: 32,
}
@libentry()
@triton.jit
def _unique_dim_argsort_rank_kernel(
    keys_ptr: tl.tensor,
    indices_ptr: tl.tensor,
    sorted_keys_ptr: tl.tensor,
    num_keys: int,
    BLOCK_SIZE: tl.constexpr,
    STORE_SORTED_KEYS: tl.constexpr,
):
    """Stable rank sort with one program per key.

    Program ``i`` counts the keys that must precede key ``i``. These are the
    strictly smaller keys plus equal keys at a lower index; counting the
    latter is what makes the sort stable. The program then writes ``i`` to
    that rank in ``indices``. With ``STORE_SORTED_KEYS`` the key itself is
    also written at that rank. All keys are read as one tile, so the caller
    is expected to pass ``BLOCK_SIZE >= num_keys``.
    """
    me = ext.program_id(0)
    others = tl.arange(0, BLOCK_SIZE)
    valid = others < num_keys

    my_key = tl.load(keys_ptr + me)
    other_keys = tl.load(keys_ptr + others, mask=valid, other=my_key)
    smaller = other_keys < my_key
    tie_earlier = (other_keys == my_key) & (others < me)
    precedes = (smaller | tie_earlier) & valid
    dest = tl.sum(precedes.to(tl.int32), axis=0)

    tl.store(indices_ptr + dest, me)
    if STORE_SORTED_KEYS:
        tl.store(sorted_keys_ptr + dest, my_key)
@libentry()
@triton.jit
def _unique_dim_gather_1d_kernel(
    input_ptr: tl.tensor,
    index_ptr: tl.tensor,
    output_ptr: tl.tensor,
    num_elements: int,
    BLOCK_SIZE: tl.constexpr,
):
    """``output[i] = input[index[i]]`` for ``i < num_elements``, one tile per program."""
    pos = ext.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = pos < num_elements
    # Out-of-range lanes read index 0 so the dependent address stays valid.
    src = tl.load(index_ptr + pos, mask=in_range, other=0)
    gathered = tl.load(input_ptr + src, mask=in_range)
    tl.store(output_ptr + pos, gathered, mask=in_range)
@libentry()
@triton.jit
def _unique_dim_group_id_kernel(
    composite_ptr: tl.tensor,
    group_id_ptr: tl.tensor,
    last_group_id_ptr: tl.tensor,
    num_rows: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Single-program dense group ids over a sorted key array.

    ``group_id[i]`` is the number of positions ``j`` in ``1..i`` where
    ``keys[j] != keys[j - 1]``. The entry at ``num_rows - 1`` is also written
    to the scalar ``last_group_id``. The whole array is one tile, so
    ``BLOCK_SIZE`` is expected to cover ``num_rows``.
    """
    idx = tl.arange(0, BLOCK_SIZE)
    live = idx < num_rows
    here = tl.load(composite_ptr + idx, mask=live, other=0)
    # Lane 0 has no predecessor: point it at itself and fall back to ``here``.
    left_idx = tl.where(idx == 0, 0, idx - 1)
    left = tl.load(composite_ptr + left_idx, mask=idx > 0, other=here)
    changed = ((here - left) != 0) & live
    changed = tl.where(idx == 0, False, changed)
    ids = tl.cumsum(changed.to(tl.int64), axis=0)
    tl.store(group_id_ptr + idx, ids, mask=live)
    tail = tl.sum(tl.where(idx == num_rows - 1, ids, 0), axis=0)
    tl.store(last_group_id_ptr, tail)
@libentry()
@triton.jit
def _unique_dim_row_hash_chunk_kernel(
    flat_ptr: tl.tensor,
    chunk_hash_ptr: tl.tensor,
    num_rows: int,
    row_len: int,
    num_chunks: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Partial row hash for one ``(row, chunk)`` program.

    Sums a position-salted mix of one ``BLOCK_SIZE`` slice of ``row`` into
    ``chunk_hash[row, chunk]``. Elements are cast to int64 before mixing and
    the int64 sum may wrap; identical slices still give identical partials.
    """
    r = ext.program_id(0)
    c = ext.program_id(1)
    cols = c * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    inside = cols < row_len

    elems = tl.load(flat_ptr + r * row_len + cols, mask=inside, other=0).to(tl.int64)
    # Position salt (wrapping int64 addition is associative, so grouping the
    # constant terms together leaves the result unchanged).
    salt = (cols.to(tl.int64) + 1) * 1009 + 9176
    mixed = (elems + salt) * 131071
    mixed = tl.where(inside, mixed, 0)
    tl.store(chunk_hash_ptr + r * num_chunks + c, tl.sum(mixed, axis=0))
@libentry()
@triton.jit
def _unique_dim_row_hash_reduce_kernel(
    chunk_hash_ptr: tl.tensor,
    row_hash_ptr: tl.tensor,
    num_chunks: int,
    BLOCK_CHUNKS: tl.constexpr,
):
    """``row_hash[row] = sum(chunk_hash[row, :num_chunks])``, one program per row."""
    r = ext.program_id(0)
    lane = tl.arange(0, BLOCK_CHUNKS)
    partials = tl.load(
        chunk_hash_ptr + r * num_chunks + lane, mask=lane < num_chunks, other=0
    )
    total = tl.sum(partials, axis=0)
    tl.store(row_hash_ptr + r, total)
@libentry()
@triton.jit
def _unique_dim_row_chunk_diff_kernel(
    flat_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,
    row_chunk_diff_ptr: tl.tensor,
    num_rows: int,
    row_len: int,
    num_chunks: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Per-chunk "differs from the previous sorted row" flags.

    Program ``(row, chunk)`` compares one ``BLOCK_SIZE`` slice of the row at
    sorted position ``row`` with the row at sorted position ``row - 1`` (both
    looked up through ``sorted_indices``). It writes int32 1 to
    ``row_chunk_diff[row, chunk]`` if any element in the slice differs, else
    0. Sorted position 0 has no predecessor, so only its chunk 0 is flagged.
    """
    row = ext.program_id(0)
    chunk = ext.program_id(1)
    offsets = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < row_len

    out = tl.full((), 0, dtype=tl.int32)
    if row == 0:
        # The first sorted row always starts a group; flag it once (chunk 0)
        # so the per-row sum over chunks is non-zero.
        out = tl.where(chunk == 0, 1, 0)
    else:
        cur_row = tl.load(sorted_indices_ptr + row)
        prev_row = tl.load(sorted_indices_ptr + row - 1)
        cur = tl.load(flat_ptr + cur_row * row_len + offsets, mask=mask)
        prev = tl.load(flat_ptr + prev_row * row_len + offsets, mask=mask)
        neq = (cur != prev) & mask
        has_diff = tl.sum(neq.to(tl.int32), axis=0) != 0
        out = has_diff.to(tl.int32)
    tl.store(row_chunk_diff_ptr + row * num_chunks + chunk, out)
@libentry()
@triton.jit
def _unique_dim_row_chunk_diff_hash_kernel(
    flat_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,
    row_hash_ptr: tl.tensor,
    row_chunk_diff_ptr: tl.tensor,
    num_rows: int,
    row_len: int,
    num_chunks: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Hash-filtered per-chunk "differs from the previous sorted row" flags.

    Program ``(row, chunk)`` writes int32 0/1 to ``row_chunk_diff[row, chunk]``
    for sorted position ``row``. Rows are looked up through
    ``sorted_indices``, and their precomputed hashes through ``row_hash``.
    - Hashes of positions ``row`` and ``row - 1`` differ: only chunk 0 is
      flagged and no row elements are loaded.
    - Hashes match: this chunk's ``BLOCK_SIZE`` slice is compared element-wise.
    Sorted position 0 has no predecessor, so only its chunk 0 is flagged.
    """
    row = ext.program_id(0)
    chunk = ext.program_id(1)
    offsets = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < row_len

    out = tl.full((), 0, dtype=tl.int32)
    if row == 0:
        # First sorted row: flag once (chunk 0) so the per-row sum is non-zero.
        out = tl.where(chunk == 0, 1, 0)
    else:
        cur_row = tl.load(sorted_indices_ptr + row)
        prev_row = tl.load(sorted_indices_ptr + row - 1)
        cur_hash = tl.load(row_hash_ptr + cur_row)
        prev_hash = tl.load(row_hash_ptr + prev_row)
        if cur_hash != prev_hash:
            # Relies on identical rows always hashing identically: a hash
            # mismatch then proves the rows differ, and one flag per row is
            # enough.
            out = tl.where(chunk == 0, 1, 0)
        else:
            # Equal hashes may be a collision; confirm element-wise.
            cur = tl.load(flat_ptr + cur_row * row_len + offsets, mask=mask)
            prev = tl.load(flat_ptr + prev_row * row_len + offsets, mask=mask)
            neq = (cur != prev) & mask
            has_diff = tl.sum(neq.to(tl.int32), axis=0) != 0
            out = has_diff.to(tl.int32)
    tl.store(row_chunk_diff_ptr + row * num_chunks + chunk, out)
@libentry()
@triton.jit
def _unique_dim_row_diff_reduce_kernel(
    row_chunk_diff_ptr: tl.tensor,
    is_first_ptr: tl.tensor,
    num_chunks: int,
    BLOCK_CHUNKS: tl.constexpr,
):
    """``is_first[row] = any(row_chunk_diff[row, :num_chunks])``, one program per row."""
    r = ext.program_id(0)
    lane = tl.arange(0, BLOCK_CHUNKS)
    flags = tl.load(
        row_chunk_diff_ptr + r * num_chunks + lane, mask=lane < num_chunks, other=0
    )
    any_diff = tl.sum(flags, axis=0) != 0
    tl.store(is_first_ptr + r, any_diff)
@libentry()
@triton.jit
def _unique_dim_row_single_chunk_first_kernel(
    flat_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,
    is_first_ptr: tl.tensor,
    num_rows: int,
    row_len: int,
    BLOCK_SIZE: tl.constexpr,
):
    """One-launch ``is_first`` for rows that fit in a single tile.

    Program ``row`` compares the rows at sorted positions ``row`` and
    ``row - 1`` (through ``sorted_indices``) and stores whether any element
    differs. Sorted position 0 is always ``True``. Only the first
    ``BLOCK_SIZE`` columns are examined, so callers are expected to pass
    ``BLOCK_SIZE >= row_len``.
    """
    row = ext.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < row_len

    out = tl.full((), True, dtype=tl.int1)
    if row != 0:
        cur_row = tl.load(sorted_indices_ptr + row)
        prev_row = tl.load(sorted_indices_ptr + row - 1)
        cur = tl.load(flat_ptr + cur_row * row_len + offsets, mask=mask)
        prev = tl.load(flat_ptr + prev_row * row_len + offsets, mask=mask)
        neq = (cur != prev) & mask
        out = tl.sum(neq.to(tl.int32), axis=0) != 0
    tl.store(is_first_ptr + row, out)
@libentry()
@triton.jit
def _unique_dim_gather_moved_kernel(
    flat_ptr: tl.tensor,
    unique_indices_ptr: tl.tensor,
    output_ptr: tl.tensor,
    num_unique: int,
    row_len: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy row ``unique_indices[row]`` of ``flat`` into row ``row`` of ``output``.

    One program per (output row, column chunk). The source row index is
    loaded once per program as a scalar and the columns are a contiguous
    span. This avoids the per-element integer divide/modulo and scattered
    indexing of a flat-offset gather, which dominate NPU time.
    """
    dst_row = ext.program_id(0)
    cols = ext.program_id(1) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_row = cols < row_len

    src_row = tl.load(unique_indices_ptr + dst_row)
    span = tl.load(flat_ptr + src_row * row_len + cols, mask=in_row)
    tl.store(output_ptr + dst_row * row_len + cols, span, mask=in_row)
@libentry()
@triton.jit
def _unique_dim_inverse_permutation_kernel(
    sorted_indices_ptr: tl.tensor,
    inverse_ptr: tl.tensor,
    num_rows: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter positions: ``inverse[sorted_indices[i]] = i`` (as int64) for ``i < num_rows``."""
    pos = ext.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    live = pos < num_rows
    dest = tl.load(sorted_indices_ptr + pos, mask=live, other=0)
    tl.store(inverse_ptr + dest, pos.to(tl.int64), mask=live)
@libentry()
@triton.jit
def _unique_dim_inverse_kernel(
    sorted_indices_ptr: tl.tensor,
    inverse_sorted_ptr: tl.tensor,
    inverse_ptr: tl.tensor,
    num_rows: int,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter ``inverse[sorted_indices[i]] = inverse_sorted[i]`` for ``i < num_rows``."""
    pos = ext.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    live = pos < num_rows
    dest = tl.load(sorted_indices_ptr + pos, mask=live)
    payload = tl.load(inverse_sorted_ptr + pos, mask=live)
    tl.store(inverse_ptr + dest, payload, mask=live)
@libentry()
@triton.jit
def _unique_dim_counts_kernel(
    first_positions_ptr: tl.tensor,
    counts_ptr: tl.tensor,
    num_rows: int,
    num_unique: int,
    BLOCK_SIZE: tl.constexpr,
):
    """``counts[k] = first_positions[k + 1] - first_positions[k]`` for ``k < num_unique``.

    ``num_rows`` stands in for the position one past the last group.
    """
    k = ext.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    live = k < num_unique
    start = tl.load(first_positions_ptr + k, mask=live)
    nxt = k + 1
    end = tl.load(first_positions_ptr + nxt, mask=nxt < num_unique, other=num_rows)
    tl.store(counts_ptr + k, end - start, mask=live)
def _triton_num_warps(block_size: int) -> int:
    """Warp count for a tile of ``block_size`` elements.

    8 warps from 8192 elements, 4 from 2048, otherwise 1.
    """
    for threshold, warps in ((8192, 8), (2048, 4)):
        if block_size >= threshold:
            return warps
    return 1
def _monotonic_key_bits(dtype: torch.dtype):
    """Bits ``dtype`` occupies in a composite sort key.

    Returns ``None`` when this module has no order-preserving int64 mapping
    for ``dtype``.
    """
    try:
        return _INT_DTYPE_BITS[dtype]
    except KeyError:
        return None
# Remap selectors passed as ``REMAP_KIND`` to the fused key-build kernel.
_REMAP_INT = 0  # signed/unsigned int: value + KEY_OFFSET
_REMAP_FP16 = _REMAP_INT + 1  # 16-bit float: order-preserving bit twiddle
_REMAP_FP32 = _REMAP_FP16 + 1  # 32-bit float: order-preserving bit twiddle
def _remap_info(flat: torch.Tensor):
    """Describe how the fused key kernel maps ``flat``'s dtype to a sortable int64.

    Returns ``(int_view, remap_kind, key_offset)``:
    - ``int_view``: ``flat`` or a same-size integer reinterpretation of it.
      Bool is viewed as uint8 and floats are bit-cast, so the kernel can load
      raw integers.
    - ``remap_kind``: which on-device remap to apply.
    - ``key_offset``: added to integer values to shift them non-negative.

    Raises ``NotImplementedError`` for any other dtype.
    """
    dt = flat.dtype
    # Offset that moves each integer type's minimum to zero.
    int_offsets = {
        torch.uint8: 0,
        torch.int8: 1 << 7,
        torch.int16: 1 << 15,
        torch.int32: 1 << 31,
    }
    if dt == torch.bool:
        return flat.view(torch.uint8), _REMAP_INT, 0
    if dt in int_offsets:
        return flat, _REMAP_INT, int_offsets[dt]
    if dt in (torch.float16, torch.bfloat16):
        return flat.view(torch.int16), _REMAP_FP16, 0
    if dt == torch.float32:
        return flat.view(torch.int32), _REMAP_FP32, 0
    raise NotImplementedError(dt)
@libentry()
@triton.jit
def _unique_dim_build_key_kernel(
    flat_ptr: tl.tensor,
    indices_ptr: tl.tensor,
    group_id_ptr: tl.tensor,
    out_ptr: tl.tensor,
    num_rows: int,
    row_stride: int,
    col: int,
    KEY_OFFSET: tl.constexpr,
    KEY_SCALE: tl.constexpr,
    REMAP_KIND: tl.constexpr,
    FIRST: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Build one cascade pass' composite key in a single launch.

    For ``FIRST`` (first column) the key is just the column's monotonic remap.
    Otherwise the row is fetched through the current permutation ``indices``
    and the running ``group_id`` prefix is folded in as
    ``group_id * key_scale + value``. This uses multiply/add rather than
    shift/or, matching the rest of the file.

    Float remaps first fold negative zero onto positive zero. ``-0.0`` and
    ``0.0`` compare equal under the element-wise ``!=`` used by the
    duplicate-row kernels, so they must also map to the same key. Otherwise,
    rows that differ only in the sign of a zero get distinct group ids. That
    can wrongly trigger the all-unique early exit, or leave equal rows
    non-adjacent after sorting, so they are reported as separate unique rows.

    This fuses what was a ``select -> contiguous -> cast -> add -> gather ->
    mul -> add`` chain of separate ops into one kernel, which is the dominant
    per-pass host/launch cost on backends with a native sort.
    """
    pid = ext.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_rows

    if FIRST:
        row = offsets.to(tl.int64)
    else:
        # Sorted position -> original row under the current permutation.
        row = tl.load(indices_ptr + offsets, mask=mask, other=0)
    base = row * row_stride + col

    if REMAP_KIND == 0:  # _REMAP_INT
        x = tl.load(flat_ptr + base, mask=mask, other=0).to(tl.int64)
        val = x + KEY_OFFSET
    elif REMAP_KIND == 1:  # _REMAP_FP16
        bits = tl.load(flat_ptr + base, mask=mask, other=0).to(tl.int64) & 0xFFFF
        # -0.0 (0x8000) -> +0.0 (0x0000) so both zeros share one key.
        bits = tl.where(bits == 0x8000, 0, bits)
        sign = (bits & 0x8000) != 0
        # Negative: flip every bit (reverses magnitude order, lands below
        # 0x8000). Non-negative: set the top bit (lands at/above 0x8000).
        val = tl.where(sign, bits ^ 0xFFFF, bits ^ 0x8000)
    else:  # _REMAP_FP32
        bits = tl.load(flat_ptr + base, mask=mask, other=0).to(tl.int64) & 0xFFFFFFFF
        # -0.0 (0x80000000) -> +0.0 (0x00000000) so both zeros share one key.
        bits = tl.where(bits == 0x80000000, 0, bits)
        sign = (bits & 0x80000000) != 0
        val = tl.where(sign, bits ^ 0xFFFFFFFF, bits ^ 0x80000000)

    if FIRST:
        out = val
    else:
        gid = tl.load(group_id_ptr + offsets, mask=mask, other=0)
        out = gid * KEY_SCALE + val
    tl.store(out_ptr + offsets, out, mask=mask)
def _build_composite_key(
    flat_view: torch.Tensor,
    col: int,
    indices: torch.Tensor | None,
    group_id: torch.Tensor | None,
    num_rows: int,
    row_stride: int,
    key_offset: int,
    key_scale: int,
    remap_kind: int,
) -> torch.Tensor:
    """Launch the fused key kernel for cascade pass ``col``.

    On the first pass ``indices`` and ``group_id`` are ``None``. Later passes
    supply the current permutation and running group ids. Returns a fresh
    int64 tensor of ``num_rows`` keys.
    """
    keys = torch.empty(num_rows, dtype=torch.int64, device=flat_view.device)
    is_first_pass = indices is None
    if is_first_pass:
        # The kernel never dereferences these under ``FIRST``, but Triton still
        # needs real tensor handles to bind.
        perm_arg = gid_arg = flat_view
    else:
        perm_arg, gid_arg = indices, group_id
    # The float remaps hold several int64 temporaries per element and overflow
    # the Ascend unified buffer at the full tile, so they get a smaller one.
    if remap_kind == _REMAP_INT:
        tile = _UNIQUE_DIM_GATHER_BLOCK_SIZE
    else:
        tile = _UNIQUE_DIM_BUILD_KEY_FLOAT_BLOCK_SIZE
    launch_grid = (triton.cdiv(num_rows, tile), 1, 1)
    with torch_device_fn.device(flat_view.device.index):
        _unique_dim_build_key_kernel[launch_grid](
            flat_view,
            perm_arg,
            gid_arg,
            keys,
            num_rows,
            row_stride,
            col,
            KEY_OFFSET=key_offset,
            KEY_SCALE=key_scale,
            REMAP_KIND=remap_kind,
            FIRST=is_first_pass,
            BLOCK_SIZE=tile,
            num_warps=4,
        )
    return keys
def _triton_gather_1d(values: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Return ``values[indices]`` as a new 1D tensor.

    The result has ``indices.numel()`` elements of ``values.dtype``. Nothing is
    launched for an empty index tensor.
    """
    count = indices.numel()
    gathered = torch.empty(count, dtype=values.dtype, device=values.device)
    if count > 0:
        launch_grid = (triton.cdiv(count, _UNIQUE_DIM_GATHER_BLOCK_SIZE), 1, 1)
        with torch_device_fn.device(values.device.index):
            _unique_dim_gather_1d_kernel[launch_grid](
                values,
                indices,
                gathered,
                count,
                BLOCK_SIZE=_UNIQUE_DIM_GATHER_BLOCK_SIZE,
                num_warps=4,
            )
    return gathered
def _argsort_keys(keys: torch.Tensor):
    """Ascending argsort of a 1D key tensor.

    Returns ``(perm, sorted_keys)`` where ``sorted_keys = keys[perm]``.

    Small key counts (``<= _UNIQUE_DIM_RANK_SORT_MAX_KEYS``) use the
    single-launch Triton rank-sort kernel. That path is stable and cheap for
    tiny shapes.

    Larger counts delegate to ``torch.sort`` with default arguments. Under
    FlagGems op interception this dispatches to the backend's Triton radix
    sort. That call does not request stability, so the relative order of equal
    keys is not guaranteed there. The composite-key caller in this file only
    relies on equal keys ending up adjacent.
    """
    num_keys = keys.numel()
    if num_keys == 0:
        return torch.empty(0, dtype=torch.int64, device=keys.device), keys
    if num_keys <= _UNIQUE_DIM_RANK_SORT_MAX_KEYS:
        perm = torch.empty(num_keys, dtype=torch.int64, device=keys.device)
        sorted_keys = torch.empty_like(keys)
        # The rank kernel reads every key in one tile, one program per key.
        block_size = triton.next_power_of_2(num_keys)
        with torch_device_fn.device(keys.device.index):
            _unique_dim_argsort_rank_kernel[(num_keys, 1, 1)](
                keys.contiguous(),
                perm,
                sorted_keys,
                num_keys,
                BLOCK_SIZE=block_size,
                STORE_SORTED_KEYS=True,
                num_warps=_triton_num_warps(block_size),
            )
        return perm, sorted_keys
    sorted_keys, perm = torch.sort(keys)
    return perm, sorted_keys
def _group_id_from_sorted(sorted_keys: torch.Tensor):
    """Dense group ids for an ascending 1D key tensor.

    Returns ``(group_id, last_group_id)``.

    ``group_id[i]`` counts the positions ``j`` in ``1..i`` where
    ``sorted_keys[j] != sorted_keys[j - 1]``. For ascending input this equals
    the number of distinct values strictly smaller than ``sorted_keys[i]``.
    Equal keys therefore share an id, and ids run ``0, 1, 2, ...`` without
    gaps.

    ``last_group_id`` is ``group_id[-1]`` read back to the host as an ``int``,
    which forces a device sync. It is one less than the number of distinct
    keys, or ``-1`` when ``sorted_keys`` is empty.

    Small row counts use the single-launch scan kernel. Larger counts use a
    safe ``int64`` adjacent-difference followed by ``torch.cumsum`` (a
    FlagGems multi-block scan under op interception). The difference is
    computed as ``int64 - int64`` then ``!= 0`` against a scalar. Running it
    through the registered tensor-vs-tensor comparison op would route int64
    through float32 and lose precision around ``2**24``.
    """
    num_rows = sorted_keys.numel()
    device = sorted_keys.device
    if num_rows == 0:
        return torch.empty(0, dtype=torch.int64, device=device), -1
    if num_rows <= _UNIQUE_DIM_GROUP_SCAN_BLOCK_SIZE:
        group_id = torch.empty(num_rows, dtype=torch.int64, device=device)
        last_group_id = torch.empty((), dtype=torch.int64, device=device)
        block_size = triton.next_power_of_2(num_rows)
        with torch_device_fn.device(device.index):
            _unique_dim_group_id_kernel[(1, 1, 1)](
                sorted_keys,
                group_id,
                last_group_id,
                num_rows,
                BLOCK_SIZE=block_size,
                num_warps=_triton_num_warps(block_size),
            )
        return group_id, int(last_group_id.item())

    # Keys are non-negative int64, so the adjacent difference cannot overflow.
    diff = ((sorted_keys[1:] - sorted_keys[:-1]) != 0).to(torch.int64)
    group_id = torch.cat(
        [
            torch.zeros(1, dtype=torch.int64, device=device),
            torch.cumsum(diff, dim=0),
        ]
    )
    return group_id, int(group_id[-1].item())
def _lex_argsort_rows_composite(flat: torch.Tensor):
    """Lex-sort rows by packing ``(group_id, monotonic_key)`` per column.

    This mirrors ATen's CUDA ``unique_dim``, which uses a single
    comparator-driven sort. Each cascade step performs *one* argsort on an
    int64 key: the high part encodes the current lexicographic prefix (the
    running group id) and the low part this column's value.

    The loop stops as soon as every row has a unique prefix. For random data
    that happens after one or two columns even when ``M`` is large, replacing
    ``M`` argsorts with a small constant.

    Returns ``None`` for dtypes without a monotonic key encoding. Otherwise
    returns ``(order, all_unique)``, where ``order`` sorts the rows and
    ``all_unique`` is ``True`` when every row is known to be distinct.
    """
    bits = _monotonic_key_bits(flat.dtype)
    if bits is None:
        return None

    n_rows, n_cols = flat.shape
    if n_cols == 0 or n_rows <= 1:
        order = torch.arange(n_rows, dtype=torch.int64, device=flat.device)
        # Zero columns: not reported as all-unique. Otherwise 0/1 rows are
        # trivially unique.
        return order, n_cols != 0

    scale = 1 << bits
    src, kind, offset = _remap_info(flat)
    order = None
    gid = None
    for col in range(n_cols):
        # One fused launch builds ``gid * scale + monotonic(value)``,
        # reading rows through the current order after the first column.
        keys = _build_composite_key(
            src, col, order, gid, n_rows, n_cols, offset, scale, kind
        )
        perm, sorted_keys = _argsort_keys(keys)
        order = perm if order is None else _triton_gather_1d(order, perm)
        gid, last = _group_id_from_sorted(sorted_keys)
        if last == n_rows - 1:
            # Every row already has a distinct prefix: the order is final.
            return order, True
    return order, False
def _lex_argsort_rows_cascade(flat: torch.Tensor) -> torch.Tensor:
    """Generic-dtype fallback lexicographic row sort.

    Runs an LSD cascade of stable argsorts, from the last column to the
    first. That is ``O(M)`` argsorts of length ``D`` with ``O(D)`` memory
    traffic per step. Used for dtypes without a monotonic int64 remap.
    """
    n_rows, n_cols = flat.shape
    order = torch.arange(n_rows, dtype=torch.int64, device=flat.device)
    if n_rows > 1 and n_cols > 0:
        columns = flat.t().contiguous()
        for col in reversed(range(n_cols)):
            col_vals = _triton_gather_1d(columns[col], order)
            # Stability keeps the order established by later columns on ties.
            _, perm = torch.sort(col_vals, stable=True)
            order = _triton_gather_1d(order, perm)
    return order
def _lex_argsort_rows(flat: torch.Tensor) -> tuple[torch.Tensor, bool]:
    """Row order that sorts a 2D tensor lexicographically, plus an all-unique flag.

    Uses the composite-key sort when the dtype supports it. Otherwise it falls
    back to the generic cascade, whose flag is always ``False``.
    """
    result = _lex_argsort_rows_composite(flat)
    if result is None:
        result = (_lex_argsort_rows_cascade(flat), False)
    return result
def _unique_dim_row_hash(flat: torch.Tensor) -> torch.Tensor:
    """Per-row int64 hash of a 2D tensor, computed in two launches.

    The first launch writes one partial hash per (row, column chunk). The
    second sums each row's partials into ``row_hash[row]``.
    """
    n_rows, width = flat.shape
    dev = flat.device
    tile = min(_UNIQUE_DIM_COMPARE_BLOCK_SIZE, triton.next_power_of_2(width))
    n_chunks = triton.cdiv(width, tile)
    reduce_tile = triton.next_power_of_2(n_chunks)
    partials = torch.empty((n_rows, n_chunks), dtype=torch.int64, device=dev)
    hashes = torch.empty(n_rows, dtype=torch.int64, device=dev)
    with torch_device_fn.device(dev.index):
        _unique_dim_row_hash_chunk_kernel[(n_rows, n_chunks, 1)](
            flat,
            partials,
            n_rows,
            width,
            n_chunks,
            BLOCK_SIZE=tile,
            num_warps=_triton_num_warps(tile),
        )
        _unique_dim_row_hash_reduce_kernel[(n_rows, 1, 1)](
            partials,
            hashes,
            n_chunks,
            BLOCK_CHUNKS=reduce_tile,
            num_warps=_triton_num_warps(reduce_tile),
        )
    return hashes
def _unique_dim_first_mask(flat: torch.Tensor, sorted_indices: torch.Tensor):
    """Bool mask marking the first row of each group in sorted lexicographic order.

    ``mask[i]`` is ``True`` when the row at sorted position ``i`` differs from
    the row at position ``i - 1``; position 0 is always ``True``.
    - Rows that fit one tile: a single kernel.
    - Wider rows: per-chunk diff flags, then a per-row reduction.
    - Rows of at least ``_UNIQUE_DIM_HASH_MIN_ROW_LEN``: row hashes are
      computed first so unequal row pairs skip the element-wise compare.
    """
    n_rows, width = flat.shape
    dev = flat.device
    if n_rows == 1 or width == 0:
        # A single row, or empty rows that are all equal: one group.
        mask = torch.zeros(n_rows, dtype=torch.bool, device=dev)
        mask[0] = True
        return mask

    tile = min(_UNIQUE_DIM_COMPARE_BLOCK_SIZE, triton.next_power_of_2(width))
    n_chunks = triton.cdiv(width, tile)
    warps = _triton_num_warps(tile)
    mask = torch.empty(n_rows, dtype=torch.bool, device=dev)

    if n_chunks == 1:
        with torch_device_fn.device(dev.index):
            _unique_dim_row_single_chunk_first_kernel[(n_rows, 1, 1)](
                flat,
                sorted_indices,
                mask,
                n_rows,
                width,
                BLOCK_SIZE=tile,
                num_warps=warps,
            )
        return mask

    chunk_flags = torch.empty((n_rows, n_chunks), dtype=torch.int32, device=dev)
    launch_grid = (n_rows, n_chunks, 1)
    hashes = (
        _unique_dim_row_hash(flat) if width >= _UNIQUE_DIM_HASH_MIN_ROW_LEN else None
    )
    reduce_tile = triton.next_power_of_2(n_chunks)
    with torch_device_fn.device(dev.index):
        if hashes is not None:
            _unique_dim_row_chunk_diff_hash_kernel[launch_grid](
                flat,
                sorted_indices,
                hashes,
                chunk_flags,
                n_rows,
                width,
                n_chunks,
                BLOCK_SIZE=tile,
                num_warps=warps,
            )
        else:
            _unique_dim_row_chunk_diff_kernel[launch_grid](
                flat,
                sorted_indices,
                chunk_flags,
                n_rows,
                width,
                n_chunks,
                BLOCK_SIZE=tile,
                num_warps=warps,
            )
        _unique_dim_row_diff_reduce_kernel[(n_rows, 1, 1)](
            chunk_flags,
            mask,
            n_chunks,
            BLOCK_CHUNKS=reduce_tile,
            num_warps=_triton_num_warps(reduce_tile),
        )
    return mask
def _unique_dim_gather_output(
    moved: torch.Tensor,
    unique_indices: torch.Tensor,
    dim: int,
    input_shape: torch.Size,
) -> torch.Tensor:
    """Materialize the selected slices and move them back to ``dim``.

    ``moved`` has the deduplicated dimension first, as the caller produces it
    with ``movedim(dim, 0)``. Slice ``k`` of the result along ``dim`` is
    ``moved[unique_indices[k]]``. The result's shape is ``input_shape`` with
    the ``dim`` entry replaced by ``unique_indices.numel()``.
    """
    num_unique = unique_indices.numel()
    output_shape = (
        tuple(input_shape[:dim]) + (num_unique,) + tuple(input_shape[dim + 1 :])
    )
    if num_unique == 0:
        return torch.empty(output_shape, dtype=moved.dtype, device=moved.device)

    row_len = moved[0].numel()
    moved_output = torch.empty(
        (num_unique,) + tuple(moved.shape[1:]),
        dtype=moved.dtype,
        device=moved.device,
    )
    # Zero-width rows (some trailing dimension has size 0): there is nothing to
    # copy, and launching would use an empty second grid dimension, which a
    # backend launcher may reject. The allocation above is kept so the
    # result's layout matches the non-empty path.
    if row_len > 0:
        flat = moved.reshape(moved.shape[0], row_len)
        num_chunks = triton.cdiv(row_len, _UNIQUE_DIM_GATHER_BLOCK_SIZE)
        grid = (num_unique, num_chunks, 1)
        with torch_device_fn.device(moved.device.index):
            _unique_dim_gather_moved_kernel[grid](
                flat,
                unique_indices,
                moved_output,
                num_unique,
                row_len,
                BLOCK_SIZE=_UNIQUE_DIM_GATHER_BLOCK_SIZE,
                num_warps=4,
            )
    return moved_output.movedim(0, dim)
def _unique_dim_inverse_from_permutation(sorted_indices: torch.Tensor) -> torch.Tensor:
    """Inverse mapping for the all-unique case: ``inverse[sorted_indices[i]] = i``.

    This is a plain 1D scatter with no per-element column predicate, which is
    correct on every backend. The fused gather+scatter variant miscompiles its
    masked inverse store on some Ascend/NPU backends.
    """
    inverse = torch.empty_like(sorted_indices)
    n = sorted_indices.numel()
    if n > 0:
        launch_grid = (triton.cdiv(n, _UNIQUE_DIM_GATHER_BLOCK_SIZE), 1, 1)
        with torch_device_fn.device(sorted_indices.device.index):
            _unique_dim_inverse_permutation_kernel[launch_grid](
                sorted_indices,
                inverse,
                n,
                BLOCK_SIZE=_UNIQUE_DIM_GATHER_BLOCK_SIZE,
                num_warps=4,
            )
    return inverse
def _unique_dim_inverse(
    sorted_indices: torch.Tensor,
    is_first: torch.Tensor,
) -> torch.Tensor:
    """Map every original row to the dense id of its sorted group.

    Group ids along sorted order are the running count of group starts minus
    one. They are scattered back through ``sorted_indices`` so that
    ``inverse[orig_row]`` is that row's group id.
    """
    n = sorted_indices.numel()
    inverse = torch.empty(n, dtype=torch.int64, device=sorted_indices.device)
    if n == 0:
        return inverse

    group_of_sorted = torch.cumsum(is_first.to(torch.int64), dim=0) - 1
    launch_grid = (triton.cdiv(n, _UNIQUE_DIM_GATHER_BLOCK_SIZE), 1, 1)
    with torch_device_fn.device(sorted_indices.device.index):
        _unique_dim_inverse_kernel[launch_grid](
            sorted_indices,
            group_of_sorted,
            inverse,
            n,
            BLOCK_SIZE=_UNIQUE_DIM_GATHER_BLOCK_SIZE,
            num_warps=4,
        )
    return inverse
def _unique_dim_unique_indices(
    sorted_indices: torch.Tensor,
    is_first: torch.Tensor,
) -> torch.Tensor:
    """Original-row index of the first member of each sorted group."""
    group_starts = torch.nonzero(is_first, as_tuple=False).flatten()
    return _triton_gather_1d(sorted_indices, group_starts)
def _unique_dim_unique_indices_and_inverse(
    sorted_indices: torch.Tensor,
    is_first: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(unique_indices, inverse_indices)`` for the sorted groups."""
    return (
        _unique_dim_unique_indices(sorted_indices, is_first),
        _unique_dim_inverse(sorted_indices, is_first),
    )
def _unique_dim_counts(
    is_first: torch.Tensor,
    num_rows: int,
) -> torch.Tensor:
    """Size of each sorted group.

    Each count is the gap between consecutive group starts, with ``num_rows``
    closing the last group.
    """
    starts = torch.nonzero(is_first, as_tuple=False).flatten()
    n_groups = starts.numel()
    counts = torch.empty(n_groups, dtype=torch.int64, device=is_first.device)
    if n_groups > 0:
        launch_grid = (triton.cdiv(n_groups, _UNIQUE_DIM_GATHER_BLOCK_SIZE), 1, 1)
        with torch_device_fn.device(is_first.device.index):
            _unique_dim_counts_kernel[launch_grid](
                starts,
                counts,
                num_rows,
                n_groups,
                BLOCK_SIZE=_UNIQUE_DIM_GATHER_BLOCK_SIZE,
                num_warps=4,
            )
    return counts
def unique_dim(
    input: torch.Tensor,
    dim: int,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
):
    """Dimension-aware ``torch.unique`` (a.k.a. ``aten::unique_dim``).

    Treats each slice along ``dim`` as a single element. Returns the unique
    slices in lexicographic order, plus two optional outputs: an inverse
    mapping of shape ``(input.size(dim),)`` and a per-unique count tensor of
    shape ``(output.size(dim),)``.

    Args:
        input: tensor whose slices are deduplicated.
        dim: dimension whose slices are compared; negative values count from
            the end. A 0-dim input is treated as having one dimension.
        sorted: accepted for signature compatibility; not read here.
        return_inverse: also compute the inverse mapping.
        return_counts: also compute the per-unique counts.

    Returns:
        ``(output, inverse_indices, counts)``. Outputs that were not requested
        are empty ``int64`` tensors.

    Raises:
        IndexError: if ``dim`` is outside ``[-ndim, ndim - 1]``.
    """
    logger.debug("GEMS UNIQUE_DIM")

    # A 0-dim input is treated as having a single dimension.
    ndim = max(input.ndim, 1)
    wrapped_dim = dim + ndim if dim < 0 else dim
    if not 0 <= wrapped_dim < ndim:
        # Report the caller's ``dim`` (not the wrapped value) and the range that
        # is actually accepted.
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    dim = wrapped_dim

    device = input.device
    size_dim = input.size(dim) if input.ndim > 0 else input.numel()

    if size_dim == 0:
        output = input.clone()
        inverse_indices = torch.empty(0, dtype=torch.int64, device=device)
        counts = torch.empty(0, dtype=torch.int64, device=device)
        return output, inverse_indices, counts

    # Lay the slices out as rows: row i of ``flat`` is slice i along ``dim``.
    moved = input.movedim(dim, 0).contiguous()
    flat = moved.reshape(size_dim, -1)

    sorted_indices, all_unique = _lex_argsort_rows(flat)

    inverse_indices = torch.empty(0, dtype=torch.int64, device=device)
    counts = torch.empty(0, dtype=torch.int64, device=device)

    if all_unique:
        # Every slice is distinct: the sort order is the output order, every
        # count is 1, and the inverse is the inverse permutation.
        if return_counts:
            counts = torch.ones(size_dim, dtype=torch.int64, device=device)
        if return_inverse:
            inverse_indices = _unique_dim_inverse_from_permutation(sorted_indices)
        output = _unique_dim_gather_output(moved, sorted_indices, dim, input.shape)
        return output, inverse_indices, counts

    is_first = _unique_dim_first_mask(flat, sorted_indices)
    if return_inverse:
        unique_in_orig, inverse_indices = _unique_dim_unique_indices_and_inverse(
            sorted_indices,
            is_first,
        )
    else:
        unique_in_orig = _unique_dim_unique_indices(sorted_indices, is_first)

    if return_counts:
        counts = _unique_dim_counts(is_first, size_dim)

    output = _unique_dim_gather_output(moved, unique_in_orig, dim, input.shape)

    return output, inverse_indices, counts