Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/bincount.py: 0%
63 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
# Module logger: a child of the shared "flag_gems" logger, named after this
# module's dotted path with any leading dots removed.
_flag_gems_logger = logging.getLogger("flag_gems")
logger = _flag_gems_logger.getChild(__name__.lstrip("."))
# ---------------------------------------------------------------------------
# Per-bin scalar-sequential kernels
#
# Design rationale (mirrors moe_align_block_size_stage1):
#   - grid = (output_size,), each program owns exactly ONE output bin
#   - Inner loop iterates over ALL input elements one by one (scalar loads)
#   - Conditional `if val == bin_id` accumulates only matching elements
#   - No atomics, no vectorised scatter, no tl.sum tree-reduction
#
# Why scalar-sequential matters for float:
#   Any parallel split of the input changes the fp32 intermediate totals and
#   makes the result diverge from torch.bincount's sequential scan by ~0.02
#   for n=100_000, which exceeds the test tolerance. Scalar-sequential order
#   exactly reproduces torch.bincount's per-bin accumulation order, giving
#   bit-identical float results.
#
# XPU compatibility:
#   - `for i in range(n_elements)` with do_not_specialize is the same pattern
#     used by moe_align_block_size_stage1 on this backend.
#   - `if scalar_triton_bool:` inside a loop is likewise supported.
#   - isCloseUnrollControl=True prevents the compiler from trying to unroll
#     the dynamic-bound loop (which would blow up code size for large n).
# ---------------------------------------------------------------------------
@triton.jit(do_not_specialize=["n_elements"])
def _bincount_kernel(input_ptr, output_ptr, n_elements):
    """Unweighted bincount: tally the input elements equal to this bin.

    Each program handles the single output bin given by ``tl.program_id(0)``.
    It reads the ``n_elements`` input values one at a time, compares each
    (converted to int64) with its bin index, and stores the final tally as
    int64 at ``output_ptr + bin_id``.
    """
    bin_id = tl.program_id(0).to(tl.int64)
    # Starts from a plain literal; widened to int64 only at the final store.
    hits = 0
    for i in range(n_elements):
        val = tl.load(input_ptr + i).to(tl.int64)
        if val == bin_id:
            hits += 1
    tl.store(output_ptr + bin_id, hits.to(tl.int64))
@triton.jit(do_not_specialize=["n_elements"])
def _bincount_weights_fp32_kernel(input_ptr, weights_ptr, output_ptr, n_elements):
    """Weighted bincount for one bin, accumulated in float32.

    Each program handles the single output bin given by ``tl.program_id(0)``.
    It scans the ``n_elements`` input values in index order; whenever a value
    (converted to int64) equals the bin index, the weight at the same
    position is loaded, converted to float32 and added to the running total.
    The total is stored at ``output_ptr + bin_id``.

    The in-order scalar accumulation is intended to follow the same per-bin
    summation order as torch.bincount.
    """
    bin_id = tl.program_id(0).to(tl.int64)
    # Float literal start value; every addend below is converted to float32.
    total = 0.0
    for i in range(n_elements):
        val = tl.load(input_ptr + i).to(tl.int64)
        if val == bin_id:
            # The weight is only loaded for elements that belong to this bin.
            total += tl.load(weights_ptr + i).to(tl.float32)
    tl.store(output_ptr + bin_id, total)
@triton.jit(do_not_specialize=["n_elements"])
def _bincount_weights_fp64_kernel(input_ptr, weights_ptr, output_ptr, n_elements):
    """Weighted bincount for one bin, accumulated in float64.

    Same scan as the float32 variant: one program per output bin
    (``tl.program_id(0)``), one scalar load per input element, and a weight
    load only for elements whose value equals the bin index. Weights are
    converted to float64 before being added, and the total is stored at
    ``output_ptr + bin_id``.
    """
    bin_id = tl.program_id(0).to(tl.int64)
    # Build the zero explicitly as float64 so the loop-carried accumulator
    # has the same type before and after the first addition.
    total = tl.zeros([1], dtype=tl.float64)[0]
    for i in range(n_elements):
        val = tl.load(input_ptr + i).to(tl.int64)
        if val == bin_id:
            total += tl.load(weights_ptr + i).to(tl.float64)
    tl.store(output_ptr + bin_id, total)
def bincount(input, weights=None, minlength=0):
    """Count (or weight-sum) the occurrences of each value in a 1-D tensor.

    One Triton program is launched per output bin; each program scans the
    whole input sequentially (see the kernels in this module).

    Args:
        input: 1-D tensor of non-negative values. The kernels compare each
            element, converted to int64, against the bin index.
        weights: Optional tensor with the same shape as ``input``. When
            given, each bin holds the sum of the weights of its elements
            instead of a count.
        minlength: Minimum number of bins in the result; must be >= 0.

    Returns:
        A 1-D tensor on ``input.device`` with
        ``max(input.max() + 1, minlength)`` bins. Without weights the dtype
        is int64. With weights the dtype is ``weights.dtype``: float64
        weights are accumulated in float64, every other dtype is accumulated
        in float32 and cast back afterwards. For an empty ``input`` the
        result is ``minlength`` zeros.

    Raises:
        AssertionError: If ``input`` is not 1-D, ``minlength`` is negative,
            or ``weights`` does not have the same shape as ``input``.
        RuntimeError: If ``input`` contains a negative value.
    """
    logger.debug("GEMS_KUNLUNXIN BINCOUNT")

    assert input.dim() == 1, "input must be a 1-D tensor"
    assert minlength >= 0, "minlength must be non-negative"
    if weights is not None:
        assert weights.shape == input.shape, "weights must have the same shape as input"

    n = input.numel()
    if n == 0:
        empty_dtype = torch.int64 if weights is None else weights.dtype
        return torch.zeros(minlength, dtype=empty_dtype, device=input.device)

    input_contig = input.contiguous()

    # Bin indices start at 0, so a negative value can never match any program
    # and would be dropped silently; an all-negative input would even produce
    # an empty launch grid. Reject such input explicitly, as torch.bincount
    # only accepts non-negative values.
    if int(input_contig.min().item()) < 0:
        raise RuntimeError("bincount only supports non-negative input values")

    # Determine output size; use PyTorch max to avoid tl.atomic_max with int64
    # (incomplete support on XPU).
    max_val = int(input_contig.max().item())
    output_size = max(max_val + 1, minlength)
    grid = (output_size,)

    if weights is None:
        output = torch.zeros(output_size, dtype=torch.int64, device=input.device)
        _bincount_kernel[grid](
            input_contig,
            output,
            n,
            isCloseUnrollControl=True,
        )
        return output

    weights_contig = weights.contiguous()
    out_dtype = weights.dtype

    # float64 weights keep full precision; fp16 / bf16 / fp32 weights are
    # accumulated in float32.
    # NOTE(review): integer-typed weights also take the float32 path and are
    # cast back to their integer dtype -- confirm this is the intended result.
    if out_dtype == torch.float64:
        kernel = _bincount_weights_fp64_kernel
        acc_dtype = torch.float64
    else:
        kernel = _bincount_weights_fp32_kernel
        acc_dtype = torch.float32

    output = torch.zeros(output_size, dtype=acc_dtype, device=input.device)
    kernel[grid](
        input_contig,
        weights_contig,
        output,
        n,
        isCloseUnrollControl=True,
    )

    # Only the float32 accumulation path can differ from the requested dtype.
    if output.dtype != out_dtype:
        output = output.to(out_dtype)
    return output