Coverage for src/flag_gems/ops/histc.py: 48%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry 

9from flag_gems.utils import triton_lang_extension as ext 

10 

# Module-scoped logger named after this module; histc() emits its debug trace here.
logger = logging.getLogger(name=__name__)

12 

13 

@libentry()
@triton.jit
def histc_kernel(
    inp_ptr,
    out_ptr,
    n_elements,
    bins: tl.constexpr,
    min_val,
    max_val,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Accumulate a histogram over one BLOCK_SIZE-wide tile of the input.

    Variant of the histogram kernel with ``bins`` as a compile-time constant.
    Program ``pid`` covers flat indices
    ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``. It maps each element
    ``x`` with ``min_val <= x <= max_val`` to the bin
    ``floor((x - min_val) * bins / (max_val - min_val))``, clamped to
    ``[0, bins - 1]``. It then atomically adds 1.0 to that counter in
    ``out_ptr``. Values equal to ``max_val`` land in the last bin. Lanes past
    ``n_elements`` and NaN values are not counted.
    """
    pid = ext.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Out-of-bounds lanes read 0.0; they are dropped by `mask` below.
    inp_val = tl.load(inp_ptr + offset, mask=mask, other=0.0).to(tl.float32)

    # Scale before dividing (same formula as histc_kernel_simple) so both
    # kernels place values sitting on a bin edge identically; dividing by a
    # precomputed bin width rounds differently.
    bin_idx = tl.floor((inp_val - min_val) * bins / (max_val - min_val)).to(tl.int32)

    # The top edge is inclusive: x == max_val belongs to the last bin.
    bin_idx = tl.where(inp_val == max_val, bins - 1, bin_idx)
    # Clamp into [0, bins - 1] to absorb floating-point rounding at the edges.
    bin_idx = tl.where(bin_idx < 0, 0, bin_idx)
    bin_idx = tl.where(bin_idx >= bins, bins - 1, bin_idx)

    # NaN fails both comparisons, so it is excluded along with out-of-range values.
    valid_mask = mask & (inp_val >= min_val) & (inp_val <= max_val)

    # One vectorised, masked atomic add per tile; masked-off lanes write nothing.
    tl.atomic_add(out_ptr + bin_idx, 1.0, mask=valid_mask, sem="relaxed")

66 

67 

@libentry()
@triton.jit
def histc_kernel_simple(
    inp_ptr,
    out_ptr,
    n_elements,
    bins,
    min_val,
    max_val,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Accumulate a histogram over one BLOCK_SIZE-wide tile of the input.

    Program ``pid`` covers flat indices
    ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``. Each element ``x`` with
    ``min_val <= x <= max_val`` adds 1.0 to
    ``out_ptr[floor((x - min_val) * bins / (max_val - min_val))]``, with the
    index clamped to ``[0, bins - 1]``. ``x == max_val`` always lands in the
    last bin. Lanes past ``n_elements`` and NaN values contribute nothing.
    """
    lanes = ext.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < n_elements

    # Padding lanes read NaN, which fails every range comparison below.
    vals = tl.load(inp_ptr + lanes, mask=in_bounds, other=float("nan")).to(tl.float32)

    # Count only in-bounds lanes whose value lies in [min_val, max_val] (not NaN).
    counted = in_bounds & (vals >= min_val) & (vals <= max_val)

    # Scale first, then divide once by the range width.
    span = max_val - min_val
    slot = tl.floor((vals - min_val) * bins / span).to(tl.int64)
    # Inclusive upper edge: the maximum value belongs to the final bin.
    slot = tl.where(vals == max_val, bins - 1, slot)
    # Clamp the bin index into [0, bins - 1].
    slot = tl.minimum(tl.maximum(slot, 0), bins - 1)

    tl.atomic_add(out_ptr + slot, 1.0, mask=counted, sem="relaxed")

109 

110 

def _histc_range(inp, lo, hi):
    """
    Resolve and validate the ``(min_val, max_val)`` range used for binning.

    Follows ``torch.histc``:

    - When ``lo == hi`` and ``inp`` is non-empty, the data's own min/max are
      used instead of the given bounds.
    - A range that is still degenerate is widened to ``[v - 1, v + 1]``.

    Args:
        inp: Input tensor.
        lo: Requested lower bound.
        hi: Requested upper bound.

    Returns:
        tuple[float, float]: ``(min_val, max_val)`` with ``min_val < max_val``.

    Raises:
        RuntimeError: if the range is not finite (e.g. data extrema are NaN or
            inf) or if ``max < min``.
    """
    import math

    min_val = float(lo)
    max_val = float(hi)

    # Equal bounds mean "use the data range", but only if there is data.
    if min_val == max_val and inp.numel() > 0:
        min_val = float(inp.min().item())
        max_val = float(inp.max().item())

    # Widen a degenerate range so the bin width is non-zero.
    if min_val == max_val:
        min_val -= 1.0
        max_val += 1.0

    if not (math.isfinite(min_val) and math.isfinite(max_val)):
        raise RuntimeError(f"torch.histc: range of [{min_val}, {max_val}] is not finite")
    if not min_val < max_val:
        raise RuntimeError("torch.histc: max must be larger than min")
    return min_val, max_val


def histc(inp, bins=100, min=0, max=0):
    """
    Compute the histogram of a tensor, following ``torch.histc`` semantics.

    Elements are counted into ``bins`` equal-width bins spanning
    ``[min, max]``. The last bin is closed on both ends, so values equal to
    ``max`` are counted. Values outside the range, and NaNs, are ignored.

    Args:
        inp: Input tensor (any shape; it is treated as flat).
        bins: Number of histogram bins (default: 100). Must be > 0.
        min: Lower end of the range (inclusive).
        max: Upper end of the range (inclusive). If ``min == max`` and the
            input is non-empty, the data's min and max are used instead. A
            range that is still degenerate is widened to ``[v - 1, v + 1]``.

    Returns:
        Tensor: Histogram of shape ``(bins,)`` with the dtype and device of
        ``inp``.

    Raises:
        RuntimeError: if ``bins <= 0``, if the range is not finite, or if
            ``max < min``.
    """
    logger.debug("GEMS HISTC")

    # Must be checked before launching: with bins == 0 the kernel's
    # "last bin" index is -1, which would write out of bounds.
    if bins <= 0:
        raise RuntimeError(f"torch.histc: bins must be > 0, but got {bins}")

    # The kernel indexes the input as a flat contiguous buffer.
    inp = inp.contiguous()

    # Resolved before the empty check so invalid ranges raise even on empty input.
    min_val, max_val = _histc_range(inp, min, max)

    n_elements = inp.numel()
    if n_elements == 0:
        return torch.zeros(bins, dtype=inp.dtype, device=inp.device)

    # Accumulate counts in fp32 (fp64 for double inputs). fp16 / bf16 counters
    # stop being exact after 2048 / 256 increments. This also keeps the
    # kernel's 1.0 increments on a floating-point buffer for any input dtype.
    acc_dtype = torch.float64 if inp.dtype == torch.float64 else torch.float32
    hist = torch.zeros(bins, dtype=acc_dtype, device=inp.device)

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    with torch_device_fn.device(inp.device):
        histc_kernel_simple[grid](
            inp,
            hist,
            n_elements,
            bins,
            min_val,
            max_val,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    # No copy when the accumulator already has the input's dtype.
    return hist.to(inp.dtype)