Coverage for src/flag_gems/runtime/backend/_sunrise/fused/bincount.py: 0%

228 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9 

10logger = logging.getLogger(__name__) 

11 

12 

13def _select_params(n): 

14 if n <= 256: 

15 return 256, 2 

16 if n <= 1024: 

17 return 256, 4 

18 if n <= 4096: 

19 return 512, 4 

20 return 1024, 4 

21 

22 

23def _estimate_output_size(n, minlength): 

24 estimate = max(8192, n * 4, minlength) 

25 estimate = min(estimate, 65536) 

26 return max(estimate, minlength) 

27 

28 

def _select_max_block_size(n):
    """Return the block size used by the two-stage max reduction.

    The size is the smallest power of two that is at least ``ceil(sqrt(n))``
    (and at least 1).
    """
    root = math.ceil(math.sqrt(n))
    if root < 1:
        root = 1
    return triton.next_power_of_2(root)

31 

32 

def _select_bins_block(output_size):
    """Return how many bins one program covers: a power of two, at most 128."""
    padded = triton.next_power_of_2(output_size if output_size > 1 else 1)
    return padded if padded < 128 else 128

35 

36 

@triton.jit
def fused_max_bincount_kernel(
    input_ptr,
    max_ptr,
    output_ptr,
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Count occurrences and track the input maximum in one pass.

    Each program handles one ``BLOCK_SIZE``-wide slice of the input.  It
    folds the slice maximum into ``max_ptr`` with an atomic max and adds 1 to
    ``output_ptr[v]`` for every loaded value ``v`` below ``output_size``.
    Values at or above ``output_size`` are not counted by this kernel.
    """
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Lanes past the end read 0.
    # NOTE(review): this presumes inputs are non-negative so 0 cannot raise
    # the maximum - confirm against callers.
    bin_ids = tl.load(input_ptr + idx, mask=in_bounds, other=0)

    tl.atomic_max(max_ptr, tl.max(bin_ids, axis=0))

    # Only values that fit in the pre-allocated output are accumulated.
    fits = in_bounds & (bin_ids < output_size)
    tl.atomic_add(output_ptr + bin_ids, 1, mask=fits)

56 

57 

@triton.jit
def bincount_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Atomically add 1 to ``output_ptr[v]`` for every input value ``v``.

    There is no bound check on ``v`` here.
    NOTE(review): presumably the caller sizes ``output_ptr`` to hold the
    largest input value - confirm.
    """
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    bin_ids = tl.load(input_ptr + idx, mask=in_bounds, other=0)
    tl.atomic_add(output_ptr + bin_ids, 1, mask=in_bounds)

70 

71 

@triton.jit
def fused_max_bincount_weights_fp32_kernel(
    input_ptr,
    weights_ptr,
    max_ptr,
    output_ptr,
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Weighted bincount in float32 fused with a running maximum.

    Each program folds the maximum of its input slice into ``max_ptr`` and
    atomically adds the float32-converted weight of every value ``v`` below
    ``output_size`` to ``output_ptr[v]``.  Values at or above ``output_size``
    are skipped.
    """
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    bin_ids = tl.load(input_ptr + idx, mask=in_bounds, other=0)
    weights = tl.load(weights_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)

    tl.atomic_max(max_ptr, tl.max(bin_ids, axis=0))

    # Only values that fit in the pre-allocated output are accumulated.
    fits = in_bounds & (bin_ids < output_size)
    tl.atomic_add(output_ptr + bin_ids, weights, mask=fits)

94 

95 

@triton.jit
def fused_max_bincount_weights_fp64_kernel(
    input_ptr,
    weights_ptr,
    max_ptr,
    output_ptr,
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Weighted bincount in float64 fused with a running maximum.

    Each program folds the maximum of its input slice into ``max_ptr`` and
    atomically adds the float64-converted weight of every value ``v`` below
    ``output_size`` to ``output_ptr[v]``.  Values at or above ``output_size``
    are skipped.
    """
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    bin_ids = tl.load(input_ptr + idx, mask=in_bounds, other=0)
    weights = tl.load(weights_ptr + idx, mask=in_bounds, other=0.0).to(tl.float64)

    tl.atomic_max(max_ptr, tl.max(bin_ids, axis=0))

    # Only values that fit in the pre-allocated output are accumulated.
    fits = in_bounds & (bin_ids < output_size)
    tl.atomic_add(output_ptr + bin_ids, weights, mask=fits)

118 

119 

@triton.jit
def bincount_max_kernel_1(
    input_ptr,
    mid_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """First reduction stage: store each block's maximum in ``mid_ptr[pid]``."""
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Lanes past the end of the input contribute 0 to the block maximum.
    chunk = tl.load(input_ptr + idx, mask=idx < n_elements, other=0)
    tl.store(mid_ptr + block_id, tl.max(chunk, axis=0))

133 

134 

@triton.jit
def bincount_max_kernel_2(
    mid_ptr,
    max_ptr,
    mid_size,
    BLOCK_MID: tl.constexpr,
):
    """Second reduction stage: reduce the per-block maxima to a single value.

    Loads the first ``mid_size`` entries of ``mid_ptr`` (``BLOCK_MID`` lanes,
    the rest masked to 0) and stores their maximum at ``max_ptr``.
    """
    lanes = tl.arange(0, BLOCK_MID)
    block_maxima = tl.load(mid_ptr + lanes, mask=lanes < mid_size, other=0)
    tl.store(max_ptr, tl.max(block_maxima, axis=0))

147 

148 

@triton.jit
def bincount_weights_fp32_kernel(
    input_ptr,
    weights_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Atomically add each float32-converted weight to ``output_ptr[v]``.

    There is no bound check on the input value ``v`` here.
    NOTE(review): presumably the caller sizes ``output_ptr`` to hold the
    largest input value - confirm.
    """
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    bin_ids = tl.load(input_ptr + idx, mask=in_bounds, other=0)
    weights = tl.load(weights_ptr + idx, mask=in_bounds, other=0.0).to(tl.float32)
    tl.atomic_add(output_ptr + bin_ids, weights, mask=in_bounds)

164 

165 

@triton.jit
def bincount_weights_fp64_kernel(
    input_ptr,
    weights_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Atomically add each float64-converted weight to ``output_ptr[v]``.

    There is no bound check on the input value ``v`` here.
    NOTE(review): presumably the caller sizes ``output_ptr`` to hold the
    largest input value - confirm.
    """
    block_id = tl.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    bin_ids = tl.load(input_ptr + idx, mask=in_bounds, other=0)
    weights = tl.load(weights_ptr + idx, mask=in_bounds, other=0.0).to(tl.float64)
    tl.atomic_add(output_ptr + bin_ids, weights, mask=in_bounds)

181 

182 

@triton.jit
def bincount_partial_int64_kernel(
    input_ptr,
    partial_ptr,
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_BINS: tl.constexpr,
    TILE_INPUT: tl.constexpr,
):
    """Atomic-free counting: one partial histogram row per input block.

    Program ``(pid_block, pid_bin)`` counts, for ``BLOCK_BINS`` consecutive
    bins, how many of the ``BLOCK_SIZE`` inputs starting at
    ``pid_block * BLOCK_SIZE`` equal each bin index.  The counts are
    accumulated in int32 and stored as int64 at
    ``partial_ptr[pid_block * output_size + bin]``.

    The input is scanned ``TILE_INPUT`` values at a time and each tile is
    compared against all bins of this program at once.
    """
    pid_block = tl.program_id(0)
    pid_bin = tl.program_id(1)

    block_start = pid_block * BLOCK_SIZE
    bin_offsets = pid_bin * BLOCK_BINS + tl.arange(0, BLOCK_BINS)
    bin_mask = bin_offsets < output_size
    # Compare in int64.  Narrowing the bin index to the input dtype would wrap
    # for small integer inputs (e.g. bin 256 becomes 0 for uint8) and produce
    # false matches whenever output_size exceeds the dtype's range.
    bins = bin_offsets.to(tl.int64)
    acc = tl.zeros([BLOCK_BINS], dtype=tl.int32)

    for tile_start in range(0, BLOCK_SIZE, TILE_INPUT):
        input_offsets = block_start + tile_start + tl.arange(0, TILE_INPUT)
        input_mask = input_offsets < n_elements
        vals = tl.load(input_ptr + input_offsets, mask=input_mask, other=0).to(
            tl.int64
        )
        # matches[b, i] is true when valid input i falls into valid bin b.
        matches = (
            bin_mask[:, None] & input_mask[None, :] & (bins[:, None] == vals[None, :])
        )
        acc += tl.sum(matches.to(tl.int32), axis=1)

    partial_offsets = pid_block * output_size + bin_offsets
    tl.store(partial_ptr + partial_offsets, acc.to(tl.int64), mask=bin_mask)

213 

214 

@triton.jit
def bincount_partial_weights_fp64_kernel(
    input_ptr,
    weights_ptr,
    partial_ptr,
    n_elements,
    output_size,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_BINS: tl.constexpr,
    TILE_INPUT: tl.constexpr,
):
    """Atomic-free weighted counting: one partial row per input block.

    Program ``(pid_block, pid_bin)`` sums, for ``BLOCK_BINS`` consecutive
    bins, the float64-converted weights of those inputs in the block starting
    at ``pid_block * BLOCK_SIZE`` whose value equals the bin index.  The sums
    are stored at ``partial_ptr[pid_block * output_size + bin]``.

    The input is scanned ``TILE_INPUT`` values at a time and each tile is
    compared against all bins of this program at once.
    """
    pid_block = tl.program_id(0)
    pid_bin = tl.program_id(1)

    block_start = pid_block * BLOCK_SIZE
    bin_offsets = pid_bin * BLOCK_BINS + tl.arange(0, BLOCK_BINS)
    bin_mask = bin_offsets < output_size
    # Compare in int64.  Narrowing the bin index to the input dtype would wrap
    # for small integer inputs (e.g. bin 256 becomes 0 for uint8) and produce
    # false matches whenever output_size exceeds the dtype's range.
    bins = bin_offsets.to(tl.int64)
    acc = tl.zeros([BLOCK_BINS], dtype=tl.float64)

    for tile_start in range(0, BLOCK_SIZE, TILE_INPUT):
        input_offsets = block_start + tile_start + tl.arange(0, TILE_INPUT)
        input_mask = input_offsets < n_elements
        vals = tl.load(input_ptr + input_offsets, mask=input_mask, other=0).to(
            tl.int64
        )
        w = tl.load(weights_ptr + input_offsets, mask=input_mask, other=0.0).to(
            tl.float64
        )
        # matches[b, i] is true when valid input i falls into valid bin b.
        matches = (
            bin_mask[:, None] & input_mask[None, :] & (bins[:, None] == vals[None, :])
        )
        acc += tl.sum(tl.where(matches, w[None, :], 0.0), axis=1)

    partial_offsets = pid_block * output_size + bin_offsets
    tl.store(partial_ptr + partial_offsets, acc, mask=bin_mask)

249 

250 

@triton.jit
def bincount_reduce_partial_kernel(
    partial_ptr,
    output_ptr,
    num_partials,
    output_size,
    BLOCK_PARTIAL: tl.constexpr,
    BLOCK_BINS: tl.constexpr,
):
    """Sum the rows of the partial table into the final histogram.

    ``partial_ptr`` is read as a row-major ``(num_partials, output_size)``
    table.  Each program owns ``BLOCK_BINS`` consecutive bins and walks the
    rows ``BLOCK_PARTIAL`` at a time, accumulating in the element type of
    ``output_ptr``.
    """
    bin_block = tl.program_id(0)
    bin_idx = bin_block * BLOCK_BINS + tl.arange(0, BLOCK_BINS)
    valid_bins = bin_idx < output_size
    total = tl.zeros([BLOCK_BINS], dtype=output_ptr.dtype.element_ty)

    for row_start in range(0, num_partials, BLOCK_PARTIAL):
        rows = row_start + tl.arange(0, BLOCK_PARTIAL)
        row_col = rows[:, None]
        # Rows past the end and bins past output_size read as 0.
        tile_mask = (row_col < num_partials) & valid_bins[None, :]
        tile = tl.load(
            partial_ptr + row_col * output_size + bin_idx[None, :],
            mask=tile_mask,
            other=0,
        )
        total += tl.sum(tile, axis=0)

    tl.store(output_ptr + bin_idx, total, mask=valid_bins)

275 

276 

def _compute_output_size(input_contig, n, minlength):
    """Return the number of output bins: ``max(max(input) + 1, minlength)``.

    The maximum is found on the device with a two-stage reduction (per-block
    maxima, then a single-program reduction) and read back with ``.item()``.
    """
    device = input_contig.device
    dtype = input_contig.dtype

    block = _select_max_block_size(n)
    num_blocks = triton.cdiv(n, block)

    block_maxima = torch.empty((num_blocks,), dtype=dtype, device=device)
    global_max = torch.empty([], dtype=dtype, device=device)

    with torch_device_fn.device(device):
        bincount_max_kernel_1[(num_blocks, 1, 1)](
            input_contig,
            block_maxima,
            n,
            BLOCK_SIZE=block,
        )
        bincount_max_kernel_2[(1, 1, 1)](
            block_maxima,
            global_max,
            num_blocks,
            BLOCK_MID=triton.next_power_of_2(num_blocks),
        )

    largest = int(global_max.item())
    return max(largest + 1, minlength)

300 

301 

def _bincount_atomic_launch(
    input_contig,
    weights_contig,
    n,
    output_size,
    BLOCK_SIZE,
    num_warps,
):
    """Run the float32 atomic weighted bincount into a fresh zeroed buffer.

    Returns:
        A float32 tensor of length ``output_size`` on the input's device.
    """
    device = input_contig.device
    histogram = torch.zeros(output_size, dtype=torch.float32, device=device)
    num_programs = triton.cdiv(n, BLOCK_SIZE)

    with torch_device_fn.device(device):
        bincount_weights_fp32_kernel[(num_programs,)](
            input_contig,
            weights_contig,
            histogram,
            n,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    return histogram

324 

325 

def _fused_bincount_atomic_launch(
    input_contig,
    weights_contig,
    n,
    pre_size,
    minlength,
    out_dtype,
    grid,
    BLOCK_SIZE,
    num_warps,
):
    """Weighted bincount that guesses the output size before the max is known.

    A first pass accumulates into a ``pre_size`` buffer while also computing
    the input maximum.  If ``max(max + 1, minlength)`` fits in ``pre_size``
    the buffer is sliced and returned.  Otherwise a correctly sized buffer is
    allocated and the histogram is recomputed with the unfused kernel.

    Accumulation is in float64 when ``out_dtype`` is ``torch.float64`` and in
    float32 otherwise; the returned tensor has that accumulation dtype.
    """
    device = input_contig.device

    if out_dtype == torch.float64:
        acc_dtype = torch.float64
        fused_kernel = fused_max_bincount_weights_fp64_kernel
        plain_kernel = bincount_weights_fp64_kernel
    else:
        acc_dtype = torch.float32
        fused_kernel = fused_max_bincount_weights_fp32_kernel
        plain_kernel = bincount_weights_fp32_kernel

    running_max = torch.zeros(1, dtype=input_contig.dtype, device=device)
    histogram = torch.zeros(pre_size, dtype=acc_dtype, device=device)

    with torch_device_fn.device(device):
        fused_kernel[grid](
            input_contig,
            weights_contig,
            running_max,
            histogram,
            n,
            pre_size,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    required = max(int(running_max.item()) + 1, minlength)
    if required <= pre_size:
        return histogram[:required]

    # The guess was too small: the first pass skipped values >= pre_size, so
    # start again from zeros with the exact size.
    histogram = torch.zeros(required, dtype=acc_dtype, device=device)

    with torch_device_fn.device(device):
        plain_kernel[grid](
            input_contig,
            weights_contig,
            histogram,
            n,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    return histogram

402 

403 

def _bincount_no_atomic_launch(
    input_contig,
    weights_contig,
    n,
    output_size,
    out_dtype,
    BLOCK_SIZE,
    num_warps,
):
    """Compute bincount without atomics via per-block partial histograms.

    Stage one writes a ``(num_partials, output_size)`` table of partial
    results, one row per ``BLOCK_SIZE`` inputs.  Stage two sums the rows.
    Counts are produced when ``weights_contig`` is ``None``, weighted sums
    otherwise.

    Returns:
        A tensor of length ``output_size`` with dtype ``out_dtype``.
    """
    device = input_contig.device
    bins_per_program = _select_bins_block(output_size)
    num_bin_blocks = triton.cdiv(output_size, bins_per_program)
    num_partials = triton.cdiv(n, BLOCK_SIZE)

    partial = torch.empty((num_partials, output_size), dtype=out_dtype, device=device)
    output = torch.empty(output_size, dtype=out_dtype, device=device)

    partial_grid = (num_partials, num_bin_blocks)
    partial_opts = dict(
        BLOCK_SIZE=BLOCK_SIZE,
        BLOCK_BINS=bins_per_program,
        TILE_INPUT=min(64, BLOCK_SIZE),
        num_warps=num_warps,
    )

    with torch_device_fn.device(device):
        if weights_contig is None:
            bincount_partial_int64_kernel[partial_grid](
                input_contig,
                partial,
                n,
                output_size,
                **partial_opts,
            )
        else:
            bincount_partial_weights_fp64_kernel[partial_grid](
                input_contig,
                weights_contig,
                partial,
                n,
                output_size,
                **partial_opts,
            )

        bincount_reduce_partial_kernel[(num_bin_blocks, 1, 1)](
            partial,
            output,
            num_partials,
            output_size,
            BLOCK_PARTIAL=8,
            BLOCK_BINS=bins_per_program,
            num_warps=4,
        )

    return output

459 

460 

461def _supports_atomic_accumulate(out_dtype): 

462 return out_dtype not in (torch.int64, torch.float64) 

463 

464 

def _supports_fused_atomic(input_dtype, out_dtype):
    """Return True when the fused max+bincount atomic path may be used.

    That requires an ``int32`` input and an output dtype accepted by
    ``_supports_atomic_accumulate``.
    """
    if input_dtype != torch.int32:
        return False
    return _supports_atomic_accumulate(out_dtype)

467 

468 

def bincount(input, weights=None, minlength=0):
    """Count (or weight-sum) the occurrences of each value in a 1-D tensor.

    Args:
        input: 1-D tensor of bin indices.
        weights: Optional tensor with the same shape as ``input``; when given,
            each bin receives the sum of the matching weights instead of a
            count.
        minlength: Minimum number of bins in the result; must be non-negative.

    Returns:
        A 1-D tensor with ``max(max(input) + 1, minlength)`` bins (just
        ``minlength`` for an empty input).  Its dtype is ``weights.dtype``
        when weights are given and ``torch.int64`` otherwise.

    Raises:
        AssertionError: If ``input`` is not 1-D, ``minlength`` is negative, or
            ``weights`` does not match the shape of ``input``.
    """
    logger.debug("GEMS BINCOUNT")

    assert input.dim() == 1, "input must be a 1-D tensor"
    assert minlength >= 0, "minlength must be non-negative"

    has_weights = weights is not None
    if has_weights:
        assert weights.shape == input.shape, "weights must have the same shape as input"

    n = input.numel()

    if n == 0:
        empty_dtype = weights.dtype if has_weights else torch.int64
        return torch.zeros(minlength, dtype=empty_dtype, device=input.device)

    input_contig = input.contiguous()
    weights_contig = weights.contiguous() if has_weights else None

    # float64 weights are computed on the CPU and copied back to the device.
    if has_weights and weights.dtype == torch.float64:
        cpu_result = torch.bincount(
            input_contig.cpu(),
            weights=weights_contig.cpu(),
            minlength=minlength,
        )
        return cpu_result.to(input.device)

    BLOCK_SIZE, num_warps = _select_params(n)
    out_dtype = weights.dtype if has_weights else torch.int64

    if _supports_fused_atomic(input_contig.dtype, out_dtype):
        # Single pass with a guessed output size; falls back internally if
        # the guess turns out to be too small.
        output = _fused_bincount_atomic_launch(
            input_contig,
            weights_contig,
            n,
            _estimate_output_size(n, minlength),
            minlength,
            out_dtype,
            (triton.cdiv(n, BLOCK_SIZE),),
            BLOCK_SIZE,
            num_warps,
        )
    else:
        output_size = _compute_output_size(input_contig, n, minlength)
        if _supports_atomic_accumulate(out_dtype):
            output = _bincount_atomic_launch(
                input_contig,
                weights_contig,
                n,
                output_size,
                BLOCK_SIZE,
                num_warps,
            )
        else:
            output = _bincount_no_atomic_launch(
                input_contig,
                weights_contig,
                n,
                output_size,
                out_dtype,
                BLOCK_SIZE,
                num_warps,
            )

    # The result is cast back to the weights' own dtype unless that dtype is
    # already float32 or float64.
    if has_weights and weights.dtype not in (torch.float64, torch.float32):
        output = output.to(dtype=weights.dtype)

    return output