Coverage for src/flag_gems/ops/nanmedian.py: 34%

825 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3from collections import namedtuple 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import dim_compress, libentry, tl_extra_shim 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_max, get_dtype_min 

13 

14from .topk import _get_finfo_val 

15 

logger = logging.getLogger(__name__)

# Result container mirroring torch.nanmedian(dim=...): a (values, indices) pair.
# The namedtuple typename is "nanmedian".
NanMedian = namedtuple("nanmedian", ["values", "indices"])
INT32_MAX = torch.iinfo(torch.int32).max

# Tile widths / radix digit widths (RADIX_BLOCK_N and RADIX_BITS are used by the
# block-size helpers below; the others are consumed outside this view).
MAX_BLOCK_N = 128
RADIX_BLOCK_N = 1024
RADIX_BITS = 2
FLAT_RADIX_BLOCK_N = 4096
FLAT_RADIX_BITS = 8

# Reduction-length thresholds used to choose tile sizes and strategies.
MEDIUM_REDUCTION_N = 1024
LARGE_FLOAT_REDUCTION_N = 4096
LONG_RADIX_REDUCTION_N = 128 * 1024
ASCEND_FLAT_SORT_MIN_N = 1 << 20

# Dtype groups for the Ascend-specific code paths.
ASCEND_HISTOGRAM_SELECT_DTYPES = (torch.int8, torch.uint8)
ASCEND_BYTE_HISTOGRAM_SELECT_DTYPES = (torch.int16, torch.int32)
ASCEND_FLOAT_SELECT_DTYPES = (torch.float16, torch.float32)
ASCEND_HISTOGRAM_BINS = 256
ASCEND_MULTI_HISTOGRAM_MIN_N = 8192
ASCEND_FLAT_SORT_DTYPES = (
    ASCEND_FLOAT_SELECT_DTYPES
    + ASCEND_HISTOGRAM_SELECT_DTYPES
    + ASCEND_BYTE_HISTOGRAM_SELECT_DTYPES
)

# Dtypes eligible for radix select (presumably gates nanmedian_radix_select_kernel
# in the dispatcher outside this view): the three float types plus the 8/16/32-bit ints.
RADIX_SELECT_DTYPES = (
    (torch.float16, torch.bfloat16, torch.float32)
    + ASCEND_HISTOGRAM_SELECT_DTYPES
    + ASCEND_BYTE_HISTOGRAM_SELECT_DTYPES
)

60 

61 

62def _triton_version_at_least(major, minor): 

63 version = getattr(triton, "__version__", "0.0").split("+", 1)[0] 

64 parts = [] 

65 for token in version.split(".")[:2]: 

66 digits = [] 

67 for char in token: 

68 if not char.isdigit(): 

69 break 

70 digits.append(char) 

71 parts.append(int("".join(digits) or 0)) 

72 parts.extend([0] * (2 - len(parts))) 

73 return tuple(parts[:2]) >= (major, minor) 

74 

75 

# tl.histogram only accepts a mask argument from Triton 3.4 onward; this flag
# presumably feeds the USE_HISTOGRAM switch of the radix kernels on CUDA.
CUDA_SUPPORTS_MASKED_HISTOGRAM = _triton_version_at_least(major=3, minor=4)

78 

79 

@triton.jit
def _is_not_nan(vals, USE_ISNAN: tl.constexpr):
    # Elementwise "not NaN" mask. Values are widened to fp32 first; USE_ISNAN
    # selects tl_extra_shim.isnan, otherwise the NaN != NaN self-comparison.
    widened = vals.to(tl.float32)
    if USE_ISNAN:
        return ~tl_extra_shim.isnan(widened)
    return widened == widened

86 

87 

@triton.jit
def _to_order_key(vals, valid):
    # Map vals to same-width unsigned keys whose unsigned order matches the
    # value order. Lanes where `valid` is False get the all-ones key.
    src_dtype = vals.dtype
    nbits: tl.constexpr = src_dtype.primitive_bitwidth
    key_dtype = tl.dtype(f"uint{nbits}")
    sign_bit: tl.constexpr = 1 << (nbits - 1)
    all_ones: tl.constexpr = (1 << nbits) - 1

    if src_dtype.is_floating():
        raw = vals.to(key_dtype, bitcast=True)
        # Negative floats: flip every bit; non-negative floats: flip the sign bit.
        flip = tl.where((raw & sign_bit) != 0, all_ones, sign_bit)
        key = raw ^ flip
    elif src_dtype.is_int_signed():
        # Two's complement -> offset binary.
        raw = vals.to(key_dtype, bitcast=True)
        key = raw ^ sign_bit
    else:
        key = vals.to(key_dtype)

    invalid_key = tl.full(vals.shape, all_ones, dtype=key_dtype)
    return tl.where(valid, key, invalid_key)

107 

108 

@libentry()
@triton.jit
def count_valid_kernel(
    inp,
    valid_counts,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_ISNAN: tl.constexpr,
):
    # One program per row: count the non-NaN entries of row `program_id(0)`
    # of the row-major (M, N) input and store it in valid_counts[row].
    row = tle.program_id(0)
    row_ptr = inp + row * N
    lanes = tl.arange(0, BLOCK_N)
    total = tl.full((), 0, dtype=tl.int32)
    for base in tl.range(0, N, BLOCK_N):
        cols = base + lanes
        in_bounds = cols < N
        # Out-of-bounds lanes load NaN, so the NaN test alone would exclude
        # them; the explicit in_bounds term keeps that independent of USE_ISNAN.
        chunk_vals = tl.load(row_ptr + cols, mask=in_bounds, other=float("nan"))
        keep = in_bounds & _is_not_nan(chunk_vals, USE_ISNAN)
        total += tl.sum(keep.to(tl.int32), axis=0)
    tl.store(valid_counts + row, total)

129 

130 

@libentry()
@triton.jit
def nanmedian_select_kernel(
    inp,
    out_values,
    out_indices,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_ISNAN: tl.constexpr,
):
    # Single-tile nanmedian: one program per row, whole row loaded as one
    # BLOCK_N-wide block. The kernel repeatedly extracts the smallest remaining
    # active value (lowest column among ties) and keeps the one extracted at
    # step (valid_count - 1) // 2. Cost is O(BLOCK_N^2) per row.
    # All-NaN float rows produce (NaN, 0).
    pid = tle.program_id(0)
    offsets = tl.arange(0, BLOCK_N)
    mask = offsets < N
    dtype = inp.dtype.element_ty
    if dtype.is_floating():
        # The sentinel for inactive lanes must compare >= every real value.
        # finfo.max is smaller than +inf. With it, once any lane is inactive
        # (padding, NaN or already extracted), a row whose median is +inf
        # (e.g. [1, inf, inf]) would extract the sentinel itself and report
        # finfo.max with the out-of-range index BLOCK_N.
        sentinel = float("inf")
        fallback_value = _get_finfo_val(dtype, return_max=False)
    else:
        sentinel = get_dtype_max(dtype)
        fallback_value = get_dtype_min(dtype)
    vals = tl.load(inp + pid * N + offsets, mask=mask, other=sentinel)

    if dtype.is_floating():
        valid = mask & _is_not_nan(vals, USE_ISNAN)
    else:
        valid = mask
    valid_count = tl.sum(valid.to(tl.int32), axis=0)
    median_rank = (valid_count - 1) // 2

    active = valid
    median_val = tl.full((), fallback_value, dtype=vals.dtype)
    median_idx = tl.full((), 0, dtype=tl.int32)
    for select_iter in tl.static_range(0, BLOCK_N):
        select_vals = tl.where(active, vals, sentinel)
        cur_val = tl.min(select_vals, axis=0)
        # Lowest active column holding cur_val; `active &` keeps padding lanes
        # (which also hold the sentinel) from matching.
        cur_idx = tl.min(tl.where(active & (vals == cur_val), offsets, BLOCK_N), axis=0)
        take = select_iter == median_rank
        median_val = tl.where(take, cur_val, median_val)
        median_idx = tl.where(take, cur_idx, median_idx)
        active = active & (offsets != cur_idx)

    if dtype.is_floating():
        all_nan = valid_count == 0
        median_val = tl.where(all_nan, float("nan"), median_val)
        median_idx = tl.where(all_nan, 0, median_idx)

    tl.store(out_values + pid, median_val)
    tl.store(out_indices + pid, median_idx)

180 

181 

@libentry()
@triton.jit
def nanmedian_float_clean_count_kernel(
    inp,
    cleaned,
    valid_counts,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # One program per row: copy the row into `cleaned` with every NaN replaced
    # by the dtype's largest value (via _get_finfo_val(return_max=True)), and
    # store the row's non-NaN count in valid_counts[row].
    # NOTE(review): if _get_finfo_val returns the finite max (not +inf), rows
    # that contain +inf sort the replaced NaNs *before* the infs, so a median
    # that should be +inf can come back as that finite value — confirm.
    row = tle.program_id(0)
    src_row = inp + row * N
    dst_row = cleaned + row * N
    lanes = tl.arange(0, BLOCK_N)
    pad = _get_finfo_val(inp.dtype.element_ty, return_max=True)
    n_valid = tl.full((), 0, dtype=tl.int32)

    for base in tl.range(0, N, BLOCK_N):
        cols = base + lanes
        in_bounds = cols < N
        chunk_vals = tl.load(src_row + cols, mask=in_bounds, other=pad)
        keep = in_bounds & _is_not_nan(chunk_vals, False)
        tl.store(dst_row + cols, tl.where(keep, chunk_vals, pad), mask=in_bounds)
        n_valid += tl.sum(keep.to(tl.int32), axis=0)

    tl.store(valid_counts + row, n_valid)

207 

208 

@libentry()
@triton.jit
def nanmedian_float_sorted_gather_kernel(
    sorted_values,
    sorted_indices,
    valid_counts,
    out_values,
    out_indices,
    N: tl.constexpr,
):
    # One program per row of the row-sorted (M, N) buffers: emit the entry at
    # rank (count - 1) // 2, or (NaN, 0) when the row has no valid entries.
    row = tle.program_id(0)
    n_valid = tl.load(valid_counts + row)
    has_valid = n_valid > 0
    rank = tl.where(has_valid, (n_valid - 1) // 2, 0)
    value_ptr = sorted_values + row * N + rank
    index_ptr = sorted_indices + row * N + rank
    value = tl.load(value_ptr, mask=has_valid, other=float("nan"))
    index = tl.load(index_ptr, mask=has_valid, other=0)

    tl.store(out_values + row, tl.where(has_valid, value, float("nan")))
    tl.store(out_indices + row, tl.where(has_valid, index, 0))

231 

232 

@libentry()
@triton.jit
def nanmedian_ascend_histogram_select_kernel(
    inp,
    out_values,
    out_indices,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HISTOGRAM_BINS: tl.constexpr,
):
    # One program per row (no NaN handling: the key of every in-bounds lane is
    # counted). Histogram the order keys, pick the bin holding the
    # ((N + 1) // 2)-th smallest key, then return the first column with that key.
    row = tle.program_id(0)
    row_ptr = inp + row * N
    lanes = tl.arange(0, BLOCK_N)
    bins = tl.arange(0, HISTOGRAM_BINS)
    hist = tl.zeros((HISTOGRAM_BINS,), dtype=tl.int32)

    # Pass 1: histogram. Out-of-bounds lanes are forced into bin 0 and then
    # subtracted back out.
    for base in tl.range(0, N, BLOCK_N):
        cols = base + lanes
        in_bounds = cols < N
        chunk_vals = tl.load(row_ptr + cols, mask=in_bounds, other=0)
        keys = _to_order_key(chunk_vals, in_bounds).to(tl.int32)
        keys = tl.where(in_bounds, keys, 0)
        oob = tl.sum((~in_bounds).to(tl.int32), axis=0)
        hist += tl.histogram(keys, HISTOGRAM_BINS).to(tl.int32) - tl.where(bins == 0, oob, 0)

    # Pass 2: the bin whose cumulative range covers the target rank.
    target_rank: tl.constexpr = (N + 1) // 2
    cumulative = tl.cumsum(hist, axis=0)
    before = cumulative - hist
    hit = (target_rank <= cumulative) & (target_rank > before)
    target_key = tl.min(tl.where(hit, bins, HISTOGRAM_BINS - 1), axis=0)

    # Pass 3: lowest column whose key equals the selected key.
    first_idx = tl.full((), N, dtype=tl.int32)
    for base in tl.range(0, N, BLOCK_N):
        cols = base + lanes
        in_bounds = cols < N
        chunk_vals = tl.load(row_ptr + cols, mask=in_bounds, other=0)
        keys = _to_order_key(chunk_vals, in_bounds).to(tl.int32)
        local_first = tl.min(tl.where(in_bounds & (keys == target_key), cols, N), axis=0)
        first_idx = tl.where(local_first < first_idx, local_first, first_idx)

    tl.store(out_values + row, tl.load(row_ptr + first_idx))
    tl.store(out_indices + row, first_idx)

277 

278 

@libentry()
@triton.jit
def nanmedian_ascend_histogram_count_kernel(
    inp,
    partial_counts,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_CHUNKS: tl.constexpr,
    HISTOGRAM_BINS: tl.constexpr,
):
    # Grid (row, chunk): histogram of the order keys of one BLOCK_N chunk of a
    # row, written to that chunk's own HISTOGRAM_BINS-wide slot of
    # partial_counts laid out as (M, NUM_CHUNKS, HISTOGRAM_BINS).
    row = tle.program_id(0)
    chunk_id = tle.program_id(1)
    cols = chunk_id * BLOCK_N + tl.arange(0, BLOCK_N)
    bins = tl.arange(0, HISTOGRAM_BINS)
    in_bounds = cols < N

    chunk_vals = tl.load(inp + row * N + cols, mask=in_bounds, other=0)
    keys = tl.where(in_bounds, _to_order_key(chunk_vals, in_bounds).to(tl.int32), 0)
    # Out-of-bounds lanes landed in bin 0; remove them again.
    oob = tl.sum((~in_bounds).to(tl.int32), axis=0)
    hist = tl.histogram(keys, HISTOGRAM_BINS).to(tl.int32) - tl.where(bins == 0, oob, 0)

    slot = (row * NUM_CHUNKS + chunk_id) * HISTOGRAM_BINS + bins
    tl.store(partial_counts + slot, hist)

303 

304 

@libentry()
@triton.jit
def nanmedian_ascend_histogram_reduce_kernel(
    inp,
    partial_counts,
    out_values,
    out_indices,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_CHUNKS: tl.constexpr,
    HISTOGRAM_BINS: tl.constexpr,
):
    # One program per row: sum the row's per-chunk histograms (same
    # (M, NUM_CHUNKS, HISTOGRAM_BINS) layout as the count kernel writes),
    # pick the bin holding the ((N + 1) // 2)-th smallest key, and return the
    # first column with that key.
    row = tle.program_id(0)
    row_ptr = inp + row * N
    lanes = tl.arange(0, BLOCK_N)
    bins = tl.arange(0, HISTOGRAM_BINS)
    hist = tl.zeros((HISTOGRAM_BINS,), dtype=tl.int32)

    for chunk_id in tl.range(0, NUM_CHUNKS):
        hist += tl.load(partial_counts + (row * NUM_CHUNKS + chunk_id) * HISTOGRAM_BINS + bins)

    target_rank: tl.constexpr = (N + 1) // 2
    cumulative = tl.cumsum(hist, axis=0)
    before = cumulative - hist
    hit = (target_rank <= cumulative) & (target_rank > before)
    target_key = tl.min(tl.where(hit, bins, HISTOGRAM_BINS - 1), axis=0)

    first_idx = tl.full((), N, dtype=tl.int32)
    for base in tl.range(0, N, BLOCK_N):
        cols = base + lanes
        in_bounds = cols < N
        chunk_vals = tl.load(row_ptr + cols, mask=in_bounds, other=0)
        keys = _to_order_key(chunk_vals, in_bounds).to(tl.int32)
        local_first = tl.min(tl.where(in_bounds & (keys == target_key), cols, N), axis=0)
        first_idx = tl.where(local_first < first_idx, local_first, first_idx)

    tl.store(out_values + row, tl.load(row_ptr + first_idx))
    tl.store(out_indices + row, first_idx)

345 

346 

@libentry()
@triton.jit
def nanmedian_ascend_byte_histogram_select_kernel(
    inp,
    out_values,
    out_indices,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HISTOGRAM_BINS: tl.constexpr,
):
    # One program per row: most-significant-digit-first radix select over the
    # order keys, 8 bits per step (no NaN handling). Each step histograms the
    # current byte among keys matching the prefix chosen so far, then narrows
    # the prefix to the bin containing the remaining target rank. Finally the
    # first column whose key equals the fully determined key is returned.
    row = tle.program_id(0)
    row_ptr = inp + row * N
    lanes = tl.arange(0, BLOCK_N)
    bins = tl.arange(0, HISTOGRAM_BINS)
    elem_ty = inp.dtype.element_ty
    nbits: tl.constexpr = elem_ty.primitive_bitwidth
    key_ty = tl.dtype(f"uint{nbits}")
    byte_mask = tl.full((), HISTOGRAM_BINS - 1, dtype=key_ty)

    remaining = tl.full((), (N + 1) // 2, dtype=tl.int32)
    prefix = tl.full((), 0, dtype=key_ty)
    prefix_mask = tl.full((), 0, dtype=key_ty)

    for shift in tl.static_range(nbits - 8, -1, -8):
        hist = tl.zeros((HISTOGRAM_BINS,), dtype=tl.int32)
        for base in tl.range(0, N, BLOCK_N):
            cols = base + lanes
            in_bounds = cols < N
            chunk_vals = tl.load(row_ptr + cols, mask=in_bounds, other=0)
            keys = _to_order_key(chunk_vals, in_bounds)
            live = in_bounds & ((keys & prefix_mask) == prefix)
            digit = tl.where(live, ((keys >> shift) & byte_mask).to(tl.int32), 0)
            # Non-live lanes were forced into bin 0; take them back out.
            dead = tl.sum((~live).to(tl.int32), axis=0)
            hist += tl.histogram(digit, HISTOGRAM_BINS).to(tl.int32) - tl.where(bins == 0, dead, 0)

        cumulative = tl.cumsum(hist, axis=0)
        before = cumulative - hist
        hit = (remaining <= cumulative) & (remaining > before)
        chosen = tl.min(tl.where(hit, bins, HISTOGRAM_BINS - 1), axis=0)
        skipped = tl.max(tl.where(hit, before, 0), axis=0)

        prefix = prefix | (chosen.to(key_ty) << shift)
        prefix_mask = prefix_mask | (byte_mask << shift)
        remaining = remaining - skipped

    first_idx = tl.full((), N, dtype=tl.int32)
    for base in tl.range(0, N, BLOCK_N):
        cols = base + lanes
        in_bounds = cols < N
        chunk_vals = tl.load(row_ptr + cols, mask=in_bounds, other=0)
        keys = _to_order_key(chunk_vals, in_bounds)
        local_first = tl.min(tl.where(in_bounds & (keys == prefix), cols, N), axis=0)
        first_idx = tl.where(local_first < first_idx, local_first, first_idx)

    tl.store(out_values + row, tl.load(row_ptr + first_idx))
    tl.store(out_indices + row, first_idx)

408 

409 

@libentry()
@triton.jit
def nanmedian_ascend_byte_histogram_init_kernel(
    state,
    M,
    N: tl.constexpr,
):
    # Reset this row's 3-slot radix-select state:
    #   [0] key prefix chosen so far, [1] mask of decided key bits,
    #   [2] 1-based rank still to locate inside that prefix, (N + 1) // 2.
    slot = state + tle.program_id(0) * 3
    tl.store(slot + 0, 0)
    tl.store(slot + 1, 0)
    tl.store(slot + 2, (N + 1) // 2)

422 

423 

@libentry()
@triton.jit
def nanmedian_ascend_byte_histogram_count_kernel(
    inp,
    state,
    partial_counts,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_CHUNKS: tl.constexpr,
    HISTOGRAM_BINS: tl.constexpr,
    DIGIT_POS: tl.constexpr,
):
    # Grid (row, chunk): histogram of the byte at DIGIT_POS over this chunk's
    # keys that still match the row's prefix from `state`; the result goes to
    # the chunk's own slot of partial_counts (M, NUM_CHUNKS, HISTOGRAM_BINS).
    row = tle.program_id(0)
    chunk_id = tle.program_id(1)
    cols = chunk_id * BLOCK_N + tl.arange(0, BLOCK_N)
    bins = tl.arange(0, HISTOGRAM_BINS)
    in_bounds = cols < N

    elem_ty = inp.dtype.element_ty
    nbits: tl.constexpr = elem_ty.primitive_bitwidth
    key_ty = tl.dtype(f"uint{nbits}")
    byte_mask = tl.full((), HISTOGRAM_BINS - 1, dtype=key_ty)
    slot = state + row * 3
    prefix = tl.load(slot + 0).to(key_ty)
    prefix_mask = tl.load(slot + 1).to(key_ty)

    chunk_vals = tl.load(inp + row * N + cols, mask=in_bounds, other=0)
    keys = _to_order_key(chunk_vals, in_bounds)
    live = in_bounds & ((keys & prefix_mask) == prefix)
    digit = tl.where(live, ((keys >> DIGIT_POS) & byte_mask).to(tl.int32), 0)
    # Non-live lanes were forced into bin 0; take them back out.
    dead = tl.sum((~live).to(tl.int32), axis=0)
    hist = tl.histogram(digit, HISTOGRAM_BINS).to(tl.int32) - tl.where(bins == 0, dead, 0)

    tl.store(partial_counts + (row * NUM_CHUNKS + chunk_id) * HISTOGRAM_BINS + bins, hist)

462 

463 

@libentry()
@triton.jit
def nanmedian_ascend_byte_histogram_update_kernel(
    inp,
    partial_counts,
    state,
    M,
    NUM_CHUNKS: tl.constexpr,
    HISTOGRAM_BINS: tl.constexpr,
    DIGIT_POS: tl.constexpr,
):
    # One program per row: reduce the per-chunk byte histograms, choose the
    # bin containing the remaining rank, and advance the row's state
    # (prefix |= bin << DIGIT_POS, mask |= 0xFF << DIGIT_POS, rank -= skipped).
    # `inp` is only used for its element dtype.
    row = tle.program_id(0)
    bins = tl.arange(0, HISTOGRAM_BINS)
    hist = tl.zeros((HISTOGRAM_BINS,), dtype=tl.int32)
    for chunk_id in tl.range(0, NUM_CHUNKS):
        hist += tl.load(partial_counts + (row * NUM_CHUNKS + chunk_id) * HISTOGRAM_BINS + bins)

    slot = state + row * 3
    remaining = tl.load(slot + 2).to(tl.int32)
    cumulative = tl.cumsum(hist, axis=0)
    before = cumulative - hist
    hit = (remaining <= cumulative) & (remaining > before)
    chosen = tl.min(tl.where(hit, bins, HISTOGRAM_BINS - 1), axis=0)
    skipped = tl.max(tl.where(hit, before, 0), axis=0)

    elem_ty = inp.dtype.element_ty
    nbits: tl.constexpr = elem_ty.primitive_bitwidth
    key_ty = tl.dtype(f"uint{nbits}")
    byte_mask = tl.full((), HISTOGRAM_BINS - 1, dtype=key_ty)
    # All state loads happen before any state store.
    prefix = tl.load(slot + 0).to(key_ty)
    prefix_mask = tl.load(slot + 1).to(key_ty)

    prefix = prefix | (chosen.to(key_ty) << DIGIT_POS)
    prefix_mask = prefix_mask | (byte_mask << DIGIT_POS)
    tl.store(slot + 0, prefix)
    tl.store(slot + 1, prefix_mask)
    tl.store(slot + 2, remaining - skipped)

504 

505 

@libentry()
@triton.jit
def nanmedian_ascend_byte_histogram_find_index_kernel(
    inp,
    state,
    out_values,
    out_indices,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # One program per row: the key selected by the radix passes is state[row*3];
    # return the first column whose order key equals it, plus its value.
    row = tle.program_id(0)
    row_ptr = inp + row * N
    lanes = tl.arange(0, BLOCK_N)
    elem_ty = inp.dtype.element_ty
    nbits: tl.constexpr = elem_ty.primitive_bitwidth
    key_ty = tl.dtype(f"uint{nbits}")
    target = tl.load(state + row * 3).to(key_ty)

    first_idx = tl.full((), N, dtype=tl.int32)
    for base in tl.range(0, N, BLOCK_N):
        cols = base + lanes
        in_bounds = cols < N
        chunk_vals = tl.load(row_ptr + cols, mask=in_bounds, other=0)
        keys = _to_order_key(chunk_vals, in_bounds)
        local_first = tl.min(tl.where(in_bounds & (keys == target), cols, N), axis=0)
        first_idx = tl.where(local_first < first_idx, local_first, first_idx)

    tl.store(out_values + row, tl.load(row_ptr + first_idx))
    tl.store(out_indices + row, first_idx)

536 

537 

@libentry()
@triton.jit
def nanmedian_radix_select_kernel(
    inp,
    out_values,
    out_indices,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    RADIX_BITS_: tl.constexpr,
    USE_ISNAN: tl.constexpr,
    USE_HISTOGRAM: tl.constexpr,
):
    # One program per row: nanmedian via MSD radix select on order keys.
    #   1. count the valid (non-NaN) entries;
    #   2. for each RADIX_BITS_-wide digit (high to low), count matching valid
    #      keys per digit value and narrow the prefix to the bin holding the
    #      remaining rank (starting at (valid + 1) // 2);
    #   3. return the first valid column whose key equals the final prefix.
    # Float rows with no valid entries produce (NaN, 0).
    row = tle.program_id(0)
    row_ptr = inp + row * N
    lanes = tl.arange(0, BLOCK_N)
    elem_ty = inp.dtype.element_ty
    nbits: tl.constexpr = elem_ty.primitive_bitwidth
    key_ty = tl.dtype(f"uint{nbits}")
    n_bins: tl.constexpr = 1 << RADIX_BITS_
    digit_mask: tl.constexpr = n_bins - 1
    bins = tl.arange(0, n_bins)

    n_valid = tl.full((), 0, dtype=tl.int32)
    for base in tl.range(0, N, BLOCK_N):
        cols = base + lanes
        in_bounds = cols < N
        chunk_vals = tl.load(row_ptr + cols, mask=in_bounds, other=0.0)
        if elem_ty.is_floating():
            keep = in_bounds & _is_not_nan(chunk_vals, USE_ISNAN)
        else:
            keep = in_bounds
        n_valid += tl.sum(keep.to(tl.int32), axis=0)

    remaining = (n_valid + 1) // 2
    prefix = tl.full((), 0, dtype=key_ty)
    prefix_mask = tl.full((), 0, dtype=key_ty)
    digit_mask_val = tl.full((), digit_mask, dtype=key_ty)

    for shift in tl.static_range(nbits - RADIX_BITS_, -1, -RADIX_BITS_):
        hist = tl.zeros((n_bins,), dtype=tl.int32)
        for base in tl.range(0, N, BLOCK_N):
            cols = base + lanes
            in_bounds = cols < N
            chunk_vals = tl.load(row_ptr + cols, mask=in_bounds, other=0.0)
            if elem_ty.is_floating():
                keep = in_bounds & _is_not_nan(chunk_vals, USE_ISNAN)
            else:
                keep = in_bounds
            keys = _to_order_key(chunk_vals, keep)
            live = keep & ((keys & prefix_mask) == prefix)
            digit = ((keys >> shift) & digit_mask_val).to(tl.int32)
            if USE_HISTOGRAM:
                hist += tl.histogram(digit, n_bins, live)
            else:
                # Masked tl.histogram unavailable: one reduction per bin.
                for bin_id in tl.static_range(0, n_bins):
                    in_bin = tl.sum((live & (digit == bin_id)).to(tl.int32), axis=0)
                    hist += tl.where(bins == bin_id, in_bin, 0)

        cumulative = tl.cumsum(hist, axis=0)
        before = cumulative - hist
        hit = (cumulative >= remaining) & (before < remaining)
        chosen = tl.min(tl.where(hit, bins, n_bins - 1), axis=0)
        skipped = tl.max(tl.where(hit, before, 0), axis=0)

        prefix = prefix | (chosen.to(key_ty) << shift)
        prefix_mask = prefix_mask | (digit_mask_val << shift)
        remaining = remaining - skipped

    first_idx = tl.full((), N, dtype=tl.int32)
    for base in tl.range(0, N, BLOCK_N):
        cols = base + lanes
        in_bounds = cols < N
        chunk_vals = tl.load(row_ptr + cols, mask=in_bounds, other=0.0)
        if elem_ty.is_floating():
            keep = in_bounds & _is_not_nan(chunk_vals, USE_ISNAN)
        else:
            keep = in_bounds
        keys = _to_order_key(chunk_vals, keep)
        local_first = tl.min(tl.where(keep & (keys == prefix), cols, N), axis=0)
        first_idx = tl.where(local_first < first_idx, local_first, first_idx)

    if elem_ty.is_floating():
        fallback_value = _get_finfo_val(elem_ty, return_max=False)
    else:
        fallback_value = get_dtype_min(elem_ty)
    # first_idx is still N when nothing matched; the mask avoids reading it.
    result_val = tl.load(row_ptr + first_idx, mask=n_valid > 0, other=fallback_value)

    if elem_ty.is_floating():
        all_nan = n_valid == 0
        result_val = tl.where(all_nan, float("nan"), result_val)
        first_idx = tl.where(all_nan, 0, first_idx)

    tl.store(out_values + row, result_val)
    tl.store(out_indices + row, first_idx)

638 

639 

@libentry()
@triton.jit
def flat_radix_init_kernel(
    valid_count,
    state,
    result_idx,
    N: tl.constexpr,
    IS_FLOAT: tl.constexpr,
):
    # Reset the whole-tensor radix-select scratch.
    # Integers have no NaNs, so all N entries are valid up front; for floats
    # the count starts at 0 (presumably accumulated by
    # flat_radix_count_valid_kernel afterwards).
    tl.store(valid_count, 0 if IS_FLOAT else N)
    # state = [key prefix, decided-bit mask, remaining rank]
    tl.store(state + 0, 0)
    tl.store(state + 1, 0)
    tl.store(state + 2, 0)
    # N is the "not found" value that flat_radix_find_index_kernel's
    # atomic_min lowers.
    tl.store(result_idx, N)

654 

655 

@libentry()
@triton.jit
def flat_radix_count_valid_kernel(
    inp,
    valid_count,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_ISNAN: tl.constexpr,
):
    # One program per BLOCK_N slice of the flat input: add the slice's
    # non-NaN count to the global valid_count (int64 accumulation).
    block = tle.program_id(0)
    idx = block * BLOCK_N + tl.arange(0, BLOCK_N)
    in_bounds = idx < N
    chunk_vals = tl.load(inp + idx, mask=in_bounds, other=0.0)
    keep = in_bounds & _is_not_nan(chunk_vals, USE_ISNAN)
    n_valid = tl.sum(keep.to(tl.int64), axis=0)
    tl.atomic_add(valid_count, n_valid, sem="relaxed")

672 

673 

@libentry()
@triton.jit
def flat_radix_init_rank_kernel(valid_count, state):
    # Seed state[2] with the 1-based rank of the lower median among the
    # valid entries: (valid_count + 1) // 2.
    n_valid = tl.load(valid_count)
    median_rank = (n_valid + 1) // 2
    tl.store(state + 2, median_rank)

679 

680 

@libentry()
@triton.jit
def flat_radix_count_kernel(
    inp,
    bin_counts,
    state,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    DIGIT_POS: tl.constexpr,
    RADIX_BITS_: tl.constexpr,
    RADIX_SIZE: tl.constexpr,
    USE_ISNAN: tl.constexpr,
    USE_HISTOGRAM: tl.constexpr,
):
    # One program per BLOCK_N slice of the flat input: histogram the digit at
    # DIGIT_POS of valid keys that match the current prefix in `state`, and
    # atomically add the slice's counts into the global bin_counts (int64).
    block = tle.program_id(0)
    idx = block * BLOCK_N + tl.arange(0, BLOCK_N)
    in_bounds = idx < N
    chunk_vals = tl.load(inp + idx, mask=in_bounds, other=0.0)
    elem_ty = inp.dtype.element_ty
    nbits: tl.constexpr = elem_ty.primitive_bitwidth
    key_ty = tl.dtype(f"uint{nbits}")
    digit_mask: tl.constexpr = (1 << RADIX_BITS_) - 1
    digit_mask_val = tl.full((), digit_mask, dtype=key_ty)

    if elem_ty.is_floating():
        keep = in_bounds & _is_not_nan(chunk_vals, USE_ISNAN)
    else:
        keep = in_bounds

    prefix = tl.load(state + 0).to(key_ty)
    prefix_mask = tl.load(state + 1).to(key_ty)
    keys = _to_order_key(chunk_vals, keep)
    live = keep & ((keys & prefix_mask) == prefix)
    digit = ((keys >> DIGIT_POS) & digit_mask_val).to(tl.int32)
    bins = tl.arange(0, RADIX_SIZE)
    if USE_HISTOGRAM:
        hist = tl.histogram(digit, RADIX_SIZE, live).to(tl.int64)
    else:
        # Masked tl.histogram unavailable: one reduction per bin.
        hist = tl.zeros((RADIX_SIZE,), dtype=tl.int64)
        for bin_id in tl.static_range(0, RADIX_SIZE):
            in_bin = tl.sum((live & (digit == bin_id)).to(tl.int64), axis=0)
            hist += tl.where(bins == bin_id, in_bin, 0)
    tl.atomic_add(bin_counts + bins, hist, sem="relaxed")

724 

725 

@libentry()
@triton.jit
def flat_radix_update_kernel(
    bin_counts,
    state,
    DIGIT_POS: tl.constexpr,
    RADIX_BITS_: tl.constexpr,
    RADIX_SIZE: tl.constexpr,
):
    # Choose the digit bin containing the remaining rank (state[2]) and fold
    # it into the prefix: state[0] |= bin << DIGIT_POS,
    # state[1] |= digit_mask << DIGIT_POS, state[2] -= entries in lower bins.
    bins = tl.arange(0, RADIX_SIZE)
    hist = tl.load(bin_counts + bins)
    remaining = tl.load(state + 2)
    cumulative = tl.cumsum(hist, axis=0)
    before = cumulative - hist
    hit = (remaining <= cumulative) & (remaining > before)
    chosen = tl.min(tl.where(hit, bins, RADIX_SIZE - 1), axis=0).to(tl.int64)
    skipped = tl.max(tl.where(hit, before, 0), axis=0)

    digit_mask: tl.constexpr = (1 << RADIX_BITS_) - 1
    # All state loads happen before any state store.
    prefix = tl.load(state + 0) | (chosen << DIGIT_POS)
    prefix_mask = tl.load(state + 1) | (digit_mask << DIGIT_POS)
    tl.store(state + 0, prefix)
    tl.store(state + 1, prefix_mask)
    tl.store(state + 2, remaining - skipped)

752 

753 

@libentry()
@triton.jit
def flat_radix_find_index_kernel(
    inp,
    state,
    valid_count,
    result_idx,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_ISNAN: tl.constexpr,
):
    # One program per BLOCK_N slice: atomically lower result_idx to the first
    # valid flat index whose order key equals the selected prefix (state[0]).
    # Nothing is done when there are no valid entries.
    if tl.load(valid_count) > 0:
        block = tle.program_id(0)
        idx = block * BLOCK_N + tl.arange(0, BLOCK_N)
        in_bounds = idx < N
        chunk_vals = tl.load(inp + idx, mask=in_bounds, other=0.0)
        elem_ty = inp.dtype.element_ty
        nbits: tl.constexpr = elem_ty.primitive_bitwidth
        key_ty = tl.dtype(f"uint{nbits}")

        if elem_ty.is_floating():
            keep = in_bounds & _is_not_nan(chunk_vals, USE_ISNAN)
        else:
            keep = in_bounds

        target = tl.load(state + 0).to(key_ty)
        keys = _to_order_key(chunk_vals, keep)
        local_first = tl.min(tl.where(keep & (keys == target), idx, N), axis=0)
        tl.atomic_min(result_idx, local_first, sem="relaxed")

783 

784 

@libentry()
@triton.jit
def flat_radix_store_result_kernel(inp, out, valid_count, result_idx):
    # Write inp[result_idx] to `out`. For floats with no valid entries the
    # load is masked off and NaN is written instead.
    elem_ty = inp.dtype.element_ty
    median_idx = tl.load(result_idx)
    if elem_ty.is_floating():
        has_valid = tl.load(valid_count) > 0
        value = tl.load(inp + median_idx, mask=has_valid, other=float("nan"))
    else:
        value = tl.load(inp + median_idx)
    tl.store(out, value)

795 

796 

797def _check_supported_dtype(inp): 

798 if inp.dtype is torch.bool: 

799 raise NotImplementedError("\"median_out_impl\" not implemented for 'Bool'") 

800 

801 

802def _normalize_dim(dim, ndim): 

803 if ndim == 0: 

804 if dim in (0, -1): 

805 return 0 

806 elif -ndim <= dim < ndim: 

807 return dim % ndim 

808 raise IndexError( 

809 f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})" 

810 ) 

811 

812 

813def _empty_flat_value(inp): 

814 out = torch.empty((), dtype=inp.dtype, device=inp.device) 

815 if torch.is_floating_point(inp): 

816 out.fill_(float("nan")) 

817 elif inp.is_cuda: 

818 out.fill_(torch.iinfo(inp.dtype).min) 

819 else: 

820 out.zero_() 

821 return out 

822 

823 

def _radix_block_n(inp, n):
    """Choose the per-iteration tile width for a radix-select pass over rows of length ``n``.

    The result is always ``min(next_power_of_2(n), cap)``. The cap depends on
    the device and dtype:

    * CUDA: 8192 above LARGE_FLOAT_REDUCTION_N, 4096 above MEDIUM_REDUCTION_N,
      else 512 for uint8 and RADIX_BLOCK_N otherwise.
    * fp16/bf16: 2048.
    * fp32/int32: 512 above MEDIUM_REDUCTION_N, else RADIX_BLOCK_N.
    * int8/uint8: RADIX_BLOCK_N above MEDIUM_REDUCTION_N, else 512.
    * anything else: RADIX_BLOCK_N.

    (Where the original returned a fixed width for long rows, the power of two
    already exceeds that cap, so the min() form is equivalent.)
    """
    dtype = inp.dtype
    if inp.is_cuda:
        if n > LARGE_FLOAT_REDUCTION_N:
            cap = 8192
        elif n > MEDIUM_REDUCTION_N:
            cap = 4096
        elif dtype is torch.uint8:
            cap = 512
        else:
            cap = RADIX_BLOCK_N
    elif dtype in (torch.float16, torch.bfloat16):
        cap = 2048
    elif dtype is torch.float32 or dtype is torch.int32:
        cap = 512 if n > MEDIUM_REDUCTION_N else RADIX_BLOCK_N
    elif dtype in (torch.int8, torch.uint8):
        cap = RADIX_BLOCK_N if n > MEDIUM_REDUCTION_N else 512
    else:
        cap = RADIX_BLOCK_N
    return min(triton.next_power_of_2(n), cap)

847 

848 

def _radix_bits(inp, n):
    """Choose the radix digit width (bits per pass).

    On CUDA, longer rows use wider digits: 8 bits above
    LARGE_FLOAT_REDUCTION_N, 4 bits above MEDIUM_REDUCTION_N. Every other case
    uses RADIX_BITS.
    """
    if not inp.is_cuda:
        return RADIX_BITS
    if n > LARGE_FLOAT_REDUCTION_N:
        return 8
    return 4 if n > MEDIUM_REDUCTION_N else RADIX_BITS

856 

857 

def _full_nan_result(shape, dtype, device):
    """Build a NanMedian of ``shape``: values all NaN (in ``dtype``), indices all 0 (int64)."""
    nan_values = torch.full(shape, float("nan"), dtype=dtype, device=device)
    zero_indices = torch.zeros(shape, dtype=torch.long, device=device)
    return NanMedian(nan_values, zero_indices)

862 

863 

def _count_block_n(inp, n):
    """Choose the tile width for the per-row valid-count kernel.

    Returns ``min(next_power_of_2(n), cap)``. The cap is 16384 (CUDA) or 4096
    (other devices) for n >= LONG_RADIX_REDUCTION_N, 2048 for
    n >= LARGE_FLOAT_REDUCTION_N, and RADIX_BLOCK_N otherwise.
    """
    if n >= LONG_RADIX_REDUCTION_N:
        cap = 16384 if inp.is_cuda else 4096
    elif n >= LARGE_FLOAT_REDUCTION_N:
        cap = 2048
    else:
        cap = RADIX_BLOCK_N
    return min(triton.next_power_of_2(n), cap)

873 

874 

def _nanmedian_kthvalue_fallback(inp, M, N):
    """Compute nanmedian of each row of ``inp`` viewed as ``(M, N)`` using ``torch.kthvalue``.

    Floating inputs:

    * ``count_valid_kernel`` counts the non-NaN entries of every row.
    * Each row's result is its ``(count + 1) // 2``-th smallest entry. This
      assumes ``torch.kthvalue`` orders NaN after every number, so the first
      ``count`` ranks are the non-NaN values.
    * Rows with no valid entry get NaN / index 0.
    * ``kthvalue`` takes a single k, so rows are grouped by their valid count.
      The common cases (all rows share one count, or the counts differ by
      one) avoid ``torch.unique`` / ``torch.nonzero``.

    Integer inputs contain no NaNs; the result is the ``(N + 1) // 2``-th
    smallest value.

    Args:
        inp: input tensor with ``M * N`` elements.
        M: number of rows (length of the result).
        N: reduction length.

    Returns:
        NanMedian(values, indices), each of shape ``(M,)``.
    """
    inp = inp.reshape(M, N)

    if not torch.is_floating_point(inp):
        k = (N + 1) // 2
        if inp.device.type == "npu" and inp.dtype in (torch.int32, torch.int64):
            # These dtypes go through a full sort on NPU (presumably kthvalue is
            # unsupported or slow there — TODO confirm).
            sorted_values, sorted_indices = torch.sort(inp, dim=1)
            return NanMedian(
                values=sorted_values[:, k - 1], indices=sorted_indices[:, k - 1]
            )
        values, indices = torch.kthvalue(inp, k, dim=1)
        return NanMedian(values=values, indices=indices)

    if M == 0:
        # torch.min / torch.max below reject empty tensors; nothing to reduce.
        return _full_nan_result((0,), inp.dtype, inp.device)

    valid_count = torch.empty((M,), dtype=torch.long, device=inp.device)
    block_n = _count_block_n(inp, N)
    with torch_device_fn.device(inp.device):
        count_valid_kernel[(M,)](inp, valid_count, M, N, block_n, inp.is_cuda)
    min_count = int(torch.min(valid_count).item())
    max_count = int(torch.max(valid_count).item())

    if min_count == max_count:
        if max_count == 0:
            return _full_nan_result((M,), inp.dtype, inp.device)
        values, indices = torch.kthvalue(inp, (max_count + 1) // 2, dim=1)
        return NanMedian(values=values, indices=indices)

    if max_count - min_count == 1:
        # Exactly two distinct counts: at most two kthvalue calls, then a select.
        max_k = (max_count + 1) // 2
        max_values, max_indices = torch.kthvalue(inp, max_k, dim=1)
        if min_count == 0:
            # All-NaN rows keep the NaN / 0 fallback.
            fallback = _full_nan_result((M,), inp.dtype, inp.device)
            base_values, base_indices = fallback.values, fallback.indices
        else:
            min_k = (min_count + 1) // 2
            if min_k == max_k:
                # Both counts share the same median rank.
                return NanMedian(values=max_values, indices=max_indices)
            base_values, base_indices = torch.kthvalue(inp, min_k, dim=1)
        use_max = valid_count == max_count
        return NanMedian(
            values=torch.where(use_max, max_values, base_values),
            indices=torch.where(use_max, max_indices, base_indices),
        )

    # General case: one kthvalue call per distinct non-zero valid count.
    result = _full_nan_result((M,), inp.dtype, inp.device)
    for count in torch.unique(valid_count).tolist():
        count = int(count)
        if count == 0:
            continue
        row_indices = torch.nonzero(valid_count == count).flatten()
        rows = torch.index_select(inp, 0, row_indices)
        values, indices = torch.kthvalue(rows, (count + 1) // 2, dim=1)
        result.values[row_indices] = values
        result.indices[row_indices] = indices
    return result

942 

943 

def _nanmedian_ascend_float_sort_select(inp, M, N, values, indices):
    """Sort-based float nanmedian over rows of ``inp`` viewed as ``(M, N)``.

    Results are written in place through ``values.reshape(M)`` and
    ``indices.reshape(M)``; nothing is returned.

    * Short rows (``N <= LARGE_FLOAT_REDUCTION_N``):
      ``nanmedian_float_clean_count_kernel`` copies each row with NaNs
      replaced by the dtype maximum and counts the valid entries; the cleaned
      copy is then sorted.
    * Longer rows are sorted as-is, and valid entries are counted as
      ``x == x`` on the sorted values. This assumes the sort places NaNs
      last — confirm.
    * ``nanmedian_float_sorted_gather_kernel`` then picks rank
      ``(count - 1) // 2`` per row, or NaN / 0 for all-NaN rows.

    NOTE(review): writes land in the reshaped views; this assumes ``values``
    and ``indices`` are contiguous so ``reshape`` does not copy — TODO confirm
    at call sites.
    """
    rows = inp.reshape(M, N)
    out_values = values.reshape(M)
    out_indices = indices.reshape(M)

    if N > LARGE_FLOAT_REDUCTION_N:
        sorted_values, sorted_indices = torch.sort(rows, dim=1)
        not_nan = (sorted_values == sorted_values).to(torch.int32)
        valid_counts = torch.sum(not_nan, dim=1)
    else:
        cleaned = torch.empty_like(rows)
        valid_counts = torch.empty((M,), dtype=torch.int32, device=rows.device)
        block_n = min(triton.next_power_of_2(N), RADIX_BLOCK_N)
        with torch_device_fn.device(rows.device):
            nanmedian_float_clean_count_kernel[(M,)](
                rows,
                cleaned,
                valid_counts,
                N,
                block_n,
                num_warps=4 if block_n <= 512 else 8,
                num_stages=1,
            )
        sorted_values, sorted_indices = torch.sort(cleaned, dim=1)

    with torch_device_fn.device(rows.device):
        nanmedian_float_sorted_gather_kernel[(M,)](
            sorted_values,
            sorted_indices,
            valid_counts,
            out_values,
            out_indices,
            N,
            num_warps=1,
            num_stages=1,
        )

981 

982 

def _nanmedian_dim_impl(inp, dim, keepdim, out=None, use_ascend_float_select=True):
    """Compute the NaN-ignoring median of ``inp`` along ``dim``.

    The reduction is dispatched by device, dtype and reduction length ``N``:
    CUDA radix select, Ascend histogram / byte-histogram / float sort-select
    paths, a small-``N`` Triton select kernel, or the
    ``_nanmedian_kthvalue_fallback`` path.

    Args:
        inp: Input tensor.
        dim: Reduction dimension (negative values allowed).
        keepdim: Keep the reduced dimension with size 1.
        out: Optional ``(values, indices)`` tensors to write into. They may be
            non-contiguous; results are then staged through dense buffers.
        use_ascend_float_select: Allow the Ascend float sort/select path.

    Returns:
        ``NanMedian(values, indices)``. When ``out`` is given, these are the
        ``out`` tensors themselves.

    Raises:
        IndexError: If the reduction dim has size 0 while the other dims are
            non-empty.
    """
    dim = _normalize_dim(dim, inp.ndim)

    if inp.ndim == 0:
        if out is None:
            values = inp.clone()
            indices = torch.zeros((), dtype=torch.long, device=inp.device)
        else:
            values, indices = out
            values.copy_(inp)
            indices.zero_()
        return NanMedian(values=values, indices=indices)

    shape = list(inp.shape)
    N = shape[dim]
    out_shape = shape[:dim] + shape[dim + 1 :]
    M = math.prod(out_shape)

    keepdim_shape = shape.copy()
    keepdim_shape[dim] = 1
    output_shape = keepdim_shape if keepdim else out_shape
    # Without ``out`` we compute in the keepdim layout and squeeze at the end;
    # with ``out`` we compute directly in the caller's requested layout.
    compute_shape = output_shape if out is not None else keepdim_shape

    if N == 0:
        if M != 0:
            raise IndexError(
                f"median(): Expected reduction dim {dim} to have non-zero size."
            )
        if out is None:
            values = torch.empty(compute_shape, dtype=inp.dtype, device=inp.device)
            indices = torch.empty(compute_shape, dtype=torch.long, device=inp.device)
            if not keepdim:
                values = torch.squeeze(values, dim)
                indices = torch.squeeze(indices, dim)
        else:
            values, indices = out
        return NanMedian(values=values, indices=indices)

    if out is None:
        values = torch.empty(compute_shape, dtype=inp.dtype, device=inp.device)
        indices = torch.empty(compute_shape, dtype=torch.long, device=inp.device)
    else:
        values, indices = out

    if M == 0:
        if out is None and not keepdim:
            values = torch.squeeze(values, dim)
            indices = torch.squeeze(indices, dim)
        return NanMedian(values=values, indices=indices)

    # The Triton kernels below receive their outputs as ``reshape(M)`` tensors
    # and are passed no stride arguments, so they can only address them as
    # dense buffers. For a non-contiguous ``out`` the reshape either copies
    # (kernel writes never reach ``out``) or yields a strided view. Stage the
    # results in dense scratch buffers and copy them back afterwards.
    staged = out is not None and not (
        values.is_contiguous() and indices.is_contiguous()
    )
    if staged:
        dst_values = torch.empty_like(values, memory_format=torch.contiguous_format)
        dst_indices = torch.empty_like(indices, memory_format=torch.contiguous_format)
    else:
        dst_values, dst_indices = values, indices

    inp = dim_compress(inp, dim)
    is_cuda = inp.is_cuda
    is_ascend = inp.device.type == "npu"
    in_radix_range = MAX_BLOCK_N < N <= LONG_RADIX_REDUCTION_N
    use_cuda_histogram = (
        is_cuda
        and CUDA_SUPPORTS_MASKED_HISTOGRAM
        and N > MAX_BLOCK_N
        and N == triton.next_power_of_2(N)
    )
    use_ascend_float_select_path = (
        use_ascend_float_select
        and is_ascend
        and inp.dtype in ASCEND_FLOAT_SELECT_DTYPES
        and in_radix_range
    )
    use_ascend_histogram = (
        is_ascend and inp.dtype in ASCEND_HISTOGRAM_SELECT_DTYPES and in_radix_range
    )
    use_ascend_byte_histogram = (
        is_ascend
        and inp.dtype in ASCEND_BYTE_HISTOGRAM_SELECT_DTYPES
        and in_radix_range
    )

    if is_cuda and inp.dtype in RADIX_SELECT_DTYPES and in_radix_range:
        flat_values = dst_values.reshape(M)
        flat_indices = dst_indices.reshape(M)
        block_n = _radix_block_n(inp, N)
        num_warps = 4 if block_n <= 512 else 8
        with torch_device_fn.device(inp.device):
            nanmedian_radix_select_kernel[(M,)](
                inp,
                flat_values,
                flat_indices,
                M,
                N,
                block_n,
                _radix_bits(inp, N) if use_cuda_histogram else RADIX_BITS,
                is_cuda,
                use_cuda_histogram,
                num_warps=num_warps,
                num_stages=1,
            )
    elif use_ascend_float_select_path:
        _nanmedian_ascend_float_sort_select(inp, M, N, dst_values, dst_indices)
    elif use_ascend_histogram and N >= ASCEND_MULTI_HISTOGRAM_MIN_N:
        flat_values = dst_values.reshape(M)
        flat_indices = dst_indices.reshape(M)
        block_n = _radix_block_n(inp, N)
        num_chunks = triton.cdiv(N, block_n)
        partial_counts = torch.empty(
            (M, num_chunks, ASCEND_HISTOGRAM_BINS),
            dtype=torch.int32,
            device=inp.device,
        )
        num_warps = 4 if block_n <= 512 else 8
        with torch_device_fn.device(inp.device):
            nanmedian_ascend_histogram_count_kernel[(M, num_chunks)](
                inp,
                partial_counts,
                M,
                N,
                block_n,
                num_chunks,
                ASCEND_HISTOGRAM_BINS,
                num_warps=num_warps,
                num_stages=1,
            )
            nanmedian_ascend_histogram_reduce_kernel[(M,)](
                inp,
                partial_counts,
                flat_values,
                flat_indices,
                M,
                N,
                block_n,
                num_chunks,
                ASCEND_HISTOGRAM_BINS,
                num_warps=num_warps,
                num_stages=1,
            )
    elif use_ascend_histogram:
        flat_values = dst_values.reshape(M)
        flat_indices = dst_indices.reshape(M)
        block_n = _radix_block_n(inp, N)
        num_warps = 4 if block_n <= 512 else 8
        with torch_device_fn.device(inp.device):
            nanmedian_ascend_histogram_select_kernel[(M,)](
                inp,
                flat_values,
                flat_indices,
                M,
                N,
                block_n,
                ASCEND_HISTOGRAM_BINS,
                num_warps=num_warps,
                num_stages=1,
            )
    elif use_ascend_byte_histogram and N >= ASCEND_MULTI_HISTOGRAM_MIN_N:
        flat_values = dst_values.reshape(M)
        flat_indices = dst_indices.reshape(M)
        block_n = _radix_block_n(inp, N)
        num_chunks = triton.cdiv(N, block_n)
        partial_counts = torch.empty(
            (M, num_chunks, ASCEND_HISTOGRAM_BINS),
            dtype=torch.int32,
            device=inp.device,
        )
        state = torch.empty((M, 3), dtype=torch.int64, device=inp.device)
        num_warps = 4 if block_n <= 512 else 8
        nbits = inp.element_size() * 8
        with torch_device_fn.device(inp.device):
            nanmedian_ascend_byte_histogram_init_kernel[(M,)](
                state,
                M,
                N,
                num_warps=1,
                num_stages=1,
            )
            # One count/update pass per byte, most significant byte first.
            for digit_pos in range(nbits - 8, -1, -8):
                nanmedian_ascend_byte_histogram_count_kernel[(M, num_chunks)](
                    inp,
                    state,
                    partial_counts,
                    M,
                    N,
                    block_n,
                    num_chunks,
                    ASCEND_HISTOGRAM_BINS,
                    digit_pos,
                    num_warps=num_warps,
                    num_stages=1,
                )
                nanmedian_ascend_byte_histogram_update_kernel[(M,)](
                    inp,
                    partial_counts,
                    state,
                    M,
                    num_chunks,
                    ASCEND_HISTOGRAM_BINS,
                    digit_pos,
                    num_warps=num_warps,
                    num_stages=1,
                )
            nanmedian_ascend_byte_histogram_find_index_kernel[(M,)](
                inp,
                state,
                flat_values,
                flat_indices,
                M,
                N,
                block_n,
                num_warps=num_warps,
                num_stages=1,
            )
    elif use_ascend_byte_histogram:
        flat_values = dst_values.reshape(M)
        flat_indices = dst_indices.reshape(M)
        block_n = _radix_block_n(inp, N)
        num_warps = 4 if block_n <= 512 else 8
        with torch_device_fn.device(inp.device):
            nanmedian_ascend_byte_histogram_select_kernel[(M,)](
                inp,
                flat_values,
                flat_indices,
                M,
                N,
                block_n,
                ASCEND_HISTOGRAM_BINS,
                num_warps=num_warps,
                num_stages=1,
            )
    elif N <= MAX_BLOCK_N and inp.dtype is not torch.float64:
        flat_values = dst_values.reshape(M)
        flat_indices = dst_indices.reshape(M)
        block_n = triton.next_power_of_2(N)
        with torch_device_fn.device(inp.device):
            nanmedian_select_kernel[(M,)](
                inp,
                flat_values,
                flat_indices,
                M,
                N,
                block_n,
                is_cuda,
            )
    else:
        result = _nanmedian_kthvalue_fallback(inp, M, N)
        computed_values = result.values.reshape(compute_shape)
        computed_indices = result.indices.reshape(compute_shape)
        if out is None:
            values = computed_values
            indices = computed_indices
        else:
            # ``copy_`` honors arbitrary strides, so write straight into
            # ``out``; the scratch buffers (if any) were never filled.
            values.copy_(computed_values)
            indices.copy_(computed_indices)
            staged = False

    if staged:
        values.copy_(dst_values)
        indices.copy_(dst_indices)

    if out is None and not keepdim:
        values = torch.squeeze(values, dim)
        indices = torch.squeeze(indices, dim)

    return NanMedian(values=values, indices=indices)

1236 

1237 

def _nanmedian_ascend_flat_sort(inp):
    """Return the lower NaN-ignoring median of all elements of ``inp`` via a full sort.

    For floating dtypes the rank is taken among the non-NaN entries. If every
    entry is NaN, the rank becomes -1 and the last sorted element (a NaN) is
    returned. For integer dtypes every element participates.
    """
    ordered = torch.sort(inp.reshape(-1).contiguous()).values
    if torch.is_floating_point(ordered):
        # NaN != NaN, so self-equality counts the non-NaN entries.
        count = (ordered == ordered).sum()
    else:
        count = ordered.numel()
    return ordered[(count - 1) // 2]

1247 

1248 

def _nanmedian_cuda_flat_radix_select(inp, out=None):
    """NaN-ignoring median over all elements of ``inp`` via multi-pass radix select.

    The pipeline is:
    1. Initialize the device-side counters.
    2. For floating dtypes, count the valid (non-NaN) entries.
    3. Initialize the target rank.
    4. Run one histogram/update pass per radix digit, from the most
       significant digit down.
    5. Locate the matching element's index and store the value into ``out``.

    Args:
        inp: Input tensor (flattened internally).
        out: Optional 0-d tensor to receive the result; allocated with
            ``inp``'s dtype when omitted.

    Returns:
        ``out``.
    """
    flat = inp.reshape(-1).contiguous()
    numel = flat.numel()
    device = flat.device
    is_float = torch.is_floating_point(flat)
    if out is None:
        out = torch.empty((), dtype=flat.dtype, device=device)

    valid_count = torch.empty((), dtype=torch.int64, device=device)
    state = torch.empty((3,), dtype=torch.int64, device=device)
    result_idx = torch.empty((), dtype=torch.int64, device=device)

    block_n = min(triton.next_power_of_2(numel), FLAT_RADIX_BLOCK_N)
    grid = (triton.cdiv(numel, block_n),)
    # Histogram mode (wider digits) is used only when numel is an exact
    # multiple of block_n, i.e. there is no partial tail block.
    use_histogram = CUDA_SUPPORTS_MASKED_HISTOGRAM and numel % block_n == 0
    radix_bits = FLAT_RADIX_BITS if use_histogram else RADIX_BITS
    radix_size = 1 << radix_bits
    bin_counts = torch.empty((radix_size,), dtype=torch.int64, device=device)
    digit_positions = range(flat.element_size() * 8 - radix_bits, -1, -radix_bits)

    with torch_device_fn.device(device):
        flat_radix_init_kernel[(1,)](valid_count, state, result_idx, numel, is_float)
        if is_float:
            flat_radix_count_valid_kernel[grid](
                flat,
                valid_count,
                numel,
                block_n,
                True,
                num_warps=8,
                num_stages=1,
            )
        flat_radix_init_rank_kernel[(1,)](valid_count, state)
        for digit_pos in digit_positions:
            bin_counts.zero_()
            flat_radix_count_kernel[grid](
                flat,
                bin_counts,
                state,
                numel,
                block_n,
                digit_pos,
                radix_bits,
                radix_size,
                True,
                use_histogram,
                num_warps=8,
                num_stages=1,
            )
            flat_radix_update_kernel[(1,)](
                bin_counts,
                state,
                digit_pos,
                radix_bits,
                radix_size,
                num_warps=8,
                num_stages=1,
            )
        flat_radix_find_index_kernel[grid](
            flat,
            state,
            valid_count,
            result_idx,
            numel,
            block_n,
            True,
            num_warps=8,
            num_stages=1,
        )
        flat_radix_store_result_kernel[(1,)](flat, out, valid_count, result_idx)
    return out

1322 

1323 

def _nanmedian_flat_impl(inp, out=None):
    """NaN-ignoring median over all elements of ``inp``.

    Dispatch:
    - Empty input goes to ``_empty_flat_value``.
    - Long CUDA inputs in a radix-select dtype go to the flat radix select.
    - Large Ascend inputs go to a full sort.
    - Everything else is reduced over dim 0 of the flattened tensor.

    When ``out`` is given, the result is written into it and ``out`` is
    returned.
    """

    def _emit(result):
        # Route a computed result into ``out`` when the caller supplied one.
        if out is None:
            return result
        out.copy_(result)
        return out

    numel = inp.numel()
    if numel == 0:
        return _emit(_empty_flat_value(inp))

    cuda_radix_eligible = (
        inp.is_cuda
        and inp.dtype in RADIX_SELECT_DTYPES
        and LONG_RADIX_REDUCTION_N < numel <= INT32_MAX
    )
    if cuda_radix_eligible:
        return _nanmedian_cuda_flat_radix_select(inp, out=out)

    ascend_sort_eligible = (
        inp.device.type == "npu"
        and inp.dtype in ASCEND_FLAT_SORT_DTYPES
        and numel >= ASCEND_FLAT_SORT_MIN_N
    )
    if ascend_sort_eligible:
        return _emit(_nanmedian_ascend_flat_sort(inp))

    flat = inp.reshape(-1)
    if out is not None:
        # The dim path needs an indices buffer even though only values matter.
        _nanmedian_dim_impl(
            flat,
            0,
            False,
            out=(out, torch.empty((), dtype=torch.long, device=inp.device)),
            use_ascend_float_select=False,
        )
        return out
    return _nanmedian_dim_impl(flat, 0, False, use_ascend_float_select=False).values

1364 

1365 

def nanmedian(inp):
    """Return the lower median of all elements of ``inp``, ignoring NaNs."""
    logger.debug("GEMS NANMEDIAN")
    _check_supported_dtype(inp)
    median = _nanmedian_flat_impl(inp)
    return median

1370 

1371 

def nanmedian_out(inp, *, out):
    """``out=`` variant of ``nanmedian``: write the result into ``out`` and return it."""
    logger.debug("GEMS NANMEDIAN OUT")
    _check_supported_dtype(inp)
    written = _nanmedian_flat_impl(inp, out=out)
    return written

1376 

1377 

def nanmedian_dim(inp, dim=-1, keepdim=False):
    """NaN-ignoring median along ``dim``, returned as ``NanMedian(values, indices)``."""
    logger.debug("GEMS NANMEDIAN DIM")
    _check_supported_dtype(inp)
    result = _nanmedian_dim_impl(inp, dim, keepdim)
    return result

1382 

1383 

def nanmedian_dim_values(inp, dim=-1, keepdim=False, *, values, indices):
    """``out=`` variant of ``nanmedian_dim``, writing into ``values`` / ``indices``.

    Validates the input dtype exactly like the other public entry points
    (``nanmedian``, ``nanmedian_out``, ``nanmedian_dim``) before dispatching.

    Returns:
        ``NanMedian(values, indices)`` wrapping the supplied output tensors.
    """
    logger.debug("GEMS NANMEDIAN DIM VALUES")
    _check_supported_dtype(inp)
    return _nanmedian_dim_impl(inp, dim, keepdim, out=(values, indices))
1386 return _nanmedian_dim_impl(inp, dim, keepdim, out=(values, indices))