Coverage for src/flag_gems/ops/histc.py: 45%
64 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry
9from flag_gems.utils import triton_lang_extension as ext
11logger = logging.getLogger(__name__)
@libentry()
@triton.jit
def histc_kernel(
    inp_ptr,
    out_ptr,
    n_elements,
    bins: tl.constexpr,
    min_val,
    max_val,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Accumulate a histogram of the input into ``out_ptr``.

    Each program handles one tile of BLOCK_SIZE consecutive elements: it maps
    every element to a bin index and atomically adds 1.0 to that bin.

    Args:
        inp_ptr: Pointer to the input elements.
        out_ptr: Pointer to the histogram counters; counts are added to
            whatever is already stored there.
        n_elements: Number of input elements.
        bins: Number of histogram bins (compile-time constant).
        min_val: Lower end of the histogram range (inclusive).
        max_val: Upper end of the histogram range (inclusive).
        BLOCK_SIZE: Number of elements handled by one program.
    """
    pid = ext.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Lanes past the end read 0.0; they are excluded again via valid_mask.
    inp_val = tl.load(inp_ptr + offset, mask=mask, other=0.0)

    # Bin arithmetic is carried out in float32.
    inp_val = inp_val.to(tl.float32)

    bin_width = (max_val - min_val) / bins

    # The integer cast truncates toward zero, which equals floor for the
    # non-negative quotients of in-range elements (the only ones counted).
    bin_idx = ((inp_val - min_val) / bin_width).to(tl.int32)

    # NaN fails both comparisons, so it is never counted.
    in_range = (inp_val >= min_val) & (inp_val <= max_val)

    # Elements exactly equal to max_val belong to the last bin.
    bin_idx = tl.where(inp_val == max_val, bins - 1, bin_idx)
    # Clamp so that every address formed below stays inside the histogram,
    # even for lanes that end up masked off.
    bin_idx = tl.where(bin_idx < 0, 0, bin_idx)
    bin_idx = tl.where(bin_idx >= bins, bins - 1, bin_idx)

    valid_mask = mask & in_range

    # One masked vector atomic add covers the whole tile. tl.load only accepts
    # pointers, so it cannot be used to pick single lanes out of the
    # mask/index tensors in a Python-style loop.
    tl.atomic_add(out_ptr + bin_idx, 1.0, mask=valid_mask, sem="relaxed")
@libentry()
@triton.jit
def histc_kernel_simple(
    inp_ptr,
    out_ptr,
    n_elements,
    bins,
    min_val,
    max_val,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Histogram kernel: one program bins one tile of BLOCK_SIZE elements.

    Every lane maps its element to a bin index and issues a masked atomic add
    of 1.0 into ``out_ptr``. Lanes past ``n_elements`` and values outside
    ``[min_val, max_val]`` (NaN included) contribute nothing.
    """
    tile_start = ext.program_id(0) * BLOCK_SIZE
    lanes = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < n_elements

    # Out-of-bounds lanes read NaN; the values are then converted to float32.
    x = tl.load(inp_ptr + lanes, mask=in_bounds, other=float("nan")).to(tl.float32)

    # Scale by the bin count first and divide by the range width afterwards.
    scaled = (x - min_val) * bins
    span = max_val - min_val
    slot = tl.floor(scaled / span).to(tl.int64)

    # max_val itself is assigned to the last bin.
    slot = tl.where(x == max_val, bins - 1, slot)

    # Clamp the index into [0, bins - 1].
    slot = tl.where(slot < 0, 0, slot)
    slot = tl.where(slot >= bins, bins - 1, slot)

    # NaN fails both comparisons, so it is filtered out here.
    inside = (x >= min_val) & (x <= max_val)
    keep = in_bounds & inside

    tl.atomic_add(out_ptr + slot, 1.0, mask=keep, sem="relaxed")
def histc(inp, bins=100, min=0, max=0):
    """
    Compute the histogram of a tensor.

    Args:
        inp: Input tensor.
        bins: Number of histogram bins (default: 100).
        min: Lower end of the range (inclusive).
        max: Upper end of the range (inclusive). When ``min == max`` (which
            includes the default ``0, 0``) the minimum and maximum of the
            data are used instead.

    Returns:
        Tensor: Histogram of shape (bins,) with the dtype and device of
        ``inp``. An empty input yields an all-zero histogram.
    """
    logger.debug("GEMS HISTC")

    # The kernel indexes the input linearly, so it must be contiguous.
    inp = inp.contiguous()

    out = torch.zeros(bins, dtype=inp.dtype, device=inp.device)

    n_elements = inp.numel()
    if n_elements == 0:
        # Must be checked before min()/max(), which raise on empty tensors.
        return out

    min_val = float(min)
    max_val = float(max)

    if min_val == max_val:
        # No usable range was given: fall back to the range of the data.
        min_val = float(inp.min().item())
        max_val = float(inp.max().item())

    if min_val == max_val:
        # Constant data. torch.histc widens such a range to
        # [min_val - 1, max_val + 1], which places the value in bin
        # floor(bins / 2); the count is written there directly on the host.
        # NaN never compares equal, so it is not counted.
        count = (inp == min_val).sum().item()
        out[bins // 2] = count
        return out

    BLOCK_SIZE = 1024
    grid = (triton.cdiv(n_elements, BLOCK_SIZE),)

    with torch_device_fn.device(inp.device):
        histc_kernel_simple[grid](
            inp,
            out,
            n_elements,
            bins,
            min_val,
            max_val,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    return out