Coverage for src/flag_gems/ops/median.py: 46%

799 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3from collections import namedtuple 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.ops.topk import _get_iinfo_val 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import libentry 

12 

logger = logging.getLogger(__name__)

# Result container with ``values`` and ``indices`` fields; the namedtuple's
# type name is "median".
MedianResult = namedtuple("median", ["values", "indices"])

# Thresholds for the "direct" paths.
# NOTE(review): the dispatch code reading _DIRECT_FLAT_LIMIT and
# _DIRECT_REDUCTION_DTYPES is not in the part of the file reviewed here;
# confirm their exact meaning against the callers.
_DIRECT_REDUCTION_LIMIT = 256
_DIRECT_FLAT_LIMIT = 256
# Elements handled per program by the bool counting kernels
# (used as BLOCK in _median_bool_flat and _median_bool_dim).
_BOOL_FLAT_BLOCK = 1024
# Partial results folded per program in each bool reduction pass.
_BOOL_COUNT_REDUCE_BLOCK = 1024
_DIRECT_REDUCTION_DTYPES = {
    torch.bool,
    torch.float16,
    torch.bfloat16,
    torch.float32,
    torch.float64,
    torch.int8,
    torch.uint8,
    torch.int16,
    torch.int32,
    torch.int64,
}
# NOTE(review): _FLAT_SORT_LIMIT and _FLAT_SORT_DTYPES are presumably read by
# a flat sort path defined elsewhere; verify.
_FLAT_SORT_LIMIT = 1024
# Maximum row widths accepted by _use_lastdim_sort for float16 and bfloat16.
_LASTDIM_SORT_LIMIT = 1024
_BF16_LASTDIM_SORT_LIMIT = 2048
_LASTDIM_SORT_DTYPES = {torch.float16, torch.bfloat16}
_FLAT_SORT_DTYPES = _LASTDIM_SORT_DTYPES | {torch.float32}
# Inclusive width ranges checked by the _use_*_key_select predicates.
_F16_KEY_SELECT_MIN = 2
_F16_KEY_SELECT_LIMIT = 16384
_F16_KEY_SELECT_DTYPES = {torch.float16, torch.bfloat16}
_FP32_KEY_SELECT_MIN = 2
_FP32_KEY_SELECT_LIMIT = 16384
_FP64_KEY_SELECT_MIN = 2
_FP64_KEY_SELECT_LIMIT = 8192
# Width limit and dtypes for the integer last-dim selection path chosen in
# _median_from_rows.
_INT_LASTDIM_SELECT_LIMIT = 16384
_INT_LASTDIM_SELECT_DTYPES = {
    torch.int8,
    torch.uint8,
    torch.int16,
    torch.int32,
    torch.int64,
}
# Inclusive width range checked by _use_strided_select; it starts right above
# the direct reduction limit.
_STRIDED_SELECT_MIN = _DIRECT_REDUCTION_LIMIT + 1
_STRIDED_SELECT_LIMIT = 4096

55 

56 

@libentry()
@triton.jit
def median_small_dim_kernel(
    inp,
    values,
    indices,
    total_outputs,
    reduction_size,
    inner_size,
    BLOCK_N: tl.constexpr,
    BLOCK_OUT: tl.constexpr,
):
    """Sort-based lower median along one dimension, with indices.

    Each program handles BLOCK_OUT output positions. For each it loads up to
    BLOCK_N samples along the reduction axis, sorts them ascending, and picks
    the element of rank ``(reduction_size - 1) // 2``.

    The input is addressed as
    ``outer * reduction_size * inner_size + r * inner_size + inner``, i.e. as
    a densely packed (outer, reduction, inner) layout.

    For floating inputs, a row containing any NaN yields a NaN taken from the
    input together with its position.
    """
    out_offsets = tl.program_id(0) * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
    out_mask = out_offsets < total_outputs
    # Split the flat output index into its (outer, inner) coordinates.
    inner_offsets = out_offsets % inner_size
    outer_offsets = out_offsets // inner_size

    reduction_offsets = tl.arange(0, BLOCK_N)
    sample_mask = (reduction_offsets[None, :] < reduction_size) & out_mask[:, None]
    sample_ptrs = (
        inp
        + outer_offsets[:, None] * reduction_size * inner_size
        + reduction_offsets[None, :] * inner_size
        + inner_offsets[:, None]
    )

    # Padding value for masked lanes: the largest representable value, so the
    # padding ends up at the tail of the ascending sort.
    if inp.dtype.element_ty.is_floating():
        high = float("inf")
    else:
        high = _get_iinfo_val(inp.dtype.element_ty, return_max=True)

    samples = tl.load(sample_ptrs, mask=sample_mask, other=high)
    sortable = samples

    if inp.dtype.element_ty.is_floating():
        # NaN compares unequal to itself; replace NaNs by +inf for sorting.
        nan_mask = sample_mask & (samples != samples)
        sortable = tl.where(nan_mask, high, samples)

    ordered = tl.sort(sortable, dim=1, descending=False)
    rank = (reduction_size - 1) // 2
    # Extract column `rank` of each sorted row via a one-hot mask and a sum.
    rank_mask = reduction_offsets[None, :] == rank
    median_values = tl.sum(tl.where(rank_mask, ordered, tl.zeros_like(ordered)), axis=1)

    # Position of a sample equal to the median in the unsorted data.
    # NOTE(review): relies on tl.argmax returning the lowest index among ties
    # to get the first occurrence; confirm for the Triton version in use.
    first_match = tl.argmax(
        (sample_mask & (samples == median_values[:, None])).to(tl.int32), axis=1
    )

    if inp.dtype.element_ty.is_floating():
        nan_i32 = nan_mask.to(tl.int32)
        has_nan = tl.max(nan_i32, axis=1) != 0
        first_nan = tl.argmax(nan_i32, axis=1)
        # Reload the NaN from memory so the stored value is the input's NaN.
        nan_values = tl.load(
            inp
            + outer_offsets * reduction_size * inner_size
            + first_nan * inner_size
            + inner_offsets,
            mask=out_mask,
            other=0.0,
        )
        # A NaN in the row overrides both the value and the index.
        median_values = tl.where(has_nan, nan_values, median_values)
        first_match = tl.where(has_nan, first_nan, first_match)

    tl.store(values + out_offsets, median_values, mask=out_mask)
    tl.store(indices + out_offsets, first_match.to(tl.int64), mask=out_mask)

121 

122 

@libentry()
@triton.jit
def median_small_flat_kernel(
    inp,
    value,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Sort-based lower median of the first WIDTH elements at ``inp``.

    Loads the data as one block of BLOCK lanes, sorts ascending, and stores
    the element of rank ``(WIDTH - 1) // 2`` to ``value``. No index is
    produced. For floating inputs a NaN in the data is stored instead.
    """
    offsets = tl.arange(0, BLOCK)
    valid = offsets < WIDTH

    # Padding for lanes beyond WIDTH: the largest value of the element type,
    # so padding sorts to the end.
    if inp.dtype.element_ty.is_floating():
        high = float("inf")
    elif inp.dtype.element_ty is tl.int1:
        high = True
    else:
        high = _get_iinfo_val(inp.dtype.element_ty, return_max=True)

    data = tl.load(inp + offsets, mask=valid, other=high)
    sortable = data

    if inp.dtype.element_ty.is_floating():
        # NaNs (x != x) are replaced by +inf for the sort only.
        nan_mask = valid & (data != data)
        sortable = tl.where(nan_mask, high, data)

    ordered = tl.sort(sortable, descending=False)
    rank = (WIDTH - 1) // 2
    # One-hot select of the rank-th sorted element.
    median_value = tl.sum(
        tl.where(offsets == rank, ordered, tl.zeros_like(ordered)), axis=0
    )

    if inp.dtype.element_ty.is_floating():
        nan_i32 = nan_mask.to(tl.int32)
        has_nan = tl.max(nan_i32, axis=0) != 0
        first_nan = tl.argmax(nan_i32, axis=0)
        # Only touch memory when a NaN actually exists.
        nan_value = tl.load(inp + first_nan, mask=has_nan, other=0.0)
        median_value = tl.where(has_nan, nan_value, median_value)

    tl.store(value, median_value)

162 

163 

@libentry()
@triton.jit
def median_bool_count_kernel(
    inp,
    counts,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Count the true elements in one BLOCK-sized slice of ``inp``.

    Program ``block_id`` covers offsets ``[block_id * BLOCK, (block_id + 1) *
    BLOCK)`` clipped to WIDTH and writes its int64 count to
    ``counts[block_id]``.
    """
    block_id = tl.program_id(0)
    offsets = block_id * BLOCK + tl.arange(0, BLOCK)
    valid = offsets < WIDTH
    # Out-of-range lanes load False and therefore do not contribute.
    data = tl.load(inp + offsets, mask=valid, other=False)
    true_count = tl.sum((valid & data).to(tl.int64), axis=0)
    tl.store(counts + block_id, true_count)

178 

179 

@libentry()
@triton.jit
def median_bool_from_counts_kernel(
    counts,
    value,
    WIDTH: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Derive the boolean lower median from partial true-counts.

    Sums the first NUM_BLOCKS entries of ``counts`` and stores to ``value``
    whether rank ``(WIDTH - 1) // 2`` falls past all false elements.
    """
    offsets = tl.arange(0, BLOCK)
    valid = offsets < NUM_BLOCKS
    block_counts = tl.load(counts + offsets, mask=valid, other=0)
    true_count = tl.sum(block_counts, axis=0)
    rank = (WIDTH - 1) // 2
    false_count = WIDTH - true_count
    # In ascending order all falses precede all trues, so the rank-th element
    # is true exactly when rank >= number of falses.
    median_value = rank >= false_count
    tl.store(value, median_value)

197 

198 

@libentry()
@triton.jit
def median_bool_reduce_counts_kernel(
    counts_in,
    counts_out,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Fold BLOCK consecutive partial counts into one.

    Program ``block_id`` sums ``counts_in[block_id * BLOCK : ...]`` (clipped
    to WIDTH entries) and writes the result to ``counts_out[block_id]``.
    """
    block_id = tl.program_id(0)
    offsets = block_id * BLOCK + tl.arange(0, BLOCK)
    valid = offsets < WIDTH
    # Masked lanes read 0, the neutral element of the sum.
    block_counts = tl.load(counts_in + offsets, mask=valid, other=0)
    count = tl.sum(block_counts, axis=0)
    tl.store(counts_out + block_id, count)

213 

214 

@libentry()
@triton.jit
def median_bool_dim_count_chunks_kernel(
    inp,
    counts,
    first_false,
    first_true,
    total_outputs,
    reduction_size,
    inner_size,
    chunks_per_output: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Per-chunk statistics for a boolean median along one dimension.

    One program handles one (output position, chunk) pair, where a chunk is
    BLOCK consecutive positions along the reduction axis. It stores, at flat
    slot ``pid``: the number of true elements, and the smallest reduction
    index holding False / True (``reduction_size`` when the chunk has none).

    The input is addressed as
    ``outer * reduction_size * inner_size + col * inner_size + inner``.
    """
    pid = tl.program_id(0)
    # Decompose the flat program id into (output position, chunk index).
    out_offset = pid // chunks_per_output
    chunk_id = pid - out_offset * chunks_per_output
    out_mask = out_offset < total_outputs
    inner_offset = out_offset % inner_size
    outer_offset = out_offset // inner_size

    cols = chunk_id * BLOCK + tl.arange(0, BLOCK)
    valid = (cols < reduction_size) & out_mask
    ptrs = (
        inp
        + outer_offset * reduction_size * inner_size
        + cols * inner_size
        + inner_offset
    )
    data = tl.load(ptrs, mask=valid, other=False)

    # `valid &` keeps padded lanes out of both the true and the false set.
    true_mask = valid & data
    false_mask = valid & ~data
    true_count = tl.sum(true_mask.to(tl.int64), axis=0)
    # reduction_size acts as the "not found" sentinel for the min reductions.
    first_false_idx = tl.min(tl.where(false_mask, cols, reduction_size), axis=0)
    first_true_idx = tl.min(tl.where(true_mask, cols, reduction_size), axis=0)

    tl.store(counts + pid, true_count)
    tl.store(first_false + pid, first_false_idx.to(tl.int64))
    tl.store(first_true + pid, first_true_idx.to(tl.int64))

254 

255 

@libentry()
@triton.jit
def median_bool_dim_reduce_chunks_kernel(
    counts_in,
    first_false_in,
    first_true_in,
    counts_out,
    first_false_out,
    first_true_out,
    input_chunks: tl.constexpr,
    output_chunks: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Combine BLOCK per-chunk records of one row into a single record.

    Grid axis 0 selects the row (output position) and axis 1 the output
    chunk. Counts are summed; first-false / first-true indices are reduced
    with ``min``. Inputs are laid out as ``row * input_chunks + chunk`` and
    outputs as ``row * output_chunks + out_chunk``.
    """
    row = tl.program_id(0)
    out_chunk = tl.program_id(1)
    chunk_offsets = out_chunk * BLOCK + tl.arange(0, BLOCK)
    valid = chunk_offsets < input_chunks
    in_base = row * input_chunks + chunk_offsets

    counts = tl.load(counts_in + in_base, mask=valid, other=0)
    # 9223372036854775807 (2**63 - 1) pads masked lanes so they never win
    # the min reduction.
    first_false = tl.load(
        first_false_in + in_base, mask=valid, other=9223372036854775807
    )
    first_true = tl.load(first_true_in + in_base, mask=valid, other=9223372036854775807)

    true_count = tl.sum(counts, axis=0)
    first_false_idx = tl.min(first_false, axis=0)
    first_true_idx = tl.min(first_true, axis=0)
    out_base = row * output_chunks + out_chunk
    tl.store(counts_out + out_base, true_count)
    tl.store(first_false_out + out_base, first_false_idx)
    tl.store(first_true_out + out_base, first_true_idx)

288 

289 

@libentry()
@triton.jit
def median_bool_dim_finish_kernel(
    counts,
    first_false,
    first_true,
    values,
    indices,
    reduction_size,
    chunks_per_output: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Produce the boolean median and its index for one output position.

    Program ``row`` reduces the ``chunks_per_output`` chunk records of that
    row, decides whether the lower median is True or False, and stores the
    median together with the smallest index holding that value.
    """
    row = tl.program_id(0)
    chunk_offsets = tl.arange(0, BLOCK)
    valid = chunk_offsets < chunks_per_output
    base = row * chunks_per_output + chunk_offsets

    block_counts = tl.load(counts + base, mask=valid, other=0)
    true_count = tl.sum(block_counts, axis=0)
    false_count = reduction_size - true_count
    rank = (reduction_size - 1) // 2
    # Falses sort before trues, so rank >= false_count means the median is True.
    median_value = rank >= false_count

    # 2**63 - 1 pads masked lanes so they cannot win the min reductions.
    false_indices = tl.load(first_false + base, mask=valid, other=9223372036854775807)
    true_indices = tl.load(first_true + base, mask=valid, other=9223372036854775807)
    first_false_idx = tl.min(false_indices, axis=0)
    first_true_idx = tl.min(true_indices, axis=0)
    # Report the first position whose value equals the median.
    first_match = tl.where(median_value, first_true_idx, first_false_idx)

    tl.store(values + row, median_value)
    tl.store(indices + row, first_match)

321 

322 

@libentry()
@triton.jit
def median_lastdim_sort_kernel(
    row_data,
    values,
    indices,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Sort-based lower median of each WIDTH-element row, with indices.

    Program ``row`` reads ``row_data[row * WIDTH : (row + 1) * WIDTH]``,
    sorts it ascending (NaNs treated as +inf) and stores the element of rank
    ``(WIDTH - 1) // 2`` plus a position where it occurs. If the row holds a
    NaN, that NaN and its position are stored instead.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    valid = cols < WIDTH
    base = row_data + row * WIDTH
    # Pad with +inf so lanes beyond WIDTH sort to the end.
    data = tl.load(base + cols, mask=valid, other=float("inf"))

    nan_mask = valid & (data != data)
    sortable = tl.where(nan_mask, float("inf"), data)
    ordered = tl.sort(sortable, descending=False)
    rank = (WIDTH - 1) // 2
    # One-hot select of the rank-th sorted element.
    median_value = tl.sum(
        tl.where(cols == rank, ordered, tl.zeros_like(ordered)), axis=0
    )

    # NOTE(review): first-occurrence semantics depend on tl.argmax picking
    # the lowest index among equal maxima; confirm.
    first_match = tl.argmax((valid & (data == median_value)).to(tl.int32), axis=0)
    nan_i32 = nan_mask.to(tl.int32)
    has_nan = tl.max(nan_i32, axis=0) != 0
    first_nan = tl.argmax(nan_i32, axis=0)
    nan_value = tl.load(base + first_nan, mask=has_nan, other=0.0)
    # A NaN in the row overrides both outputs.
    median_value = tl.where(has_nan, nan_value, median_value)
    first_match = tl.where(has_nan, first_nan, first_match)

    tl.store(values + row, median_value)
    tl.store(indices + row, first_match.to(tl.int64))

356 

357 

@libentry()
@triton.jit
def median_int_lastdim_select_kernel(
    row_data,
    values,
    indices,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
    SEARCH_STEPS: tl.constexpr,
):
    """Lower median of each integer row via binary search on the value range.

    Instead of sorting, the kernel bisects ``[row_min, row_max]`` for the
    smallest value ``v`` such that more than ``rank`` elements are ``<= v``,
    where ``rank = (WIDTH - 1) // 2``. SEARCH_STEPS bisection steps are
    unrolled at compile time.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    valid = cols < WIDTH
    base = row_data + row * WIDTH
    data = tl.load(base + cols, mask=valid, other=0)

    dtype = row_data.dtype.element_ty
    high = _get_iinfo_val(dtype, return_max=True)
    low = _get_iinfo_val(dtype, return_max=False)
    # Replace padded lanes by the identity of min / max before reducing; the
    # bounds are widened to int64 for the search arithmetic.
    row_min = tl.min(tl.where(valid, data, high), axis=0).to(tl.int64)
    row_max = tl.max(tl.where(valid, data, low), axis=0).to(tl.int64)

    lo = row_min
    hi = row_max
    rank = (WIDTH - 1) // 2
    # NOTE(review): `hi - lo` is evaluated in int64; for int64 rows spanning
    # almost the whole type range this difference could wrap. Confirm
    # whether such inputs need to be supported.
    for _ in tl.static_range(0, SEARCH_STEPS):
        mid = lo + ((hi - lo) // 2)
        le_count = tl.sum((valid & (data <= mid.to(dtype))).to(tl.int32), axis=0)
        # More than `rank` elements are <= mid: the answer is at or below mid.
        take_left = le_count > rank
        hi = tl.where(take_left, mid, hi)
        lo = tl.where(take_left, lo, mid + 1)

    median_value = lo.to(dtype)
    first_match = tl.argmax((valid & (data == median_value)).to(tl.int32), axis=0)
    tl.store(values + row, median_value)
    tl.store(indices + row, first_match.to(tl.int64))

394 

395 

@triton.jit
def _fp32_order_key(x):
    """Map float32 bit patterns to uint32 keys for unsigned comparison.

    Values with the sign bit set get every bit flipped; other values get only
    the sign bit flipped.
    """
    bits = x.to(tl.uint32, bitcast=True)
    signed = x.to(tl.int32, bitcast=True)
    # Shifting the signed view right by 31 yields 0 (sign bit clear) or
    # -1 / all ones (sign bit set).
    sign = signed >> 31
    sign_mask = tl.full((), 0x80000000, dtype=tl.uint32)
    # 0x80000000 when the sign bit is clear, 0xFFFFFFFF when it is set.
    mask = sign_mask | sign.to(tl.uint32, bitcast=True)
    return bits ^ mask

404 

405 

@triton.jit
def _fp64_order_key(x):
    """Map float64 bit patterns to uint64 keys for unsigned comparison.

    Values with the sign bit set get every bit flipped; other values get only
    the sign bit flipped.
    """
    bits = x.to(tl.uint64, bitcast=True)
    signed = x.to(tl.int64, bitcast=True)
    # 0 when the sign bit is clear, all ones when it is set.
    sign = signed >> 63
    # The top-bit mask is built with a shift rather than a 64-bit literal.
    sign_mask = tl.full((), 1, dtype=tl.uint64) << 63
    mask = sign_mask | sign.to(tl.uint64, bitcast=True)
    return bits ^ mask

414 

415 

@triton.jit
def _f16_order_key(x):
    """Map 16-bit float bit patterns to uint16 keys for unsigned comparison.

    Values with the sign bit set get every bit flipped; other values get only
    the sign bit flipped. The input is reinterpreted bitwise as 16-bit, so it
    is expected to be a 16-bit floating type.
    """
    bits = x.to(tl.uint16, bitcast=True)
    signed = x.to(tl.int16, bitcast=True)
    # 0 when the sign bit is clear, all ones when it is set.
    sign = signed >> 15
    sign_mask = tl.full((), 0x8000, dtype=tl.uint16)
    mask = sign_mask | sign.to(tl.uint16, bitcast=True)
    return bits ^ mask

424 

425 

@libentry()
@triton.jit
def median_f16_key_select_kernel(
    row_data,
    values,
    indices,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Lower median of each 16-bit float row via bisection on ordered keys.

    Program ``row`` processes ``row_data[row * WIDTH : (row + 1) * WIDTH]``.
    NaNs, -inf and +inf are counted separately; the remaining values are
    converted to order-preserving integer keys and the key of the requested
    rank is found with 16 unrolled bisection steps. The stored value is
    reloaded from the input at the selected position.

    Priority of the result: NaN (if any) > +inf / -inf (if the rank falls into
    that group) > selected ordinary value.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    valid = cols < WIDTH
    base = row_data + row * WIDTH
    data = tl.load(base + cols, mask=valid, other=0.0)

    # NaN detection: a NaN is the only value unequal to itself.
    nan_mask = valid & (data != data)
    nan_i32 = nan_mask.to(tl.int32)
    has_nan = tl.max(nan_i32, axis=0) != 0
    first_nan = tl.argmax(nan_i32, axis=0)
    nan_value = tl.load(base + first_nan, mask=has_nan, other=0.0)

    # Partition the non-NaN lanes into -inf, +inf and ordinary values.
    finite = valid & ~nan_mask
    neg_inf_mask = finite & (data == -float("inf"))
    pos_inf_mask = finite & (data == float("inf"))
    real_finite = finite & ~(neg_inf_mask | pos_inf_mask)
    neg_inf_count = tl.sum(neg_inf_mask.to(tl.int32), axis=0)
    real_finite_count = tl.sum(real_finite.to(tl.int32), axis=0)

    rank = (WIDTH - 1) // 2
    # Rank relative to the ordinary values, after skipping the -inf entries.
    search_rank = rank - neg_inf_count
    take_neg_inf = rank < neg_inf_count
    take_pos_inf = search_rank >= real_finite_count

    # Keys are widened to uint32 for the search arithmetic below.
    keys = _f16_order_key(data).to(tl.uint32)
    key_min_fill = tl.full((), 0xFFFF, dtype=tl.uint32)
    key_max_fill = tl.full((), 0, dtype=tl.uint32)
    row_min = tl.min(tl.where(real_finite, keys, key_min_fill), axis=0)
    row_max = tl.max(tl.where(real_finite, keys, key_max_fill), axis=0)
    has_real_finite = real_finite_count != 0
    # With no ordinary values the bounds collapse to 0.
    row_min = tl.where(has_real_finite, row_min, 0)
    row_max = tl.where(has_real_finite, row_max, 0)

    lo = row_min
    hi = row_max
    # Bisect for the smallest key with more than `search_rank` ordinary
    # values at or below it.
    for _ in tl.static_range(0, 16):
        mid = lo + ((hi - lo) >> 1)
        le_count = tl.sum((real_finite & (keys <= mid)).to(tl.int32), axis=0)
        take_left = le_count > search_rank
        hi = tl.where(take_left, mid, hi)
        lo = tl.where(take_left, lo, mid + 1)

    selected_key = lo
    key_match = real_finite & (keys == selected_key)
    selected_key_first = tl.argmax(key_match.to(tl.int32), axis=0)
    # Reload the value from memory rather than decoding the key.
    selected_value = tl.load(base + selected_key_first)

    first_neg_inf = tl.argmax(neg_inf_mask.to(tl.int32), axis=0)
    neg_inf_value = tl.load(base + first_neg_inf, mask=take_neg_inf, other=0.0)
    first_pos_inf = tl.argmax(pos_inf_mask.to(tl.int32), axis=0)
    pos_inf_value = tl.load(base + first_pos_inf, mask=take_pos_inf, other=0.0)
    # Infinity overrides apply when the rank falls into the respective group.
    selected_value = tl.where(take_neg_inf, neg_inf_value, selected_value)
    selected_value = tl.where(take_pos_inf, pos_inf_value, selected_value)
    selected_key_first = tl.where(take_neg_inf, first_neg_inf, selected_key_first)
    selected_key_first = tl.where(take_pos_inf, first_pos_inf, selected_key_first)

    # A NaN anywhere in the row takes precedence over everything else.
    selected_value = tl.where(has_nan, nan_value, selected_value)
    first_match = tl.where(has_nan, first_nan, selected_key_first)
    tl.store(values + row, selected_value)
    tl.store(indices + row, first_match.to(tl.int64))

495 

496 

@libentry()
@triton.jit
def median_fp32_key_select_kernel(
    row_data,
    values,
    indices,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Lower median of each float32 row via bisection on ordered keys.

    Program ``row`` processes ``row_data[row * WIDTH : (row + 1) * WIDTH]``.
    Non-NaN values are mapped to order-preserving uint32 keys and the key of
    rank ``(WIDTH - 1) // 2`` is located with 32 unrolled bisection steps.
    If the row contains a NaN, that NaN and its position are stored instead.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    valid = cols < WIDTH
    base = row_data + row * WIDTH
    data = tl.load(base + cols, mask=valid, other=0.0)

    nan_mask = valid & (data != data)
    nan_i32 = nan_mask.to(tl.int32)
    has_nan = tl.max(nan_i32, axis=0) != 0
    first_nan = tl.argmax(nan_i32, axis=0)
    nan_value = tl.load(base + first_nan, mask=has_nan, other=0.0)

    keys = _fp32_order_key(data)
    # "finite" here means valid and not NaN; infinities remain included.
    finite = valid & ~nan_mask
    # Excluded lanes get the identity of min / max respectively.
    key_min_fill = tl.full((), 0xFFFFFFFF, dtype=tl.uint32)
    key_max_fill = tl.full((), 0, dtype=tl.uint32)
    row_min = tl.min(tl.where(finite, keys, key_min_fill), axis=0)
    row_max = tl.max(tl.where(finite, keys, key_max_fill), axis=0)

    lo = row_min
    hi = row_max
    rank = (WIDTH - 1) // 2
    # Bisect for the smallest key with more than `rank` keys at or below it.
    for _ in tl.static_range(0, 32):
        mid = lo + ((hi - lo) >> 1)
        le_count = tl.sum((finite & (keys <= mid)).to(tl.int32), axis=0)
        take_left = le_count > rank
        hi = tl.where(take_left, mid, hi)
        lo = tl.where(take_left, lo, mid + 1)

    selected_key = lo
    key_match = finite & (keys == selected_key)
    selected_key_first = tl.argmax(key_match.to(tl.int32), axis=0)
    # Reload the value from memory rather than decoding the key.
    selected_value = tl.load(base + selected_key_first)

    # A NaN anywhere in the row overrides the search result.
    selected_value = tl.where(has_nan, nan_value, selected_value)
    first_match = tl.where(has_nan, first_nan, selected_key_first)
    tl.store(values + row, selected_value)
    tl.store(indices + row, first_match.to(tl.int64))

544 

545 

@libentry()
@triton.jit
def median_fp64_key_select_kernel(
    row_data,
    values,
    indices,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Lower median of each float64 row via bisection on ordered keys.

    Program ``row`` processes ``row_data[row * WIDTH : (row + 1) * WIDTH]``.
    Non-NaN values are mapped to order-preserving uint64 keys and the key of
    rank ``(WIDTH - 1) // 2`` is located with 64 unrolled bisection steps.
    If the row contains a NaN, that NaN and its position are stored instead.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    valid = cols < WIDTH
    base = row_data + row * WIDTH
    data = tl.load(base + cols, mask=valid, other=0.0)

    nan_mask = valid & (data != data)
    # This kernel widens the NaN mask to int64, unlike its siblings (int32).
    nan_i64 = nan_mask.to(tl.int64)
    has_nan = tl.max(nan_i64, axis=0) != 0
    first_nan = tl.argmax(nan_i64, axis=0)
    nan_value = tl.load(base + first_nan, mask=has_nan, other=0.0)

    keys = _fp64_order_key(data)
    # "finite" here means valid and not NaN; infinities remain included.
    finite = valid & ~nan_mask
    # Excluded lanes get the identity of min / max respectively.
    key_min_fill = tl.full((), 0xFFFFFFFFFFFFFFFF, dtype=tl.uint64)
    key_max_fill = tl.full((), 0, dtype=tl.uint64)
    row_min = tl.min(tl.where(finite, keys, key_min_fill), axis=0)
    row_max = tl.max(tl.where(finite, keys, key_max_fill), axis=0)

    lo = row_min
    hi = row_max
    rank = (WIDTH - 1) // 2
    # Bisect for the smallest key with more than `rank` keys at or below it.
    for _ in tl.static_range(0, 64):
        mid = lo + ((hi - lo) >> 1)
        le_count = tl.sum((finite & (keys <= mid)).to(tl.int32), axis=0)
        take_left = le_count > rank
        hi = tl.where(take_left, mid, hi)
        lo = tl.where(take_left, lo, mid + 1)

    selected_key = lo
    key_match = finite & (keys == selected_key)
    selected_key_first = tl.argmax(key_match.to(tl.int32), axis=0)
    # Reload the value from memory rather than decoding the key.
    selected_value = tl.load(base + selected_key_first)

    # A NaN anywhere in the row overrides the search result.
    selected_value = tl.where(has_nan, nan_value, selected_value)
    first_match = tl.where(has_nan, first_nan, selected_key_first)
    tl.store(values + row, selected_value)
    tl.store(indices + row, first_match.to(tl.int64))

593 

594 

@libentry()
@triton.jit
def median_f16_strided_key_select_kernel(
    inp,
    values,
    indices,
    total_outputs,
    reduction_size,
    inner_size,
    BLOCK: tl.constexpr,
    BLOCK_OUT: tl.constexpr,
):
    """Key-bisection lower median along a strided dimension for 16-bit floats.

    Each program handles BLOCK_OUT output positions and, for each, up to
    BLOCK samples along the reduction axis. The input is addressed as
    ``outer * reduction_size * inner_size + col * inner_size + inner``.
    Non-NaN samples are mapped to ordered integer keys and the key of rank
    ``(reduction_size - 1) // 2`` is found with 16 unrolled bisection steps.
    Rows containing a NaN yield that NaN and its position.
    """
    out_offsets = tl.program_id(0) * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
    out_mask = out_offsets < total_outputs
    # Split the flat output index into its (outer, inner) coordinates.
    inner_offsets = out_offsets % inner_size
    outer_offsets = out_offsets // inner_size
    cols = tl.arange(0, BLOCK)
    valid = (cols[None, :] < reduction_size) & out_mask[:, None]
    ptrs = (
        inp
        + outer_offsets[:, None] * reduction_size * inner_size
        + cols[None, :] * inner_size
        + inner_offsets[:, None]
    )
    data = tl.load(ptrs, mask=valid, other=0.0)

    nan_mask = valid & (data != data)
    nan_i32 = nan_mask.to(tl.int32)
    has_nan = tl.max(nan_i32, axis=1) != 0
    first_nan = tl.argmax(nan_i32, axis=1)
    nan_value = tl.load(
        inp
        + outer_offsets * reduction_size * inner_size
        + first_nan * inner_size
        + inner_offsets,
        mask=out_mask,
        other=0.0,
    )

    # Keys are widened to uint32 for the search arithmetic below.
    keys = _f16_order_key(data).to(tl.uint32)
    # "finite" here means valid and not NaN; infinities remain included.
    finite = valid & ~nan_mask
    key_min_fill = tl.full((), 0xFFFF, dtype=tl.uint32)
    key_max_fill = tl.full((), 0, dtype=tl.uint32)
    row_min = tl.min(tl.where(finite, keys, key_min_fill), axis=1)
    row_max = tl.max(tl.where(finite, keys, key_max_fill), axis=1)

    lo = row_min
    hi = row_max
    rank = (reduction_size - 1) // 2
    # Per-row bisection for the smallest key with more than `rank` keys at or
    # below it.
    for _ in tl.static_range(0, 16):
        mid = lo + ((hi - lo) >> 1)
        le_count = tl.sum((finite & (keys <= mid[:, None])).to(tl.int32), axis=1)
        take_left = le_count > rank
        hi = tl.where(take_left, mid, hi)
        lo = tl.where(take_left, lo, mid + 1)

    selected_key = lo
    key_match = finite & (keys == selected_key[:, None])
    selected_key_first = tl.argmax(key_match.to(tl.int32), axis=1)
    # Reload the selected element from memory rather than decoding the key.
    selected_value = tl.load(
        inp
        + outer_offsets * reduction_size * inner_size
        + selected_key_first * inner_size
        + inner_offsets,
        mask=out_mask,
        other=0.0,
    )

    # A NaN anywhere in the row overrides the search result.
    selected_value = tl.where(has_nan, nan_value, selected_value)
    first_match = tl.where(has_nan, first_nan, selected_key_first)
    tl.store(values + out_offsets, selected_value, mask=out_mask)
    tl.store(indices + out_offsets, first_match.to(tl.int64), mask=out_mask)

667 

668 

@libentry()
@triton.jit
def median_fp32_strided_key_select_kernel(
    inp,
    values,
    indices,
    total_outputs,
    reduction_size,
    inner_size,
    BLOCK: tl.constexpr,
    BLOCK_OUT: tl.constexpr,
):
    """Key-bisection lower median along a strided dimension for float32.

    Each program handles BLOCK_OUT output positions and, for each, up to
    BLOCK samples along the reduction axis. The input is addressed as
    ``outer * reduction_size * inner_size + col * inner_size + inner``.
    Non-NaN samples are mapped to ordered uint32 keys and the key of rank
    ``(reduction_size - 1) // 2`` is found with 32 unrolled bisection steps.
    Rows containing a NaN yield that NaN and its position.
    """
    out_offsets = tl.program_id(0) * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
    out_mask = out_offsets < total_outputs
    # Split the flat output index into its (outer, inner) coordinates.
    inner_offsets = out_offsets % inner_size
    outer_offsets = out_offsets // inner_size
    cols = tl.arange(0, BLOCK)
    valid = (cols[None, :] < reduction_size) & out_mask[:, None]
    ptrs = (
        inp
        + outer_offsets[:, None] * reduction_size * inner_size
        + cols[None, :] * inner_size
        + inner_offsets[:, None]
    )
    data = tl.load(ptrs, mask=valid, other=0.0)

    nan_mask = valid & (data != data)
    nan_i32 = nan_mask.to(tl.int32)
    has_nan = tl.max(nan_i32, axis=1) != 0
    first_nan = tl.argmax(nan_i32, axis=1)
    nan_value = tl.load(
        inp
        + outer_offsets * reduction_size * inner_size
        + first_nan * inner_size
        + inner_offsets,
        mask=out_mask,
        other=0.0,
    )

    keys = _fp32_order_key(data)
    # "finite" here means valid and not NaN; infinities remain included.
    finite = valid & ~nan_mask
    # Excluded lanes get the identity of min / max respectively.
    key_min_fill = tl.full((), 0xFFFFFFFF, dtype=tl.uint32)
    key_max_fill = tl.full((), 0, dtype=tl.uint32)
    row_min = tl.min(tl.where(finite, keys, key_min_fill), axis=1)
    row_max = tl.max(tl.where(finite, keys, key_max_fill), axis=1)

    lo = row_min
    hi = row_max
    rank = (reduction_size - 1) // 2
    # Per-row bisection for the smallest key with more than `rank` keys at or
    # below it.
    for _ in tl.static_range(0, 32):
        mid = lo + ((hi - lo) >> 1)
        le_count = tl.sum((finite & (keys <= mid[:, None])).to(tl.int32), axis=1)
        take_left = le_count > rank
        hi = tl.where(take_left, mid, hi)
        lo = tl.where(take_left, lo, mid + 1)

    selected_key = lo
    key_match = finite & (keys == selected_key[:, None])
    selected_key_first = tl.argmax(key_match.to(tl.int32), axis=1)
    # Reload the selected element from memory rather than decoding the key.
    selected_value = tl.load(
        inp
        + outer_offsets * reduction_size * inner_size
        + selected_key_first * inner_size
        + inner_offsets,
        mask=out_mask,
        other=0.0,
    )

    # A NaN anywhere in the row overrides the search result.
    selected_value = tl.where(has_nan, nan_value, selected_value)
    first_match = tl.where(has_nan, first_nan, selected_key_first)
    tl.store(values + out_offsets, selected_value, mask=out_mask)
    tl.store(indices + out_offsets, first_match.to(tl.int64), mask=out_mask)

741 

742 

743def _has_names(inp): 

744 return any(name is not None for name in inp.names) 

745 

746 

def _anonymous(inp):
    """Return ``inp`` with its dimension names dropped.

    Inputs that carry no names are returned unchanged (same object); named
    inputs are passed through ``inp.rename(None)``.
    """
    if not _has_names(inp):
        return inp
    return inp.rename(None)

749 

750 

751def _canonical_dim(ndim, dim): 

752 lower = -1 if ndim == 0 else -ndim 

753 upper = 0 if ndim == 0 else ndim - 1 

754 if dim < lower or dim > upper: 

755 raise IndexError( 

756 f"Dimension out of range (expected to be in range of " 

757 f"[{lower}, {upper}], but got {dim})" 

758 ) 

759 return 0 if ndim == 0 else dim % ndim 

760 

761 

762def _name_to_dim(inp, dim): 

763 if dim not in inp.names: 

764 raise RuntimeError(f"Name '{dim}' not found in Tensor{inp.names}.") 

765 return inp.names.index(dim) 

766 

767 

768def _kept_names(names, dim, keepdim): 

769 if names is None: 

770 return None 

771 if keepdim: 

772 return names 

773 return names[:dim] + names[dim + 1 :] 

774 

775 

776def _empty_result_value(inp): 

777 if inp.dtype.is_complex: 

778 out = torch.empty((), dtype=inp.dtype, device=inp.device) 

779 out.real.fill_(float("nan")) 

780 out.imag.zero_() 

781 return out 

782 if inp.dtype.is_floating_point: 

783 return torch.full((), float("nan"), dtype=inp.dtype, device=inp.device) 

784 if inp.dtype == torch.bool: 

785 return torch.ones((), dtype=inp.dtype, device=inp.device) 

786 if inp.dtype in (torch.int32, torch.int64): 

787 return torch.full( 

788 (), torch.iinfo(inp.dtype).min, dtype=inp.dtype, device=inp.device 

789 ) 

790 return torch.zeros((), dtype=inp.dtype, device=inp.device) 

791 

792 

793def _raise_dim_dtype(dtype): 

794 dtype_names = { 

795 torch.bool: "Bool", 

796 torch.complex64: "ComplexFloat", 

797 torch.complex128: "ComplexDouble", 

798 } 

799 dtype_name = dtype_names.get(dtype, str(dtype).removeprefix("torch.")) 

800 raise NotImplementedError(f'"median_out_impl" not implemented for {dtype_name!r}') 

801 

802 

803def _int_search_steps(dtype): 

804 if dtype in (torch.int8, torch.uint8): 

805 return 8 

806 if dtype == torch.int16: 

807 return 16 

808 if dtype == torch.int32: 

809 return 32 

810 if dtype == torch.int64: 

811 return 64 

812 raise NotImplementedError(f"median integer selection not implemented for {dtype}") 

813 

814 

815def _unsupported_width(dtype, width): 

816 raise NotImplementedError( 

817 f"median Triton selection not implemented for dtype {dtype} " 

818 f"with reduction width {width}" 

819 ) 

820 

821 

def _median_from_rows(row_data, output_shape):
    """Dispatch a row-wise median to the launcher matching dtype and width.

    The width is the size of the last dimension of ``row_data``. Candidates
    are tried in a fixed order: 16-bit key select, last-dim sort, float32 key
    select, float64 key select, integer select. When none applies,
    ``_unsupported_width`` raises NotImplementedError.
    """
    dtype = row_data.dtype
    width = row_data.shape[-1]

    if _use_f16_key_select(dtype, width):
        launcher = _median_f16_key_select
    elif _use_lastdim_sort(dtype, width):
        launcher = _median_lastdim_sort
    elif _use_fp32_key_select(dtype, width):
        launcher = _median_fp32_key_select
    elif _use_fp64_key_select(dtype, width):
        launcher = _median_fp64_key_select
    elif width <= _INT_LASTDIM_SELECT_LIMIT and dtype in _INT_LASTDIM_SELECT_DTYPES:
        launcher = _median_int_lastdim_select
    else:
        launcher = None

    if launcher is None:
        return _unsupported_width(dtype, width)
    return launcher(row_data, output_shape)

838 

839 

def _median_small_flat(inp):
    """Launch the single-program flat median kernel and return a 0-d tensor.

    The block size is the element count rounded up to a power of two; the
    warp count is ``block // 32`` clamped to ``[4, 8]``.
    """
    numel = inp.numel()
    block = triton.next_power_of_2(numel)
    warps = min(8, max(4, block // 32))
    result = torch.empty((), dtype=inp.dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        median_small_flat_kernel[(1,)](
            inp.reshape(-1),
            result,
            WIDTH=numel,
            BLOCK=block,
            num_warps=warps,
        )
    return result

852 

853 

def _median_bool_flat(inp):
    """Median over all elements of a boolean tensor via counting.

    Three stages: per-block true counts, repeated folding of those counts
    until at most ``_BOOL_COUNT_REDUCE_BLOCK`` remain, then a single program
    that turns the counts into the median. Returns a 0-d tensor of
    ``inp.dtype``.
    """
    device = inp.device
    total = inp.numel()
    n_blocks = triton.cdiv(total, _BOOL_FLAT_BLOCK)
    partial = torch.empty((n_blocks,), dtype=torch.int64, device=device)
    result = torch.empty((), dtype=inp.dtype, device=device)

    with torch_device_fn.device(device):
        # Stage 1: one true-count per block of the flattened input.
        median_bool_count_kernel[(n_blocks,)](
            inp.reshape(-1),
            partial,
            WIDTH=total,
            BLOCK=_BOOL_FLAT_BLOCK,
            num_warps=4,
        )

        # Stage 2: fold the partial counts until one program can read them all.
        while partial.numel() > _BOOL_COUNT_REDUCE_BLOCK:
            n_partial = partial.numel()
            n_folded = triton.cdiv(n_partial, _BOOL_COUNT_REDUCE_BLOCK)
            folded = torch.empty((n_folded,), dtype=torch.int64, device=device)
            median_bool_reduce_counts_kernel[(n_folded,)](
                partial,
                folded,
                WIDTH=n_partial,
                BLOCK=_BOOL_COUNT_REDUCE_BLOCK,
                num_warps=4,
            )
            partial = folded

        # Stage 3: derive the median from the remaining counts.
        n_remaining = partial.numel()
        final_block = triton.next_power_of_2(n_remaining)
        median_bool_from_counts_kernel[(1,)](
            partial,
            result,
            WIDTH=total,
            NUM_BLOCKS=n_remaining,
            BLOCK=final_block,
            num_warps=min(8, max(1, final_block // 32)),
        )
    return result

891 

892 

def _median_bool_dim(inp, dim, output_shape):
    """Median of a boolean tensor along ``dim`` via chunked counting.

    Stage 1 records, per (output position, chunk of the reduction axis), the
    true count and the first false / true index. Stage 2 folds those records
    until at most ``_BOOL_COUNT_REDUCE_BLOCK`` chunks per output remain.
    Stage 3 writes the median and its index. Returns ``(values, indices)``
    shaped like ``output_shape``.
    """
    device = inp.device

    def _int64_buffer(shape):
        # All intermediate records and the indices are int64.
        return torch.empty(shape, dtype=torch.int64, device=device)

    reduction_size = inp.shape[dim]
    inner_size = math.prod(inp.shape[dim + 1 :])
    total_outputs = math.prod(output_shape)

    values = torch.empty(output_shape, dtype=inp.dtype, device=device)
    indices = _int64_buffer(output_shape)

    n_chunks = triton.cdiv(reduction_size, _BOOL_FLAT_BLOCK)
    record_shape = (total_outputs, n_chunks)
    counts = _int64_buffer(record_shape)
    first_false = _int64_buffer(record_shape)
    first_true = _int64_buffer(record_shape)

    with torch_device_fn.device(device):
        # Stage 1: one program per (output position, chunk).
        median_bool_dim_count_chunks_kernel[(total_outputs * n_chunks,)](
            inp,
            counts.reshape(-1),
            first_false.reshape(-1),
            first_true.reshape(-1),
            total_outputs,
            reduction_size,
            inner_size,
            chunks_per_output=n_chunks,
            BLOCK=_BOOL_FLAT_BLOCK,
            num_warps=4,
        )

        # Stage 2: fold records until the finish kernel can read a row at once.
        while n_chunks > _BOOL_COUNT_REDUCE_BLOCK:
            n_folded = triton.cdiv(n_chunks, _BOOL_COUNT_REDUCE_BLOCK)
            folded_shape = (total_outputs, n_folded)
            folded_counts = _int64_buffer(folded_shape)
            folded_first_false = _int64_buffer(folded_shape)
            folded_first_true = _int64_buffer(folded_shape)
            median_bool_dim_reduce_chunks_kernel[(total_outputs, n_folded)](
                counts.reshape(-1),
                first_false.reshape(-1),
                first_true.reshape(-1),
                folded_counts.reshape(-1),
                folded_first_false.reshape(-1),
                folded_first_true.reshape(-1),
                input_chunks=n_chunks,
                output_chunks=n_folded,
                BLOCK=_BOOL_COUNT_REDUCE_BLOCK,
                num_warps=4,
            )
            counts, first_false, first_true = (
                folded_counts,
                folded_first_false,
                folded_first_true,
            )
            n_chunks = n_folded

        # Stage 3: one program per output position.
        finish_block = triton.next_power_of_2(n_chunks)
        finish_warps = min(8, max(1, finish_block // 32))
        median_bool_dim_finish_kernel[(total_outputs,)](
            counts.reshape(-1),
            first_false.reshape(-1),
            first_true.reshape(-1),
            values.reshape(-1),
            indices.reshape(-1),
            reduction_size,
            chunks_per_output=n_chunks,
            BLOCK=finish_block,
            num_warps=finish_warps,
        )
    return values, indices

961 

962 

def _median_lastdim_sort(row_data, output_shape):
    """Launch ``median_lastdim_sort_kernel`` with one program per row.

    ``row_data`` is viewed as ``(rows, width)``, where ``width`` is the size
    of its last dimension. Returns ``(values, indices)``, both allocated with
    ``output_shape``: ``values`` uses ``row_data``'s dtype and ``indices`` is
    int64.
    """
    device = row_data.device
    width = row_data.shape[-1]
    rows = row_data.numel() // width
    block = triton.next_power_of_2(width)

    # A single wide row gets the maximum warp count; every other case scales
    # with the block size, clamped to the range [4, 8].
    if rows == 1 and block >= 1024:
        num_warps = 8
    else:
        num_warps = min(8, max(4, block // 512))

    values = torch.empty(output_shape, dtype=row_data.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    with torch_device_fn.device(device):
        median_lastdim_sort_kernel[(rows,)](
            row_data.reshape(rows, width),
            values.reshape(rows),
            indices.reshape(rows),
            WIDTH=width,
            BLOCK=block,
            num_warps=num_warps,
        )
    return values, indices

980 

981 

def _use_lastdim_sort(dtype, width):
    """Return True when ``width`` is within the last-dim sort limit for ``dtype``.

    Only float16 and bfloat16 have a limit; any other dtype yields False.
    """
    if dtype == torch.bfloat16:
        limit = _BF16_LASTDIM_SORT_LIMIT
    elif dtype == torch.float16:
        limit = _LASTDIM_SORT_LIMIT
    else:
        return False
    return width <= limit

988 

989 

def _use_f16_key_select(dtype, width):
    """Return True for a 16-bit float dtype whose ``width`` is in the key-select range."""
    if dtype not in _F16_KEY_SELECT_DTYPES:
        return False
    return _F16_KEY_SELECT_MIN <= width <= _F16_KEY_SELECT_LIMIT

995 

996 

def _use_fp32_key_select(dtype, width):
    """Return True for float32 input whose ``width`` is in the key-select range."""
    if dtype != torch.float32:
        return False
    return _FP32_KEY_SELECT_MIN <= width <= _FP32_KEY_SELECT_LIMIT

1002 

1003 

def _use_fp64_key_select(dtype, width):
    """Return True for float64 input whose ``width`` is in the key-select range."""
    if dtype != torch.float64:
        return False
    return _FP64_KEY_SELECT_MIN <= width <= _FP64_KEY_SELECT_LIMIT

1009 

1010 

def _use_strided_select(dtype, width):
    """Return True when ``width`` is in the strided-select range and ``dtype``
    is float32 or one of the 16-bit float key-select dtypes.
    """
    if not (_STRIDED_SELECT_MIN <= width <= _STRIDED_SELECT_LIMIT):
        return False
    return dtype in (_F16_KEY_SELECT_DTYPES | {torch.float32})

1015 

1016 

def _use_float_key_select(dtype, width):
    """Return True when any of the f16, fp32 or fp64 key-select predicates accepts the input."""
    if _use_f16_key_select(dtype, width):
        return True
    if _use_fp32_key_select(dtype, width):
        return True
    return _use_fp64_key_select(dtype, width)

1023 

1024 

def _median_float_key_select_rows(row_data, output_shape):
    """Dispatch ``row_data`` to the f16, fp32 or fp64 key-select launcher.

    The f16 and fp32 predicates are tried in that order against the dtype and
    last-dimension width; anything they reject goes to the fp64 launcher.
    Returns that launcher's ``(values, indices)`` result.
    """
    dtype = row_data.dtype
    width = row_data.shape[-1]
    if _use_f16_key_select(dtype, width):
        launcher = _median_f16_key_select
    elif _use_fp32_key_select(dtype, width):
        launcher = _median_fp32_key_select
    else:
        launcher = _median_fp64_key_select
    return launcher(row_data, output_shape)

1031 

1032 

def _median_float_key_select_dim(work, dim, output_shape, keepdim):
    """Run float key-select along ``dim`` of ``work``.

    Three paths, tried in order:
      1. ``dim`` is the last dimension: select directly on the contiguous rows.
      2. ``work`` is contiguous and 16-bit float or float32: use the strided
         launcher for that dtype.
      3. Otherwise: move ``dim`` to the end, make the rows contiguous, select
         on them, and re-insert the reduced dimension when ``keepdim`` is set.

    Returns ``(values, indices)``.
    """
    if dim == work.ndim - 1:
        return _median_float_key_select_rows(work.contiguous(), output_shape)

    if work.is_contiguous():
        if work.dtype in _F16_KEY_SELECT_DTYPES:
            return _median_f16_strided_key_select(work, dim, output_shape)
        if work.dtype == torch.float32:
            return _median_fp32_strided_key_select(work, dim, output_shape)

    moved = torch.movedim(work, dim, -1).contiguous()
    values, indices = _median_float_key_select_rows(moved, moved.shape[:-1])
    if not keepdim:
        return values, indices

    # The row result lacks the reduced dim; put a size-1 dim back at ``dim``.
    values = torch.movedim(values.unsqueeze(-1), -1, dim)
    indices = torch.movedim(indices.unsqueeze(-1), -1, dim)
    return values, indices

1050 

1051 

def _median_int_lastdim_select(row_data, output_shape):
    """Launch ``median_int_lastdim_select_kernel`` with one program per row.

    ``row_data`` is viewed as ``(rows, width)``. The kernel's ``SEARCH_STEPS``
    comes from ``_int_search_steps`` for the input dtype. Returns
    ``(values, indices)`` allocated with ``output_shape``; ``indices`` is int64.
    """
    device = row_data.device
    dtype = row_data.dtype
    width = row_data.shape[-1]
    rows = row_data.numel() // width

    values = torch.empty(output_shape, dtype=dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)

    block = triton.next_power_of_2(width)
    search_steps = _int_search_steps(dtype)
    # Warp count scales with the block size, clamped to the range [4, 8].
    num_warps = min(8, max(4, block // 512))
    with torch_device_fn.device(device):
        median_int_lastdim_select_kernel[(rows,)](
            row_data.reshape(rows, width),
            values.reshape(rows),
            indices.reshape(rows),
            WIDTH=width,
            BLOCK=block,
            SEARCH_STEPS=search_steps,
            num_warps=num_warps,
        )
    return values, indices

1070 

1071 

def _median_f16_key_select(row_data, output_shape):
    """Launch ``median_f16_key_select_kernel`` with one program per row.

    ``row_data`` is viewed as ``(rows, width)``. Returns ``(values, indices)``
    allocated with ``output_shape``; ``values`` keeps the input dtype and
    ``indices`` is int64.
    """
    device = row_data.device
    width = row_data.shape[-1]
    rows = row_data.numel() // width
    block = triton.next_power_of_2(width)

    # Warp count steps up with the padded row width.
    if block <= 1024:
        num_warps = 1
    elif block <= 2048:
        num_warps = 2
    else:
        num_warps = 4

    values = torch.empty(output_shape, dtype=row_data.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    with torch_device_fn.device(device):
        median_f16_key_select_kernel[(rows,)](
            row_data.reshape(rows, width),
            values.reshape(rows),
            indices.reshape(rows),
            WIDTH=width,
            BLOCK=block,
            num_warps=num_warps,
        )
    return values, indices

1089 

1090 

def _median_f16_strided_key_select(inp, dim, output_shape):
    """Launch ``median_f16_strided_key_select_kernel`` to reduce ``inp`` along ``dim``.

    The kernel receives ``inp`` itself plus the reduction size, the product of
    the dimensions after ``dim`` (``inner_size``) and the total number of
    outputs. The grid has ``cdiv(total_outputs, BLOCK_OUT)`` programs.
    Returns ``(values, indices)`` allocated with ``output_shape``.
    """
    device = inp.device
    reduction_size = inp.shape[dim]
    inner_size = math.prod(inp.shape[dim + 1 :])
    total_outputs = math.prod(output_shape)

    block = triton.next_power_of_2(reduction_size)
    block_out = 2
    if block <= 1024:
        num_warps = 1
    elif block <= 2048:
        num_warps = 2
    else:
        num_warps = 4

    values = torch.empty(output_shape, dtype=inp.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    grid = (triton.cdiv(total_outputs, block_out),)
    with torch_device_fn.device(device):
        median_f16_strided_key_select_kernel[grid](
            inp,
            values.reshape(-1),
            indices.reshape(-1),
            total_outputs,
            reduction_size,
            inner_size,
            BLOCK=block,
            BLOCK_OUT=block_out,
            num_warps=num_warps,
        )
    return values, indices

1113 

1114 

def _median_fp32_key_select(row_data, output_shape):
    """Launch ``median_fp32_key_select_kernel`` with one program per row.

    ``row_data`` is viewed as ``(rows, width)``. Returns ``(values, indices)``
    allocated with ``output_shape``; ``indices`` is int64.
    """
    device = row_data.device
    width = row_data.shape[-1]
    rows = row_data.numel() // width
    block = triton.next_power_of_2(width)
    if block <= 1024:
        num_warps = 2
    else:
        num_warps = 8

    values = torch.empty(output_shape, dtype=row_data.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    with torch_device_fn.device(device):
        median_fp32_key_select_kernel[(rows,)](
            row_data.reshape(rows, width),
            values.reshape(rows),
            indices.reshape(rows),
            WIDTH=width,
            BLOCK=block,
            num_warps=num_warps,
        )
    return values, indices

1132 

1133 

def _median_fp64_key_select(row_data, output_shape):
    """Launch ``median_fp64_key_select_kernel`` with one program per row.

    ``row_data`` is viewed as ``(rows, width)``. Returns ``(values, indices)``
    allocated with ``output_shape``; ``indices`` is int64.
    """
    device = row_data.device
    width = row_data.shape[-1]
    rows = row_data.numel() // width
    block = triton.next_power_of_2(width)
    if block <= 1024:
        num_warps = 2
    else:
        num_warps = 8

    values = torch.empty(output_shape, dtype=row_data.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    with torch_device_fn.device(device):
        median_fp64_key_select_kernel[(rows,)](
            row_data.reshape(rows, width),
            values.reshape(rows),
            indices.reshape(rows),
            WIDTH=width,
            BLOCK=block,
            num_warps=num_warps,
        )
    return values, indices

1151 

1152 

def _median_fp32_strided_key_select(inp, dim, output_shape):
    """Launch ``median_fp32_strided_key_select_kernel`` to reduce ``inp`` along ``dim``.

    The kernel receives ``inp`` itself plus the reduction size, the product of
    the dimensions after ``dim`` (``inner_size``) and the total number of
    outputs. The grid has ``cdiv(total_outputs, BLOCK_OUT)`` programs.
    Returns ``(values, indices)`` allocated with ``output_shape``.
    """
    device = inp.device
    reduction_size = inp.shape[dim]
    inner_size = math.prod(inp.shape[dim + 1 :])
    total_outputs = math.prod(output_shape)

    block = triton.next_power_of_2(reduction_size)
    block_out = 2
    if block <= 1024:
        num_warps = 2
    else:
        num_warps = 8

    values = torch.empty(output_shape, dtype=inp.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    grid = (triton.cdiv(total_outputs, block_out),)
    with torch_device_fn.device(device):
        median_fp32_strided_key_select_kernel[grid](
            inp,
            values.reshape(-1),
            indices.reshape(-1),
            total_outputs,
            reduction_size,
            inner_size,
            BLOCK=block,
            BLOCK_OUT=block_out,
            num_warps=num_warps,
        )
    return values, indices

1175 

1176 

def _median_direct_dim(inp, dim, output_shape):
    """Launch ``median_small_dim_kernel`` to reduce ``inp`` along ``dim``.

    ``BLOCK_N`` is the reduction size rounded up to a power of two. The grid
    has ``cdiv(total_outputs, BLOCK_OUT)`` programs. Returns
    ``(values, indices)`` allocated with ``output_shape``; ``indices`` is int64.
    """
    device = inp.device
    reduction_size = inp.shape[dim]
    inner_size = math.prod(inp.shape[dim + 1 :])
    total_outputs = math.prod(output_shape)

    block_n = triton.next_power_of_2(reduction_size)
    if block_n >= 128:
        # Wide reductions: few outputs per program, and more warps for the
        # 32/64-bit integer dtypes.
        block_out = 2
        num_warps = 8 if inp.dtype in (torch.int32, torch.int64) else 4
    else:
        block_out = 16
        num_warps = 1

    values = torch.empty(output_shape, dtype=inp.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    grid = (triton.cdiv(total_outputs, block_out),)
    with torch_device_fn.device(device):
        median_small_dim_kernel[grid](
            inp,
            values.reshape(-1),
            indices.reshape(-1),
            total_outputs,
            reduction_size,
            inner_size,
            BLOCK_N=block_n,
            BLOCK_OUT=block_out,
            num_warps=num_warps,
        )
    return values, indices

1202 

1203 

def _copy_out(src, out, name):
    """Copy ``src`` into the caller-provided ``out`` tensor and return ``out``.

    ``out`` is resized to ``src``'s shape before the copy.

    Raises:
        RuntimeError: if ``out`` is on a different device than ``src`` or has
            a different dtype. ``name`` appears only in the device message.
    """
    device_matches = out.device == src.device
    if not device_matches:
        raise RuntimeError(
            f"Expected {name} tensor to have device {src.device}, "
            f"but got {out.device} instead"
        )
    dtype_matches = out.dtype == src.dtype
    if not dtype_matches:
        raise RuntimeError(
            f"Expected out tensor to have dtype {src.dtype}, but got {out.dtype}"
        )
    out.resize_as_(src)
    out.copy_(src)
    return out

1217 

1218 

def median(inp):
    """Return the median over all elements of ``inp`` as a 0-dim tensor.

    The input is flattened into a single row and dispatched, by dtype and
    element count, to one of the launchers in this module.

    Args:
        inp: input tensor; it is first passed through ``_anonymous``
            (presumably to drop dimension names; confirm against that helper).

    Returns:
        A 0-dim tensor. Empty inputs return whatever ``_empty_result_value``
        produces; single-element inputs return a reshaped clone.

    Raises:
        RuntimeError: if ``inp`` is non-empty and has a complex dtype.
    """
    logger.debug("GEMS MEDIAN")

    inp = _anonymous(inp)
    numel = inp.numel()
    if numel == 0:
        return _empty_result_value(inp)
    if inp.dtype.is_complex:
        raise RuntimeError("Sort does not support complex dtypes on CPU")
    if numel == 1:
        return inp.reshape(()).clone()

    flat = inp.contiguous().reshape(-1)
    row_data = flat.reshape(1, numel)
    if _use_float_key_select(inp.dtype, numel):
        values, _ = _median_float_key_select_rows(row_data, ())
        return values.reshape(())
    if inp.dtype in _DIRECT_REDUCTION_DTYPES and numel <= _DIRECT_FLAT_LIMIT:
        return _median_small_flat(flat)
    if inp.dtype == torch.bool:
        return _median_bool_flat(flat)

    # No separate fp32/fp64 key-select checks are needed here: every input
    # those predicates accept was already handled by the
    # _use_float_key_select early return above, so such branches could never
    # be taken.
    if inp.dtype in _FLAT_SORT_DTYPES and numel <= _FLAT_SORT_LIMIT:
        values, _ = _median_lastdim_sort(row_data, ())
    elif (
        numel <= _INT_LASTDIM_SELECT_LIMIT
        and inp.dtype in _INT_LASTDIM_SELECT_DTYPES
    ):
        values, _ = _median_int_lastdim_select(row_data, ())
    else:
        values, _ = _median_from_rows(row_data, ())
    return values.reshape(())

1254 

1255 

def median_out(inp, *, out):
    """Compute ``median(inp)`` and copy the result into ``out``, which is returned."""
    logger.debug("GEMS MEDIAN.OUT")
    result = median(inp)
    return _copy_out(result, out, "out")

1259 

1260 

def median_dim(inp, dim=0, keepdim=False):
    """Compute the median along ``dim``, returning values and their indices.

    Args:
        inp: input tensor.
        dim: dimension to reduce. A ``str`` is resolved with ``_name_to_dim``;
            the result is then normalized with ``_canonical_dim``.
        keepdim: when True the reduced dimension is kept with size 1,
            otherwise it is removed from the output shape.

    Returns:
        ``MedianResult(values, indices)``. ``values`` has the input dtype and
        ``indices`` is int64. When the input carried names, the outputs are
        refined with the names computed by ``_kept_names``.

    Raises:
        IndexError: if the reduction dimension has size zero.
        Complex inputs are rejected through ``_raise_dim_dtype`` (exception
        type defined by that helper).
    """
    logger.debug("GEMS MEDIAN.DIM")

    if isinstance(dim, str):
        dim = _name_to_dim(inp, dim)
    dim = _canonical_dim(inp.ndim, dim)
    names = inp.names if _has_names(inp) else None
    work = _anonymous(inp)

    if work.ndim == 0:
        if work.dtype.is_complex:
            _raise_dim_dtype(work.dtype)
        return MedianResult(
            values=work.clone(),
            indices=torch.zeros((), dtype=torch.int64, device=work.device),
        )

    if work.shape[dim] == 0:
        raise IndexError(
            f"median(): Expected reduction dim {dim} to have non-zero size."
        )

    output_shape = list(work.shape)
    if keepdim:
        output_shape[dim] = 1
    else:
        del output_shape[dim]
    output_names = _kept_names(names, dim, keepdim)

    if work.numel() == 0:
        # Some non-reduced dimension is empty: nothing to compute.
        values = torch.empty(output_shape, dtype=work.dtype, device=work.device)
        indices = torch.empty(output_shape, dtype=torch.int64, device=work.device)
    else:
        if work.dtype.is_complex:
            _raise_dim_dtype(work.dtype)

        width = work.shape[dim]
        is_last_dim = dim == work.ndim - 1

        # Once the _use_float_key_select branch below has been tested, the
        # individual f16/fp32/fp64 key-select predicates are all known to be
        # False, so no later branch re-checks them.
        if work.dtype == torch.bool:
            values, indices = _median_bool_dim(work.contiguous(), dim, output_shape)
        elif _use_float_key_select(work.dtype, width):
            values, indices = _median_float_key_select_dim(
                work, dim, output_shape, keepdim
            )
        elif (
            width <= _DIRECT_REDUCTION_LIMIT
            and work.dtype in _DIRECT_REDUCTION_DTYPES
        ):
            values, indices = _median_direct_dim(work.contiguous(), dim, output_shape)
        elif (
            not is_last_dim
            and work.is_contiguous()
            and _use_strided_select(work.dtype, width)
        ):
            # _use_strided_select only accepts the 16-bit float dtypes and
            # float32, so the else arm is exactly the float32 case.
            if work.dtype in _F16_KEY_SELECT_DTYPES:
                values, indices = _median_f16_strided_key_select(
                    work, dim, output_shape
                )
            else:
                values, indices = _median_fp32_strided_key_select(
                    work, dim, output_shape
                )
        elif is_last_dim and _use_lastdim_sort(work.dtype, width):
            values, indices = _median_lastdim_sort(work.contiguous(), output_shape)
        elif (
            is_last_dim
            and width <= _INT_LASTDIM_SELECT_LIMIT
            and work.dtype in _INT_LASTDIM_SELECT_DTYPES
        ):
            values, indices = _median_int_lastdim_select(
                work.contiguous(), output_shape
            )
        else:
            # Generic fallback: move the reduction dim to the end so that each
            # row is contiguous, then use a row-wise launcher.
            rows = torch.movedim(work, dim, -1).contiguous()
            row_output_shape = rows.shape[:-1]
            row_width = rows.shape[-1]
            if _use_lastdim_sort(rows.dtype, row_width):
                values, indices = _median_lastdim_sort(rows, row_output_shape)
            elif (
                row_width <= _INT_LASTDIM_SELECT_LIMIT
                and rows.dtype in _INT_LASTDIM_SELECT_DTYPES
            ):
                values, indices = _median_int_lastdim_select(rows, row_output_shape)
            else:
                values, indices = _median_from_rows(rows, row_output_shape)
            if keepdim:
                # Row results lack the reduced dim; re-insert it at ``dim``.
                values = torch.movedim(values.unsqueeze(-1), -1, dim)
                indices = torch.movedim(indices.unsqueeze(-1), -1, dim)

    if output_names is not None:
        values = values.refine_names(*output_names)
        indices = indices.refine_names(*output_names)

    return MedianResult(values=values, indices=indices)

1364 

1365 

def median_dim_values(inp, dim=0, keepdim=False, *, values, indices):
    """Run ``median_dim`` and copy its results into the provided ``values`` and ``indices``.

    Returns a ``MedianResult`` wrapping the two caller-provided tensors.
    """
    logger.debug("GEMS MEDIAN.DIM_VALUES")
    computed = median_dim(inp, dim=dim, keepdim=keepdim)
    for source, target, label in (
        (computed.values, values, "values"),
        (computed.indices, indices, "indices"),
    ):
        _copy_out(source, target, label)
    return MedianResult(values=values, indices=indices)