Coverage for src/flag_gems/ops/nanmedian.py: 34%
825 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
3from collections import namedtuple
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import dim_compress, libentry, tl_extra_shim
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_max, get_dtype_min
14from .topk import _get_finfo_val
logger = logging.getLogger(__name__)

# Result pair handed back by the nanmedian helpers in this module.
NanMedian = namedtuple("nanmedian", ["values", "indices"])

INT32_MAX = torch.iinfo(torch.int32).max

# Kernel block sizes and radix digit widths.
MAX_BLOCK_N = 128
RADIX_BLOCK_N = 1024
FLAT_RADIX_BLOCK_N = 4096
RADIX_BITS = 2
FLAT_RADIX_BITS = 8

# Reduction-length thresholds consulted when picking a code path or block size.
MEDIUM_REDUCTION_N = 1024
LARGE_FLOAT_REDUCTION_N = 4096
LONG_RADIX_REDUCTION_N = 131072
ASCEND_MULTI_HISTOGRAM_MIN_N = 8192
ASCEND_FLAT_SORT_MIN_N = 1 << 20

# Bin count used when sizing the histogram buffers of the NPU paths.
ASCEND_HISTOGRAM_BINS = 256

# dtype groups that gate the individual selection strategies.
RADIX_SELECT_DTYPES = (
    torch.float16,
    torch.bfloat16,
    torch.float32,
    torch.int8,
    torch.uint8,
    torch.int16,
    torch.int32,
)
ASCEND_FLOAT_SELECT_DTYPES = (torch.float16, torch.float32)
ASCEND_HISTOGRAM_SELECT_DTYPES = (torch.int8, torch.uint8)
ASCEND_BYTE_HISTOGRAM_SELECT_DTYPES = (torch.int16, torch.int32)
# Same members, in the same order, as listing the six dtypes explicitly.
ASCEND_FLAT_SORT_DTYPES = (
    ASCEND_FLOAT_SELECT_DTYPES
    + ASCEND_HISTOGRAM_SELECT_DTYPES
    + ASCEND_BYTE_HISTOGRAM_SELECT_DTYPES
)
def _triton_version_at_least(major, minor):
    """Return True when the installed Triton is at least ``major.minor``.

    The version string is read from ``triton.__version__`` (``"0.0"`` when the
    attribute is missing). Anything after a ``+`` (local version suffix) is
    dropped, and only the leading decimal digits of the first two
    dot-separated components are considered; a component without leading
    digits, or a missing component, counts as 0.
    """
    release = getattr(triton, "__version__", "0.0").split("+", 1)[0]

    def _leading_int(token):
        # Length of the run of digit characters at the start of the token.
        end = 0
        while end < len(token) and token[end].isdigit():
            end += 1
        return int(token[:end]) if end else 0

    numbers = [_leading_int(token) for token in release.split(".")[:2]]
    while len(numbers) < 2:
        numbers.append(0)
    return tuple(numbers) >= (major, minor)


# Triton added tl.histogram(..., mask) in 3.4.
CUDA_SUPPORTS_MASKED_HISTOGRAM = _triton_version_at_least(3, 4)
@triton.jit
def _is_not_nan(vals, USE_ISNAN: tl.constexpr):
    """Return a mask that is true where ``vals`` is not NaN.

    ``vals`` is converted to float32 first. With ``USE_ISNAN`` the check is
    delegated to ``tl_extra_shim.isnan``; otherwise it uses the fact that NaN
    is the only value that compares unequal to itself.
    """
    as_fp32 = vals.to(tl.float32)
    if USE_ISNAN:
        not_nan = ~tl_extra_shim.isnan(as_fp32)
    else:
        not_nan = as_fp32 == as_fp32
    return not_nan
@triton.jit
def _to_order_key(vals, valid):
    """Map ``vals`` to unsigned keys of the same bit width for radix selection.

    * floating point: the bit pattern is reinterpreted as unsigned; values
      with the sign bit set have every bit flipped, the others only the sign
      bit.
    * signed integers: the sign bit of the bit pattern is flipped.
    * everything else: the value is converted to the unsigned type as is.

    Lanes where ``valid`` is false receive the all-ones key.
    """
    elem_ty = vals.dtype
    nbits: tl.constexpr = elem_ty.primitive_bitwidth
    key_ty = tl.dtype(f"uint{nbits}")
    sign_bit: tl.constexpr = 1 << (nbits - 1)
    all_ones: tl.constexpr = (1 << nbits) - 1
    sentinel = tl.full(vals.shape, all_ones, dtype=key_ty)

    if elem_ty.is_floating():
        raw = vals.to(key_ty, bitcast=True)
        # Negative values flip all bits, non-negative ones just the sign bit.
        flip = tl.where((raw & sign_bit) != 0, all_ones, sign_bit)
        ordered = raw ^ flip
    elif elem_ty.is_int_signed():
        raw = vals.to(key_ty, bitcast=True)
        ordered = raw ^ sign_bit
    else:
        ordered = vals.to(key_ty)
    return tl.where(valid, ordered, sentinel)
@libentry()
@triton.jit
def count_valid_kernel(
    inp,
    valid_counts,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_ISNAN: tl.constexpr,
):
    """Count the non-NaN elements of one row per program.

    Program ``p`` walks elements ``inp[p * N : (p + 1) * N]`` in chunks of
    ``BLOCK_N`` and stores the number of non-NaN entries to
    ``valid_counts[p]``. ``M`` is not referenced in the body.
    """
    row = tle.program_id(0)
    lanes = tl.arange(0, BLOCK_N)
    total = tl.full((), 0, dtype=tl.int32)
    for block_start in tl.range(0, N, BLOCK_N):
        cols = block_start + lanes
        in_bounds = cols < N
        # Out-of-range lanes load NaN and are masked out again below.
        vals = tl.load(inp + row * N + cols, mask=in_bounds, other=float("nan"))
        not_nan = in_bounds & _is_not_nan(vals, USE_ISNAN)
        total += tl.sum(not_nan.to(tl.int32), axis=0)
    tl.store(valid_counts + row, total)
@libentry()
@triton.jit
def nanmedian_select_kernel(
    inp,
    out_values,
    out_indices,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_ISNAN: tl.constexpr,
):
    """Pick the lower median of one row by repeated minimum extraction.

    Each program loads a single block of ``BLOCK_N`` lanes from its row
    (lanes at or beyond ``N`` are masked), so only the first ``BLOCK_N``
    columns are ever examined. NaNs are ignored for floating dtypes. The
    element of rank ``(valid_count - 1) // 2`` among the valid entries and
    its column are written to ``out_values[pid]`` / ``out_indices[pid]``; a
    floating row without valid entries yields NaN and index 0.
    ``M`` is not referenced in the body.
    """
    row = tle.program_id(0)
    lanes = tl.arange(0, BLOCK_N)
    in_bounds = lanes < N
    elem_ty = inp.dtype.element_ty
    if elem_ty.is_floating():
        pad_value = _get_finfo_val(elem_ty, return_max=True)
        init_value = _get_finfo_val(elem_ty, return_max=False)
    else:
        pad_value = get_dtype_max(elem_ty)
        init_value = get_dtype_min(elem_ty)
    vals = tl.load(inp + row * N + lanes, mask=in_bounds, other=pad_value)

    if elem_ty.is_floating():
        valid = in_bounds & _is_not_nan(vals, USE_ISNAN)
    else:
        valid = in_bounds
    num_valid = tl.sum(valid.to(tl.int32), axis=0)
    target_rank = (num_valid - 1) // 2

    remaining = valid
    picked_val = tl.full((), init_value, dtype=vals.dtype)
    picked_idx = tl.full((), 0, dtype=tl.int32)
    # Step i removes the i-th smallest remaining element (lowest column wins
    # ties); the one removed at step == target_rank is the result.
    # NOTE(review): if pad_value is a finite maximum rather than +inf, rows
    # whose remaining entries are all +inf would not match `vals == smallest`
    # - confirm what _get_finfo_val returns.
    for step in tl.static_range(0, BLOCK_N):
        candidates = tl.where(remaining, vals, pad_value)
        smallest = tl.min(candidates, axis=0)
        smallest_idx = tl.min(
            tl.where(remaining & (vals == smallest), lanes, BLOCK_N), axis=0
        )
        hit = step == target_rank
        picked_val = tl.where(hit, smallest, picked_val)
        picked_idx = tl.where(hit, smallest_idx, picked_idx)
        remaining = remaining & (lanes != smallest_idx)

    if elem_ty.is_floating():
        no_valid = num_valid == 0
        picked_val = tl.where(no_valid, float("nan"), picked_val)
        picked_idx = tl.where(no_valid, 0, picked_idx)

    tl.store(out_values + row, picked_val)
    tl.store(out_indices + row, picked_idx)
@libentry()
@triton.jit
def nanmedian_float_clean_count_kernel(
    inp,
    cleaned,
    valid_counts,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Copy a row with NaNs replaced and count its non-NaN entries.

    For row ``pid`` every NaN is replaced by
    ``_get_finfo_val(dtype, return_max=True)`` while copying ``inp`` to
    ``cleaned``; the number of non-NaN elements goes to ``valid_counts[pid]``.
    NaN detection always uses the self-comparison variant of ``_is_not_nan``.
    """
    row = tle.program_id(0)
    lanes = tl.arange(0, BLOCK_N)
    elem_ty = inp.dtype.element_ty
    replacement = _get_finfo_val(elem_ty, return_max=True)
    total = tl.full((), 0, dtype=tl.int32)

    for block_start in tl.range(0, N, BLOCK_N):
        cols = block_start + lanes
        in_bounds = cols < N
        vals = tl.load(inp + row * N + cols, mask=in_bounds, other=replacement)
        keep = in_bounds & _is_not_nan(vals, False)
        tl.store(
            cleaned + row * N + cols,
            tl.where(keep, vals, replacement),
            mask=in_bounds,
        )
        total += tl.sum(keep.to(tl.int32), axis=0)

    tl.store(valid_counts + row, total)
@libentry()
@triton.jit
def nanmedian_float_sorted_gather_kernel(
    sorted_values,
    sorted_indices,
    valid_counts,
    out_values,
    out_indices,
    N: tl.constexpr,
):
    """Gather the median entry of each already-sorted row.

    For row ``pid`` with ``count = valid_counts[pid]`` the element at position
    ``(count - 1) // 2`` of ``sorted_values`` / ``sorted_indices`` is written
    to the outputs. Rows with ``count == 0`` produce NaN and index 0.
    """
    row = tle.program_id(0)
    num_valid = tl.load(valid_counts + row)
    has_valid = num_valid > 0
    rank = tl.where(has_valid, (num_valid - 1) // 2, 0)
    picked_val = tl.load(
        sorted_values + row * N + rank, mask=has_valid, other=float("nan")
    )
    picked_idx = tl.load(sorted_indices + row * N + rank, mask=has_valid, other=0)
    picked_val = tl.where(has_valid, picked_val, float("nan"))
    picked_idx = tl.where(has_valid, picked_idx, 0)

    tl.store(out_values + row, picked_val)
    tl.store(out_indices + row, picked_idx)
@libentry()
@triton.jit
def nanmedian_ascend_histogram_select_kernel(
    inp,
    out_values,
    out_indices,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HISTOGRAM_BINS: tl.constexpr,
):
    """Find the lower median of one row with a single histogram over its keys.

    Pass 1 histograms the order keys of the row into ``HISTOGRAM_BINS`` bins
    and locates the bin holding the element of 1-based rank ``(N + 1) // 2``.
    Pass 2 finds the lowest column whose key equals that bin. All ``N``
    elements are treated as valid (no NaN handling). ``M`` is not referenced.
    Assumes every key is smaller than ``HISTOGRAM_BINS`` - presumably the
    caller restricts this path to 8-bit dtypes; confirm.
    """
    row = tle.program_id(0)
    lanes = tl.arange(0, BLOCK_N)
    bin_ids = tl.arange(0, HISTOGRAM_BINS)
    hist = tl.zeros((HISTOGRAM_BINS,), dtype=tl.int32)

    for block_start in tl.range(0, N, BLOCK_N):
        cols = block_start + lanes
        in_bounds = cols < N
        vals = tl.load(inp + row * N + cols, mask=in_bounds, other=0)
        keys = _to_order_key(vals, in_bounds).to(tl.int32)
        # Padding lanes are binned at 0 and subtracted from bin 0 right after.
        keys = tl.where(in_bounds, keys, 0)
        block_hist = tl.histogram(keys, HISTOGRAM_BINS).to(tl.int32)
        padding = tl.sum((~in_bounds).to(tl.int32), axis=0)
        hist += block_hist - tl.where(bin_ids == 0, padding, 0)

    k_to_find: tl.constexpr = (N + 1) // 2
    running = tl.cumsum(hist, axis=0)
    before = running - hist
    hit = (k_to_find <= running) & (k_to_find > before)
    median_key = tl.min(tl.where(hit, bin_ids, HISTOGRAM_BINS - 1), axis=0)

    best_idx = tl.full((), N, dtype=tl.int32)
    for block_start in tl.range(0, N, BLOCK_N):
        cols = block_start + lanes
        in_bounds = cols < N
        vals = tl.load(inp + row * N + cols, mask=in_bounds, other=0)
        keys = _to_order_key(vals, in_bounds).to(tl.int32)
        block_idx = tl.min(
            tl.where(in_bounds & (keys == median_key), cols, N), axis=0
        )
        best_idx = tl.where(block_idx < best_idx, block_idx, best_idx)

    median_val = tl.load(inp + row * N + best_idx)
    tl.store(out_values + row, median_val)
    tl.store(out_indices + row, best_idx)
@libentry()
@triton.jit
def nanmedian_ascend_histogram_count_kernel(
    inp,
    partial_counts,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_CHUNKS: tl.constexpr,
    HISTOGRAM_BINS: tl.constexpr,
):
    """Histogram the order keys of one ``BLOCK_N``-wide chunk of one row.

    Grid axis 0 selects the row and axis 1 the chunk. The per-chunk histogram
    is stored at ``partial_counts[(row * NUM_CHUNKS + chunk) * HISTOGRAM_BINS
    + bin]``. ``M`` is not referenced in the body.
    """
    row = tle.program_id(0)
    chunk = tle.program_id(1)
    cols = chunk * BLOCK_N + tl.arange(0, BLOCK_N)
    bin_ids = tl.arange(0, HISTOGRAM_BINS)
    in_bounds = cols < N
    vals = tl.load(inp + row * N + cols, mask=in_bounds, other=0)
    keys = _to_order_key(vals, in_bounds).to(tl.int32)
    # Padding lanes are binned at 0 and subtracted from bin 0 right after.
    keys = tl.where(in_bounds, keys, 0)
    hist = tl.histogram(keys, HISTOGRAM_BINS).to(tl.int32)
    padding = tl.sum((~in_bounds).to(tl.int32), axis=0)
    hist = hist - tl.where(bin_ids == 0, padding, 0)
    dst = (row * NUM_CHUNKS + chunk) * HISTOGRAM_BINS + bin_ids
    tl.store(partial_counts + dst, hist)
@libentry()
@triton.jit
def nanmedian_ascend_histogram_reduce_kernel(
    inp,
    partial_counts,
    out_values,
    out_indices,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_CHUNKS: tl.constexpr,
    HISTOGRAM_BINS: tl.constexpr,
):
    """Combine per-chunk histograms of a row and emit its lower median.

    Sums the ``NUM_CHUNKS`` partial histograms of row ``pid``, locates the
    bin holding the element of 1-based rank ``(N + 1) // 2`` and then scans
    the row for the lowest column whose order key equals that bin.
    ``M`` is not referenced in the body.
    """
    row = tle.program_id(0)
    lanes = tl.arange(0, BLOCK_N)
    bin_ids = tl.arange(0, HISTOGRAM_BINS)
    hist = tl.zeros((HISTOGRAM_BINS,), dtype=tl.int32)

    for chunk in tl.range(0, NUM_CHUNKS):
        src = (row * NUM_CHUNKS + chunk) * HISTOGRAM_BINS + bin_ids
        hist += tl.load(partial_counts + src)

    k_to_find: tl.constexpr = (N + 1) // 2
    running = tl.cumsum(hist, axis=0)
    before = running - hist
    hit = (k_to_find <= running) & (k_to_find > before)
    median_key = tl.min(tl.where(hit, bin_ids, HISTOGRAM_BINS - 1), axis=0)

    best_idx = tl.full((), N, dtype=tl.int32)
    for block_start in tl.range(0, N, BLOCK_N):
        cols = block_start + lanes
        in_bounds = cols < N
        vals = tl.load(inp + row * N + cols, mask=in_bounds, other=0)
        keys = _to_order_key(vals, in_bounds).to(tl.int32)
        block_idx = tl.min(
            tl.where(in_bounds & (keys == median_key), cols, N), axis=0
        )
        best_idx = tl.where(block_idx < best_idx, block_idx, best_idx)

    median_val = tl.load(inp + row * N + best_idx)
    tl.store(out_values + row, median_val)
    tl.store(out_indices + row, best_idx)
@libentry()
@triton.jit
def nanmedian_ascend_byte_histogram_select_kernel(
    inp,
    out_values,
    out_indices,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    HISTOGRAM_BINS: tl.constexpr,
):
    """Radix-select the lower median of one row, eight key bits per round.

    Starting at the most significant byte of the order key, each round
    histograms the current byte over the elements whose already-fixed bytes
    match, picks the bin containing the wanted rank and narrows the rank to
    that bin. After the last byte the full key is known and the lowest column
    holding it is looked up. All ``N`` elements count as valid.
    ``M`` is not referenced. The digit mask is ``HISTOGRAM_BINS - 1`` while
    the shift step is fixed at 8 bits - presumably ``HISTOGRAM_BINS`` is 256;
    confirm at the call site.
    """
    row = tle.program_id(0)
    lanes = tl.arange(0, BLOCK_N)
    bin_ids = tl.arange(0, HISTOGRAM_BINS)
    elem_ty = inp.dtype.element_ty
    nbits: tl.constexpr = elem_ty.primitive_bitwidth
    key_ty = tl.dtype(f"uint{nbits}")
    digit_mask = tl.full((), HISTOGRAM_BINS - 1, dtype=key_ty)

    remaining_rank = tl.full((), (N + 1) // 2, dtype=tl.int32)
    prefix = tl.full((), 0, dtype=key_ty)
    prefix_mask = tl.full((), 0, dtype=key_ty)

    for shift in tl.static_range(nbits - 8, -1, -8):
        hist = tl.zeros((HISTOGRAM_BINS,), dtype=tl.int32)

        for block_start in tl.range(0, N, BLOCK_N):
            cols = block_start + lanes
            in_bounds = cols < N
            vals = tl.load(inp + row * N + cols, mask=in_bounds, other=0)
            keys = _to_order_key(vals, in_bounds)
            counted = in_bounds & ((keys & prefix_mask) == prefix)
            digit = ((keys >> shift) & digit_mask).to(tl.int32)
            # Lanes that do not take part are binned at 0 and subtracted
            # from bin 0 right after.
            digit = tl.where(counted, digit, 0)
            block_hist = tl.histogram(digit, HISTOGRAM_BINS).to(tl.int32)
            not_counted = tl.sum((~counted).to(tl.int32), axis=0)
            hist += block_hist - tl.where(bin_ids == 0, not_counted, 0)

        running = tl.cumsum(hist, axis=0)
        before = running - hist
        hit = (remaining_rank <= running) & (remaining_rank > before)
        chosen = tl.min(tl.where(hit, bin_ids, HISTOGRAM_BINS - 1), axis=0)
        skipped = tl.max(tl.where(hit, before, 0), axis=0)

        chosen = chosen.to(key_ty)
        prefix = prefix | (chosen << shift)
        prefix_mask = prefix_mask | (digit_mask << shift)
        remaining_rank = remaining_rank - skipped

    best_idx = tl.full((), N, dtype=tl.int32)
    for block_start in tl.range(0, N, BLOCK_N):
        cols = block_start + lanes
        in_bounds = cols < N
        vals = tl.load(inp + row * N + cols, mask=in_bounds, other=0)
        keys = _to_order_key(vals, in_bounds)
        block_idx = tl.min(tl.where(in_bounds & (keys == prefix), cols, N), axis=0)
        best_idx = tl.where(block_idx < best_idx, block_idx, best_idx)

    median_val = tl.load(inp + row * N + best_idx)
    tl.store(out_values + row, median_val)
    tl.store(out_indices + row, best_idx)
@libentry()
@triton.jit
def nanmedian_ascend_byte_histogram_init_kernel(
    state,
    M,
    N: tl.constexpr,
):
    """Reset the three-slot radix-select state of row ``pid``.

    Slots ``pid * 3 + 0`` and ``pid * 3 + 1`` are zeroed; slot
    ``pid * 3 + 2`` receives the 1-based target rank ``(N + 1) // 2``.
    ``M`` is not referenced in the body.
    """
    row = tle.program_id(0)
    slot = state + row * 3
    tl.store(slot + 0, 0)
    tl.store(slot + 1, 0)
    tl.store(slot + 2, (N + 1) // 2)
@libentry()
@triton.jit
def nanmedian_ascend_byte_histogram_count_kernel(
    inp,
    state,
    partial_counts,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NUM_CHUNKS: tl.constexpr,
    HISTOGRAM_BINS: tl.constexpr,
    DIGIT_POS: tl.constexpr,
):
    """Histogram one key digit for one chunk of one row.

    Only elements whose key agrees with the prefix stored in
    ``state[row * 3 + 0]`` under the mask in ``state[row * 3 + 1]`` are
    counted. The digit is ``(key >> DIGIT_POS) & (HISTOGRAM_BINS - 1)``, and
    the result is stored at ``partial_counts[(row * NUM_CHUNKS + chunk) *
    HISTOGRAM_BINS + bin]``. ``M`` is not referenced in the body.
    """
    row = tle.program_id(0)
    chunk = tle.program_id(1)
    cols = chunk * BLOCK_N + tl.arange(0, BLOCK_N)
    bin_ids = tl.arange(0, HISTOGRAM_BINS)
    in_bounds = cols < N

    elem_ty = inp.dtype.element_ty
    nbits: tl.constexpr = elem_ty.primitive_bitwidth
    key_ty = tl.dtype(f"uint{nbits}")
    digit_mask = tl.full((), HISTOGRAM_BINS - 1, dtype=key_ty)
    slot = row * 3
    prefix = tl.load(state + slot + 0).to(key_ty)
    prefix_mask = tl.load(state + slot + 1).to(key_ty)

    vals = tl.load(inp + row * N + cols, mask=in_bounds, other=0)
    keys = _to_order_key(vals, in_bounds)
    counted = in_bounds & ((keys & prefix_mask) == prefix)
    digit = ((keys >> DIGIT_POS) & digit_mask).to(tl.int32)
    # Lanes that do not take part are binned at 0 and subtracted again.
    digit = tl.where(counted, digit, 0)
    hist = tl.histogram(digit, HISTOGRAM_BINS).to(tl.int32)
    not_counted = tl.sum((~counted).to(tl.int32), axis=0)
    hist = hist - tl.where(bin_ids == 0, not_counted, 0)

    dst = (row * NUM_CHUNKS + chunk) * HISTOGRAM_BINS + bin_ids
    tl.store(partial_counts + dst, hist)
@libentry()
@triton.jit
def nanmedian_ascend_byte_histogram_update_kernel(
    inp,
    partial_counts,
    state,
    M,
    NUM_CHUNKS: tl.constexpr,
    HISTOGRAM_BINS: tl.constexpr,
    DIGIT_POS: tl.constexpr,
):
    """Advance the radix-select state of row ``pid`` by one digit.

    Sums the row's ``NUM_CHUNKS`` partial histograms and finds the bin that
    contains the rank held in ``state[pid * 3 + 2]``. That bin is then merged
    into the key prefix (slot 0) and prefix mask (slot 1) at ``DIGIT_POS``,
    and the rank (slot 2) is reduced by the number of elements in lower bins.
    ``inp`` is used only for its element type; ``M`` is not referenced.
    """
    row = tle.program_id(0)
    bin_ids = tl.arange(0, HISTOGRAM_BINS)
    hist = tl.zeros((HISTOGRAM_BINS,), dtype=tl.int32)

    for chunk in tl.range(0, NUM_CHUNKS):
        src = (row * NUM_CHUNKS + chunk) * HISTOGRAM_BINS + bin_ids
        hist += tl.load(partial_counts + src)

    slot = row * 3
    remaining_rank = tl.load(state + slot + 2).to(tl.int32)
    running = tl.cumsum(hist, axis=0)
    before = running - hist
    hit = (remaining_rank <= running) & (remaining_rank > before)
    chosen = tl.min(tl.where(hit, bin_ids, HISTOGRAM_BINS - 1), axis=0)
    skipped = tl.max(tl.where(hit, before, 0), axis=0)

    elem_ty = inp.dtype.element_ty
    nbits: tl.constexpr = elem_ty.primitive_bitwidth
    key_ty = tl.dtype(f"uint{nbits}")
    digit_mask = tl.full((), HISTOGRAM_BINS - 1, dtype=key_ty)
    prefix = tl.load(state + slot + 0).to(key_ty)
    prefix_mask = tl.load(state + slot + 1).to(key_ty)
    chosen = chosen.to(key_ty)

    prefix = prefix | (chosen << DIGIT_POS)
    prefix_mask = prefix_mask | (digit_mask << DIGIT_POS)
    tl.store(state + slot + 0, prefix)
    tl.store(state + slot + 1, prefix_mask)
    tl.store(state + slot + 2, remaining_rank - skipped)
@libentry()
@triton.jit
def nanmedian_ascend_byte_histogram_find_index_kernel(
    inp,
    state,
    out_values,
    out_indices,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Locate the element whose order key equals the selected key.

    Reads the key from ``state[pid * 3 + 0]``, scans row ``pid`` for the
    lowest column whose order key matches, and stores that column and the
    value found there. ``M`` is not referenced in the body.
    """
    row = tle.program_id(0)
    lanes = tl.arange(0, BLOCK_N)
    elem_ty = inp.dtype.element_ty
    nbits: tl.constexpr = elem_ty.primitive_bitwidth
    key_ty = tl.dtype(f"uint{nbits}")
    wanted_key = tl.load(state + row * 3 + 0).to(key_ty)

    best_idx = tl.full((), N, dtype=tl.int32)
    for block_start in tl.range(0, N, BLOCK_N):
        cols = block_start + lanes
        in_bounds = cols < N
        vals = tl.load(inp + row * N + cols, mask=in_bounds, other=0)
        keys = _to_order_key(vals, in_bounds)
        block_idx = tl.min(
            tl.where(in_bounds & (keys == wanted_key), cols, N), axis=0
        )
        best_idx = tl.where(block_idx < best_idx, block_idx, best_idx)

    median_val = tl.load(inp + row * N + best_idx)
    tl.store(out_values + row, median_val)
    tl.store(out_indices + row, best_idx)
@libentry()
@triton.jit
def nanmedian_radix_select_kernel(
    inp,
    out_values,
    out_indices,
    M,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    RADIX_BITS_: tl.constexpr,
    USE_ISNAN: tl.constexpr,
    USE_HISTOGRAM: tl.constexpr,
):
    """Radix-select the NaN-ignoring lower median of one row per program.

    Steps for row ``pid``:
      1. count the valid (in-range, non-NaN for floats) elements;
      2. walk the order key from its most significant ``RADIX_BITS_``-wide
         digit downwards, each round histogramming the digit over elements
         that match the prefix fixed so far and narrowing the 1-based rank
         ``(valid_count + 1) // 2`` to the bin that contains it;
      3. find the lowest column whose key equals the completed key.

    Digit counts come from a masked ``tl.histogram`` when ``USE_HISTOGRAM``
    is set, otherwise from one reduction per bin. A floating row without
    valid elements yields NaN and index 0. ``M`` is not referenced.
    """
    row = tle.program_id(0)
    lanes = tl.arange(0, BLOCK_N)
    elem_ty = inp.dtype.element_ty
    nbits: tl.constexpr = elem_ty.primitive_bitwidth
    key_ty = tl.dtype(f"uint{nbits}")
    radix_size: tl.constexpr = 1 << RADIX_BITS_
    radix_mask: tl.constexpr = radix_size - 1
    bin_ids = tl.arange(0, radix_size)

    # Step 1: number of valid elements in the row.
    num_valid = tl.full((), 0, dtype=tl.int32)
    for block_start in tl.range(0, N, BLOCK_N):
        cols = block_start + lanes
        in_bounds = cols < N
        vals = tl.load(inp + row * N + cols, mask=in_bounds, other=0.0)
        if elem_ty.is_floating():
            valid = in_bounds & _is_not_nan(vals, USE_ISNAN)
        else:
            valid = in_bounds
        num_valid += tl.sum(valid.to(tl.int32), axis=0)

    remaining_rank = (num_valid + 1) // 2
    prefix = tl.full((), 0, dtype=key_ty)
    prefix_mask = tl.full((), 0, dtype=key_ty)
    digit_mask = tl.full((), radix_mask, dtype=key_ty)

    # Step 2: fix one digit of the key per round, most significant first.
    for shift in tl.static_range(nbits - RADIX_BITS_, -1, -RADIX_BITS_):
        hist = tl.zeros((radix_size,), dtype=tl.int32)
        for block_start in tl.range(0, N, BLOCK_N):
            cols = block_start + lanes
            in_bounds = cols < N
            vals = tl.load(inp + row * N + cols, mask=in_bounds, other=0.0)
            if elem_ty.is_floating():
                valid = in_bounds & _is_not_nan(vals, USE_ISNAN)
            else:
                valid = in_bounds
            keys = _to_order_key(vals, valid)
            same_prefix = (keys & prefix_mask) == prefix
            digit = ((keys >> shift) & digit_mask).to(tl.int32)
            counted = valid & same_prefix
            if USE_HISTOGRAM:
                hist += tl.histogram(digit, radix_size, counted)
            else:
                for bin_id in tl.static_range(0, radix_size):
                    in_bin = tl.sum(
                        (counted & (digit == bin_id)).to(tl.int32), axis=0
                    )
                    hist += tl.where(bin_ids == bin_id, in_bin, 0)

        running = tl.cumsum(hist, axis=0)
        before = running - hist
        hit = (running >= remaining_rank) & (before < remaining_rank)
        chosen = tl.min(tl.where(hit, bin_ids, radix_size - 1), axis=0)
        skipped = tl.max(tl.where(hit, before, 0), axis=0)

        chosen = chosen.to(key_ty)
        prefix = prefix | (chosen << shift)
        prefix_mask = prefix_mask | (digit_mask << shift)
        remaining_rank = remaining_rank - skipped

    # Step 3: lowest column carrying the selected key.
    best_idx = tl.full((), N, dtype=tl.int32)
    for block_start in tl.range(0, N, BLOCK_N):
        cols = block_start + lanes
        in_bounds = cols < N
        vals = tl.load(inp + row * N + cols, mask=in_bounds, other=0.0)
        if elem_ty.is_floating():
            valid = in_bounds & _is_not_nan(vals, USE_ISNAN)
        else:
            valid = in_bounds
        keys = _to_order_key(vals, valid)
        block_idx = tl.min(tl.where(valid & (keys == prefix), cols, N), axis=0)
        best_idx = tl.where(block_idx < best_idx, block_idx, best_idx)

    if elem_ty.is_floating():
        init_value = _get_finfo_val(elem_ty, return_max=False)
    else:
        init_value = get_dtype_min(elem_ty)
    # The load is masked off when nothing was valid (best_idx is still N then).
    median_val = tl.load(
        inp + row * N + best_idx, mask=num_valid > 0, other=init_value
    )

    if elem_ty.is_floating():
        no_valid = num_valid == 0
        median_val = tl.where(no_valid, float("nan"), median_val)
        best_idx = tl.where(no_valid, 0, best_idx)

    tl.store(out_values + row, median_val)
    tl.store(out_indices + row, best_idx)
@libentry()
@triton.jit
def flat_radix_init_kernel(
    valid_count,
    state,
    result_idx,
    N: tl.constexpr,
    IS_FLOAT: tl.constexpr,
):
    """Initialise the scalar buffers of the flat radix-select pipeline.

    ``valid_count`` starts at 0 for floating inputs (it is accumulated later)
    and at ``N`` otherwise. The three ``state`` slots are zeroed, and
    ``result_idx`` is set to ``N``, one past the last valid position.
    """
    initial_count = 0 if IS_FLOAT else N
    tl.store(valid_count, initial_count)
    tl.store(state + 0, 0)
    tl.store(state + 1, 0)
    tl.store(state + 2, 0)
    tl.store(result_idx, N)
@libentry()
@triton.jit
def flat_radix_count_valid_kernel(
    inp,
    valid_count,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_ISNAN: tl.constexpr,
):
    """Atomically add this block's number of non-NaN elements to ``valid_count``.

    Program ``pid`` covers elements ``[pid * BLOCK_N, (pid + 1) * BLOCK_N)``
    of the flat input, clipped to ``N``.
    """
    block = tle.program_id(0)
    cols = block * BLOCK_N + tl.arange(0, BLOCK_N)
    in_bounds = cols < N
    vals = tl.load(inp + cols, mask=in_bounds, other=0.0)
    not_nan = in_bounds & _is_not_nan(vals, USE_ISNAN)
    block_total = tl.sum(not_nan.to(tl.int64), axis=0)
    tl.atomic_add(valid_count, block_total, sem="relaxed")
@libentry()
@triton.jit
def flat_radix_init_rank_kernel(valid_count, state):
    """Store the 1-based median rank ``(valid_count + 1) // 2`` in ``state[2]``."""
    num_valid = tl.load(valid_count)
    target_rank = (num_valid + 1) // 2
    tl.store(state + 2, target_rank)
@libentry()
@triton.jit
def flat_radix_count_kernel(
    inp,
    bin_counts,
    state,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    DIGIT_POS: tl.constexpr,
    RADIX_BITS_: tl.constexpr,
    RADIX_SIZE: tl.constexpr,
    USE_ISNAN: tl.constexpr,
    USE_HISTOGRAM: tl.constexpr,
):
    """Accumulate one block's digit histogram into ``bin_counts``.

    Elements are counted when they are in range, not NaN (floats only) and
    their order key matches the prefix in ``state[0]`` under the mask in
    ``state[1]``. The digit is the ``RADIX_BITS_``-wide field of the key at
    bit offset ``DIGIT_POS``. The ``RADIX_SIZE`` per-block counts are added
    to ``bin_counts`` atomically.
    """
    block = tle.program_id(0)
    cols = block * BLOCK_N + tl.arange(0, BLOCK_N)
    in_bounds = cols < N
    vals = tl.load(inp + cols, mask=in_bounds, other=0.0)
    elem_ty = inp.dtype.element_ty
    nbits: tl.constexpr = elem_ty.primitive_bitwidth
    key_ty = tl.dtype(f"uint{nbits}")
    radix_mask: tl.constexpr = (1 << RADIX_BITS_) - 1
    digit_mask = tl.full((), radix_mask, dtype=key_ty)

    if elem_ty.is_floating():
        valid = in_bounds & _is_not_nan(vals, USE_ISNAN)
    else:
        valid = in_bounds

    prefix = tl.load(state + 0).to(key_ty)
    prefix_mask = tl.load(state + 1).to(key_ty)
    keys = _to_order_key(vals, valid)
    counted = valid & ((keys & prefix_mask) == prefix)
    digit = ((keys >> DIGIT_POS) & digit_mask).to(tl.int32)
    bin_ids = tl.arange(0, RADIX_SIZE)
    hist = tl.zeros((RADIX_SIZE,), dtype=tl.int64)
    if USE_HISTOGRAM:
        hist = tl.histogram(digit, RADIX_SIZE, counted).to(tl.int64)
    else:
        # One reduction per bin when the masked histogram is unavailable.
        for bin_id in tl.static_range(0, RADIX_SIZE):
            in_bin = tl.sum((counted & (digit == bin_id)).to(tl.int64), axis=0)
            hist += tl.where(bin_ids == bin_id, in_bin, 0)
    tl.atomic_add(bin_counts + bin_ids, hist, sem="relaxed")
@libentry()
@triton.jit
def flat_radix_update_kernel(
    bin_counts,
    state,
    DIGIT_POS: tl.constexpr,
    RADIX_BITS_: tl.constexpr,
    RADIX_SIZE: tl.constexpr,
):
    """Fold the accumulated digit histogram into the radix-select state.

    Finds the bin that contains the rank held in ``state[2]``. That bin is
    merged into the key prefix (``state[0]``) and prefix mask (``state[1]``)
    at ``DIGIT_POS``, and the rank is lowered by the number of elements in
    lower bins.
    """
    bin_ids = tl.arange(0, RADIX_SIZE)
    hist = tl.load(bin_counts + bin_ids)
    remaining_rank = tl.load(state + 2)
    running = tl.cumsum(hist, axis=0)
    before = running - hist
    hit = (remaining_rank <= running) & (remaining_rank > before)
    chosen = tl.min(tl.where(hit, bin_ids, RADIX_SIZE - 1), axis=0).to(tl.int64)
    skipped = tl.max(tl.where(hit, before, 0), axis=0)

    prefix = tl.load(state + 0)
    prefix_mask = tl.load(state + 1)
    radix_mask: tl.constexpr = (1 << RADIX_BITS_) - 1
    prefix = prefix | (chosen << DIGIT_POS)
    prefix_mask = prefix_mask | (radix_mask << DIGIT_POS)
    tl.store(state + 0, prefix)
    tl.store(state + 1, prefix_mask)
    tl.store(state + 2, remaining_rank - skipped)
@libentry()
@triton.jit
def flat_radix_find_index_kernel(
    inp,
    state,
    valid_count,
    result_idx,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
    USE_ISNAN: tl.constexpr,
):
    """Reduce ``result_idx`` to the lowest position holding the selected key.

    Does nothing when ``valid_count`` is zero. Otherwise each program scans
    its ``BLOCK_N``-wide slice for valid elements whose order key equals
    ``state[0]`` and applies an atomic minimum with the lowest match (``N``
    when the slice has none).
    """
    if tl.load(valid_count) > 0:
        block = tle.program_id(0)
        cols = block * BLOCK_N + tl.arange(0, BLOCK_N)
        in_bounds = cols < N
        vals = tl.load(inp + cols, mask=in_bounds, other=0.0)
        elem_ty = inp.dtype.element_ty
        nbits: tl.constexpr = elem_ty.primitive_bitwidth
        key_ty = tl.dtype(f"uint{nbits}")

        if elem_ty.is_floating():
            valid = in_bounds & _is_not_nan(vals, USE_ISNAN)
        else:
            valid = in_bounds

        wanted_key = tl.load(state + 0).to(key_ty)
        keys = _to_order_key(vals, valid)
        block_idx = tl.min(tl.where(valid & (keys == wanted_key), cols, N), axis=0)
        tl.atomic_min(result_idx, block_idx, sem="relaxed")
@libentry()
@triton.jit
def flat_radix_store_result_kernel(inp, out, valid_count, result_idx):
    """Write ``inp[result_idx]`` to ``out``.

    For floating inputs the load is masked by ``valid_count > 0`` and yields
    NaN when no element was valid; other dtypes load unconditionally.
    """
    elem_ty = inp.dtype.element_ty
    position = tl.load(result_idx)
    if elem_ty.is_floating():
        has_valid = tl.load(valid_count) > 0
        picked = tl.load(inp + position, mask=has_valid, other=float("nan"))
    else:
        picked = tl.load(inp + position)
    tl.store(out, picked)
797def _check_supported_dtype(inp):
798 if inp.dtype is torch.bool:
799 raise NotImplementedError("\"median_out_impl\" not implemented for 'Bool'")
802def _normalize_dim(dim, ndim):
803 if ndim == 0:
804 if dim in (0, -1):
805 return 0
806 elif -ndim <= dim < ndim:
807 return dim % ndim
808 raise IndexError(
809 f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {dim})"
810 )
813def _empty_flat_value(inp):
814 out = torch.empty((), dtype=inp.dtype, device=inp.device)
815 if torch.is_floating_point(inp):
816 out.fill_(float("nan"))
817 elif inp.is_cuda:
818 out.fill_(torch.iinfo(inp.dtype).min)
819 else:
820 out.zero_()
821 return out
def _radix_block_n(inp, n):
    """Choose the kernel block size for a reduction of length ``n``.

    The starting point is ``triton.next_power_of_2(n)``. It is then capped,
    or replaced by a fixed size for long reductions, depending on whether
    ``inp`` lives on CUDA and on its dtype.
    """
    pow2 = triton.next_power_of_2(n)
    dtype = inp.dtype

    if inp.is_cuda:
        if n > LARGE_FLOAT_REDUCTION_N:
            cap = 8192
        elif n > MEDIUM_REDUCTION_N:
            cap = 4096
        elif dtype is torch.uint8:
            cap = 512
        else:
            cap = RADIX_BLOCK_N
        return min(pow2, cap)

    # Non-CUDA devices: per-dtype choices.
    if dtype in (torch.float16, torch.bfloat16):
        return 2048 if n > LARGE_FLOAT_REDUCTION_N else min(pow2, 2048)
    if dtype is torch.float32 or dtype is torch.int32:
        return 512 if n > MEDIUM_REDUCTION_N else min(pow2, RADIX_BLOCK_N)
    if dtype in (torch.int8, torch.uint8):
        return RADIX_BLOCK_N if n > MEDIUM_REDUCTION_N else min(pow2, 512)
    return min(pow2, RADIX_BLOCK_N)
def _radix_bits(inp, n):
    """Return the radix digit width (in bits) for a reduction of length ``n``.

    CUDA inputs use 8 bits above ``LARGE_FLOAT_REDUCTION_N`` and 4 bits above
    ``MEDIUM_REDUCTION_N``; every other case uses ``RADIX_BITS``.
    """
    if not inp.is_cuda:
        return RADIX_BITS
    if n > LARGE_FLOAT_REDUCTION_N:
        return 8
    if n > MEDIUM_REDUCTION_N:
        return 4
    return RADIX_BITS
def _full_nan_result(shape, dtype, device):
    """Return a ``NanMedian`` whose values are all NaN and whose indices are all 0."""
    nan_values = torch.full(shape, float("nan"), dtype=dtype, device=device)
    zero_indices = torch.zeros(shape, dtype=torch.long, device=device)
    return NanMedian(values=nan_values, indices=zero_indices)
def _count_block_n(inp, n):
    """Choose the block size for the valid-count kernel.

    The power of two at or above ``n`` is capped at 16384 (CUDA) or 4096
    (other devices) for ``n >= LONG_RADIX_REDUCTION_N``, at 2048 for
    ``n >= LARGE_FLOAT_REDUCTION_N`` and at ``RADIX_BLOCK_N`` otherwise.
    """
    pow2 = triton.next_power_of_2(n)
    if n >= LONG_RADIX_REDUCTION_N:
        cap = 16384 if inp.is_cuda else 4096
    elif n >= LARGE_FLOAT_REDUCTION_N:
        cap = 2048
    else:
        cap = RADIX_BLOCK_N
    return min(pow2, cap)
def _nanmedian_kthvalue_fallback(inp, M, N):
    """Compute the per-row nanmedian of an ``(M, N)`` view via ``torch.kthvalue``.

    Floating inputs: the non-NaN count of every row is obtained with
    ``count_valid_kernel``; a row with ``c`` valid entries takes the
    ``(c + 1) // 2``-th smallest value of that row. Rows without any valid
    entry yield NaN with index 0. The code relies on ``torch.kthvalue``
    ranking NaN after every number - presumably consistent with
    ``torch.sort``; confirm.

    Non-floating inputs: the ``(N + 1) // 2``-th smallest value per row; on
    ``npu`` devices int32/int64 rows go through ``torch.sort`` instead.

    Args:
        inp: Tensor reshaped to ``(M, N)`` before processing.
        M: Number of rows.
        N: Reduction length.

    Returns:
        ``NanMedian(values, indices)`` with one entry per row.
    """
    inp = inp.reshape(M, N)
    if torch.is_floating_point(inp):
        valid_count = torch.empty((M,), dtype=torch.long, device=inp.device)
        block_n = _count_block_n(inp, N)
        with torch_device_fn.device(inp.device):
            count_valid_kernel[(M,)](inp, valid_count, M, N, block_n, inp.is_cuda)
        min_count = int(torch.min(valid_count).item())
        max_count = int(torch.max(valid_count).item())

        # Every row has the same number of valid entries: one kthvalue call.
        if min_count == max_count:
            if max_count == 0:
                return _full_nan_result((M,), inp.dtype, inp.device)
            values, indices = torch.kthvalue(inp, (max_count + 1) // 2, dim=1)
            return NanMedian(values=values, indices=indices)

        # Counts differ by exactly one: at most two distinct ranks are needed.
        if max_count - min_count <= 1:
            min_k = (min_count + 1) // 2 if min_count > 0 else 0
            max_k = (max_count + 1) // 2

            if min_k == max_k:
                # Equal ranks imply min_count > 0 (min_count == 0 would give
                # min_k == 0 < 1 == max_k), so every row has a valid median
                # and no NaN fallback is required here.
                values, indices = torch.kthvalue(inp, max_k, dim=1)
                return NanMedian(values=values, indices=indices)

            result = _full_nan_result((M,), inp.dtype, inp.device)

            if min_count > 0:
                values, indices = torch.kthvalue(inp, min_k, dim=1)
                mask = valid_count == min_count
                result = NanMedian(
                    values=torch.where(mask, values, result.values),
                    indices=torch.where(mask, indices, result.indices),
                )

            values, indices = torch.kthvalue(inp, max_k, dim=1)
            mask = valid_count == max_count
            return NanMedian(
                values=torch.where(mask, values, result.values),
                indices=torch.where(mask, indices, result.indices),
            )

        # General case: process the rows grouped by their valid count.
        result = _full_nan_result((M,), inp.dtype, inp.device)
        for count in torch.unique(valid_count).tolist():
            count = int(count)
            if count == 0:
                # All-NaN rows keep the NaN / index-0 defaults.
                continue
            row_indices = torch.nonzero(valid_count == count).flatten()
            rows = torch.index_select(inp, 0, row_indices)
            values, indices = torch.kthvalue(rows, (count + 1) // 2, dim=1)
            result.values[row_indices] = values
            result.indices[row_indices] = indices
        return result
    else:
        if inp.device.type == "npu" and inp.dtype in (torch.int32, torch.int64):
            sorted_values, sorted_indices = torch.sort(inp, dim=1)
            kth = (N + 1) // 2 - 1
            values = sorted_values[:, kth]
            indices = sorted_indices[:, kth]
            return NanMedian(values=values, indices=indices)
        values, indices = torch.kthvalue(inp, (N + 1) // 2, dim=1)
        return NanMedian(values=values, indices=indices)
def _nanmedian_ascend_float_sort_select(inp, M, N, values, indices):
    """Sort-based nanmedian for floating inputs, writing into ``values`` / ``indices``.

    For ``N <= LARGE_FLOAT_REDUCTION_N`` a kernel first replaces NaNs and
    counts the valid entries per row, and the cleaned copy is sorted. Longer
    rows are sorted directly and the valid count is derived from the sorted
    values. A gather kernel then picks the element at rank
    ``(count - 1) // 2`` of each sorted row.

    NOTE(review): the direct-sort branch appears to rely on ``torch.sort``
    placing NaNs after all numbers - confirm. Also, ``reshape`` returns a
    copy when ``values`` / ``indices`` cannot be viewed as 1-D, in which case
    the kernel writes would not reach the caller's tensors - confirm callers
    pass contiguous outputs.
    """
    inp = inp.reshape(M, N)
    out_vals = values.reshape(M)
    out_idx = indices.reshape(M)

    if N > LARGE_FLOAT_REDUCTION_N:
        sorted_values, sorted_indices = torch.sort(inp, dim=1)
        # x == x is false only for NaN, so this counts the non-NaN entries.
        not_nan = (sorted_values == sorted_values).to(torch.int32)
        valid_counts = torch.sum(not_nan, dim=1)
    else:
        cleaned = torch.empty_like(inp)
        valid_counts = torch.empty((M,), dtype=torch.int32, device=inp.device)
        block_n = min(triton.next_power_of_2(N), RADIX_BLOCK_N)
        num_warps = 8 if block_n > 512 else 4
        with torch_device_fn.device(inp.device):
            nanmedian_float_clean_count_kernel[(M,)](
                inp,
                cleaned,
                valid_counts,
                N,
                block_n,
                num_warps=num_warps,
                num_stages=1,
            )
        sorted_values, sorted_indices = torch.sort(cleaned, dim=1)

    with torch_device_fn.device(inp.device):
        nanmedian_float_sorted_gather_kernel[(M,)](
            sorted_values,
            sorted_indices,
            valid_counts,
            out_vals,
            out_idx,
            N,
            num_warps=1,
            num_stages=1,
        )
def _nanmedian_dim_impl(inp, dim, keepdim, out=None, use_ascend_float_select=True):
    """Shared implementation of the per-dimension nanmedian entry points.

    Args:
        inp: Input tensor.
        dim: Dimension to reduce; passed through ``_normalize_dim``.
        keepdim: If true, the reduced dimension is kept with size 1.
        out: Optional ``(values, indices)`` pair of preallocated tensors. When
            given, results are written into them and the same tensors are
            returned. They may be non-contiguous.
        use_ascend_float_select: Allows the Ascend float sort/select path;
            ``_nanmedian_flat_impl`` passes ``False``.

    Returns:
        ``NanMedian(values, indices)``.

    Raises:
        IndexError: If the reduced dimension is empty while the remaining
            dimensions are not.
    """
    dim = _normalize_dim(dim, inp.ndim)

    # A 0-d tensor is its own result; its index is always 0.
    if inp.ndim == 0:
        if out is None:
            values = inp.clone()
            indices = torch.zeros((), dtype=torch.long, device=inp.device)
        else:
            values, indices = out
            values.copy_(inp)
            indices.zero_()
        return NanMedian(values=values, indices=indices)

    shape = list(inp.shape)
    N = shape[dim]  # length of the reduced dimension
    out_shape = shape[:dim] + shape[dim + 1 :]
    M = math.prod(out_shape)  # number of independent reductions

    keepdim_shape = shape.copy()
    keepdim_shape[dim] = 1
    output_shape = keepdim_shape if keepdim else out_shape
    # Fresh outputs are always built with the reduced dim kept and squeezed at
    # the end; caller-supplied outputs are used in their final shape.
    compute_shape = output_shape if out is not None else keepdim_shape

    if N == 0:
        if M != 0:
            raise IndexError(
                f"median(): Expected reduction dim {dim} to have non-zero size."
            )
        if out is None:
            values = torch.empty(compute_shape, dtype=inp.dtype, device=inp.device)
            indices = torch.empty(compute_shape, dtype=torch.long, device=inp.device)
            if not keepdim:
                values = torch.squeeze(values, dim)
                indices = torch.squeeze(indices, dim)
        else:
            values, indices = out
        return NanMedian(values=values, indices=indices)

    if out is None:
        values = torch.empty(compute_shape, dtype=inp.dtype, device=inp.device)
        indices = torch.empty(compute_shape, dtype=torch.long, device=inp.device)
    else:
        values, indices = out

    # Nothing to reduce: return the (empty) outputs as they are.
    if M == 0:
        if out is None and not keepdim:
            values = torch.squeeze(values, dim)
            indices = torch.squeeze(indices, dim)
        return NanMedian(values=values, indices=indices)

    # The kernel launches below receive flat aliases of the outputs and no
    # stride arguments, so they can only address densely packed storage. With
    # a non-contiguous caller-supplied ``out``, ``reshape`` returns either a
    # copy (the results would be dropped) or a strided view (stores would land
    # on the wrong elements). Stage such outputs in contiguous scratch tensors
    # and copy them back once the reduction is done.
    staged_out = None
    if out is not None and not (values.is_contiguous() and indices.is_contiguous()):
        staged_out = (values, indices)
        values = torch.empty_like(values, memory_format=torch.contiguous_format)
        indices = torch.empty_like(indices, memory_format=torch.contiguous_format)

    inp = dim_compress(inp, dim)
    is_cuda = inp.is_cuda
    is_ascend = inp.device.type == "npu"
    in_radix_range = MAX_BLOCK_N < N <= LONG_RADIX_REDUCTION_N
    use_cuda_histogram = (
        is_cuda
        and CUDA_SUPPORTS_MASKED_HISTOGRAM
        and N > MAX_BLOCK_N
        and N == triton.next_power_of_2(N)
    )
    use_ascend_float_select_path = (
        use_ascend_float_select
        and is_ascend
        and inp.dtype in ASCEND_FLOAT_SELECT_DTYPES
        and in_radix_range
    )
    use_ascend_histogram = (
        is_ascend and inp.dtype in ASCEND_HISTOGRAM_SELECT_DTYPES and in_radix_range
    )
    use_ascend_byte_histogram = (
        is_ascend
        and inp.dtype in ASCEND_BYTE_HISTOGRAM_SELECT_DTYPES
        and in_radix_range
    )

    if is_cuda and inp.dtype in RADIX_SELECT_DTYPES and in_radix_range:
        # CUDA radix select, one program per reduction row.
        flat_values = values.reshape(M)
        flat_indices = indices.reshape(M)
        block_n = _radix_block_n(inp, N)
        num_warps = 4 if block_n <= 512 else 8
        with torch_device_fn.device(inp.device):
            nanmedian_radix_select_kernel[(M,)](
                inp,
                flat_values,
                flat_indices,
                M,
                N,
                block_n,
                _radix_bits(inp, N) if use_cuda_histogram else RADIX_BITS,
                is_cuda,
                use_cuda_histogram,
                num_warps=num_warps,
                num_stages=1,
            )
    elif use_ascend_float_select_path:
        _nanmedian_ascend_float_sort_select(inp, M, N, values, indices)
    elif use_ascend_histogram and N >= ASCEND_MULTI_HISTOGRAM_MIN_N:
        # Long rows: per-chunk histograms followed by a per-row reduce.
        flat_values = values.reshape(M)
        flat_indices = indices.reshape(M)
        block_n = _radix_block_n(inp, N)
        num_chunks = triton.cdiv(N, block_n)
        partial_counts = torch.empty(
            (M, num_chunks, ASCEND_HISTOGRAM_BINS),
            dtype=torch.int32,
            device=inp.device,
        )
        num_warps = 4 if block_n <= 512 else 8
        with torch_device_fn.device(inp.device):
            nanmedian_ascend_histogram_count_kernel[(M, num_chunks)](
                inp,
                partial_counts,
                M,
                N,
                block_n,
                num_chunks,
                ASCEND_HISTOGRAM_BINS,
                num_warps=num_warps,
                num_stages=1,
            )
            nanmedian_ascend_histogram_reduce_kernel[(M,)](
                inp,
                partial_counts,
                flat_values,
                flat_indices,
                M,
                N,
                block_n,
                num_chunks,
                ASCEND_HISTOGRAM_BINS,
                num_warps=num_warps,
                num_stages=1,
            )
    elif use_ascend_histogram:
        flat_values = values.reshape(M)
        flat_indices = indices.reshape(M)
        block_n = _radix_block_n(inp, N)
        num_warps = 4 if block_n <= 512 else 8
        with torch_device_fn.device(inp.device):
            nanmedian_ascend_histogram_select_kernel[(M,)](
                inp,
                flat_values,
                flat_indices,
                M,
                N,
                block_n,
                ASCEND_HISTOGRAM_BINS,
                num_warps=num_warps,
                num_stages=1,
            )
    elif use_ascend_byte_histogram and N >= ASCEND_MULTI_HISTOGRAM_MIN_N:
        # Long rows of wider integers: one count/update pass per byte, starting
        # from the most significant one, then a final index lookup.
        flat_values = values.reshape(M)
        flat_indices = indices.reshape(M)
        block_n = _radix_block_n(inp, N)
        num_chunks = triton.cdiv(N, block_n)
        partial_counts = torch.empty(
            (M, num_chunks, ASCEND_HISTOGRAM_BINS),
            dtype=torch.int32,
            device=inp.device,
        )
        state = torch.empty((M, 3), dtype=torch.int64, device=inp.device)
        num_warps = 4 if block_n <= 512 else 8
        nbits = inp.element_size() * 8
        with torch_device_fn.device(inp.device):
            nanmedian_ascend_byte_histogram_init_kernel[(M,)](
                state,
                M,
                N,
                num_warps=1,
                num_stages=1,
            )
            for digit_pos in range(nbits - 8, -1, -8):
                nanmedian_ascend_byte_histogram_count_kernel[(M, num_chunks)](
                    inp,
                    state,
                    partial_counts,
                    M,
                    N,
                    block_n,
                    num_chunks,
                    ASCEND_HISTOGRAM_BINS,
                    digit_pos,
                    num_warps=num_warps,
                    num_stages=1,
                )
                nanmedian_ascend_byte_histogram_update_kernel[(M,)](
                    inp,
                    partial_counts,
                    state,
                    M,
                    num_chunks,
                    ASCEND_HISTOGRAM_BINS,
                    digit_pos,
                    num_warps=num_warps,
                    num_stages=1,
                )
            nanmedian_ascend_byte_histogram_find_index_kernel[(M,)](
                inp,
                state,
                flat_values,
                flat_indices,
                M,
                N,
                block_n,
                num_warps=num_warps,
                num_stages=1,
            )
    elif use_ascend_byte_histogram:
        flat_values = values.reshape(M)
        flat_indices = indices.reshape(M)
        block_n = _radix_block_n(inp, N)
        num_warps = 4 if block_n <= 512 else 8
        with torch_device_fn.device(inp.device):
            nanmedian_ascend_byte_histogram_select_kernel[(M,)](
                inp,
                flat_values,
                flat_indices,
                M,
                N,
                block_n,
                ASCEND_HISTOGRAM_BINS,
                num_warps=num_warps,
                num_stages=1,
            )
    elif N <= MAX_BLOCK_N and inp.dtype is not torch.float64:
        # Short rows: a single block covers the whole reduction.
        flat_values = values.reshape(M)
        flat_indices = indices.reshape(M)
        block_n = triton.next_power_of_2(N)
        with torch_device_fn.device(inp.device):
            nanmedian_select_kernel[(M,)](
                inp,
                flat_values,
                flat_indices,
                M,
                N,
                block_n,
                is_cuda,
            )
    else:
        # Everything else goes through the kthvalue-based fallback.
        result = _nanmedian_kthvalue_fallback(inp, M, N)
        computed_values = result.values.reshape(compute_shape)
        computed_indices = result.indices.reshape(compute_shape)
        if out is None:
            values = computed_values
            indices = computed_indices
        else:
            values.copy_(computed_values)
            indices.copy_(computed_indices)

    if staged_out is not None:
        # Publish the staged results into the caller's tensors and hand those
        # same tensors back.
        out_values, out_indices = staged_out
        out_values.copy_(values)
        out_indices.copy_(indices)
        values, indices = out_values, out_indices

    if out is None and not keepdim:
        values = torch.squeeze(values, dim)
        indices = torch.squeeze(indices, dim)

    return NanMedian(values=values, indices=indices)
1238def _nanmedian_ascend_flat_sort(inp):
1239 flat = inp.reshape(-1).contiguous()
1240 sorted_values = torch.sort(flat).values
1241 if torch.is_floating_point(flat):
1242 valid_count = (sorted_values == sorted_values).sum()
1243 rank = (valid_count - 1) // 2
1244 else:
1245 rank = (flat.numel() - 1) // 2
1246 return sorted_values[rank]
def _nanmedian_cuda_flat_radix_select(inp, out=None):
    """Select the median of the flattened input with the flat radix kernels.

    Launch sequence: init, a valid-element count for floating-point inputs,
    rank init, one count/update pair per radix digit (most significant digit
    first), index lookup, and the final store into ``out``.

    Args:
        inp: Input tensor; flattened and made contiguous here.
        out: Optional 0-d result tensor. Allocated with the input dtype on the
            input device when omitted.

    Returns:
        ``out``.
    """
    flat = inp.reshape(-1).contiguous()
    device = flat.device
    n = flat.numel()
    is_float = torch.is_floating_point(flat)

    if out is None:
        out = torch.empty((), dtype=flat.dtype, device=device)

    int64_on_device = {"dtype": torch.int64, "device": device}
    valid_count = torch.empty((), **int64_on_device)
    state = torch.empty((3,), **int64_on_device)
    result_idx = torch.empty((), **int64_on_device)

    block_n = min(triton.next_power_of_2(n), FLAT_RADIX_BLOCK_N)
    grid = (triton.cdiv(n, block_n),)
    nbits = flat.element_size() * 8
    # The histogram variant is only used when the blocks tile n exactly.
    use_histogram = CUDA_SUPPORTS_MASKED_HISTOGRAM and n % block_n == 0
    radix_bits = FLAT_RADIX_BITS if use_histogram else RADIX_BITS
    radix_size = 1 << radix_bits
    bin_counts = torch.empty((radix_size,), **int64_on_device)

    wide_launch = {"num_warps": 8, "num_stages": 1}

    with torch_device_fn.device(device):
        flat_radix_init_kernel[(1,)](valid_count, state, result_idx, n, is_float)
        if is_float:
            flat_radix_count_valid_kernel[grid](
                flat, valid_count, n, block_n, True, **wide_launch
            )
        flat_radix_init_rank_kernel[(1,)](valid_count, state)

        for digit_pos in range(nbits - radix_bits, -1, -radix_bits):
            # Counts are accumulated per digit, so clear them before each pass.
            bin_counts.zero_()
            flat_radix_count_kernel[grid](
                flat,
                bin_counts,
                state,
                n,
                block_n,
                digit_pos,
                radix_bits,
                radix_size,
                True,
                use_histogram,
                **wide_launch,
            )
            flat_radix_update_kernel[(1,)](
                bin_counts,
                state,
                digit_pos,
                radix_bits,
                radix_size,
                **wide_launch,
            )

        flat_radix_find_index_kernel[grid](
            flat, state, valid_count, result_idx, n, block_n, True, **wide_launch
        )
        flat_radix_store_result_kernel[(1,)](flat, out, valid_count, result_idx)
    return out
def _nanmedian_flat_impl(inp, out=None):
    """Compute the nanmedian over every element of ``inp``.

    Dispatch order: empty input, CUDA flat radix select for large supported
    inputs, Ascend flat sort for large supported inputs, and otherwise the
    per-dimension implementation applied to the flattened tensor.

    Args:
        inp: Input tensor.
        out: Optional tensor that receives the result.

    Returns:
        ``out`` when it was supplied, otherwise a newly produced result.
    """

    def _deliver(result):
        # Route a freshly computed result into ``out`` when one was supplied.
        if out is None:
            return result
        out.copy_(result)
        return out

    n = inp.numel()
    if n == 0:
        return _deliver(_empty_flat_value(inp))

    cuda_radix_ok = (
        inp.is_cuda
        and inp.dtype in RADIX_SELECT_DTYPES
        and LONG_RADIX_REDUCTION_N < n <= INT32_MAX
    )
    if cuda_radix_ok:
        return _nanmedian_cuda_flat_radix_select(inp, out=out)

    ascend_sort_ok = (
        inp.device.type == "npu"
        and inp.dtype in ASCEND_FLAT_SORT_DTYPES
        and n >= ASCEND_FLAT_SORT_MIN_N
    )
    if ascend_sort_ok:
        return _deliver(_nanmedian_ascend_flat_sort(inp))

    flat = inp.reshape(-1)
    if out is not None:
        # The dim implementation needs an indices output; it is discarded here.
        scratch_indices = torch.empty((), dtype=torch.long, device=inp.device)
        _nanmedian_dim_impl(
            flat,
            0,
            False,
            out=(out, scratch_indices),
            use_ascend_float_select=False,
        )
        return out

    reduced = _nanmedian_dim_impl(flat, 0, False, use_ascend_float_select=False)
    return reduced.values
def nanmedian(inp):
    """Return the nanmedian over all elements of ``inp``.

    Logs the call, validates the dtype with ``_check_supported_dtype`` and
    delegates to ``_nanmedian_flat_impl``.
    """
    logger.debug("GEMS NANMEDIAN")
    _check_supported_dtype(inp)
    result = _nanmedian_flat_impl(inp)
    return result
def nanmedian_out(inp, *, out):
    """Compute the nanmedian over all elements of ``inp`` into ``out``.

    Logs the call, validates the dtype with ``_check_supported_dtype`` and
    delegates to ``_nanmedian_flat_impl`` with the given ``out`` tensor.
    """
    logger.debug("GEMS NANMEDIAN OUT")
    _check_supported_dtype(inp)
    result = _nanmedian_flat_impl(inp, out=out)
    return result
def nanmedian_dim(inp, dim=-1, keepdim=False):
    """Return ``NanMedian(values, indices)`` for ``inp`` reduced along ``dim``.

    Logs the call, validates the dtype with ``_check_supported_dtype`` and
    delegates to ``_nanmedian_dim_impl``.
    """
    logger.debug("GEMS NANMEDIAN DIM")
    _check_supported_dtype(inp)
    result = _nanmedian_dim_impl(inp, dim, keepdim)
    return result
def nanmedian_dim_values(inp, dim=-1, keepdim=False, *, values, indices):
    """Reduce ``inp`` along ``dim`` into the given ``values``/``indices``.

    Args:
        inp: Input tensor.
        dim: Dimension to reduce.
        keepdim: If true, the reduced dimension is kept with size 1.
        values: Output tensor for the selected values.
        indices: Output tensor for the selected positions.

    Returns:
        ``NanMedian(values, indices)`` holding the supplied output tensors.
    """
    logger.debug("GEMS NANMEDIAN DIM VALUES")
    # Validate the dtype like every other public entry point in this module
    # does before dispatching to the shared implementation.
    _check_supported_dtype(inp)
    return _nanmedian_dim_impl(inp, dim, keepdim, out=(values, indices))