Coverage for src/flag_gems/ops/unique_consecutive.py: 50%
168 statements
coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
import logging

import torch
import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import triton_lang_extension as ext
from flag_gems.utils.libentry import libentry

# Module-level logger; `unique_consecutive` emits a debug record on every call.
logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def simple_unique_consecutive_flat_kernel(
    data_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    unique_size_ptr: tl.tensor,  # out
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Simple kernel for small inputs that fit in a single tile.

    A single program processes the whole flattened input. Lane ``i`` flags
    whether ``data[i]`` starts a new consecutive group (``i == 0`` or
    ``data[i] != data[i - 1]``). An inclusive cumsum of those flags gives each
    element's 1-based group number, from which all outputs are scattered.

    Args:
        data_ptr: flattened input values (``num_tasks`` elements are read).
        data_out_ptr: receives the first value of each group at the group's
            0-based index.
        inverse_indices_ptr: receives, for every input position, its 0-based
            group index (written only when ``return_inverse``).
        idx_ptr: receives the input position at which each group starts
            (written only when ``return_counts``).
        unique_size_ptr: receives the number of groups (a single element).
        return_inverse: compile-time switch for writing ``inverse_indices_ptr``.
        return_counts: compile-time switch for writing ``idx_ptr``.
        num_tasks: number of valid input elements.
        tile_size: lane count. Only one tile is processed, so the caller must
            ensure ``tile_size >= num_tasks``.
    """
    i0 = tl.arange(0, tile_size)
    mask = i0 < num_tasks

    # load current and previous elements
    a = tl.load(data_ptr + i0, mask=mask)
    # Lane 0 has no predecessor; clamp its "previous" index to 0 so the load
    # stays in bounds (its comparison result is discarded below anyway).
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)
    b = tl.load(data_ptr + i0_prev, mask=mask)

    # Check if element differs from previous (first element always starts a new group)
    ne_result = tl.where(i0 > 0, a != b, 1)
    # Lanes at or past num_tasks compare undefined masked-load values, but an
    # inclusive cumsum only lets them influence positions >= num_tasks, and
    # every store below is masked away from those positions.
    cumsum = tl.cumsum(ne_result)

    # cumsum gives us 1-indexed positions, we want 0-indexed
    out_idx = cumsum - 1

    # unique_size is the last cumsum value
    unique_size_mask = i0 == num_tasks - 1
    tl.store(unique_size_ptr + tl.zeros_like(i0), cumsum, mask=unique_size_mask)

    # data_out: scatter unique values to their output positions
    # Only write when this is the first element of a consecutive group
    write_mask = ne_result.to(tl.int1) & mask
    tl.store(data_out_ptr + out_idx, a, mask=write_mask)

    # inverse_indices: each input position maps to its output position
    if return_inverse:
        tl.store(inverse_indices_ptr + i0, out_idx, mask=mask)

    # idx: store the starting position of each unique group
    # (group sizes are derived from these positions afterwards)
    if return_counts:
        tl.store(idx_ptr + out_idx, i0, mask=write_mask)
@triton.jit
def output_counts_impl(
    global_pid,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Derive per-group counts for one tile from group start positions.

    Every group except the last has ``count[k] = idx[k + 1] - idx[k]``. The
    last group has no successor, so it runs to the end of the original
    input: ``origin_num_tasks - idx[k]``.

    Args:
        global_pid: index of the tile to process.
        idx_ptr: ascending start positions of the ``num_tasks`` groups.
        origin_num_tasks: length of the flattened input.
        counts_ptr: output, one count per group.
        num_tasks: number of groups.
        tile_size: lanes per tile.
    """
    pos = global_pid * tile_size + tl.arange(0, tile_size)
    in_range = pos < num_tasks
    succ = pos + 1
    has_succ = succ < num_tasks

    start = tl.load(idx_ptr + pos, mask=in_range)
    next_start = tl.load(idx_ptr + succ, mask=has_succ)

    # Groups without a successor extend to the end of the input.
    length = tl.where(has_succ, next_start - start, origin_num_tasks - start)
    tl.store(counts_ptr + pos, length, mask=in_range)
@libentry()
@triton.jit
def output_counts_kernel(
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Launch wrapper that computes group counts tile by tile.

    Program ``p`` handles tiles ``p, p + G, p + 2G, ...`` (``G`` is the number
    of launched programs), ``tiles_per_cta`` tiles in total.
    """
    first_tile = ext.program_id(0)
    stride = ext.num_programs(0)
    for step in range(0, tiles_per_cta):
        output_counts_impl(
            first_tile + step * stride,
            idx_ptr,
            origin_num_tasks,
            counts_ptr,
            num_tasks,
            tile_size,
        )
@triton.jit
def local_ne_consecutive_impl(
    global_pid,
    data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Compute ne_result (whether each element differs from previous) for a tile.

    Tile ``global_pid`` covers positions
    ``[global_pid * tile_size, (global_pid + 1) * tile_size)``.

    * ``ne_result[i]`` is set to 1 when position ``i`` starts a new
      consecutive group (``i == 0`` or ``data[i] != data[i - 1]``), else 0.
    * The tile's flag total is written to ``tile_sum[global_pid]``.

    Tiles with ``global_pid >= global_ctas_num`` can occur when
    ``ctas_num * tiles_per_cta`` exceeds the tile count. Such tiles write
    nothing: all their positions are ``>= num_tasks``, and the tile_sum
    store is masked.

    In the partial last tile, lanes past ``num_tasks`` compare undefined
    masked-load values, so that tile's ``tile_sum`` entry may overcount.
    ``global_cumsum_consecutive_impl`` overwrites the last entry with the
    final total and never loads it as a prefix.
    """
    r = tl.arange(0, tile_size)
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks
    # Position 0 has no predecessor; clamp to 0 to keep the load in bounds.
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)

    # load current and previous
    a = tl.load(data_ptr + i0, mask=mask)
    b = tl.load(data_ptr + i0_prev, mask=mask)

    # compute ne_result (the global first element always starts a group)
    ne_result = tl.where(i0 > 0, a != b, 1)

    # store ne_result
    tl.store(ne_result_ptr + i0, ne_result, mask=mask)

    # store tile_sum
    tile_sum = tl.sum(ne_result)
    tile_sum_mask = global_pid < global_ctas_num
    tl.store(tile_sum_ptr + global_pid, tile_sum, mask=tile_sum_mask)
@libentry()
@triton.jit
def local_ne_consecutive_kernel(
    data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """First pass of the multi-tile path: group-start flags and per-tile sums.

    Program ``p`` handles tiles ``p, p + G, p + 2G, ...`` (``G`` is the number
    of launched programs), ``tiles_per_cta`` tiles in total.
    """
    first_tile = ext.program_id(0)
    stride = ext.num_programs(0)
    for step in range(0, tiles_per_cta):
        local_ne_consecutive_impl(
            first_tile + step * stride,
            data_ptr,
            ne_result_ptr,
            tile_sum_ptr,
            global_ctas_num,
            num_tasks,
            tile_size,
        )
@triton.jit
def global_cumsum_consecutive_impl(
    global_pid,
    total,
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: tl.constexpr,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Compute global cumsum and scatter outputs for tile ``global_pid``.

    ``total`` must be the number of groups in all tiles before
    ``global_pid - ctas_num``. The per-tile sums of the window
    ``[global_pid - ctas_num, global_pid)`` are loaded from ``tile_sum_ptr``
    and added to it, giving the exclusive prefix for this tile. A caller that
    visits tiles ``pid, pid + ctas_num, ...`` can therefore thread the
    returned value into its next call.

    The ``ne_result`` flags of this tile are prefix-summed and offset by that
    prefix to obtain 0-based group indices. The following are then scattered:

    * the unique values to ``data_out_ptr``;
    * optionally, inverse indices (``return_inverse``);
    * optionally, group start positions (``return_counts``).

    Side effect: on the last tile (``global_ctas_num - 1``), the total number
    of groups is written to ``tile_sum_ptr[global_ctas_num - 1]``, replacing
    that tile's local sum. Prefix loads only touch indices ``< global_pid``,
    so no tile reads that entry.

    Requires ``next_power_global_ctas_num >= global_ctas_num``.

    Returns:
        ``total`` plus the loaded window sums; this tile's own groups are not
        included.
    """
    offset = global_pid * tile_size
    r = tl.arange(0, tile_size)
    i0 = offset + r
    mask = i0 < num_tasks

    # load data
    data = tl.load(data_ptr + i0, mask=mask)

    # load tile_sum for previous tiles: only the window of at most ctas_num
    # tiles immediately before this one (earlier tiles are already in total)
    p = tl.arange(0, next_power_global_ctas_num)
    pre_tile_sum_mask = (
        (p >= global_pid - ctas_num)
        & (p < global_pid)
        & (p >= 0)
        & (p < global_ctas_num)
    )
    pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0)

    # cumsum within tile
    total += tl.sum(pre_tile_sum)
    ne_result = tl.load(ne_result_ptr + i0, mask=mask)
    ne_result_i1 = ne_result.to(tl.int1)
    ne_result_i32 = ne_result.to(tl.int32)
    cumsum = tl.cumsum(ne_result_i32)

    # Store final tile sum for the last tile: the global group count at the
    # last valid position (total + local inclusive cumsum), one element only.
    if global_pid == global_ctas_num - 1:
        last_tile_sum_mask = i0 == num_tasks - 1
        final_tile_sum = tl.where(last_tile_sum_mask, total + cumsum, cumsum)
        tl.store(
            tile_sum_ptr + global_pid + tl.zeros_like(r),
            final_tile_sum,
            mask=last_tile_sum_mask,
        )
    cumsum += total

    # output index (0-indexed)
    out_idx = cumsum - 1

    # data_out: scatter unique values (only first element of each consecutive group)
    tl.store(data_out_ptr + out_idx, data, mask=ne_result_i1 & mask)

    # inverse_indices: each input position maps to its output index
    if return_inverse:
        tl.store(inverse_indices_ptr + i0, out_idx, mask=mask)

    # idx: store starting position of each unique group
    if return_counts:
        tl.store(idx_ptr + out_idx, i0, mask=ne_result_i1 & mask)

    return total
@libentry()
@triton.jit
def global_cumsum_consecutive_kernel(
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Second pass of the multi-tile path: scatter outputs for every tile.

    Program ``pid`` handles tiles ``pid, pid + ctas_num, ...``.

    * ``one_tile_per_cta``: each program handles exactly its own tile, and
      its prefix is read entirely from ``tile_sum_ptr`` (starting from 0).
    * Otherwise, the prefix returned by ``global_cumsum_consecutive_impl``
      is carried across the tile loop in an int64 accumulator.
    """
    pid = ext.program_id(0)
    # Shadows the ``ctas_num`` argument with the launched program count. The
    # host launch in this module passes the grid size, so the two agree.
    ctas_num = ext.num_programs(0)
    if one_tile_per_cta:
        global_cumsum_consecutive_impl(
            pid,
            0,
            ne_result_ptr,
            tile_sum_ptr,
            data_ptr,
            data_out_ptr,
            inverse_indices_ptr,
            idx_ptr,
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_inverse,
            return_counts,
        )
    else:
        # Running count of groups in all tiles before the next window.
        total = tl.zeros([1], dtype=tl.int64)
        for j in range(0, tiles_per_cta):
            global_pid = pid + j * ctas_num
            total = global_cumsum_consecutive_impl(
                global_pid,
                total,
                ne_result_ptr,
                tile_sum_ptr,
                data_ptr,
                data_out_ptr,
                inverse_indices_ptr,
                idx_ptr,
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_inverse,
                return_counts,
            )
def simple_unique_consecutive_flat(
    data: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Handle small inputs with a single kernel launch.

    The whole input is processed by one program. When counts are requested,
    a second single-program launch turns group start positions into counts.

    Args:
        data: flattened 1-D input (the caller passes ``input.ravel()``).
        return_inverse: whether to compute inverse indices.
        return_counts: whether to compute per-group counts.

    Returns:
        ``(output, inverse_indices, counts)``. ``inverse_indices`` and
        ``counts`` are ``None`` when not requested. ``output`` is a view of
        a buffer sized for ``data.numel()`` elements.
    """
    n = data.numel()
    device = data.device
    single_cta = (1, 1, 1)

    def _index_buffer(wanted):
        # One int64 slot per input element, or nothing if not requested.
        return torch.empty(n, dtype=torch.int64, device=device) if wanted else None

    data_out = torch.empty_like(data)
    inverse_indices = _index_buffer(return_inverse)
    idx = _index_buffer(return_counts)
    unique_size = torch.empty([1], dtype=torch.int64, device=device)

    with torch_device_fn.device(device.index):
        simple_unique_consecutive_flat_kernel[single_cta](
            data,
            data_out,
            inverse_indices,
            idx,
            unique_size,
            return_inverse,
            return_counts,
            n,
            tile_size=triton.next_power_of_2(n),
            num_warps=8,
        )

    # Host sync: the number of groups decides how much of the outputs is valid.
    out_size = unique_size.item()
    if not return_counts:
        return data_out[:out_size], inverse_indices, None

    idx = idx[:out_size]
    counts = torch.empty_like(idx)
    with torch_device_fn.device(device.index):
        output_counts_kernel[single_cta](
            idx,
            n,
            counts,
            num_tasks=out_size,
            tiles_per_cta=1,
            tile_size=triton.next_power_of_2(out_size),
            num_warps=8,
        )
    return data_out[:out_size], inverse_indices, counts
def _large_launch_config(num_tasks):
    """Choose the tile size and grid shape for ``large_unique_consecutive_flat``.

    Returns:
        ``(tile_size, global_ctas_num, next_power_global_ctas_num, ctas_num,
        tiles_per_cta, num_warps)``.
    """
    next_power_num_tasks = triton.next_power_of_2(num_tasks)
    tile_size = min(8192, next_power_num_tasks)
    global_ctas_num = triton.cdiv(num_tasks, tile_size)

    if global_ctas_num <= 8192:
        # Re-derive the tile size from the initial tile count: at least 256
        # (512 when there are more than 32 tiles), at most the padded size.
        min_tile_size = 512 if global_ctas_num > 32 else 256
        tile_size = max(
            min_tile_size,
            min(triton.next_power_of_2(global_ctas_num), next_power_num_tasks),
        )
        global_ctas_num = triton.cdiv(num_tasks, tile_size)

    next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
    # Cap the grid for very many tiles; each CTA then loops over several tiles.
    ctas_num = global_ctas_num if global_ctas_num < 32768 else 8192
    tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
    num_warps = 8 if tiles_per_cta == 1 else 32
    return (
        tile_size,
        global_ctas_num,
        next_power_global_ctas_num,
        ctas_num,
        tiles_per_cta,
        num_warps,
    )


def large_unique_consecutive_flat(
    data: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Handle larger inputs with a multi-kernel (two-pass) approach.

    1. ``local_ne_consecutive_kernel`` flags every element that starts a new
       consecutive group and sums the flags per tile.
    2. ``global_cumsum_consecutive_kernel`` turns the flags into global group
       indices. It scatters the unique values, inverse indices and group
       starts, and stores the total group count in ``tile_sum[-1]``.
    3. When counts are requested, ``output_counts_kernel`` derives them from
       the group start positions.

    Every launch runs inside the device guard for ``data``'s device.

    Args:
        data: flattened 1-D input (the caller passes ``input.ravel()``).
        return_inverse: whether to compute inverse indices.
        return_counts: whether to compute per-group counts.

    Returns:
        ``(output, inverse_indices, counts)``. ``inverse_indices`` and
        ``counts`` are ``None`` when not requested. ``output`` is a view of
        a buffer sized for ``data.numel()`` elements.
    """
    num_tasks = data.numel()
    (
        tile_size,
        global_ctas_num,
        next_power_global_ctas_num,
        ctas_num,
        tiles_per_cta,
        num_warps,
    ) = _large_launch_config(num_tasks)
    grid = (ctas_num, 1, 1)

    # allocate tensors
    ne_result = torch.empty(num_tasks, dtype=torch.bool, device=data.device)
    tile_sum = torch.empty(global_ctas_num, dtype=torch.int64, device=data.device)
    data_out = torch.empty_like(data)
    inverse_indices = (
        torch.empty(num_tasks, dtype=torch.int64, device=data.device)
        if return_inverse
        else None
    )
    idx = (
        torch.empty(num_tasks, dtype=torch.int64, device=data.device)
        if return_counts
        else None
    )

    # launch kernels
    with torch_device_fn.device(data.device.index):
        local_ne_consecutive_kernel[grid](
            data,
            ne_result,
            tile_sum,
            global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            num_warps=num_warps,
        )
        global_cumsum_consecutive_kernel[grid](
            ne_result,
            tile_sum,
            data,
            data_out,
            inverse_indices,
            idx,
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            one_tile_per_cta=tiles_per_cta == 1,
            return_inverse=return_inverse,
            return_counts=return_counts,
            num_warps=num_warps,
        )
    # The second pass overwrote the last tile_sum entry with the group count.
    out_size = tile_sum[-1].item()

    counts = None
    if return_counts:
        idx = idx[:out_size]
        counts = torch.empty_like(idx)
        # Launch under the same device guard as the other kernels so the
        # kernel runs on data's device rather than the current default one.
        with torch_device_fn.device(data.device.index):
            output_counts_kernel[grid](
                idx,
                num_tasks,
                counts,
                out_size,
                tiles_per_cta,
                tile_size,
                num_warps=num_warps,
            )

    return data_out[:out_size], inverse_indices, counts
def unique_consecutive(
    input: torch.Tensor,
    return_inverse: bool = False,
    return_counts: bool = False,
    dim: int = None,
):
    """
    Eliminates all but the first element from every consecutive group of equivalent elements.

    Args:
        input: the input tensor
        return_inverse: Whether to return inverse indices
        return_counts: Whether to return counts for each unique element
        dim: the dimension to apply unique. If None, the unique of the flattened input is returned.

    Returns:
        (Tensor, Tensor (optional), Tensor (optional)): output, inverse_indices, counts.
        Always a 3-tuple; entries that were not requested are ``None``.
        For ``dim=None`` and non-empty input, ``inverse_indices`` has the
        shape of ``input``.
    """
    logger.debug("GEMS UNIQUE_CONSECUTIVE")

    if dim is not None:
        # For dim-wise unique_consecutive, fall back to PyTorch for now
        # This could be implemented with a more complex kernel
        # NOTE(review): if this function is registered as the override for
        # aten::unique_consecutive, this call may re-enter it; confirm.
        fallback = torch.unique_consecutive(
            input,
            return_inverse=return_inverse,
            return_counts=return_counts,
            dim=dim,
        )
        # torch.unique_consecutive returns a bare Tensor when no extras are
        # requested and a 2- or 3-tuple otherwise. Normalize to the same
        # (output, inverse_indices, counts) triple the flat path returns.
        if not (return_inverse or return_counts):
            return fallback, None, None
        output, *extras = fallback
        inverse_indices = extras.pop(0) if return_inverse else None
        counts = extras.pop(0) if return_counts else None
        return output, inverse_indices, counts

    # Flatten input for the None dim case
    flat_input = input.ravel()
    num_tasks = flat_input.numel()

    if num_tasks == 0:
        # Handle empty input
        output = torch.empty(0, dtype=input.dtype, device=input.device)
        inverse_indices = (
            torch.empty(0, dtype=torch.int64, device=input.device)
            if return_inverse
            else None
        )
        counts = (
            torch.empty(0, dtype=torch.int64, device=input.device)
            if return_counts
            else None
        )
        return output, inverse_indices, counts

    # Choose algorithm based on input size: one tile covers up to 8192 elements
    if num_tasks <= 8192:
        output, inverse_indices, counts = simple_unique_consecutive_flat(
            flat_input, return_inverse, return_counts
        )
    else:
        output, inverse_indices, counts = large_unique_consecutive_flat(
            flat_input, return_inverse, return_counts
        )

    # Reshape inverse_indices to match input shape
    if inverse_indices is not None:
        inverse_indices = inverse_indices.view_as(input)

    return output, inverse_indices, counts