Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/bincount.py: 0%

63 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

# Per-module logger, registered beneath the package-wide "flag_gems" logger
# so that levels and handlers configured on "flag_gems" also apply here.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)

8 

9# --------------------------------------------------------------------------- 

10# Per-bin scalar-sequential kernels 

11# 

12# Design rationale (mirrors moe_align_block_size_stage1): 

13# - grid = (output_size,), each program owns exactly ONE output bin 

14# - Inner loop iterates over ALL input elements one by one (scalar loads) 

15# - Conditional `if val == bin_id` accumulates only matching elements 

16# - No atomics, no vectorised scatter, no tl.sum tree-reduction 

17# 

18# Why scalar-sequential matters for float: 

19# Any parallel split of the input changes the fp32 intermediate totals and 

20# makes the result diverge from torch.bincount's sequential scan by ~0.02 

21# for n=100_000, which exceeds the test tolerance. Scalar-sequential order 

22# exactly reproduces torch.bincount's per-bin accumulation order, giving 

23# bit-identical float results. 

24# 

25# XPU compatibility: 

26# - `for i in range(n_elements)` with do_not_specialize is the same pattern 

27# used by moe_align_block_size_stage1 on this backend. 

28# - `if scalar_triton_bool:` inside a loop is likewise supported. 

29# - isCloseUnrollControl=True prevents the compiler from trying to unroll 

30# the dynamic-bound loop (which would blow up code size for large n). 

31# --------------------------------------------------------------------------- 

32 

33 

@triton.jit(do_not_specialize=["n_elements"])
def _bincount_kernel(
    input_ptr,
    output_ptr,
    n_elements,
):
    """Unweighted bincount: each program produces exactly one output bin.

    Program ``pid`` walks all ``n_elements`` inputs in order, counts the
    elements whose value (cast to int64) equals ``pid``, and stores that
    count as int64 at ``output_ptr[pid]``.
    """
    target = tl.program_id(axis=0).to(tl.int64)
    # Python-int literal -> int32 running count; widened to int64 only at
    # the final store so the loop-carried type never changes.
    hits = 0
    for idx in range(n_elements):
        if tl.load(input_ptr + idx).to(tl.int64) == target:
            hits += 1
    tl.store(output_ptr + target, hits.to(tl.int64))

48 

49 

@triton.jit(do_not_specialize=["n_elements"])
def _bincount_weights_fp32_kernel(
    input_ptr,
    weights_ptr,
    output_ptr,
    n_elements,
):
    """Weighted bincount accumulated in fp32, one program per output bin.

    Program ``pid`` scans the inputs strictly in index order and adds
    ``weights[i]`` (cast to fp32) whenever ``input[i] == pid``. Summing in
    that fixed sequential order is what keeps the fp32 result reproducible;
    the total is stored at ``output_ptr[pid]``.
    """
    target = tl.program_id(axis=0).to(tl.int64)
    total = 0.0  # float literal -> fp32 running sum, matching the fp32 loads
    for idx in range(n_elements):
        if tl.load(input_ptr + idx).to(tl.int64) == target:
            total += tl.load(weights_ptr + idx).to(tl.float32)
    tl.store(output_ptr + target, total)

70 

71 

@triton.jit(do_not_specialize=["n_elements"])
def _bincount_weights_fp64_kernel(
    input_ptr,
    weights_ptr,
    output_ptr,
    n_elements,
):
    """Weighted bincount with fp64 accumulation, one program per output bin.

    Program ``pid`` scans the inputs in index order and adds ``weights[i]``
    (cast to fp64) whenever ``input[i] == pid``, then stores the fp64 total
    at ``output_ptr[pid]``.
    """
    bin_id = tl.program_id(0).to(tl.int64)
    # 0-d fp64 zero so the loop-carried accumulator is fp64 from the start.
    # Avoid ``tl.zeros([1], ...)[0]``: Triton's tensor ``__getitem__`` only
    # accepts ``None`` / ``:`` and rejects integer indices.
    acc = tl.zeros([], dtype=tl.float64)
    for i in range(n_elements):
        val = tl.load(input_ptr + i).to(tl.int64)
        if val == bin_id:
            w = tl.load(weights_ptr + i).to(tl.float64)
            acc = acc + w
    tl.store(output_ptr + bin_id, acc)

89 

90 

def bincount(input, weights=None, minlength=0):
    """Count occurrences of each non-negative integer in a 1-D tensor.

    Kunlunxin backend implementation of ``torch.bincount``. One Triton program
    is launched per output bin, and every program scans all ``n`` inputs
    sequentially. Total work is therefore O(output_size * n), traded for a
    fixed, reproducible accumulation order.

    Args:
        input: 1-D tensor of non-negative integer values.
        weights: optional tensor with the same shape as ``input``. When given,
            bin ``b`` holds the sum of ``weights[i]`` over all ``i`` with
            ``input[i] == b`` instead of a plain count.
        minlength: minimum number of bins in the result (must be >= 0).

    Returns:
        1-D tensor of length ``max(input.max() + 1, minlength)``, or
        ``minlength`` for empty input. The dtype is int64 without weights and
        ``weights.dtype`` with weights. fp64 weights are accumulated in fp64;
        all other weight dtypes are accumulated in fp32 and then cast back.

        NOTE(review): ``torch.bincount`` itself returns float64 for weights
        that are not float32. This function keeps ``weights.dtype`` instead;
        confirm that callers rely on this.

    Raises:
        AssertionError: if ``input`` is not 1-D, ``minlength`` is negative, or
            ``weights`` does not match ``input``'s shape.
        RuntimeError: if ``input`` contains negative values (as torch does).
    """
    logger.debug("GEMS_KUNLUNXIN BINCOUNT")

    assert input.dim() == 1, "input must be a 1-D tensor"
    assert minlength >= 0, "minlength must be non-negative"
    if weights is not None:
        assert weights.shape == input.shape, "weights must have the same shape as input"

    n = input.numel()
    if n == 0:
        empty_dtype = torch.int64 if weights is None else weights.dtype
        return torch.zeros(minlength, dtype=empty_dtype, device=input.device)

    input_contig = input.contiguous()

    # Negative values match no bin, so they would otherwise be silently
    # dropped. If every value is negative, output_size could also become <= 0.
    # torch.bincount rejects such input, so do the same.
    if int(input_contig.min().item()) < 0:
        raise RuntimeError("bincount only supports 1-d non-negative integral inputs.")

    # Determine output size with a PyTorch reduction instead of
    # tl.atomic_max on int64, whose support on XPU is incomplete.
    max_val = int(input_contig.max().item())
    output_size = max(max_val + 1, minlength)
    grid = (output_size,)

    if weights is None:
        output = torch.zeros(output_size, dtype=torch.int64, device=input.device)
        _bincount_kernel[grid](
            input_contig,
            output,
            n,
            isCloseUnrollControl=True,
        )
        return output

    weights_contig = weights.contiguous()
    out_dtype = weights.dtype

    # fp64 weights keep full precision. Everything else (fp16 / bf16 / fp32,
    # ...) is accumulated in fp32 and cast back to out_dtype afterwards.
    if out_dtype == torch.float64:
        kernel, acc_dtype = _bincount_weights_fp64_kernel, torch.float64
    else:
        kernel, acc_dtype = _bincount_weights_fp32_kernel, torch.float32

    output = torch.zeros(output_size, dtype=acc_dtype, device=input.device)
    kernel[grid](
        input_contig,
        weights_contig,
        output,
        n,
        isCloseUnrollControl=True,
    )
    return output if acc_dtype == out_dtype else output.to(out_dtype)