Coverage for src/flag_gems/ops/median.py: 46%
799 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import math
3from collections import namedtuple
5import torch
6import triton
7import triton.language as tl
9from flag_gems.ops.topk import _get_iinfo_val
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry
13logger = logging.getLogger(__name__)
15MedianResult = namedtuple("median", ["values", "indices"])
17_DIRECT_REDUCTION_LIMIT = 256
18_DIRECT_FLAT_LIMIT = 256
19_BOOL_FLAT_BLOCK = 1024
20_BOOL_COUNT_REDUCE_BLOCK = 1024
21_DIRECT_REDUCTION_DTYPES = {
22 torch.bool,
23 torch.float16,
24 torch.bfloat16,
25 torch.float32,
26 torch.float64,
27 torch.int8,
28 torch.uint8,
29 torch.int16,
30 torch.int32,
31 torch.int64,
32}
33_FLAT_SORT_LIMIT = 1024
34_LASTDIM_SORT_LIMIT = 1024
35_BF16_LASTDIM_SORT_LIMIT = 2048
36_LASTDIM_SORT_DTYPES = {torch.float16, torch.bfloat16}
37_FLAT_SORT_DTYPES = _LASTDIM_SORT_DTYPES | {torch.float32}
38_F16_KEY_SELECT_MIN = 2
39_F16_KEY_SELECT_LIMIT = 16384
40_F16_KEY_SELECT_DTYPES = {torch.float16, torch.bfloat16}
41_FP32_KEY_SELECT_MIN = 2
42_FP32_KEY_SELECT_LIMIT = 16384
43_FP64_KEY_SELECT_MIN = 2
44_FP64_KEY_SELECT_LIMIT = 8192
45_INT_LASTDIM_SELECT_LIMIT = 16384
46_INT_LASTDIM_SELECT_DTYPES = {
47 torch.int8,
48 torch.uint8,
49 torch.int16,
50 torch.int32,
51 torch.int64,
52}
53_STRIDED_SELECT_MIN = _DIRECT_REDUCTION_LIMIT + 1
54_STRIDED_SELECT_LIMIT = 4096
@libentry()
@triton.jit
def median_small_dim_kernel(
    inp,
    values,
    indices,
    total_outputs,
    reduction_size,
    inner_size,
    BLOCK_N: tl.constexpr,
    BLOCK_OUT: tl.constexpr,
):
    """Sort-based median over a strided reduction axis.

    Each program handles BLOCK_OUT outputs. Element ``r`` of output ``o`` is
    read from ``outer * reduction_size * inner_size + r * inner_size + inner``
    with ``outer = o // inner_size`` and ``inner = o % inner_size``. The
    (BLOCK_OUT, BLOCK_N) tile is sorted ascending along the reduction axis and
    the element at rank ``(reduction_size - 1) // 2`` is taken. For floating
    inputs, a row containing NaN yields the first NaN and its position.
    """
    out_ids = tl.program_id(0) * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
    out_live = out_ids < total_outputs
    outer_ids = out_ids // inner_size
    inner_ids = out_ids % inner_size

    red_ids = tl.arange(0, BLOCK_N)
    in_bounds = (red_ids[None, :] < reduction_size) & out_live[:, None]
    src = (
        inp
        + outer_ids[:, None] * reduction_size * inner_size
        + red_ids[None, :] * inner_size
        + inner_ids[:, None]
    )

    # Padding must sort after every real element so it never reaches the rank.
    if inp.dtype.element_ty.is_floating():
        pad = float("inf")
    else:
        pad = _get_iinfo_val(inp.dtype.element_ty, return_max=True)

    raw = tl.load(src, mask=in_bounds, other=pad)
    keyed = raw

    if inp.dtype.element_ty.is_floating():
        # NaN != NaN; NaNs are replaced by +inf for sorting and handled below.
        is_nan = in_bounds & (raw != raw)
        keyed = tl.where(is_nan, pad, raw)

    k = (reduction_size - 1) // 2
    ascending = tl.sort(keyed, dim=1, descending=False)
    at_k = red_ids[None, :] == k
    # Only the rank column is non-zero, so the row sum extracts that element.
    picked = tl.sum(tl.where(at_k, ascending, tl.zeros_like(ascending)), axis=1)

    # argmax over a 0/1 mask: locates a position whose value equals the median.
    picked_idx = tl.argmax(
        (in_bounds & (raw == picked[:, None])).to(tl.int32), axis=1
    )

    if inp.dtype.element_ty.is_floating():
        nan_flags = is_nan.to(tl.int32)
        row_has_nan = tl.max(nan_flags, axis=1) != 0
        nan_pos = tl.argmax(nan_flags, axis=1)
        # Re-read the NaN from memory so the original payload is propagated.
        nan_elem = tl.load(
            inp
            + outer_ids * reduction_size * inner_size
            + nan_pos * inner_size
            + inner_ids,
            mask=out_live,
            other=0.0,
        )
        picked = tl.where(row_has_nan, nan_elem, picked)
        picked_idx = tl.where(row_has_nan, nan_pos, picked_idx)

    tl.store(values + out_ids, picked, mask=out_live)
    tl.store(indices + out_ids, picked_idx.to(tl.int64), mask=out_live)
@libentry()
@triton.jit
def median_small_flat_kernel(
    inp,
    value,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Median of the first WIDTH elements of ``inp`` using one in-block sort.

    Writes a single scalar to ``value``. Lanes beyond WIDTH are padded with a
    value that sorts last. For floating inputs, any NaN makes the result the
    first NaN found.
    """
    lane = tl.arange(0, BLOCK)
    live = lane < WIDTH

    # Pick a padding value that sorts after every real element.
    if inp.dtype.element_ty.is_floating():
        pad = float("inf")
    elif inp.dtype.element_ty is tl.int1:
        pad = True
    else:
        pad = _get_iinfo_val(inp.dtype.element_ty, return_max=True)

    raw = tl.load(inp + lane, mask=live, other=pad)
    keyed = raw

    if inp.dtype.element_ty.is_floating():
        # NaNs are replaced by +inf for sorting and handled explicitly below.
        is_nan = live & (raw != raw)
        keyed = tl.where(is_nan, pad, raw)

    k = (WIDTH - 1) // 2
    ascending = tl.sort(keyed, descending=False)
    # Only lane k survives the mask, so the sum extracts the rank-k element.
    picked = tl.sum(
        tl.where(lane == k, ascending, tl.zeros_like(ascending)), axis=0
    )

    if inp.dtype.element_ty.is_floating():
        nan_flags = is_nan.to(tl.int32)
        any_nan = tl.max(nan_flags, axis=0) != 0
        nan_pos = tl.argmax(nan_flags, axis=0)
        nan_elem = tl.load(inp + nan_pos, mask=any_nan, other=0.0)
        picked = tl.where(any_nan, nan_elem, picked)

    tl.store(value, picked)
@libentry()
@triton.jit
def median_bool_count_kernel(
    inp,
    counts,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Count the True elements in one BLOCK-sized slice of a flat bool input.

    Program ``p`` covers elements ``[p * BLOCK, (p + 1) * BLOCK)`` clipped to
    WIDTH and stores its int64 count at ``counts[p]``.
    """
    pid = tl.program_id(0)
    lane = pid * BLOCK + tl.arange(0, BLOCK)
    live = lane < WIDTH
    flags = tl.load(inp + lane, mask=live, other=False)
    n_true = tl.sum((live & flags).to(tl.int64), axis=0)
    tl.store(counts + pid, n_true)
@libentry()
@triton.jit
def median_bool_from_counts_kernel(
    counts,
    value,
    WIDTH: tl.constexpr,
    NUM_BLOCKS: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Turn per-block True counts into the median of WIDTH boolean elements.

    Sums the first NUM_BLOCKS entries of ``counts``. In ascending order the
    False values come first, so the median is True exactly when the median
    rank ``(WIDTH - 1) // 2`` is at or past the number of False values.
    """
    lane = tl.arange(0, BLOCK)
    live = lane < NUM_BLOCKS
    partial = tl.load(counts + lane, mask=live, other=0)
    n_true = tl.sum(partial, axis=0)
    n_false = WIDTH - n_true
    k = (WIDTH - 1) // 2
    result = k >= n_false
    tl.store(value, result)
@libentry()
@triton.jit
def median_bool_reduce_counts_kernel(
    counts_in,
    counts_out,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """One reduction pass over partial counts.

    Program ``p`` sums ``counts_in[p * BLOCK : (p + 1) * BLOCK]`` (clipped to
    WIDTH entries) and stores the result at ``counts_out[p]``.
    """
    pid = tl.program_id(0)
    lane = pid * BLOCK + tl.arange(0, BLOCK)
    live = lane < WIDTH
    partial = tl.load(counts_in + lane, mask=live, other=0)
    subtotal = tl.sum(partial, axis=0)
    tl.store(counts_out + pid, subtotal)
@libentry()
@triton.jit
def median_bool_dim_count_chunks_kernel(
    inp,
    counts,
    first_false,
    first_true,
    total_outputs,
    reduction_size,
    inner_size,
    chunks_per_output: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Per-chunk statistics for a boolean median along a strided axis.

    The 1-D program id encodes ``output * chunks_per_output + chunk``. For its
    BLOCK-wide chunk of the reduction axis each program stores, at index
    ``pid``: the number of True elements, the smallest reduction index holding
    False, and the smallest reduction index holding True. ``reduction_size``
    is stored as the index when the chunk holds no such element.
    """
    pid = tl.program_id(0)
    out_id = pid // chunks_per_output
    chunk = pid - out_id * chunks_per_output
    out_live = out_id < total_outputs
    outer_id = out_id // inner_size
    inner_id = out_id % inner_size

    red_ids = chunk * BLOCK + tl.arange(0, BLOCK)
    live = (red_ids < reduction_size) & out_live
    src = (
        inp
        + outer_id * reduction_size * inner_size
        + red_ids * inner_size
        + inner_id
    )
    flags = tl.load(src, mask=live, other=False)

    is_true = live & flags
    is_false = live & ~flags
    n_true = tl.sum(is_true.to(tl.int64), axis=0)
    # reduction_size is the "not found" sentinel: larger than any valid index.
    lowest_false = tl.min(tl.where(is_false, red_ids, reduction_size), axis=0)
    lowest_true = tl.min(tl.where(is_true, red_ids, reduction_size), axis=0)

    tl.store(counts + pid, n_true)
    tl.store(first_false + pid, lowest_false.to(tl.int64))
    tl.store(first_true + pid, lowest_true.to(tl.int64))
@libentry()
@triton.jit
def median_bool_dim_reduce_chunks_kernel(
    counts_in,
    first_false_in,
    first_true_in,
    counts_out,
    first_false_out,
    first_true_out,
    input_chunks: tl.constexpr,
    output_chunks: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Fold up to BLOCK per-chunk statistics of one output row into one entry.

    Grid axis 0 is the output row, axis 1 the destination chunk. Counts are
    summed and first-index values are min-reduced. Masked-off lanes contribute
    0 to the sum and the int64 maximum to the minima.
    """
    row = tl.program_id(0)
    dst_chunk = tl.program_id(1)
    src_chunks = dst_chunk * BLOCK + tl.arange(0, BLOCK)
    live = src_chunks < input_chunks
    src = row * input_chunks + src_chunks

    partial_counts = tl.load(counts_in + src, mask=live, other=0)
    partial_false = tl.load(
        first_false_in + src, mask=live, other=9223372036854775807
    )
    partial_true = tl.load(
        first_true_in + src, mask=live, other=9223372036854775807
    )

    n_true = tl.sum(partial_counts, axis=0)
    lowest_false = tl.min(partial_false, axis=0)
    lowest_true = tl.min(partial_true, axis=0)

    dst = row * output_chunks + dst_chunk
    tl.store(counts_out + dst, n_true)
    tl.store(first_false_out + dst, lowest_false)
    tl.store(first_true_out + dst, lowest_true)
@libentry()
@triton.jit
def median_bool_dim_finish_kernel(
    counts,
    first_false,
    first_true,
    values,
    indices,
    reduction_size,
    chunks_per_output: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Combine per-chunk statistics into the final boolean median and index.

    One program per output row. The median is True when the median rank
    ``(reduction_size - 1) // 2`` is at or past the number of False elements.
    The stored index is the smallest reduction index holding the median value.
    """
    row = tl.program_id(0)
    chunk_ids = tl.arange(0, BLOCK)
    live = chunk_ids < chunks_per_output
    src = row * chunks_per_output + chunk_ids

    partial_counts = tl.load(counts + src, mask=live, other=0)
    n_true = tl.sum(partial_counts, axis=0)
    n_false = reduction_size - n_true
    k = (reduction_size - 1) // 2
    result = k >= n_false

    # Masked lanes load the int64 maximum so they never win the min-reduction.
    partial_false = tl.load(first_false + src, mask=live, other=9223372036854775807)
    partial_true = tl.load(first_true + src, mask=live, other=9223372036854775807)
    lowest_false = tl.min(partial_false, axis=0)
    lowest_true = tl.min(partial_true, axis=0)
    result_idx = tl.where(result, lowest_true, lowest_false)

    tl.store(values + row, result)
    tl.store(indices + row, result_idx)
@libentry()
@triton.jit
def median_lastdim_sort_kernel(
    row_data,
    values,
    indices,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Sort-based median of one contiguous row of WIDTH floating elements.

    One program per row. The row starts at ``row_data + row * WIDTH``. Stores
    the rank ``(WIDTH - 1) // 2`` element and the position of a matching input
    element. A row containing NaN yields its first NaN and that position.
    """
    row = tl.program_id(0)
    lane = tl.arange(0, BLOCK)
    live = lane < WIDTH
    row_ptr = row_data + row * WIDTH
    raw = tl.load(row_ptr + lane, mask=live, other=float("inf"))

    # NaNs are sorted as +inf and then handled explicitly.
    is_nan = live & (raw != raw)
    keyed = tl.where(is_nan, float("inf"), raw)
    ascending = tl.sort(keyed, descending=False)
    k = (WIDTH - 1) // 2
    # Only lane k survives the mask, so the sum extracts the rank-k element.
    picked = tl.sum(
        tl.where(lane == k, ascending, tl.zeros_like(ascending)), axis=0
    )

    picked_idx = tl.argmax((live & (raw == picked)).to(tl.int32), axis=0)

    nan_flags = is_nan.to(tl.int32)
    any_nan = tl.max(nan_flags, axis=0) != 0
    nan_pos = tl.argmax(nan_flags, axis=0)
    nan_elem = tl.load(row_ptr + nan_pos, mask=any_nan, other=0.0)
    picked = tl.where(any_nan, nan_elem, picked)
    picked_idx = tl.where(any_nan, nan_pos, picked_idx)

    tl.store(values + row, picked)
    tl.store(indices + row, picked_idx.to(tl.int64))
@libentry()
@triton.jit
def median_int_lastdim_select_kernel(
    row_data,
    values,
    indices,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
    SEARCH_STEPS: tl.constexpr,
):
    """Median of one contiguous integer row by bisection on the value range.

    One program per row. The row starts at ``row_data + row * WIDTH``. The
    search keeps the invariant that the rank ``(WIDTH - 1) // 2`` element lies
    in ``[lo, hi]`` and halves that interval SEARCH_STEPS times by counting
    how many row elements are ``<= mid``. Stores the selected value and the
    position of a matching input element.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK)
    valid = cols < WIDTH
    base = row_data + row * WIDTH
    data = tl.load(base + cols, mask=valid, other=0)

    dtype = row_data.dtype.element_ty
    high = _get_iinfo_val(dtype, return_max=True)
    low = _get_iinfo_val(dtype, return_max=False)
    # Bounds are widened to int64 so the same search works for every dtype.
    row_min = tl.min(tl.where(valid, data, high), axis=0).to(tl.int64)
    row_max = tl.max(tl.where(valid, data, low), axis=0).to(tl.int64)

    lo = row_min
    hi = row_max
    rank = (WIDTH - 1) // 2
    for _ in tl.static_range(0, SEARCH_STEPS):
        # Overflow-free floor((lo + hi) / 2). The previous form
        # ``lo + (hi - lo) // 2`` wraps in int64 when the row spans more than
        # 2**63 (e.g. an int64 row holding both very negative and very
        # positive values), which stalls the bisection. Uses the identity
        # lo + hi == 2 * (lo & hi) + (lo ^ hi), with a sign-preserving shift.
        mid = (lo & hi) + ((lo ^ hi) >> 1)
        le_count = tl.sum((valid & (data <= mid.to(dtype))).to(tl.int32), axis=0)
        # More than `rank` elements are <= mid, so the answer is <= mid.
        take_left = le_count > rank
        hi = tl.where(take_left, mid, hi)
        lo = tl.where(take_left, lo, mid + 1)

    median_value = lo.to(dtype)
    first_match = tl.argmax((valid & (data == median_value)).to(tl.int32), axis=0)
    tl.store(values + row, median_value)
    tl.store(indices + row, first_match.to(tl.int64))
@triton.jit
def _fp32_order_key(x):
    """Map float32 bit patterns to uint32 keys for unsigned comparison.

    Negative inputs get every bit flipped; non-negative inputs get only the
    sign bit flipped. Relies on ``>>`` being sign-extending for int32, so the
    shifted sign is all-ones for negative inputs.
    """
    raw_bits = x.to(tl.uint32, bitcast=True)
    sign_fill = x.to(tl.int32, bitcast=True) >> 31
    top_bit = tl.full((), 0x80000000, dtype=tl.uint32)
    flip = top_bit | sign_fill.to(tl.uint32, bitcast=True)
    return raw_bits ^ flip
@triton.jit
def _fp64_order_key(x):
    """Map float64 bit patterns to uint64 keys for unsigned comparison.

    Negative inputs get every bit flipped; non-negative inputs get only the
    sign bit flipped. Relies on ``>>`` being sign-extending for int64.
    """
    raw_bits = x.to(tl.uint64, bitcast=True)
    sign_fill = x.to(tl.int64, bitcast=True) >> 63
    top_bit = tl.full((), 1, dtype=tl.uint64) << 63
    flip = top_bit | sign_fill.to(tl.uint64, bitcast=True)
    return raw_bits ^ flip
@triton.jit
def _f16_order_key(x):
    """Map 16-bit float bit patterns to uint16 keys for unsigned comparison.

    Negative inputs get every bit flipped; non-negative inputs get only the
    sign bit flipped. Relies on ``>>`` being sign-extending for int16.
    """
    raw_bits = x.to(tl.uint16, bitcast=True)
    sign_fill = x.to(tl.int16, bitcast=True) >> 15
    top_bit = tl.full((), 0x8000, dtype=tl.uint16)
    flip = top_bit | sign_fill.to(tl.uint16, bitcast=True)
    return raw_bits ^ flip
@libentry()
@triton.jit
def median_f16_key_select_kernel(
    row_data,
    values,
    indices,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Median of one contiguous 16-bit float row by bisection on order keys.

    One program per row. The row starts at ``row_data + row * WIDTH``.
    Infinities are counted separately and the 16-step bisection runs only over
    the remaining non-NaN, non-infinite elements. A row containing NaN yields
    its first NaN and that position.
    """
    row = tl.program_id(0)
    lane = tl.arange(0, BLOCK)
    live = lane < WIDTH
    row_ptr = row_data + row * WIDTH
    x = tl.load(row_ptr + lane, mask=live, other=0.0)

    # NaN bookkeeping (NaN != NaN).
    is_nan = live & (x != x)
    nan_flags = is_nan.to(tl.int32)
    any_nan = tl.max(nan_flags, axis=0) != 0
    nan_pos = tl.argmax(nan_flags, axis=0)
    nan_elem = tl.load(row_ptr + nan_pos, mask=any_nan, other=0.0)

    # Split the non-NaN elements into -inf, +inf and everything else.
    non_nan = live & ~is_nan
    is_neg_inf = non_nan & (x == -float("inf"))
    is_pos_inf = non_nan & (x == float("inf"))
    ordinary = non_nan & ~(is_neg_inf | is_pos_inf)
    n_neg_inf = tl.sum(is_neg_inf.to(tl.int32), axis=0)
    n_ordinary = tl.sum(ordinary.to(tl.int32), axis=0)

    k = (WIDTH - 1) // 2
    # Rank within the ordinary elements once the leading -inf run is skipped.
    k_ordinary = k - n_neg_inf
    pick_neg_inf = k < n_neg_inf
    pick_pos_inf = k_ordinary >= n_ordinary

    keys = _f16_order_key(x).to(tl.uint32)
    min_fill = tl.full((), 0xFFFF, dtype=tl.uint32)
    max_fill = tl.full((), 0, dtype=tl.uint32)
    key_lo = tl.min(tl.where(ordinary, keys, min_fill), axis=0)
    key_hi = tl.max(tl.where(ordinary, keys, max_fill), axis=0)
    # With no ordinary element the bounds collapse to 0 so hi >= lo holds.
    has_ordinary = n_ordinary != 0
    key_lo = tl.where(has_ordinary, key_lo, 0)
    key_hi = tl.where(has_ordinary, key_hi, 0)

    lo = key_lo
    hi = key_hi
    for _ in tl.static_range(0, 16):
        mid = lo + ((hi - lo) >> 1)
        n_le = tl.sum((ordinary & (keys <= mid)).to(tl.int32), axis=0)
        go_left = n_le > k_ordinary
        hi = tl.where(go_left, mid, hi)
        lo = tl.where(go_left, lo, mid + 1)

    found_key = lo
    hits = ordinary & (keys == found_key)
    found_pos = tl.argmax(hits.to(tl.int32), axis=0)
    found_val = tl.load(row_ptr + found_pos)

    # Overrides for medians that land on an infinity.
    neg_inf_pos = tl.argmax(is_neg_inf.to(tl.int32), axis=0)
    neg_inf_val = tl.load(row_ptr + neg_inf_pos, mask=pick_neg_inf, other=0.0)
    pos_inf_pos = tl.argmax(is_pos_inf.to(tl.int32), axis=0)
    pos_inf_val = tl.load(row_ptr + pos_inf_pos, mask=pick_pos_inf, other=0.0)
    found_val = tl.where(pick_neg_inf, neg_inf_val, found_val)
    found_val = tl.where(pick_pos_inf, pos_inf_val, found_val)
    found_pos = tl.where(pick_neg_inf, neg_inf_pos, found_pos)
    found_pos = tl.where(pick_pos_inf, pos_inf_pos, found_pos)

    # NaN takes precedence over everything else.
    found_val = tl.where(any_nan, nan_elem, found_val)
    out_pos = tl.where(any_nan, nan_pos, found_pos)
    tl.store(values + row, found_val)
    tl.store(indices + row, out_pos.to(tl.int64))
@libentry()
@triton.jit
def median_fp32_key_select_kernel(
    row_data,
    values,
    indices,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Median of one contiguous float32 row by bisection on order keys.

    One program per row. The row starts at ``row_data + row * WIDTH``. A
    32-step bisection over the uint32 keys of the non-NaN elements finds the
    rank ``(WIDTH - 1) // 2`` key. The value is re-read from the matching
    position. A row containing NaN yields its first NaN and that position.
    """
    row = tl.program_id(0)
    lane = tl.arange(0, BLOCK)
    live = lane < WIDTH
    row_ptr = row_data + row * WIDTH
    x = tl.load(row_ptr + lane, mask=live, other=0.0)

    # NaN bookkeeping (NaN != NaN).
    is_nan = live & (x != x)
    nan_flags = is_nan.to(tl.int32)
    any_nan = tl.max(nan_flags, axis=0) != 0
    nan_pos = tl.argmax(nan_flags, axis=0)
    nan_elem = tl.load(row_ptr + nan_pos, mask=any_nan, other=0.0)

    keys = _fp32_order_key(x)
    non_nan = live & ~is_nan
    # Excluded lanes are filled so they cannot affect the min / max.
    min_fill = tl.full((), 0xFFFFFFFF, dtype=tl.uint32)
    max_fill = tl.full((), 0, dtype=tl.uint32)
    key_lo = tl.min(tl.where(non_nan, keys, min_fill), axis=0)
    key_hi = tl.max(tl.where(non_nan, keys, max_fill), axis=0)

    k = (WIDTH - 1) // 2
    lo = key_lo
    hi = key_hi
    for _ in tl.static_range(0, 32):
        mid = lo + ((hi - lo) >> 1)
        n_le = tl.sum((non_nan & (keys <= mid)).to(tl.int32), axis=0)
        go_left = n_le > k
        hi = tl.where(go_left, mid, hi)
        lo = tl.where(go_left, lo, mid + 1)

    found_key = lo
    hits = non_nan & (keys == found_key)
    found_pos = tl.argmax(hits.to(tl.int32), axis=0)
    found_val = tl.load(row_ptr + found_pos)

    # NaN takes precedence over the selected value.
    found_val = tl.where(any_nan, nan_elem, found_val)
    out_pos = tl.where(any_nan, nan_pos, found_pos)
    tl.store(values + row, found_val)
    tl.store(indices + row, out_pos.to(tl.int64))
@libentry()
@triton.jit
def median_fp64_key_select_kernel(
    row_data,
    values,
    indices,
    WIDTH: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Median of one contiguous float64 row by bisection on order keys.

    One program per row. The row starts at ``row_data + row * WIDTH``. A
    64-step bisection over the uint64 keys of the non-NaN elements finds the
    rank ``(WIDTH - 1) // 2`` key. The value is re-read from the matching
    position. A row containing NaN yields its first NaN and that position.
    """
    row = tl.program_id(0)
    lane = tl.arange(0, BLOCK)
    live = lane < WIDTH
    row_ptr = row_data + row * WIDTH
    x = tl.load(row_ptr + lane, mask=live, other=0.0)

    # NaN bookkeeping (NaN != NaN).
    is_nan = live & (x != x)
    nan_flags = is_nan.to(tl.int64)
    any_nan = tl.max(nan_flags, axis=0) != 0
    nan_pos = tl.argmax(nan_flags, axis=0)
    nan_elem = tl.load(row_ptr + nan_pos, mask=any_nan, other=0.0)

    keys = _fp64_order_key(x)
    non_nan = live & ~is_nan
    # Excluded lanes are filled so they cannot affect the min / max.
    min_fill = tl.full((), 0xFFFFFFFFFFFFFFFF, dtype=tl.uint64)
    max_fill = tl.full((), 0, dtype=tl.uint64)
    key_lo = tl.min(tl.where(non_nan, keys, min_fill), axis=0)
    key_hi = tl.max(tl.where(non_nan, keys, max_fill), axis=0)

    k = (WIDTH - 1) // 2
    lo = key_lo
    hi = key_hi
    for _ in tl.static_range(0, 64):
        mid = lo + ((hi - lo) >> 1)
        n_le = tl.sum((non_nan & (keys <= mid)).to(tl.int32), axis=0)
        go_left = n_le > k
        hi = tl.where(go_left, mid, hi)
        lo = tl.where(go_left, lo, mid + 1)

    found_key = lo
    hits = non_nan & (keys == found_key)
    found_pos = tl.argmax(hits.to(tl.int32), axis=0)
    found_val = tl.load(row_ptr + found_pos)

    # NaN takes precedence over the selected value.
    found_val = tl.where(any_nan, nan_elem, found_val)
    out_pos = tl.where(any_nan, nan_pos, found_pos)
    tl.store(values + row, found_val)
    tl.store(indices + row, out_pos.to(tl.int64))
@libentry()
@triton.jit
def median_f16_strided_key_select_kernel(
    inp,
    values,
    indices,
    total_outputs,
    reduction_size,
    inner_size,
    BLOCK: tl.constexpr,
    BLOCK_OUT: tl.constexpr,
):
    """Key-bisection median of 16-bit floats along a strided reduction axis.

    Each program handles BLOCK_OUT outputs. Element ``r`` of output ``o`` is
    read from ``outer * reduction_size * inner_size + r * inner_size + inner``
    with ``outer = o // inner_size`` and ``inner = o % inner_size``. A 16-step
    bisection over the keys of the non-NaN elements selects the rank
    ``(reduction_size - 1) // 2`` key per output. Outputs whose reduction
    slice contains NaN receive the first NaN and its position.
    """
    out_ids = tl.program_id(0) * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
    out_live = out_ids < total_outputs
    outer_ids = out_ids // inner_size
    inner_ids = out_ids % inner_size
    red_ids = tl.arange(0, BLOCK)
    live = (red_ids[None, :] < reduction_size) & out_live[:, None]
    src = (
        inp
        + outer_ids[:, None] * reduction_size * inner_size
        + red_ids[None, :] * inner_size
        + inner_ids[:, None]
    )
    x = tl.load(src, mask=live, other=0.0)

    # NaN bookkeeping per output (NaN != NaN).
    is_nan = live & (x != x)
    nan_flags = is_nan.to(tl.int32)
    any_nan = tl.max(nan_flags, axis=1) != 0
    nan_pos = tl.argmax(nan_flags, axis=1)
    nan_elem = tl.load(
        inp
        + outer_ids * reduction_size * inner_size
        + nan_pos * inner_size
        + inner_ids,
        mask=out_live,
        other=0.0,
    )

    keys = _f16_order_key(x).to(tl.uint32)
    non_nan = live & ~is_nan
    # Excluded lanes are filled so they cannot affect the min / max.
    min_fill = tl.full((), 0xFFFF, dtype=tl.uint32)
    max_fill = tl.full((), 0, dtype=tl.uint32)
    key_lo = tl.min(tl.where(non_nan, keys, min_fill), axis=1)
    key_hi = tl.max(tl.where(non_nan, keys, max_fill), axis=1)

    k = (reduction_size - 1) // 2
    lo = key_lo
    hi = key_hi
    for _ in tl.static_range(0, 16):
        mid = lo + ((hi - lo) >> 1)
        n_le = tl.sum((non_nan & (keys <= mid[:, None])).to(tl.int32), axis=1)
        go_left = n_le > k
        hi = tl.where(go_left, mid, hi)
        lo = tl.where(go_left, lo, mid + 1)

    found_key = lo
    hits = non_nan & (keys == found_key[:, None])
    found_pos = tl.argmax(hits.to(tl.int32), axis=1)
    found_val = tl.load(
        inp
        + outer_ids * reduction_size * inner_size
        + found_pos * inner_size
        + inner_ids,
        mask=out_live,
        other=0.0,
    )

    # NaN takes precedence over the selected value.
    found_val = tl.where(any_nan, nan_elem, found_val)
    out_pos = tl.where(any_nan, nan_pos, found_pos)
    tl.store(values + out_ids, found_val, mask=out_live)
    tl.store(indices + out_ids, out_pos.to(tl.int64), mask=out_live)
@libentry()
@triton.jit
def median_fp32_strided_key_select_kernel(
    inp,
    values,
    indices,
    total_outputs,
    reduction_size,
    inner_size,
    BLOCK: tl.constexpr,
    BLOCK_OUT: tl.constexpr,
):
    """Key-bisection median of float32 values along a strided reduction axis.

    Each program handles BLOCK_OUT outputs. Element ``r`` of output ``o`` is
    read from ``outer * reduction_size * inner_size + r * inner_size + inner``
    with ``outer = o // inner_size`` and ``inner = o % inner_size``. A 32-step
    bisection over the keys of the non-NaN elements selects the rank
    ``(reduction_size - 1) // 2`` key per output. Outputs whose reduction
    slice contains NaN receive the first NaN and its position.
    """
    out_ids = tl.program_id(0) * BLOCK_OUT + tl.arange(0, BLOCK_OUT)
    out_live = out_ids < total_outputs
    outer_ids = out_ids // inner_size
    inner_ids = out_ids % inner_size
    red_ids = tl.arange(0, BLOCK)
    live = (red_ids[None, :] < reduction_size) & out_live[:, None]
    src = (
        inp
        + outer_ids[:, None] * reduction_size * inner_size
        + red_ids[None, :] * inner_size
        + inner_ids[:, None]
    )
    x = tl.load(src, mask=live, other=0.0)

    # NaN bookkeeping per output (NaN != NaN).
    is_nan = live & (x != x)
    nan_flags = is_nan.to(tl.int32)
    any_nan = tl.max(nan_flags, axis=1) != 0
    nan_pos = tl.argmax(nan_flags, axis=1)
    nan_elem = tl.load(
        inp
        + outer_ids * reduction_size * inner_size
        + nan_pos * inner_size
        + inner_ids,
        mask=out_live,
        other=0.0,
    )

    keys = _fp32_order_key(x)
    non_nan = live & ~is_nan
    # Excluded lanes are filled so they cannot affect the min / max.
    min_fill = tl.full((), 0xFFFFFFFF, dtype=tl.uint32)
    max_fill = tl.full((), 0, dtype=tl.uint32)
    key_lo = tl.min(tl.where(non_nan, keys, min_fill), axis=1)
    key_hi = tl.max(tl.where(non_nan, keys, max_fill), axis=1)

    k = (reduction_size - 1) // 2
    lo = key_lo
    hi = key_hi
    for _ in tl.static_range(0, 32):
        mid = lo + ((hi - lo) >> 1)
        n_le = tl.sum((non_nan & (keys <= mid[:, None])).to(tl.int32), axis=1)
        go_left = n_le > k
        hi = tl.where(go_left, mid, hi)
        lo = tl.where(go_left, lo, mid + 1)

    found_key = lo
    hits = non_nan & (keys == found_key[:, None])
    found_pos = tl.argmax(hits.to(tl.int32), axis=1)
    found_val = tl.load(
        inp
        + outer_ids * reduction_size * inner_size
        + found_pos * inner_size
        + inner_ids,
        mask=out_live,
        other=0.0,
    )

    # NaN takes precedence over the selected value.
    found_val = tl.where(any_nan, nan_elem, found_val)
    out_pos = tl.where(any_nan, nan_pos, found_pos)
    tl.store(values + out_ids, found_val, mask=out_live)
    tl.store(indices + out_ids, out_pos.to(tl.int64), mask=out_live)
743def _has_names(inp):
744 return any(name is not None for name in inp.names)
def _anonymous(inp):
    """Return ``inp`` with its dimension names dropped.

    Inputs without any named dimension are returned unchanged.
    """
    if _has_names(inp):
        return inp.rename(None)
    return inp
751def _canonical_dim(ndim, dim):
752 lower = -1 if ndim == 0 else -ndim
753 upper = 0 if ndim == 0 else ndim - 1
754 if dim < lower or dim > upper:
755 raise IndexError(
756 f"Dimension out of range (expected to be in range of "
757 f"[{lower}, {upper}], but got {dim})"
758 )
759 return 0 if ndim == 0 else dim % ndim
762def _name_to_dim(inp, dim):
763 if dim not in inp.names:
764 raise RuntimeError(f"Name '{dim}' not found in Tensor{inp.names}.")
765 return inp.names.index(dim)
768def _kept_names(names, dim, keepdim):
769 if names is None:
770 return None
771 if keepdim:
772 return names
773 return names[:dim] + names[dim + 1 :]
776def _empty_result_value(inp):
777 if inp.dtype.is_complex:
778 out = torch.empty((), dtype=inp.dtype, device=inp.device)
779 out.real.fill_(float("nan"))
780 out.imag.zero_()
781 return out
782 if inp.dtype.is_floating_point:
783 return torch.full((), float("nan"), dtype=inp.dtype, device=inp.device)
784 if inp.dtype == torch.bool:
785 return torch.ones((), dtype=inp.dtype, device=inp.device)
786 if inp.dtype in (torch.int32, torch.int64):
787 return torch.full(
788 (), torch.iinfo(inp.dtype).min, dtype=inp.dtype, device=inp.device
789 )
790 return torch.zeros((), dtype=inp.dtype, device=inp.device)
793def _raise_dim_dtype(dtype):
794 dtype_names = {
795 torch.bool: "Bool",
796 torch.complex64: "ComplexFloat",
797 torch.complex128: "ComplexDouble",
798 }
799 dtype_name = dtype_names.get(dtype, str(dtype).removeprefix("torch."))
800 raise NotImplementedError(f'"median_out_impl" not implemented for {dtype_name!r}')
803def _int_search_steps(dtype):
804 if dtype in (torch.int8, torch.uint8):
805 return 8
806 if dtype == torch.int16:
807 return 16
808 if dtype == torch.int32:
809 return 32
810 if dtype == torch.int64:
811 return 64
812 raise NotImplementedError(f"median integer selection not implemented for {dtype}")
815def _unsupported_width(dtype, width):
816 raise NotImplementedError(
817 f"median Triton selection not implemented for dtype {dtype} "
818 f"with reduction width {width}"
819 )
def _median_from_rows(row_data, output_shape):
    """Dispatch a last-dimension median to the first applicable kernel path.

    Paths are tried in this order: 16-bit key select, last-dim sort, float32
    key select, float64 key select, integer value bisection. The reduction
    width is ``row_data.shape[-1]``.

    Returns:
        The ``(values, indices)`` pair produced by the chosen path.

    Raises:
        NotImplementedError: when no path accepts the dtype / width pair.
    """
    dtype = row_data.dtype
    width = row_data.shape[-1]

    if _use_f16_key_select(dtype, width):
        return _median_f16_key_select(row_data, output_shape)
    if _use_lastdim_sort(dtype, width):
        return _median_lastdim_sort(row_data, output_shape)
    if _use_fp32_key_select(dtype, width):
        return _median_fp32_key_select(row_data, output_shape)
    if _use_fp64_key_select(dtype, width):
        return _median_fp64_key_select(row_data, output_shape)

    int_path_ok = (
        width <= _INT_LASTDIM_SELECT_LIMIT and dtype in _INT_LASTDIM_SELECT_DTYPES
    )
    if int_path_ok:
        return _median_int_lastdim_select(row_data, output_shape)

    # Always raises.
    _unsupported_width(dtype, width)
def _median_small_flat(inp):
    """Median over all elements of ``inp`` with the single-block sort kernel.

    The block size is the element count rounded up to a power of two, and the
    kernel is launched with between 4 and 8 warps.

    Returns:
        A 0-d tensor of ``inp.dtype`` on ``inp.device``.
    """
    numel = inp.numel()
    block = triton.next_power_of_2(numel)
    warps = min(8, max(4, block // 32))
    result = torch.empty((), dtype=inp.dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        median_small_flat_kernel[(1,)](
            inp.reshape(-1),
            result,
            WIDTH=numel,
            BLOCK=block,
            num_warps=warps,
        )
    return result
def _median_bool_flat(inp):
    """Median over all elements of a boolean tensor by counting True values.

    Steps:
      1. Count True values per ``_BOOL_FLAT_BLOCK``-sized block.
      2. Fold the partial counts until at most ``_BOOL_COUNT_REDUCE_BLOCK``
         remain.
      3. Derive the median from the total in a single-program kernel.

    Returns:
        A 0-d tensor of ``inp.dtype`` on ``inp.device``.
    """
    device = inp.device
    total = inp.numel()
    n_blocks = triton.cdiv(total, _BOOL_FLAT_BLOCK)
    partial = torch.empty((n_blocks,), dtype=torch.int64, device=device)
    result = torch.empty((), dtype=inp.dtype, device=device)

    with torch_device_fn.device(device):
        median_bool_count_kernel[(n_blocks,)](
            inp.reshape(-1),
            partial,
            WIDTH=total,
            BLOCK=_BOOL_FLAT_BLOCK,
            num_warps=4,
        )

        # Keep folding until one program can read all remaining counts.
        while partial.numel() > _BOOL_COUNT_REDUCE_BLOCK:
            remaining = partial.numel()
            n_folded = triton.cdiv(remaining, _BOOL_COUNT_REDUCE_BLOCK)
            folded = torch.empty((n_folded,), dtype=torch.int64, device=device)
            median_bool_reduce_counts_kernel[(n_folded,)](
                partial,
                folded,
                WIDTH=remaining,
                BLOCK=_BOOL_COUNT_REDUCE_BLOCK,
                num_warps=4,
            )
            partial = folded

        n_partial = partial.numel()
        final_block = triton.next_power_of_2(n_partial)
        median_bool_from_counts_kernel[(1,)](
            partial,
            result,
            WIDTH=total,
            NUM_BLOCKS=n_partial,
            BLOCK=final_block,
            num_warps=min(8, max(1, final_block // 32)),
        )
    return result
def _median_bool_dim(inp, dim, output_shape):
    """Median of a boolean tensor along ``dim`` via chunked counting.

    Steps:
      1. For each output and each ``_BOOL_FLAT_BLOCK``-wide chunk of the
         reduction axis, record the True count and the first False / True
         index.
      2. Fold the chunk statistics until at most ``_BOOL_COUNT_REDUCE_BLOCK``
         chunks remain per output.
      3. Produce the median value and index per output.

    NOTE(review): ``inp`` is handed to the kernel as-is and addressed as
    (outer, reduction, inner) with unit inner stride - presumably the caller
    guarantees contiguity; confirm.

    Returns:
        ``(values, indices)`` tensors of shape ``output_shape``. ``values``
        has ``inp.dtype`` and ``indices`` is int64.
    """
    device = inp.device
    reduction_size = inp.shape[dim]
    inner_size = math.prod(inp.shape[dim + 1 :])
    total_outputs = math.prod(output_shape)

    values = torch.empty(output_shape, dtype=inp.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)

    chunks = triton.cdiv(reduction_size, _BOOL_FLAT_BLOCK)
    stats_shape = (total_outputs, chunks)
    counts = torch.empty(stats_shape, dtype=torch.int64, device=device)
    first_false = torch.empty(stats_shape, dtype=torch.int64, device=device)
    first_true = torch.empty(stats_shape, dtype=torch.int64, device=device)

    with torch_device_fn.device(device):
        median_bool_dim_count_chunks_kernel[(total_outputs * chunks,)](
            inp,
            counts.reshape(-1),
            first_false.reshape(-1),
            first_true.reshape(-1),
            total_outputs,
            reduction_size,
            inner_size,
            chunks_per_output=chunks,
            BLOCK=_BOOL_FLAT_BLOCK,
            num_warps=4,
        )

        # Keep folding until one program per output can read every chunk.
        while chunks > _BOOL_COUNT_REDUCE_BLOCK:
            folded_chunks = triton.cdiv(chunks, _BOOL_COUNT_REDUCE_BLOCK)
            folded_shape = (total_outputs, folded_chunks)
            folded_counts = torch.empty(
                folded_shape, dtype=torch.int64, device=device
            )
            folded_first_false = torch.empty(
                folded_shape, dtype=torch.int64, device=device
            )
            folded_first_true = torch.empty(
                folded_shape, dtype=torch.int64, device=device
            )
            median_bool_dim_reduce_chunks_kernel[(total_outputs, folded_chunks)](
                counts.reshape(-1),
                first_false.reshape(-1),
                first_true.reshape(-1),
                folded_counts.reshape(-1),
                folded_first_false.reshape(-1),
                folded_first_true.reshape(-1),
                input_chunks=chunks,
                output_chunks=folded_chunks,
                BLOCK=_BOOL_COUNT_REDUCE_BLOCK,
                num_warps=4,
            )
            counts = folded_counts
            first_false = folded_first_false
            first_true = folded_first_true
            chunks = folded_chunks

        final_block = triton.next_power_of_2(chunks)
        median_bool_dim_finish_kernel[(total_outputs,)](
            counts.reshape(-1),
            first_false.reshape(-1),
            first_true.reshape(-1),
            values.reshape(-1),
            indices.reshape(-1),
            reduction_size,
            chunks_per_output=chunks,
            BLOCK=final_block,
            num_warps=min(8, max(1, final_block // 32)),
        )
    return values, indices
def _median_lastdim_sort(row_data, output_shape):
    """Last-dimension median using the per-row sort kernel.

    ``row_data`` is viewed as ``(rows, width)`` with ``width`` its last
    dimension, and one program is launched per row.

    Returns:
        ``(values, indices)`` tensors of shape ``output_shape``. ``values``
        has ``row_data.dtype`` and ``indices`` is int64.
    """
    device = row_data.device
    width = row_data.shape[-1]
    rows = row_data.numel() // width
    block = triton.next_power_of_2(width)

    # A single wide row gets the maximum warp count; otherwise scale by width.
    if rows == 1 and block >= 1024:
        num_warps = 8
    else:
        num_warps = min(8, max(4, block // 512))

    values = torch.empty(output_shape, dtype=row_data.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    with torch_device_fn.device(device):
        median_lastdim_sort_kernel[(rows,)](
            row_data.reshape(rows, width),
            values.reshape(rows),
            indices.reshape(rows),
            WIDTH=width,
            BLOCK=block,
            num_warps=num_warps,
        )
    return values, indices
def _use_lastdim_sort(dtype, width):
    """Return True when the row sort kernel covers this dtype and width.

    Only float16 and bfloat16 qualify, each with its own width limit.
    """
    if dtype == torch.bfloat16:
        limit = _BF16_LASTDIM_SORT_LIMIT
    elif dtype == torch.float16:
        limit = _LASTDIM_SORT_LIMIT
    else:
        return False
    return width <= limit
def _use_f16_key_select(dtype, width):
    """Return True when the 16-bit key-select kernel covers dtype and width."""
    if dtype not in _F16_KEY_SELECT_DTYPES:
        return False
    return _F16_KEY_SELECT_MIN <= width <= _F16_KEY_SELECT_LIMIT
def _use_fp32_key_select(dtype, width):
    """Return True when the float32 key-select kernel covers dtype and width."""
    if dtype != torch.float32:
        return False
    return _FP32_KEY_SELECT_MIN <= width <= _FP32_KEY_SELECT_LIMIT
def _use_fp64_key_select(dtype, width):
    """Return True when the float64 key-select kernel covers dtype and width."""
    if dtype != torch.float64:
        return False
    return _FP64_KEY_SELECT_MIN <= width <= _FP64_KEY_SELECT_LIMIT
def _use_strided_select(dtype, width):
    """Return True when a strided key-select kernel covers dtype and width.

    The width must lie in ``[_STRIDED_SELECT_MIN, _STRIDED_SELECT_LIMIT]`` and
    the dtype must be float16, bfloat16 or float32.
    """
    if not (_STRIDED_SELECT_MIN <= width <= _STRIDED_SELECT_LIMIT):
        return False
    strided_dtypes = _F16_KEY_SELECT_DTYPES | {torch.float32}
    return dtype in strided_dtypes
def _use_float_key_select(dtype, width):
    """Return True when any floating key-select kernel covers dtype and width."""
    if _use_f16_key_select(dtype, width):
        return True
    if _use_fp32_key_select(dtype, width):
        return True
    return _use_fp64_key_select(dtype, width)
def _median_float_key_select_rows(row_data, output_shape):
    """Run the key-select kernel matching ``row_data``'s dtype and width.

    The 16-bit and float32 paths are tried first; anything else goes to the
    float64 path without a further check.
    """
    dtype = row_data.dtype
    width = row_data.shape[-1]
    if _use_f16_key_select(dtype, width):
        launcher = _median_f16_key_select
    elif _use_fp32_key_select(dtype, width):
        launcher = _median_fp32_key_select
    else:
        launcher = _median_fp64_key_select
    return launcher(row_data, output_shape)
def _median_float_key_select_dim(work, dim, output_shape, keepdim):
    """Run a float key-select median along ``dim`` of ``work``.

    Three strategies, in order:
      1. ``dim`` is the last axis: make ``work`` contiguous and use the
         row-wise launchers directly.
      2. ``work`` is already contiguous and is f16/bf16/float32: use the
         strided kernels, which read along ``dim`` in place.
      3. Otherwise: move ``dim`` to the end, copy to contiguous rows, reduce,
         and re-insert the reduced axis when ``keepdim`` is set.

    Returns:
        ``(values, indices)`` tensors.
    """
    if dim == work.ndim - 1:
        return _median_float_key_select_rows(work.contiguous(), output_shape)

    if work.is_contiguous():
        if work.dtype in _F16_KEY_SELECT_DTYPES:
            return _median_f16_strided_key_select(work, dim, output_shape)
        if work.dtype == torch.float32:
            return _median_fp32_strided_key_select(work, dim, output_shape)

    # Generic fallback: materialise the reduction axis as the innermost one.
    moved = torch.movedim(work, dim, -1).contiguous()
    values, indices = _median_float_key_select_rows(moved, moved.shape[:-1])
    if not keepdim:
        return values, indices
    # Restore a size-1 axis at the original position of ``dim``.
    kept_values = torch.movedim(values.unsqueeze(-1), -1, dim)
    kept_indices = torch.movedim(indices.unsqueeze(-1), -1, dim)
    return kept_values, kept_indices
def _median_int_lastdim_select(row_data, output_shape):
    """Launch ``median_int_lastdim_select_kernel`` over the last axis.

    Args:
        row_data: Integer tensor whose last dimension is reduced; it is viewed
            as ``(rows, width)`` for the kernel.
        output_shape: Shape of the returned tensors.

    Returns:
        ``(values, indices)`` with ``values`` in ``row_data``'s dtype and
        int64 ``indices``.
    """
    device = row_data.device
    width = row_data.shape[-1]
    rows = row_data.numel() // width
    block = triton.next_power_of_2(width)
    # Search step count is derived from the dtype by a helper defined
    # elsewhere in this module.
    search_steps = _int_search_steps(row_data.dtype)
    num_warps = min(8, max(4, block // 512))

    values = torch.empty(output_shape, dtype=row_data.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    grid = (rows,)
    with torch_device_fn.device(device):
        median_int_lastdim_select_kernel[grid](
            row_data.reshape(rows, width),
            values.reshape(rows),
            indices.reshape(rows),
            WIDTH=width,
            BLOCK=block,
            SEARCH_STEPS=search_steps,
            num_warps=num_warps,
        )
    return values, indices
def _median_f16_key_select(row_data, output_shape):
    """Launch ``median_f16_key_select_kernel`` over the last axis.

    ``row_data`` is viewed as ``(rows, width)`` and the kernel is launched on
    a grid of ``rows`` programs. The warp count grows with the padded row
    width: 1 up to 1024, 2 up to 2048, otherwise 4.

    Returns:
        ``(values, indices)`` of shape ``output_shape``; ``indices`` is int64.
    """
    device = row_data.device
    width = row_data.shape[-1]
    rows = row_data.numel() // width
    block = triton.next_power_of_2(width)
    if block <= 1024:
        num_warps = 1
    elif block <= 2048:
        num_warps = 2
    else:
        num_warps = 4

    values = torch.empty(output_shape, dtype=row_data.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    grid = (rows,)
    with torch_device_fn.device(device):
        median_f16_key_select_kernel[grid](
            row_data.reshape(rows, width),
            values.reshape(rows),
            indices.reshape(rows),
            WIDTH=width,
            BLOCK=block,
            num_warps=num_warps,
        )
    return values, indices
def _median_f16_strided_key_select(inp, dim, output_shape):
    """Launch ``median_f16_strided_key_select_kernel`` along ``dim`` of ``inp``.

    ``inp`` is passed to the kernel as-is together with the reduction length
    and the product of the trailing dimensions (``inner_size``); the callers
    visible in this module only use it for contiguous inputs.

    Returns:
        ``(values, indices)`` of shape ``output_shape``; ``indices`` is int64.
    """
    device = inp.device
    reduction_size = inp.shape[dim]
    inner_size = math.prod(inp.shape[dim + 1 :])
    total_outputs = math.prod(output_shape)
    block = triton.next_power_of_2(reduction_size)
    block_out = 2  # outputs handled per program
    if block <= 1024:
        num_warps = 1
    elif block <= 2048:
        num_warps = 2
    else:
        num_warps = 4

    values = torch.empty(output_shape, dtype=inp.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    grid = (triton.cdiv(total_outputs, block_out),)
    with torch_device_fn.device(device):
        median_f16_strided_key_select_kernel[grid](
            inp,
            values.reshape(-1),
            indices.reshape(-1),
            total_outputs,
            reduction_size,
            inner_size,
            BLOCK=block,
            BLOCK_OUT=block_out,
            num_warps=num_warps,
        )
    return values, indices
def _median_fp32_key_select(row_data, output_shape):
    """Launch ``median_fp32_key_select_kernel`` over the last axis.

    ``row_data`` is viewed as ``(rows, width)`` and the kernel is launched on
    a grid of ``rows`` programs, with 2 warps for padded widths up to 1024
    and 8 warps beyond that.

    Returns:
        ``(values, indices)`` of shape ``output_shape``; ``indices`` is int64.
    """
    device = row_data.device
    width = row_data.shape[-1]
    rows = row_data.numel() // width
    block = triton.next_power_of_2(width)
    if block <= 1024:
        num_warps = 2
    else:
        num_warps = 8

    values = torch.empty(output_shape, dtype=row_data.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    grid = (rows,)
    with torch_device_fn.device(device):
        median_fp32_key_select_kernel[grid](
            row_data.reshape(rows, width),
            values.reshape(rows),
            indices.reshape(rows),
            WIDTH=width,
            BLOCK=block,
            num_warps=num_warps,
        )
    return values, indices
def _median_fp64_key_select(row_data, output_shape):
    """Launch ``median_fp64_key_select_kernel`` over the last axis.

    ``row_data`` is viewed as ``(rows, width)`` and the kernel is launched on
    a grid of ``rows`` programs, with 2 warps for padded widths up to 1024
    and 8 warps beyond that.

    Returns:
        ``(values, indices)`` of shape ``output_shape``; ``indices`` is int64.
    """
    device = row_data.device
    width = row_data.shape[-1]
    rows = row_data.numel() // width
    block = triton.next_power_of_2(width)
    if block <= 1024:
        num_warps = 2
    else:
        num_warps = 8

    values = torch.empty(output_shape, dtype=row_data.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    grid = (rows,)
    with torch_device_fn.device(device):
        median_fp64_key_select_kernel[grid](
            row_data.reshape(rows, width),
            values.reshape(rows),
            indices.reshape(rows),
            WIDTH=width,
            BLOCK=block,
            num_warps=num_warps,
        )
    return values, indices
def _median_fp32_strided_key_select(inp, dim, output_shape):
    """Launch ``median_fp32_strided_key_select_kernel`` along ``dim`` of ``inp``.

    ``inp`` is passed to the kernel as-is together with the reduction length
    and the product of the trailing dimensions (``inner_size``); the callers
    visible in this module only use it for contiguous inputs.

    Returns:
        ``(values, indices)`` of shape ``output_shape``; ``indices`` is int64.
    """
    device = inp.device
    reduction_size = inp.shape[dim]
    inner_size = math.prod(inp.shape[dim + 1 :])
    total_outputs = math.prod(output_shape)
    block = triton.next_power_of_2(reduction_size)
    block_out = 2  # outputs handled per program
    if block <= 1024:
        num_warps = 2
    else:
        num_warps = 8

    values = torch.empty(output_shape, dtype=inp.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    grid = (triton.cdiv(total_outputs, block_out),)
    with torch_device_fn.device(device):
        median_fp32_strided_key_select_kernel[grid](
            inp,
            values.reshape(-1),
            indices.reshape(-1),
            total_outputs,
            reduction_size,
            inner_size,
            BLOCK=block,
            BLOCK_OUT=block_out,
            num_warps=num_warps,
        )
    return values, indices
def _median_direct_dim(inp, dim, output_shape):
    """Launch ``median_small_dim_kernel`` for a short reduction along ``dim``.

    Launch parameters depend on the padded reduction length ``block_n``:
      * ``block_n >= 128``: 2 outputs per program, 8 warps for int32/int64
        inputs and 4 warps otherwise;
      * smaller reductions: 16 outputs per program on a single warp.

    Returns:
        ``(values, indices)`` of shape ``output_shape``; ``indices`` is int64.
    """
    device = inp.device
    reduction_size = inp.shape[dim]
    inner_size = math.prod(inp.shape[dim + 1 :])
    total_outputs = math.prod(output_shape)
    block_n = triton.next_power_of_2(reduction_size)
    if block_n >= 128:
        block_out = 2
        num_warps = 8 if inp.dtype in (torch.int32, torch.int64) else 4
    else:
        block_out = 16
        num_warps = 1

    values = torch.empty(output_shape, dtype=inp.dtype, device=device)
    indices = torch.empty(output_shape, dtype=torch.int64, device=device)
    grid = (triton.cdiv(total_outputs, block_out),)
    with torch_device_fn.device(device):
        median_small_dim_kernel[grid](
            inp,
            values.reshape(-1),
            indices.reshape(-1),
            total_outputs,
            reduction_size,
            inner_size,
            BLOCK_N=block_n,
            BLOCK_OUT=block_out,
            num_warps=num_warps,
        )
    return values, indices
1204def _copy_out(src, out, name):
1205 if out.device != src.device:
1206 raise RuntimeError(
1207 f"Expected {name} tensor to have device {src.device}, "
1208 f"but got {out.device} instead"
1209 )
1210 if out.dtype != src.dtype:
1211 raise RuntimeError(
1212 f"Expected out tensor to have dtype {src.dtype}, but got {out.dtype}"
1213 )
1214 out.resize_as_(src)
1215 out.copy_(src)
1216 return out
def median(inp):
    """Return the median of all elements of ``inp`` as a 0-dim tensor.

    The input is flattened and dispatched to one of several kernels based on
    dtype and element count.

    Args:
        inp: Input tensor. It is first passed through ``_anonymous``
            (presumably stripping dimension names -- confirm against that
            helper).

    Returns:
        A 0-dim tensor. For an empty input the result comes from
        ``_empty_result_value``; for a single element it is a clone of that
        element.

    Raises:
        RuntimeError: If ``inp`` is non-empty and has a complex dtype.
    """
    logger.debug("GEMS MEDIAN")

    inp = _anonymous(inp)
    numel = inp.numel()
    if numel == 0:
        return _empty_result_value(inp)
    if inp.dtype.is_complex:
        raise RuntimeError("Sort does not support complex dtypes on CPU")
    if numel == 1:
        return inp.reshape(()).clone()

    flat = inp.contiguous().reshape(-1)
    # Treat the flattened input as a single row for the row-wise launchers.
    row_data = flat.reshape(1, numel)
    if _use_float_key_select(inp.dtype, numel):
        values, _ = _median_float_key_select_rows(row_data, ())
        return values.reshape(())
    if inp.dtype in _DIRECT_REDUCTION_DTYPES and numel <= _DIRECT_FLAT_LIMIT:
        return _median_small_flat(flat)
    if inp.dtype == torch.bool:
        return _median_bool_flat(flat)

    # No fp32/fp64 key-select branches are needed below: those predicates are
    # part of `_use_float_key_select`, which has already been rejected for
    # this exact (dtype, numel) pair.
    if inp.dtype in _FLAT_SORT_DTYPES and numel <= _FLAT_SORT_LIMIT:
        values, _ = _median_lastdim_sort(row_data, ())
    elif (
        numel <= _INT_LASTDIM_SELECT_LIMIT
        and inp.dtype in _INT_LASTDIM_SELECT_DTYPES
    ):
        values, _ = _median_int_lastdim_select(row_data, ())
    else:
        values, _ = _median_from_rows(row_data, ())
    return values.reshape(())
def median_out(inp, *, out):
    """Compute ``median(inp)`` and copy the result into ``out``.

    Returns:
        ``out``, resized and filled by ``_copy_out``.
    """
    logger.debug("GEMS MEDIAN.OUT")
    result = median(inp)
    return _copy_out(result, out, "out")
def median_dim(inp, dim=0, keepdim=False):
    """Compute the median along ``dim`` and the index of each median element.

    Args:
        inp: Input tensor.
        dim: Reduction dimension, as an int or a dimension name. Names are
            resolved with ``_name_to_dim``; ints are normalised with
            ``_canonical_dim``.
        keepdim: If True the reduced dimension is kept with size 1, otherwise
            it is removed from the output shape.

    Returns:
        ``MedianResult(values, indices)``. ``values`` has ``inp``'s dtype and
        ``indices`` is int64. When ``_kept_names`` yields output names they
        are applied to both tensors.

    Raises:
        IndexError: If the reduction dimension has size zero.

    Complex inputs that are 0-dim or non-empty are rejected via
    ``_raise_dim_dtype`` (presumably raising -- confirm against that helper).
    """
    logger.debug("GEMS MEDIAN.DIM")

    if isinstance(dim, str):
        dim = _name_to_dim(inp, dim)
    dim = _canonical_dim(inp.ndim, dim)
    names = inp.names if _has_names(inp) else None
    work = _anonymous(inp)

    if work.ndim == 0:
        if work.dtype.is_complex:
            _raise_dim_dtype(work.dtype)
        # A scalar is its own median, at index 0.
        return MedianResult(
            values=work.clone(),
            indices=torch.zeros((), dtype=torch.int64, device=work.device),
        )

    if work.shape[dim] == 0:
        raise IndexError(
            f"median(): Expected reduction dim {dim} to have non-zero size."
        )

    output_shape = list(work.shape)
    if keepdim:
        output_shape[dim] = 1
    else:
        del output_shape[dim]
    output_names = _kept_names(names, dim, keepdim)

    if work.numel() == 0:
        # Some non-reduced dimension is empty: nothing to compute.
        values = torch.empty(output_shape, dtype=work.dtype, device=work.device)
        indices = torch.empty(output_shape, dtype=torch.int64, device=work.device)
    else:
        if work.dtype.is_complex:
            _raise_dim_dtype(work.dtype)
        width = work.shape[dim]
        is_lastdim = dim == work.ndim - 1
        if work.dtype == torch.bool:
            values, indices = _median_bool_dim(work.contiguous(), dim, output_shape)
        elif _use_float_key_select(work.dtype, width):
            values, indices = _median_float_key_select_dim(
                work, dim, output_shape, keepdim
            )
        # From here on every f16/fp32/fp64 key-select predicate is known to be
        # False for (work.dtype, width), so none of them is re-tested below.
        elif (
            width <= _DIRECT_REDUCTION_LIMIT
            and work.dtype in _DIRECT_REDUCTION_DTYPES
        ):
            values, indices = _median_direct_dim(work.contiguous(), dim, output_shape)
        elif (
            not is_lastdim
            and work.is_contiguous()
            and _use_strided_select(work.dtype, width)
        ):
            if work.dtype in _F16_KEY_SELECT_DTYPES:
                values, indices = _median_f16_strided_key_select(
                    work, dim, output_shape
                )
            else:
                # `_use_strided_select` only admits the 16-bit float dtypes
                # and float32, so the remaining case is float32.
                values, indices = _median_fp32_strided_key_select(
                    work, dim, output_shape
                )
        elif is_lastdim and _use_lastdim_sort(work.dtype, width):
            values, indices = _median_lastdim_sort(work.contiguous(), output_shape)
        elif (
            is_lastdim
            and width <= _INT_LASTDIM_SELECT_LIMIT
            and work.dtype in _INT_LASTDIM_SELECT_DTYPES
        ):
            values, indices = _median_int_lastdim_select(
                work.contiguous(), output_shape
            )
        else:
            # Generic fallback: move the reduction axis to the end, reduce
            # row-wise, then restore the axis if it must be kept.
            rows = torch.movedim(work, dim, -1).contiguous()
            row_output_shape = rows.shape[:-1]
            row_width = rows.shape[-1]
            if _use_lastdim_sort(rows.dtype, row_width):
                values, indices = _median_lastdim_sort(rows, row_output_shape)
            elif (
                row_width <= _INT_LASTDIM_SELECT_LIMIT
                and rows.dtype in _INT_LASTDIM_SELECT_DTYPES
            ):
                values, indices = _median_int_lastdim_select(rows, row_output_shape)
            else:
                values, indices = _median_from_rows(rows, row_output_shape)
            if keepdim:
                values = torch.movedim(values.unsqueeze(-1), -1, dim)
                indices = torch.movedim(indices.unsqueeze(-1), -1, dim)

    if output_names is not None:
        values = values.refine_names(*output_names)
        indices = indices.refine_names(*output_names)

    return MedianResult(values=values, indices=indices)
def median_dim_values(inp, dim=0, keepdim=False, *, values, indices):
    """Out-variant of ``median_dim`` writing into ``values`` and ``indices``.

    The result of ``median_dim`` is copied into the supplied tensors with
    ``_copy_out`` (values first, then indices).

    Returns:
        ``MedianResult`` wrapping the supplied ``values`` and ``indices``.
    """
    logger.debug("GEMS MEDIAN.DIM_VALUES")
    computed = median_dim(inp, dim=dim, keepdim=keepdim)
    destinations = (
        (computed.values, values, "values"),
        (computed.indices, indices, "indices"),
    )
    for source, target, label in destinations:
        _copy_out(source, target, label)
    return MedianResult(values=values, indices=indices)