Coverage for src/flag_gems/runtime/backend/_cambricon/ops/unique.py: 0%
94 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils.libentry import libentry
10from ..utils import TOTAL_CORE_NUM
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 2**k}, num_stages=s, num_warps=1)
        for k in range(11, 17, 1)
        for s in [1, 3]
    ],
    key=[
        "tile_size",
    ],
)
@triton.jit
def get_ne_kernel(
    sorted_data_ptr: tl.tensor,
    sorted_data_2: tl.tensor,
    ne_out_ptr: tl.tensor,
    tile_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Write an element-wise "differs" flag for two equally sized buffers.

    For every position ``p < tile_size`` the kernel stores
    ``(p > 0) * (sorted_data_ptr[p] != sorted_data_2[p])`` into ``ne_out_ptr``,
    so position 0 is always written as 0 regardless of the loaded values.
    In this file the caller passes a copy of the sorted data shifted right by
    one element as ``sorted_data_2``, which makes the flag mean "differs from
    the previous element".

    The ``tile_size`` elements are divided evenly (rounded up) between the
    launched programs, and each program walks its share in chunks of
    ``BLOCK_SIZE`` lanes.
    """
    lane = tl.arange(0, BLOCK_SIZE)
    program_count = tl.num_programs(axis=0)
    # Ceil-divide so that the programs together cover every element.
    per_program = (tile_size + program_count - 1) // program_count
    base = tl.program_id(axis=0) * per_program

    for step in range(0, per_program, BLOCK_SIZE):
        pos = base + step + lane
        # Only guards the end of the buffer; a chunk may run past this
        # program's own share when per_program is not a multiple of BLOCK_SIZE.
        in_bounds = pos < tile_size
        current = tl.load(sorted_data_ptr + pos, mask=in_bounds)
        shifted = tl.load(sorted_data_2 + pos, mask=in_bounds)
        # The (pos > 0) factor zeroes the flag of the very first element.
        differs = (pos > 0) * (current != shifted)
        tl.store(ne_out_ptr + pos, differs, mask=in_bounds)
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": k}, num_stages=s, num_warps=1)
        for k in [32, 256, 1024, 2048, 4096]
        for s in [1, 3]
    ],
    key=[
        "tile_size",
    ],
)
@triton.jit
def get_unique_out_kernel(
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    pre_sum_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
    tile_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter sorted values (and optional bookkeeping) to their output slots.

    For every position ``p < tile_size``, ``pre_sum_ptr[p]`` is used as the
    destination slot:

    * ``data_out_ptr[slot] = sorted_data_ptr[p]`` is always written.
    * If ``return_inverse``: ``inverse_indices_ptr[sorted_indices_ptr[p]] = slot``.
    * If ``return_counts``: ``idx_ptr[slot] = p``, but only where ``p == 0`` or
      ``ne_result_ptr[p]`` is set.

    ``idx_ptr`` / ``inverse_indices_ptr`` are only dereferenced when the
    matching compile-time flag is true. Work is divided between programs the
    same way as in the other kernels of this file: an even (rounded-up) share
    per program, processed in ``BLOCK_SIZE`` chunks.
    """
    lane = tl.arange(0, BLOCK_SIZE)
    program_count = tl.num_programs(axis=0)
    # Ceil-divide so that the programs together cover every element.
    per_program = (tile_size + program_count - 1) // program_count
    base = tl.program_id(axis=0) * per_program

    for step in range(0, per_program, BLOCK_SIZE):
        pos = base + step + lane
        in_bounds = pos < tile_size
        values = tl.load(sorted_data_ptr + pos, mask=in_bounds)
        slot = tl.load(pre_sum_ptr + pos, mask=in_bounds)

        # data_out: scatter_(to=slot, values)
        tl.store(data_out_ptr + slot, values, mask=in_bounds)

        # inverse_indices: scatter_(to=original position, slot)
        if return_inverse:
            origin = tl.load(sorted_indices_ptr + pos, mask=in_bounds)
            tl.store(inverse_indices_ptr + origin, slot, mask=in_bounds)

        # idx: record the sorted position at which each output slot starts
        if return_counts:
            changed = tl.load(ne_result_ptr + pos, mask=in_bounds)
            starts_run = ((pos == 0) | changed.to(tl.int1)) & in_bounds
            tl.store(idx_ptr + slot, pos, mask=starts_run)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 2**k}, num_stages=s, num_warps=1)
        for k in range(7, 14, 1)
        for s in [1, 3]
    ],
    key=[
        "tile_size",
    ],
)
@triton.jit
def get_output_counts_kernel(
    idx_ptr: tl.tensor,
    idx_next_ptr: tl.tensor,
    counts_ptr: tl.tensor,  # out
    tile_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``idx_next_ptr[p] - idx_ptr[p]`` into ``counts_ptr[p]``.

    Runs over every position ``p < tile_size``. The elements are divided
    evenly (rounded up) between the launched programs, each of which walks its
    share in chunks of ``BLOCK_SIZE`` lanes.
    """
    lane = tl.arange(0, BLOCK_SIZE)
    program_count = tl.num_programs(axis=0)
    # Ceil-divide so that the programs together cover every element.
    per_program = (tile_size + program_count - 1) // program_count
    base = tl.program_id(axis=0) * per_program

    for step in range(0, per_program, BLOCK_SIZE):
        pos = base + step + lane
        in_bounds = pos < tile_size
        run_start = tl.load(idx_ptr + pos, mask=in_bounds)
        run_end = tl.load(idx_next_ptr + pos, mask=in_bounds)
        # Length of the run is the distance between consecutive start markers.
        tl.store(counts_ptr + pos, run_end - run_start, mask=in_bounds)
def sorted_unique_flat(
    sorted_data: torch.Tensor,
    sorted_indices: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Collapse runs of equal values in an already sorted 1-D tensor.

    Args:
        sorted_data: Flat tensor whose equal values are adjacent (the caller
            in this file passes the values returned by ``torch.sort``).
        sorted_indices: For each element of ``sorted_data``, the position it
            is scattered back to when building the inverse mapping (the caller
            passes the indices returned by ``torch.sort``).
        return_inverse: Whether to build the inverse index tensor.
        return_counts: Whether to build the per-value counts tensor.

    Returns:
        A tuple ``(data_out, inverse_indices, counts)``:

        * ``data_out``: the distinct values, in the order of ``sorted_data``.
        * ``inverse_indices``: int64 tensor shaped like ``sorted_data`` holding
          the output slot of every input element, or ``None`` when
          ``return_inverse`` is false.
        * ``counts``: int64 tensor with the run length of every distinct
          value, or ``None`` when ``return_counts`` is false.
    """
    num_tasks = sorted_data.numel()

    # An empty input has no "last prefix-sum element" to read the output size
    # from and would request a zero-program launch, so answer it directly.
    if num_tasks == 0:
        return (
            torch.empty_like(sorted_data),
            torch.empty_like(sorted_data, dtype=torch.int64)
            if return_inverse
            else None,
            torch.empty_like(sorted_data, dtype=torch.int64)
            if return_counts
            else None,
        )

    def _grid_for(num_elements):
        # One program per BLOCK_SIZE chunk, capped at TOTAL_CORE_NUM; the
        # kernels split the range over however many programs are launched.
        return lambda meta: (
            min(triton.cdiv(num_elements, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
        )

    grid = _grid_for(num_tasks)

    # allocate tensor
    ne_out = torch.empty_like(sorted_data, dtype=torch.bool)
    data_out = torch.empty_like(sorted_data)
    if return_inverse:
        inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64)
    else:
        inverse_indices = None
    if return_counts:
        idx = torch.empty_like(sorted_data, dtype=torch.int64)
    else:
        idx = None
    # Copy of the data shifted right by one element. Slot 0 stays
    # uninitialised on purpose: get_ne_kernel forces the flag at position 0
    # to zero without depending on the loaded values.
    sorted_data_2 = torch.empty_like(sorted_data)
    sorted_data_2[1:] = sorted_data[:-1]

    # launch kernel
    with torch_device_fn.device(sorted_data.device.index):
        get_ne_kernel[grid](
            sorted_data,
            sorted_data_2,
            ne_out,
            tile_size=num_tasks,
        )
        # Running count of "differs from previous" flags == output slot of
        # every sorted element.
        pre_sum = ne_out.cumsum(axis=0)
        get_unique_out_kernel[grid](
            sorted_data,
            sorted_indices,
            ne_out,
            pre_sum,
            idx,
            data_out,
            inverse_indices,
            return_inverse,
            return_counts,
            tile_size=num_tasks,
        )

    # Last slot index + 1; .item() copies the value to the host.
    out_size = pre_sum[-1].item() + 1
    counts = None
    if return_counts:
        # idx[k] is the sorted position where the k-th distinct value starts;
        # the next start (or the total length for the last one) ends its run.
        idx = idx[:out_size]
        idx_next = torch.roll(idx, -1)
        idx_next[-1] = num_tasks
        counts = torch.zeros_like(idx)
        with torch_device_fn.device(sorted_data.device.index):
            # Size the launch by the number of distinct values rather than by
            # the input length, so no surplus programs are started.
            get_output_counts_kernel[_grid_for(out_size)](
                idx,
                idx_next,
                counts,  # out
                tile_size=out_size,
            )
    return data_out[:out_size], inverse_indices, counts
def _unique2(
    in0: torch.Tensor,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
):
    """Return the distinct values of ``in0``, optionally with inverse/counts.

    The input is flattened and sorted with ``torch.sort`` and the sorted data
    is then de-duplicated by ``sorted_unique_flat``.

    Args:
        in0: Input tensor of any shape.
        sorted: Accepted as part of the signature; this implementation does
            not read it.
        return_inverse: Whether to return, for every element of ``in0``, its
            index in the output values (reshaped to ``in0``'s shape).
        return_counts: Whether to return the occurrence count of every output
            value.

    Returns:
        ``(values, inverse_indices, counts)``; ``inverse_indices`` and
        ``counts`` are ``None`` when not requested.
    """
    logger.debug("GEMS_CAMBRICON _UNIQUE2")
    flat_sorted, flat_order = torch.sort(in0.ravel(), stable=False)
    uniques, inverse, counts = sorted_unique_flat(
        flat_sorted, flat_order, return_inverse, return_counts
    )
    # The inverse mapping is produced flat; give it the caller's shape.
    if inverse is not None:
        inverse = inverse.view_as(in0)
    return uniques, inverse, counts