Coverage for src/flag_gems/runtime/backend/_sunrise/fused/bincount.py: 0%
228 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
# Module-scoped logger named after this module; ``bincount`` logs a debug
# record through it on every call.
logger = logging.getLogger(name=__name__)
13def _select_params(n):
14 if n <= 256:
15 return 256, 2
16 if n <= 1024:
17 return 256, 4
18 if n <= 4096:
19 return 512, 4
20 return 1024, 4
23def _estimate_output_size(n, minlength):
24 estimate = max(8192, n * 4, minlength)
25 estimate = min(estimate, 65536)
26 return max(estimate, minlength)
def _select_max_block_size(n):
    """Chunk size for the first max-reduction pass: ``ceil(sqrt(n))`` rounded up to a power of two.

    Using roughly ``sqrt(n)`` elements per chunk keeps both the per-chunk
    work and the number of chunk maxima at about ``sqrt(n)``.
    """
    side = math.ceil(math.sqrt(n))
    return triton.next_power_of_2(side if side > 1 else 1)
def _select_bins_block(output_size):
    """Number of bins handled per program: next power of two of ``output_size``, capped at 128."""
    bins = triton.next_power_of_2(output_size if output_size > 1 else 1)
    return bins if bins < 128 else 128
@triton.jit
def fused_max_bincount_kernel(
    input_ptr,
    max_ptr,
    output_ptr,
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Unweighted atomic bincount fused with a global-maximum reduction.

    Each program takes one ``BLOCK_SIZE`` slice of the input, folds the slice
    maximum into ``*max_ptr`` with ``atomic_max``, and adds 1 to
    ``output_ptr[v]`` for every value ``v < output_size``. Values at or past
    ``output_size`` only contribute to the maximum.
    """
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    values = tl.load(input_ptr + idx, mask=in_bounds, other=0)

    # Padding lanes read 0; assumes non-negative input so they never raise the max.
    tl.atomic_max(max_ptr, tl.max(values, axis=0))

    fits = in_bounds & (values < output_size)
    tl.atomic_add(output_ptr + values, 1, mask=fits)
@triton.jit
def bincount_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Unweighted atomic bincount: ``output_ptr[v] += 1`` for each input value ``v``.

    No bounds check is applied to ``v``; ``output_ptr`` must already hold at
    least ``max(input) + 1`` entries and the input must be non-negative.
    """
    start = tl.program_id(0) * BLOCK_SIZE
    idx = start + tl.arange(0, BLOCK_SIZE)
    live = idx < n_elements
    values = tl.load(input_ptr + idx, mask=live, other=0)
    tl.atomic_add(output_ptr + values, 1, mask=live)
@triton.jit
def fused_max_bincount_weights_fp32_kernel(
    input_ptr,
    weights_ptr,
    max_ptr,
    output_ptr,
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Weighted bincount (float32 accumulation) fused with a global-max reduction.

    For its ``BLOCK_SIZE`` slice, each program folds the slice maximum into
    ``*max_ptr`` via ``atomic_max`` and atomically adds ``float32(weights[i])``
    to ``output_ptr[input[i]]`` whenever ``input[i] < output_size``. Elements
    at or past ``output_size`` contribute only to the maximum.
    """
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    values = tl.load(input_ptr + idx, mask=in_bounds, other=0)
    weight_f32 = tl.load(weights_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)

    # Padding lanes read 0; assumes non-negative input so they never raise the max.
    tl.atomic_max(max_ptr, tl.max(values, axis=0))

    fits = in_bounds & (values < output_size)
    tl.atomic_add(output_ptr + values, weight_f32, mask=fits)
@triton.jit
def fused_max_bincount_weights_fp64_kernel(
    input_ptr,
    weights_ptr,
    max_ptr,
    output_ptr,
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Weighted bincount (float64 accumulation) fused with a global-max reduction.

    Same contract as the float32 variant, but weights are widened to float64
    before being atomically added to ``output_ptr[input[i]]``.
    """
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    values = tl.load(input_ptr + idx, mask=in_bounds, other=0)
    weight_f64 = tl.load(weights_ptr + idx, mask=in_bounds, other=0.0).to(tl.float64)

    # Padding lanes read 0; assumes non-negative input so they never raise the max.
    tl.atomic_max(max_ptr, tl.max(values, axis=0))

    fits = in_bounds & (values < output_size)
    tl.atomic_add(output_ptr + values, weight_f64, mask=fits)
@triton.jit
def bincount_max_kernel_1(
    input_ptr,
    mid_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """First max-reduction pass: write the maximum of chunk ``pid`` to ``mid_ptr[pid]``.

    Out-of-range lanes load 0, which is neutral for non-negative input.
    """
    chunk_id = tl.program_id(0)
    idx = chunk_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    values = tl.load(input_ptr + idx, mask=idx < n_elements, other=0)
    tl.store(mid_ptr + chunk_id, tl.max(values, axis=0))
@triton.jit
def bincount_max_kernel_2(
    mid_ptr,
    max_ptr,
    mid_size,
    BLOCK_MID: tl.constexpr,
):
    """Second max-reduction pass: reduce the ``mid_size`` chunk maxima into ``*max_ptr``.

    Runs as a single program; ``BLOCK_MID`` must be >= ``mid_size``.
    """
    lanes = tl.arange(0, BLOCK_MID)
    chunk_maxima = tl.load(mid_ptr + lanes, mask=lanes < mid_size, other=0)
    tl.store(max_ptr, tl.max(chunk_maxima, axis=0))
@triton.jit
def bincount_weights_fp32_kernel(
    input_ptr,
    weights_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Weighted atomic bincount with float32 accumulation.

    Adds ``float32(weights[i])`` to ``output_ptr[input[i]]`` with no bounds
    check on ``input[i]``; the caller sizes the output to ``max(input) + 1``.
    """
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    live = idx < n_elements
    bins = tl.load(input_ptr + idx, mask=live, other=0)
    contribution = tl.load(weights_ptr + idx, mask=live, other=0.0).to(tl.float32)
    tl.atomic_add(output_ptr + bins, contribution, mask=live)
@triton.jit
def bincount_weights_fp64_kernel(
    input_ptr,
    weights_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Weighted atomic bincount with float64 accumulation.

    Adds ``float64(weights[i])`` to ``output_ptr[input[i]]`` with no bounds
    check on ``input[i]``; the caller sizes the output to ``max(input) + 1``.
    """
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    live = idx < n_elements
    bins = tl.load(input_ptr + idx, mask=live, other=0)
    contribution = tl.load(weights_ptr + idx, mask=live, other=0.0).to(tl.float64)
    tl.atomic_add(output_ptr + bins, contribution, mask=live)
@triton.jit
def bincount_partial_int64_kernel(
    input_ptr,
    partial_ptr,
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_BINS: tl.constexpr,
    TILE_INPUT: tl.constexpr,
):
    """Atomic-free per-block histogram (unweighted counts).

    Program ``(pid_block, pid_bin)`` counts how many of the ``BLOCK_SIZE``
    input elements starting at ``pid_block * BLOCK_SIZE`` equal each of the
    ``BLOCK_BINS`` bin indices starting at ``pid_bin * BLOCK_BINS``. It writes
    those counts (as int64) into row ``pid_block`` of a row-major
    ``(num_partials, output_size)`` partial buffer. The rows are summed
    afterwards by ``bincount_reduce_partial_kernel``.
    """
    pid_block = tl.program_id(0)
    pid_bin = tl.program_id(1)

    block_start = pid_block * BLOCK_SIZE
    bin_offsets = pid_bin * BLOCK_BINS + tl.arange(0, BLOCK_BINS)
    bin_mask = bin_offsets < output_size
    # A bin can match at most BLOCK_SIZE elements, so int32 cannot overflow.
    acc = tl.zeros([BLOCK_BINS], dtype=tl.int32)

    for tile_start in range(0, BLOCK_SIZE, TILE_INPUT):
        input_offsets = block_start + tile_start + tl.arange(0, TILE_INPUT)
        input_mask = input_offsets < n_elements
        vals = tl.load(input_ptr + input_offsets, mask=input_mask, other=0)
        # Compare without narrowing: the int32 bin indices and the input values
        # are promoted to a common (wider) integer type. Casting the bins down
        # to a narrow input dtype (int8/uint8/int16) would wrap bins beyond
        # that dtype's range onto small values, e.g. bin 256 -> 0 for uint8,
        # which miscounts whenever minlength exceeds the dtype's range.
        matches = (
            bin_mask[:, None]
            & input_mask[None, :]
            & (bin_offsets[:, None] == vals[None, :])
        )
        acc += tl.sum(matches.to(tl.int32), axis=1)

    # 64-bit row offset: num_partials * output_size may exceed the int32 range.
    partial_offsets = pid_block.to(tl.int64) * output_size + bin_offsets
    tl.store(partial_ptr + partial_offsets, acc.to(tl.int64), mask=bin_mask)
@triton.jit
def bincount_partial_weights_fp64_kernel(
    input_ptr,
    weights_ptr,
    partial_ptr,
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_BINS: tl.constexpr,
    TILE_INPUT: tl.constexpr,
):
    """Atomic-free per-block weighted histogram, accumulated in float64.

    Program ``(pid_block, pid_bin)`` sums ``weights[i]`` (widened to float64)
    over the ``BLOCK_SIZE`` input elements of block ``pid_block`` whose value
    equals each of the ``BLOCK_BINS`` bin indices of tile ``pid_bin``. The sums
    go to row ``pid_block`` of a row-major ``(num_partials, output_size)``
    partial buffer. ``tl.store`` converts the float64 sums to the partial
    buffer's element type. NOTE(review): that is int64 when reached from
    ``_bincount_no_atomic_launch`` with int64 weights — confirm the float64
    round-trip is acceptable for very large integer weights.
    """
    pid_block = tl.program_id(0)
    pid_bin = tl.program_id(1)

    block_start = pid_block * BLOCK_SIZE
    bin_offsets = pid_bin * BLOCK_BINS + tl.arange(0, BLOCK_BINS)
    bin_mask = bin_offsets < output_size
    acc = tl.zeros([BLOCK_BINS], dtype=tl.float64)

    for tile_start in range(0, BLOCK_SIZE, TILE_INPUT):
        input_offsets = block_start + tile_start + tl.arange(0, TILE_INPUT)
        input_mask = input_offsets < n_elements
        vals = tl.load(input_ptr + input_offsets, mask=input_mask, other=0)
        w = tl.load(weights_ptr + input_offsets, mask=input_mask, other=0.0).to(
            tl.float64
        )
        # Compare without narrowing the int32 bin indices to the input dtype:
        # for int8/uint8/int16 inputs that cast would wrap large bins onto
        # small values (e.g. bin 256 -> 0 for uint8) and misattribute weights.
        matches = (
            bin_mask[:, None]
            & input_mask[None, :]
            & (bin_offsets[:, None] == vals[None, :])
        )
        acc += tl.sum(tl.where(matches, w[None, :], 0.0), axis=1)

    # 64-bit row offset: num_partials * output_size may exceed the int32 range.
    partial_offsets = pid_block.to(tl.int64) * output_size + bin_offsets
    tl.store(partial_ptr + partial_offsets, acc, mask=bin_mask)
@triton.jit
def bincount_reduce_partial_kernel(
    partial_ptr,
    output_ptr,
    num_partials,
    output_size,
    BLOCK_PARTIAL: tl.constexpr,
    BLOCK_BINS: tl.constexpr,
):
    """Sum a row-major ``(num_partials, output_size)`` buffer over its rows.

    Program ``pid_bin`` owns ``BLOCK_BINS`` consecutive columns and walks the
    rows ``BLOCK_PARTIAL`` at a time, accumulating in ``output_ptr``'s element
    type before storing the column sums to ``output_ptr``.
    """
    pid_bin = tl.program_id(0)
    bin_offsets = pid_bin * BLOCK_BINS + tl.arange(0, BLOCK_BINS)
    bin_mask = bin_offsets < output_size
    acc = tl.zeros([BLOCK_BINS], dtype=output_ptr.dtype.element_ty)

    for partial_start in range(0, num_partials, BLOCK_PARTIAL):
        partial_rows = partial_start + tl.arange(0, BLOCK_PARTIAL)
        # 64-bit row offsets: num_partials * output_size may exceed int32.
        row_offsets = partial_rows.to(tl.int64)[:, None] * output_size
        partial_ptrs = partial_ptr + row_offsets + bin_offsets[None, :]
        partial_mask = (partial_rows[:, None] < num_partials) & bin_mask[None, :]
        partial_vals = tl.load(partial_ptrs, mask=partial_mask, other=0)
        acc += tl.sum(partial_vals, axis=0)

    tl.store(output_ptr + bin_offsets, acc, mask=bin_mask)
def _compute_output_size(input_contig, n, minlength):
    """Return ``max(max(input) + 1, minlength)`` using a two-pass device max reduction.

    Pass 1 reduces each chunk of ``_select_max_block_size(n)`` elements to one
    partial maximum; pass 2 reduces those partials in a single program. The
    final ``.item()`` call synchronizes with the device.
    """
    device = input_contig.device
    dtype = input_contig.dtype
    chunk = _select_max_block_size(n)
    num_chunks = triton.cdiv(n, chunk)

    chunk_max = torch.empty((num_chunks,), dtype=dtype, device=device)
    global_max = torch.empty([], dtype=dtype, device=device)

    with torch_device_fn.device(device):
        bincount_max_kernel_1[(num_chunks, 1, 1)](
            input_contig,
            chunk_max,
            n,
            BLOCK_SIZE=chunk,
        )
        bincount_max_kernel_2[(1, 1, 1)](
            chunk_max,
            global_max,
            num_chunks,
            BLOCK_MID=triton.next_power_of_2(num_chunks),
        )

    largest = int(global_max.item())
    return max(largest + 1, minlength)
def _bincount_atomic_launch(
    input_contig,
    weights_contig,
    n,
    output_size,
    BLOCK_SIZE,
    num_warps,
):
    """Weighted atomic bincount into a zero-initialized float32 buffer of ``output_size`` bins.

    The kernel does not bounds-check input values, so ``output_size`` must be
    at least ``max(input) + 1`` (as produced by ``_compute_output_size``).
    """
    device = input_contig.device
    num_programs = triton.cdiv(n, BLOCK_SIZE)
    result = torch.zeros(output_size, dtype=torch.float32, device=device)

    with torch_device_fn.device(device):
        bincount_weights_fp32_kernel[(num_programs,)](
            input_contig,
            weights_contig,
            result,
            n,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    return result
def _fused_bincount_atomic_launch(
    input_contig,
    weights_contig,
    n,
    pre_size,
    minlength,
    out_dtype,
    grid,
    BLOCK_SIZE,
    num_warps,
):
    """Weighted atomic bincount that guesses the output length up front.

    A fused kernel accumulates into a ``pre_size`` buffer while also computing
    ``max(input)``. If ``max(max + 1, minlength) <= pre_size``, the leading
    entries of that buffer are returned (a view into it, not a copy).
    Otherwise the histogram is recomputed from scratch into a buffer of the
    exact size using the non-fused kernel. Accumulation is float64 when
    ``out_dtype`` is float64 and float32 otherwise.
    """
    device = input_contig.device
    max_tensor = torch.zeros(1, dtype=input_contig.dtype, device=device)
    is_fp64 = out_dtype == torch.float64
    compute_dtype = torch.float64 if is_fp64 else torch.float32
    output = torch.zeros(pre_size, dtype=compute_dtype, device=device)

    fused_kernel = (
        fused_max_bincount_weights_fp64_kernel
        if is_fp64
        else fused_max_bincount_weights_fp32_kernel
    )
    with torch_device_fn.device(device):
        fused_kernel[grid](
            input_contig,
            weights_contig,
            max_tensor,
            output,
            n,
            pre_size,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    needed_size = max(int(max_tensor.item()) + 1, minlength)
    if needed_size <= pre_size:
        return output[:needed_size]

    # Some value fell outside the guessed buffer: recount at the exact size.
    output = torch.zeros(needed_size, dtype=compute_dtype, device=device)
    recount_kernel = (
        bincount_weights_fp64_kernel if is_fp64 else bincount_weights_fp32_kernel
    )
    with torch_device_fn.device(device):
        recount_kernel[grid](
            input_contig,
            weights_contig,
            output,
            n,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    return output
def _bincount_no_atomic_launch(
    input_contig,
    weights_contig,
    n,
    output_size,
    out_dtype,
    BLOCK_SIZE,
    num_warps,
):
    """Bincount without atomics via per-block partial histograms.

    Stage 1 launches one program per (input block, bin tile) and fills a
    ``(num_partials, output_size)`` buffer of ``out_dtype``. Unweighted counts
    use ``bincount_partial_int64_kernel``; weighted sums use
    ``bincount_partial_weights_fp64_kernel``. Stage 2 sums the buffer over its
    rows into the returned ``output_size`` vector. Temporary memory is
    ``num_partials * output_size`` elements.
    """
    device = input_contig.device
    bins_per_tile = _select_bins_block(output_size)
    num_partials = triton.cdiv(n, BLOCK_SIZE)
    num_bin_tiles = triton.cdiv(output_size, bins_per_tile)
    partial_grid = (num_partials, num_bin_tiles)
    tile_kwargs = dict(
        BLOCK_SIZE=BLOCK_SIZE,
        BLOCK_BINS=bins_per_tile,
        TILE_INPUT=min(64, BLOCK_SIZE),
        num_warps=num_warps,
    )

    partial = torch.empty((num_partials, output_size), dtype=out_dtype, device=device)
    output = torch.empty(output_size, dtype=out_dtype, device=device)

    with torch_device_fn.device(device):
        if weights_contig is not None:
            bincount_partial_weights_fp64_kernel[partial_grid](
                input_contig,
                weights_contig,
                partial,
                n,
                output_size,
                **tile_kwargs,
            )
        else:
            bincount_partial_int64_kernel[partial_grid](
                input_contig,
                partial,
                n,
                output_size,
                **tile_kwargs,
            )

        bincount_reduce_partial_kernel[(num_bin_tiles, 1, 1)](
            partial,
            output,
            num_partials,
            output_size,
            BLOCK_PARTIAL=8,
            BLOCK_BINS=bins_per_tile,
            num_warps=4,
        )

    return output
461def _supports_atomic_accumulate(out_dtype):
462 return out_dtype not in (torch.int64, torch.float64)
def _supports_fused_atomic(input_dtype, out_dtype):
    """Return True when the fused max+bincount kernels may be used.

    Requires an int32 input and an output dtype that allows atomic
    accumulation. NOTE(review): the int32-only restriction presumably
    reflects ``atomic_max`` support on this backend — confirm.
    """
    if input_dtype != torch.int32:
        return False
    return _supports_atomic_accumulate(out_dtype)
def bincount(input, weights=None, minlength=0):
    """Count occurrences of each value in a 1-D integer tensor (``torch.bincount``).

    Args:
        input: 1-D tensor of non-negative integers.
        weights: optional tensor with the same shape as ``input``. When given,
            bin ``i`` holds the sum of ``weights[j]`` over all ``j`` with
            ``input[j] == i`` instead of a count.
        minlength: minimum number of bins in the result.

    Returns:
        A 1-D tensor of length ``max(max(input) + 1, minlength)`` (``minlength``
        for empty input). Unweighted results are int64; weighted results are
        cast to ``weights.dtype``. NOTE(review): ``torch.bincount`` itself
        returns float64 for non-float32 weights — confirm this divergence is
        intended.

    Raises:
        AssertionError: non-1-D input, negative ``minlength``, or ``weights``
            whose shape differs from ``input`` (kept as assertions so existing
            callers see the same exception type).
        RuntimeError: ``input`` is not an integral dtype or contains negative
            values.
    """
    logger.debug("GEMS BINCOUNT")

    assert input.dim() == 1, "input must be a 1-D tensor"
    assert minlength >= 0, "minlength must be non-negative"

    if weights is not None:
        assert weights.shape == input.shape, "weights must have the same shape as input"

    # The kernels use input values directly as output indices, so only the
    # integral dtypes torch.bincount accepts make sense here; anything else
    # would otherwise fail deep inside Triton compilation.
    if input.dtype not in (
        torch.uint8,
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
    ):
        raise RuntimeError("bincount only supports 1-d non-negative integral inputs.")

    n = input.numel()

    if n == 0:
        if weights is not None:
            return torch.zeros(minlength, dtype=weights.dtype, device=input.device)
        return torch.zeros(minlength, dtype=torch.int64, device=input.device)

    input_contig = input.contiguous()
    weights_contig = weights.contiguous() if weights is not None else None

    # float64 weights are delegated to the CPU reference implementation.
    if weights is not None and weights.dtype == torch.float64:
        return torch.bincount(
            input_contig.cpu(),
            weights=weights_contig.cpu(),
            minlength=minlength,
        ).to(input.device)

    # Negative values would be used as negative offsets into the output
    # buffer; the atomic kernels do not mask them, which means out-of-bounds
    # device writes. Reject them up front (one extra reduction + host sync).
    if int(input_contig.min().item()) < 0:
        raise RuntimeError("bincount only supports 1-d non-negative integral inputs.")

    BLOCK_SIZE, num_warps = _select_params(n)
    grid = (triton.cdiv(n, BLOCK_SIZE),)

    out_dtype = weights.dtype if weights is not None else torch.int64

    if _supports_fused_atomic(input_contig.dtype, out_dtype):
        # Single pass with a guessed output size; falls back to an exact-size
        # recount inside the helper when the guess is too small.
        pre_size = _estimate_output_size(n, minlength)
        output = _fused_bincount_atomic_launch(
            input_contig,
            weights_contig,
            n,
            pre_size,
            minlength,
            out_dtype,
            grid,
            BLOCK_SIZE,
            num_warps,
        )
    elif _supports_atomic_accumulate(out_dtype):
        output_size = _compute_output_size(input_contig, n, minlength)
        output = _bincount_atomic_launch(
            input_contig,
            weights_contig,
            n,
            output_size,
            BLOCK_SIZE,
            num_warps,
        )
    else:
        output_size = _compute_output_size(input_contig, n, minlength)
        output = _bincount_no_atomic_launch(
            input_contig,
            weights_contig,
            n,
            output_size,
            out_dtype,
            BLOCK_SIZE,
            num_warps,
        )

    # Atomic paths accumulate in float32; cast back to the weights' dtype.
    if weights is not None and weights.dtype not in (torch.float64, torch.float32):
        output = output.to(dtype=weights.dtype)

    return output