Coverage for src/flag_gems/runtime/backend/_sunrise/ops/unique_consecutive.py: 0%
191 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
import logging
from typing import Optional

import torch
import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import triton_lang_extension as ext
from flag_gems.utils.libentry import libentry
logger = logging.getLogger(__name__)

# Upper bound on the tile size used by the Sunrise/PTPU branch of
# large_unique_consecutive_flat (the `data.device.type == "ptpu"` check).
# NOTE(review): presumably larger tiles were found unstable on PTPU, per the
# "unstable large-input organization path" comment there -- confirm.
_PTPU_SAFE_MAX_TILE_SIZE = 512
@libentry()
@triton.jit
def simple_unique_consecutive_flat_kernel(
    data_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    unique_size_ptr: tl.tensor,  # out
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Simple kernel for small inputs that fits in a single tile.

    Every index is taken from ``tl.arange(0, tile_size)`` with no program
    offset, so the whole input must fit in one tile (``tile_size >=
    num_tasks``) and a single program does all the work.

    Outputs:
        data_out_ptr: the first element of each consecutive group, written to
            slots ``[0, unique_size)``.
        inverse_indices_ptr: for every input element, the slot of its group
            (written only when ``return_inverse``).
        idx_ptr: for every group, the input position where it starts
            (written only when ``return_counts``).
        unique_size_ptr: a single element that receives the number of groups.
    """
    i0 = tl.arange(0, tile_size)
    mask = i0 < num_tasks

    # load current and previous elements
    a = tl.load(data_ptr + i0, mask=mask)
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)
    b = tl.load(data_ptr + i0_prev, mask=mask)

    # Check if element differs from previous (first element always starts a new group).
    # Lanes past num_tasks load no data, so their flag is unspecified. The
    # cumsum below is an inclusive prefix, so such lanes can only affect
    # positions >= num_tasks, and every store below is masked to exclude them.
    ne_result = tl.where(i0 > 0, a != b, 1)
    cumsum = tl.cumsum(ne_result)

    # cumsum gives us 1-indexed positions, we want 0-indexed
    out_idx = cumsum - 1

    # unique_size is the last cumsum value (only lane num_tasks - 1 stores it)
    unique_size_mask = i0 == num_tasks - 1
    tl.store(unique_size_ptr + tl.zeros_like(i0), cumsum, mask=unique_size_mask)

    # data_out: scatter unique values to their output positions
    # Only write when this is the first element of a consecutive group
    write_mask = ne_result.to(tl.int1) & mask
    tl.store(data_out_ptr + out_idx, a, mask=write_mask)

    # inverse_indices: each input position maps to its output position
    if return_inverse:
        tl.store(inverse_indices_ptr + i0, out_idx, mask=mask)

    # idx: store the starting position of each unique group
    # (consumed by output_counts_kernel to derive group lengths)
    if return_counts:
        tl.store(idx_ptr + out_idx, i0, mask=write_mask)
@triton.jit
def output_counts_impl(
    global_pid,
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Compute counts from idx positions.

    ``idx`` holds the input position where each group starts, ``num_tasks``
    is the number of groups and ``origin_num_tasks`` is the length of the
    flattened input. For the elements of tile ``global_pid``:

    * ``counts[i] = idx[i + 1] - idx[i]`` for ``i < num_tasks - 1``;
    * ``counts[num_tasks - 1] = origin_num_tasks - idx[num_tasks - 1]``.
    """
    r = tl.arange(0, tile_size)
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks

    # load idx
    idx = tl.load(idx_ptr + i0, mask=mask)

    # load idx_next (masked off for the last group, which has no successor)
    i0_next = i0 + 1
    next_mask = i0_next < num_tasks
    idx_next = tl.load(idx_ptr + i0_next, mask=next_mask)

    # counts = next_idx - current_idx (or total - current_idx for last element)
    counts = tl.where(i0_next < num_tasks, idx_next - idx, origin_num_tasks - idx)

    # store counts
    tl.store(counts_ptr + i0, counts, mask=mask)
@libentry()
@triton.jit
def output_counts_kernel(
    idx_ptr: tl.tensor,
    origin_num_tasks: int,  # in
    counts_ptr: tl.tensor,  # out
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """Turn group start positions (``idx``) into group lengths (``counts``).

    Program ``pid`` processes tiles ``pid + j * num_programs(0)`` for
    ``j`` in ``[0, tiles_per_cta)``. See ``output_counts_impl`` for the
    per-tile work.
    """
    pid = ext.program_id(0)
    ctas_num = ext.num_programs(0)
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        output_counts_impl(
            global_pid,
            idx_ptr,
            origin_num_tasks,
            counts_ptr,
            num_tasks,
            tile_size,
        )
@triton.jit
def local_ne_consecutive_impl(
    global_pid,
    data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tile_size: tl.constexpr,
):
    """Compute ne_result (whether each element differs from previous) for a tile.

    ``ne_result[i]`` is 1 when element ``i`` starts a new consecutive group
    (``i == 0`` or ``data[i] != data[i - 1]``) and 0 otherwise.
    ``tile_sum[global_pid]`` receives the number of group starts inside this
    tile. Tiles with ``global_pid >= global_ctas_num`` write no tile sum.
    """
    r = tl.arange(0, tile_size)
    i0 = global_pid * tile_size + r
    mask = i0 < num_tasks
    i0_prev = tl.where(i0 > 0, i0 - 1, 0)

    # load current and previous
    a = tl.load(data_ptr + i0, mask=mask)
    b = tl.load(data_ptr + i0_prev, mask=mask)

    # compute ne_result; the first element always starts a group.
    # Lanes past num_tasks loaded no data (no `other` value), so their flag is
    # unspecified. Force them to 0 so tile_sum is exact even for the final,
    # partially filled tile.
    ne_result = tl.where(i0 > 0, a != b, 1)
    ne_result = tl.where(mask, ne_result, 0)

    # store ne_result
    tl.store(ne_result_ptr + i0, ne_result, mask=mask)

    # store tile_sum
    tile_sum = tl.sum(ne_result)
    tile_sum_mask = global_pid < global_ctas_num
    tl.store(tile_sum_ptr + global_pid, tile_sum, mask=tile_sum_mask)
@libentry()
@triton.jit
def local_ne_consecutive_kernel(
    data_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # out
    global_ctas_num: int,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
):
    """First pass of the large-input path: flag group starts and count them per tile.

    Program ``pid`` processes tiles ``pid + j * num_programs(0)`` for
    ``j`` in ``[0, tiles_per_cta)``. See ``local_ne_consecutive_impl`` for the
    per-tile work.
    """
    pid = ext.program_id(0)
    ctas_num = ext.num_programs(0)
    for j in range(0, tiles_per_cta):
        global_pid = pid + j * ctas_num
        local_ne_consecutive_impl(
            global_pid,
            data_ptr,
            ne_result_ptr,
            tile_sum_ptr,
            global_ctas_num,
            num_tasks,
            tile_size,
        )
@triton.jit
def global_cumsum_consecutive_impl(
    global_pid,
    total,
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: tl.constexpr,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tile_size: tl.constexpr,
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Compute global cumsum and scatter outputs.

    Processes tile ``global_pid``. It uses the per-element group-start flags
    in ``ne_result`` and the per-tile flag counts in ``tile_sum``, as produced
    by ``local_ne_consecutive_kernel``, which the host launches before this
    kernel.

    On entry, ``total`` must hold the number of group starts in tiles
    ``[0, global_pid - ctas_num)``. This function adds ``tile_sum[p]`` for
    the remaining earlier tiles ``p`` in ``[global_pid - ctas_num,
    global_pid)`` to get the exclusive prefix for this tile. Element ``i``
    then maps to output slot ``prefix + inclusive_local_cumsum(i) - 1``, and:

    * ``data_out[slot] = data[i]`` for every element that starts a group;
    * ``inverse_indices[i] = slot`` for every element (if ``return_inverse``);
    * ``idx[slot] = i`` for every group start (if ``return_counts``).

    The last tile (``global_pid == global_ctas_num - 1``) overwrites its own
    ``tile_sum`` entry with the total number of groups. No tile reads that
    entry, because each tile only reads entries below its own index.

    Returns the exclusive prefix for this tile, so a program that handles
    several tiles can pass it back in for its next tile.

    NOTE(review): ``ctas_num`` is annotated ``tl.constexpr`` although the
    calling kernel passes ``ext.num_programs(0)``; confirm this is intended.
    """
    offset = global_pid * tile_size
    r = tl.arange(0, tile_size)
    i0 = offset + r
    mask = i0 < num_tasks

    # load data
    data = tl.load(data_ptr + i0, mask=mask)

    # load tile_sum for previous tiles
    # Window [global_pid - ctas_num, global_pid): the tiles not already
    # accounted for in `total`, clipped to valid tile indices.
    p = tl.arange(0, next_power_global_ctas_num)
    pre_tile_sum_mask = (
        (p >= global_pid - ctas_num)
        & (p < global_pid)
        & (p >= 0)
        & (p < global_ctas_num)
    )
    pre_tile_sum = tl.load(tile_sum_ptr + p, mask=pre_tile_sum_mask, other=0)

    # cumsum within tile
    total += tl.sum(pre_tile_sum)
    ne_result = tl.load(ne_result_ptr + i0, mask=mask)
    ne_result_i1 = ne_result.to(tl.int1)
    ne_result_i32 = ne_result.to(tl.int32)
    cumsum = tl.cumsum(ne_result_i32)

    # Store final tile sum for the last tile
    # Only lane num_tasks - 1 is stored, so the value written is
    # total + cumsum at that lane, i.e. the total number of groups.
    if global_pid == global_ctas_num - 1:
        last_tile_sum_mask = i0 == num_tasks - 1
        final_tile_sum = tl.where(last_tile_sum_mask, total + cumsum, cumsum)
        tl.store(
            tile_sum_ptr + global_pid + tl.zeros_like(r),
            final_tile_sum,
            mask=last_tile_sum_mask,
        )
    cumsum += total

    # output index (0-indexed)
    out_idx = cumsum - 1

    # data_out: scatter unique values (only first element of each consecutive group)
    tl.store(data_out_ptr + out_idx, data, mask=ne_result_i1 & mask)

    # inverse_indices: each input position maps to its output index
    if return_inverse:
        tl.store(inverse_indices_ptr + i0, out_idx, mask=mask)

    # idx: store starting position of each unique group
    if return_counts:
        tl.store(idx_ptr + out_idx, i0, mask=ne_result_i1 & mask)

    return total
@libentry()
@triton.jit
def global_cumsum_consecutive_kernel(
    ne_result_ptr: tl.tensor,
    tile_sum_ptr: tl.tensor,  # in
    data_ptr: tl.tensor,  # in
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    idx_ptr: tl.tensor,  # out
    ctas_num: int,
    global_ctas_num: int,
    next_power_global_ctas_num: tl.constexpr,
    num_tasks: int,
    tiles_per_cta: int,
    tile_size: tl.constexpr,
    one_tile_per_cta: tl.constexpr,
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
):
    """Second pass of the large-input path: global prefix sum and output scatter.

    Program ``pid`` handles tiles ``pid + j * num_programs(0)`` for
    ``j`` in ``[0, tiles_per_cta)``.

    * With ``one_tile_per_cta``, each program handles only tile ``pid`` and
      passes an initial total of 0. The impl then sums the counts of all
      earlier tiles from ``tile_sum``.
    * Otherwise, an int64 running total returned by each tile is carried into
      the program's next tile.

    See ``global_cumsum_consecutive_impl`` for the per-tile work.
    """
    pid = ext.program_id(0)
    # The ctas_num argument is replaced by the actual number of programs in
    # the launch grid.
    ctas_num = ext.num_programs(0)
    if one_tile_per_cta:
        global_cumsum_consecutive_impl(
            pid,
            0,
            ne_result_ptr,
            tile_sum_ptr,
            data_ptr,
            data_out_ptr,
            inverse_indices_ptr,
            idx_ptr,
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tile_size,
            return_inverse,
            return_counts,
        )
    else:
        total = tl.zeros([1], dtype=tl.int64)
        for j in range(0, tiles_per_cta):
            global_pid = pid + j * ctas_num
            total = global_cumsum_consecutive_impl(
                global_pid,
                total,
                ne_result_ptr,
                tile_sum_ptr,
                data_ptr,
                data_out_ptr,
                inverse_indices_ptr,
                idx_ptr,
                ctas_num,
                global_ctas_num,
                next_power_global_ctas_num,
                num_tasks,
                tile_size,
                return_inverse,
                return_counts,
            )
def simple_unique_consecutive_flat(
    data: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Run unique_consecutive on a small flat tensor with a single-tile kernel.

    The whole input is processed as one power-of-two tile by a single
    program. A second single-program launch converts group start positions
    into group lengths when counts are requested.

    Args:
        data: flat input tensor.
        return_inverse: also produce, for each input element, the index of its
            group in the output.
        return_counts: also produce the length of each group.

    Returns:
        ``(output, inverse_indices, counts)``. ``output`` is a view of the
        first ``n_unique`` slots of a ``data``-sized buffer. The last two
        entries are ``None`` when the corresponding flag is False.
    """
    n = data.numel()
    device = data.device
    single_program = (1, 1, 1)

    def _int64_buffer():
        # One int64 slot per input element.
        return torch.empty(n, dtype=torch.int64, device=device)

    uniques = torch.empty_like(data)
    inverse = _int64_buffer() if return_inverse else None
    group_starts = _int64_buffer() if return_counts else None
    n_unique_buf = torch.empty([1], dtype=torch.int64, device=device)

    with torch_device_fn.device(device.index):
        simple_unique_consecutive_flat_kernel[single_program](
            data,
            uniques,
            inverse,
            group_starts,
            n_unique_buf,
            return_inverse,
            return_counts,
            n,
            tile_size=triton.next_power_of_2(n),
            num_warps=8,
        )

    n_unique = n_unique_buf.item()

    if not return_counts:
        return uniques[:n_unique], inverse, None

    group_starts = group_starts[:n_unique]
    lengths = torch.empty_like(group_starts)
    with torch_device_fn.device(device.index):
        output_counts_kernel[single_program](
            group_starts,
            n,
            lengths,
            num_tasks=n_unique,
            tiles_per_cta=1,
            tile_size=triton.next_power_of_2(n_unique),
            num_warps=8,
        )
    return uniques[:n_unique], inverse, lengths
def _ptpu_unique_consecutive_flat(
    data: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Large-input path for Sunrise/PTPU devices.

    Only the group-start flags are computed in Triton, with tiles capped at
    ``_PTPU_SAFE_MAX_TILE_SIZE``. The unique values, inverse indices and
    counts are then derived from those flags with regular torch ops, instead
    of the tiled global-cumsum/scatter kernels used on other devices.
    """
    num_tasks = data.numel()
    next_power_num_tasks = triton.next_power_of_2(num_tasks)
    tile_size = min(_PTPU_SAFE_MAX_TILE_SIZE, next_power_num_tasks)
    global_ctas_num = triton.cdiv(num_tasks, tile_size)
    ctas_num = global_ctas_num if global_ctas_num < 32768 else 8192
    tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
    num_warps = 8 if tiles_per_cta == 1 else 32
    grid = (ctas_num, 1, 1)

    ne_result = torch.empty(num_tasks, dtype=torch.bool, device=data.device)
    # The kernel always writes per-tile counts; this path does not use them.
    tile_sum = torch.empty(global_ctas_num, dtype=torch.int64, device=data.device)

    with torch_device_fn.device(data.device.index):
        local_ne_consecutive_kernel[grid](
            data,
            ne_result,
            tile_sum,
            global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            num_warps=num_warps,
        )

    # Positions where a new consecutive group begins. Never empty for a
    # non-empty input, since the first element always starts a group.
    starts = torch.nonzero(ne_result, as_tuple=False).flatten()
    output = torch.index_select(data, 0, starts)

    inverse_indices = None
    if return_inverse:
        # Group index of element i = (number of group starts in [0, i]) - 1.
        inverse_indices = torch.cumsum(ne_result.to(torch.int64), dim=0) - 1

    counts = None
    if return_counts:
        # Each group runs until the next group's start; the last group runs
        # to the end of the input.
        tail = starts.new_tensor([num_tasks]) - starts[-1:]
        counts = torch.cat((starts[1:] - starts[:-1], tail))

    return output, inverse_indices, counts


def large_unique_consecutive_flat(
    data: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Handle larger inputs with multi-kernel approach.

    On non-PTPU devices:

    1. ``local_ne_consecutive_kernel`` flags every element that starts a new
       consecutive group and writes the number of flags per tile to
       ``tile_sum``.
    2. ``global_cumsum_consecutive_kernel`` turns the flags into global output
       slots (earlier tiles' counts plus the in-tile cumsum). It scatters the
       unique values and, if requested, the inverse indices and group start
       positions. Its last tile stores the total number of groups in
       ``tile_sum[-1]``.
    3. If ``return_counts``, ``output_counts_kernel`` converts group start
       positions into group lengths.

    Sunrise/PTPU devices use ``_ptpu_unique_consecutive_flat`` instead.

    Args:
        data: flat input tensor.
        return_inverse: also produce, for each input element, the index of its
            group in the output (flat, length ``data.numel()``).
        return_counts: also produce the length of each group.

    Returns:
        ``(output, inverse_indices, counts)``; the last two are ``None`` when
        the corresponding flag is False.
    """
    if data.device.type == "ptpu":
        # Sunrise/PTPU only changes the unstable large-input organization path.
        return _ptpu_unique_consecutive_flat(data, return_inverse, return_counts)

    num_tasks = data.numel()
    next_power_num_tasks = triton.next_power_of_2(num_tasks)
    tile_size = min(8192, next_power_num_tasks)
    global_ctas_num = triton.cdiv(num_tasks, tile_size)

    if global_ctas_num <= 8192:
        # Re-derive the tile size from the tile count, with a floor of 256
        # (512 when there are more than 32 tiles).
        # NOTE(review): tuning heuristic; its rationale is not visible here.
        min_tile_size = 512 if global_ctas_num > 32 else 256
        tile_size = max(
            min_tile_size,
            min(triton.next_power_of_2(global_ctas_num), next_power_num_tasks),
        )
        global_ctas_num = triton.cdiv(num_tasks, tile_size)

    next_power_global_ctas_num = triton.next_power_of_2(global_ctas_num)
    ctas_num = global_ctas_num if global_ctas_num < 32768 else 8192
    tiles_per_cta = triton.cdiv(num_tasks, tile_size * ctas_num)
    num_warps = 8 if tiles_per_cta == 1 else 32
    grid = (ctas_num, 1, 1)

    # allocate tensors
    ne_result = torch.empty(num_tasks, dtype=torch.bool, device=data.device)
    tile_sum = torch.empty(global_ctas_num, dtype=torch.int64, device=data.device)
    data_out = torch.empty_like(data)
    inverse_indices = (
        torch.empty(num_tasks, dtype=torch.int64, device=data.device)
        if return_inverse
        else None
    )
    idx = (
        torch.empty(num_tasks, dtype=torch.int64, device=data.device)
        if return_counts
        else None
    )

    # launch kernels
    with torch_device_fn.device(data.device.index):
        local_ne_consecutive_kernel[grid](
            data,
            ne_result,
            tile_sum,
            global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            num_warps=num_warps,
        )
        global_cumsum_consecutive_kernel[grid](
            ne_result,
            tile_sum,
            data,
            data_out,
            inverse_indices,
            idx,
            ctas_num,
            global_ctas_num,
            next_power_global_ctas_num,
            num_tasks,
            tiles_per_cta=tiles_per_cta,
            tile_size=tile_size,
            one_tile_per_cta=tiles_per_cta == 1,
            return_inverse=return_inverse,
            return_counts=return_counts,
            num_warps=num_warps,
        )
    # Total number of groups, written by the last tile of the cumsum kernel.
    out_size = tile_sum[-1].item()

    counts = None
    if return_counts:
        idx = idx[:out_size]
        counts = torch.empty_like(idx)
        # Launch under data's device context, like every other kernel here.
        # The grid sized for num_tasks elements also covers the
        # out_size <= num_tasks groups.
        with torch_device_fn.device(data.device.index):
            output_counts_kernel[grid](
                idx,
                num_tasks,
                counts,
                out_size,
                tiles_per_cta,
                tile_size,
                num_warps=num_warps,
            )

    return data_out[:out_size], inverse_indices, counts
def unique_consecutive(
    input: torch.Tensor,
    return_inverse: bool = False,
    return_counts: bool = False,
    dim: Optional[int] = None,
):
    """
    Eliminates all but the first element from every consecutive group of equivalent elements.

    Args:
        input: the input tensor
        return_inverse: Whether to return inverse indices
        return_counts: Whether to return counts for each unique element
        dim: the dimension to apply unique. If None, the unique of the flattened input is returned.

    Returns:
        (Tensor, Tensor (optional), Tensor (optional)): output, inverse_indices, counts.
        Always a 3-tuple; an entry is ``None`` when its flag is False.
        When ``dim`` is None, ``inverse_indices`` has the same shape as ``input``.
    """
    logger.debug("GEMS UNIQUE_CONSECUTIVE")

    if dim is not None:
        # For dim-wise unique_consecutive, fall back to PyTorch for now.
        # This could be implemented with a more complex kernel.
        result = torch.unique_consecutive(
            input,
            return_inverse=return_inverse,
            return_counts=return_counts,
            dim=dim,
        )
        # torch.unique_consecutive returns a bare tensor or a 2-/3-tuple
        # depending on the flags; normalize to this op's 3-tuple contract.
        if not (return_inverse or return_counts):
            return result, None, None
        output, *extras = result
        inverse_indices = extras.pop(0) if return_inverse else None
        counts = extras.pop(0) if return_counts else None
        return output, inverse_indices, counts

    # Flatten input for the None dim case
    flat_input = input.ravel()
    num_tasks = flat_input.numel()

    if num_tasks == 0:
        # Handle empty input. inverse_indices keeps the input's shape, like
        # the non-empty path below and PyTorch's own result.
        output = torch.empty(0, dtype=input.dtype, device=input.device)
        inverse_indices = (
            torch.empty(input.shape, dtype=torch.int64, device=input.device)
            if return_inverse
            else None
        )
        counts = (
            torch.empty(0, dtype=torch.int64, device=input.device)
            if return_counts
            else None
        )
        return output, inverse_indices, counts

    # Choose algorithm based on input size
    if num_tasks <= 8192:
        output, inverse_indices, counts = simple_unique_consecutive_flat(
            flat_input, return_inverse, return_counts
        )
    else:
        output, inverse_indices, counts = large_unique_consecutive_flat(
            flat_input, return_inverse, return_counts
        )

    # Reshape inverse_indices to match input shape
    if inverse_indices is not None:
        inverse_indices = inverse_indices.view_as(input)

    return output, inverse_indices, counts