Coverage for src/flag_gems/ops/bincount.py: 76%
55 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from ..utils import libentry
9logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def bincount_kernel(
    inp_ptr,
    out_ptr,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    """Unweighted bincount kernel: add 1 to ``out_ptr[v]`` for every input value ``v``.

    Each program instance processes one run of ``BLOCK_SIZE`` consecutive
    input elements; lanes at or beyond ``N`` are masked out.

    Args:
        inp_ptr: pointer to the input values, which are used directly as
            offsets into the output buffer.
        out_ptr: pointer to the output buffer of per-bin counts. The
            increments are int64; the host wrapper in this file allocates
            an int64 buffer for it.
        N: number of input elements.
        BLOCK_SIZE: compile-time number of elements handled per program.
    """
    # Positions of the input elements owned by this program instance.
    block_start = tl.program_id(0) * BLOCK_SIZE
    lane_ids = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_ids < N

    # A vector of int64 ones: the amount each active lane contributes.
    increment = tl.full((BLOCK_SIZE,), 1, dtype=tl.int64)

    # Masked-out lanes read 0 as a placeholder; they never reach the
    # atomic below because the same mask is applied there.
    bin_ids = tl.load(inp_ptr + lane_ids, mask=in_bounds, other=0)

    # Several lanes (and programs) may target the same bin, hence the atomic.
    tl.atomic_add(out_ptr + bin_ids, increment, mask=in_bounds, sem="relaxed")
@libentry()
@triton.jit
def bincount_weights_kernel(
    inp_ptr,
    weights_ptr,
    out_ptr,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    """Weighted bincount kernel: add ``weights[i]`` to ``out_ptr[inp[i]]``.

    Each program instance processes one run of ``BLOCK_SIZE`` consecutive
    elements; lanes at or beyond ``N`` are masked out.

    Args:
        inp_ptr: pointer to the input values, which are used directly as
            offsets into the output buffer.
        weights_ptr: pointer to the per-element weights, read at the same
            positions as the input values.
        out_ptr: pointer to the output buffer that accumulates the weights.
        N: number of input elements.
        BLOCK_SIZE: compile-time number of elements handled per program.
    """
    # Positions of the elements owned by this program instance.
    block_start = tl.program_id(0) * BLOCK_SIZE
    lane_ids = block_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane_ids < N

    # Masked-out lanes read placeholder values (bin 0, weight 0.0); they
    # never reach the atomic below because the same mask is applied there.
    bin_ids = tl.load(inp_ptr + lane_ids, mask=in_bounds, other=0)
    contributions = tl.load(weights_ptr + lane_ids, mask=in_bounds, other=0.0)

    # Several lanes (and programs) may target the same bin, hence the atomic.
    tl.atomic_add(out_ptr + bin_ids, contributions, mask=in_bounds, sem="relaxed")
def bincount(inp, weights=None, minlength=0):
    """
    Count the frequency of each value in an array of non-negative ints.

    Args:
        inp: 1-d int tensor (int8, int16, int32, int64 or uint8) of
            non-negative integers
        weights: optional weights tensor of same shape as inp; when given,
            each bin accumulates the weights of the elements falling into
            it instead of a plain count
        minlength: optional minimum number of bins

    Returns:
        Tensor of shape (max(inp) + 1,) or (minlength,) if minlength is larger.
        For an empty inp the result is all zeros with shape (minlength,).
        The dtype is int64 without weights and the dtype of weights otherwise.

    Raises:
        AssertionError: if inp is not 1-d, is not one of the supported
            integer dtypes, or weights does not have the same shape as inp.
        RuntimeError: if inp contains a negative value.
    """
    logger.debug("GEMS BINCOUNT")

    # Input validation
    assert inp.ndim == 1, "bincount only supports 1-d tensors"
    assert inp.dtype in (
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
    ), "bincount only supports integer tensors"

    N = inp.numel()

    # Handle empty input: nothing to count, so every bin stays at zero.
    if N == 0:
        empty_dtype = torch.int64 if weights is None else weights.dtype
        return torch.zeros(minlength, dtype=empty_dtype, device=inp.device)

    # The kernels use the input values directly as offsets into the output
    # buffer, so a negative value would make them write in front of it.
    # Reject such input up front with the same error type and message as
    # torch.bincount. uint8 cannot hold negatives, so its check is skipped.
    if inp.dtype != torch.uint8 and int(inp.min().item()) < 0:
        raise RuntimeError(
            "bincount only supports 1-d non-negative integral inputs."
        )

    # Compute output size
    max_val = int(inp.max().item())
    output_size = max(max_val + 1, minlength)

    # Ensure input is contiguous: the kernels address it as a flat array.
    inp = inp.contiguous()

    # Launch configuration shared by both kernels: one program per
    # BLOCK_SIZE input elements.
    BLOCK_SIZE = 1024
    grid = (triton.cdiv(N, BLOCK_SIZE),)

    if weights is None:
        # No weights: count occurrences
        out = torch.zeros(output_size, dtype=torch.int64, device=inp.device)
        bincount_kernel[grid](
            inp,
            out,
            N,
            BLOCK_SIZE=BLOCK_SIZE,
        )
        return out

    assert weights.shape == inp.shape, "weights must have same shape as input"
    weights = weights.contiguous()

    # Output dtype matches weights dtype.
    # For atomic_add compatibility, float16/bfloat16 weights are accumulated
    # in float32 and the result is converted back afterwards.
    weights_dtype = weights.dtype
    needs_upcast = weights_dtype in (torch.float16, torch.bfloat16)
    if needs_upcast:
        weights = weights.to(torch.float32)

    # After the optional upcast, weights.dtype is the accumulation dtype.
    out = torch.zeros(output_size, dtype=weights.dtype, device=inp.device)

    bincount_weights_kernel[grid](
        inp,
        weights,
        out,
        N,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Convert back if needed
    if needs_upcast:
        out = out.to(weights_dtype)

    return out