Coverage for src/flag_gems/ops/bincount.py: 76%

55 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from ..utils import libentry 

8 

9logger = logging.getLogger(__name__) 

10 

11 

@libentry()
@triton.jit
def bincount_kernel(
    inp_ptr,
    out_ptr,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    """Kernel for bincount without weights.

    Each program handles one BLOCK_SIZE-wide chunk of ``inp`` and atomically
    adds 1 to ``out[inp[i]]`` for every element ``i`` of the chunk with i < N.
    """
    # Promote the program id to int64 before scaling. Otherwise
    # pid * BLOCK_SIZE is computed in int32 and wraps once N reaches 2**31.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N

    # The input values are the bin indices. Masked-off lanes read a placeholder
    # 0 and are also excluded from the atomic update below.
    indices = tl.load(inp_ptr + offsets, mask=mask, other=0)

    # Add int64 ones so the increment matches the int64 output buffer the
    # wrapper allocates. Relaxed ordering is enough: nothing in the kernel reads
    # `out`, so only the final per-bin sums matter.
    ones = tl.full((BLOCK_SIZE,), 1, dtype=tl.int64)
    tl.atomic_add(out_ptr + indices, ones, mask=mask, sem="relaxed")

32 

33 

@libentry()
@triton.jit
def bincount_weights_kernel(
    inp_ptr,
    weights_ptr,
    out_ptr,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    """Kernel for bincount with weights.

    Each program handles one BLOCK_SIZE-wide chunk and atomically adds
    ``weights[i]`` to ``out[inp[i]]`` for every element ``i`` of the chunk
    with i < N.
    """
    # Promote the program id to int64 before scaling. Otherwise
    # pid * BLOCK_SIZE is computed in int32 and wraps once N reaches 2**31.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N

    # Bin indices and their weights. Masked-off lanes get placeholders that
    # never reach memory, because the atomic below uses the same mask.
    indices = tl.load(inp_ptr + offsets, mask=mask, other=0)
    weights = tl.load(weights_ptr + offsets, mask=mask, other=0.0)

    # Accumulate each weight into its bin. Relaxed ordering is enough: only the
    # final per-bin sums are observed.
    tl.atomic_add(out_ptr + indices, weights, mask=mask, sem="relaxed")

54 

55 

def bincount(inp, weights=None, minlength=0):
    """
    Count the frequency of each value in a 1-d tensor of non-negative ints.

    Args:
        inp: 1-d integer tensor (int8/int16/int32/int64/uint8) holding
            non-negative values.
        weights: optional tensor with the same shape as ``inp``. When given,
            bin ``k`` accumulates ``weights[j]`` for every ``j`` with
            ``inp[j] == k`` instead of counting occurrences.
        minlength: optional minimum number of bins; must be >= 0.

    Returns:
        Tensor of shape ``(max(max(inp) + 1, minlength),)``, or ``(minlength,)``
        for an empty ``inp``. The dtype is int64 without weights, otherwise
        ``weights.dtype``.

    Raises:
        RuntimeError: if ``inp`` is not a 1-d integer tensor, contains negative
            values, ``minlength`` is negative, or ``weights`` does not have the
            same shape as ``inp``.
    """
    logger.debug("GEMS BINCOUNT")

    # Input validation. Raise explicitly rather than assert, so the checks
    # still run under `python -O`.
    if inp.ndim != 1:
        raise RuntimeError("bincount only supports 1-d tensors")
    if inp.dtype not in (
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
    ):
        raise RuntimeError("bincount only supports integer tensors")
    if minlength < 0:
        raise RuntimeError("minlength should be >= 0")
    if weights is not None and weights.shape != inp.shape:
        raise RuntimeError("weights must have same shape as input")

    N = inp.numel()

    # Empty input: nothing to count, only honour minlength.
    if N == 0:
        out_dtype = torch.int64 if weights is None else weights.dtype
        return torch.zeros(minlength, dtype=out_dtype, device=inp.device)

    inp = inp.contiguous()

    # Fetch both bounds with a single host sync. Negative values must be
    # rejected: the kernels use them as raw offsets into `out` and would
    # write out of bounds.
    min_val, max_val = torch.stack([inp.min(), inp.max()]).tolist()
    if min_val < 0:
        raise RuntimeError("bincount only supports non-negative integers")
    output_size = max(max_val + 1, minlength)

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(N, BLOCK_SIZE),)

    if weights is None:
        # No weights: count occurrences.
        out = torch.zeros(output_size, dtype=torch.int64, device=inp.device)
        bincount_kernel[grid](inp, out, N, BLOCK_SIZE=BLOCK_SIZE)
        return out

    # NOTE(review): torch.bincount promotes every weights dtype other than
    # float32 to a float64 result; this implementation returns weights.dtype
    # instead — confirm this is intended.
    weights_dtype = weights.dtype
    # For atomic_add compatibility, accumulate float16/bfloat16 weights in
    # float32 and cast the result back at the end.
    low_precision = weights_dtype in (torch.float16, torch.bfloat16)
    acc_dtype = torch.float32 if low_precision else weights_dtype

    weights = weights.contiguous().to(acc_dtype)
    out = torch.zeros(output_size, dtype=acc_dtype, device=inp.device)
    bincount_weights_kernel[grid](inp, weights, out, N, BLOCK_SIZE=BLOCK_SIZE)

    return out.to(weights_dtype) if low_precision else out