Coverage for src/flag_gems/runtime/backend/_ascend/ops/cummin.py: 0%
260 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
2import math
3from typing import List, Tuple, Union
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_max
14Tensor = torch.Tensor
16logger = logging.getLogger(__name__)
@triton.jit
def tl_cummin(input, index, axis=0):
    """Scan (value, index) pairs along `axis` and return the scanned pair.

    Thin wrapper over `tl.associative_scan` that combines pairs with
    `tle.minimum_with_index_tie_break_right`.
    """
    return tl.associative_scan(
        (input, index),
        axis=axis,
        combine_fn=tle.minimum_with_index_tie_break_right,
    )
@triton.jit
def tl_min_tie_break_right(input, index, axis=None, keep_dims=False):
    """Reduce (value, index) pairs along `axis` and return the reduced pair.

    Thin wrapper over `tl.reduce` that combines pairs with
    `tle.minimum_with_index_tie_break_right`; `keep_dims` is forwarded as-is.
    """
    return tl.reduce(
        (input, index),
        axis=axis,
        combine_fn=tle.minimum_with_index_tie_break_right,
        keep_dims=keep_dims,
    )
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def add_base_min_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Merge each block's scanned results with the entry of the previous block.

    Program `pid` owns elements `[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)`
    of `out` / `out_indices`. For `pid > 0` it reads
    `partial_min[pid - 1]` / `partial_min_indices[pid - 1]` and rewrites its
    elements in place with whichever of (current, previous) wins the
    comparison below. Program 0 only performs the loads.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    val_ptrs = out + offset
    idx_ptrs = out_indices + offset
    cur_vals = tl.load(val_ptrs, mask=mask)
    cur_idx = tl.load(idx_ptrs, mask=mask)

    if pid > 0:
        prev_val = tl.load(partial_min + pid - 1)
        prev_idx = tl.load(partial_min_indices + pid - 1)

        # x != x is true only for NaN.
        cur_nan = cur_vals != cur_vals
        prev_nan = prev_val != prev_val

        # Two NaNs count as a tie; ties go to the larger index.
        tied = (cur_vals == prev_val) | (cur_nan & prev_nan)
        take_cur = (
            (cur_vals < prev_val)
            | (cur_nan & ~prev_nan)
            | (tied & (cur_idx > prev_idx))
        )

        merged_vals = tl.where(take_cur, cur_vals, prev_val)
        merged_idx = tl.where(take_cur, cur_idx, prev_idx)
        tl.store(val_ptrs, merged_vals.to(cur_vals.dtype), mask=mask)
        tl.store(idx_ptrs, merged_idx, mask=mask)
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def scan_part_min_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Per-block (value, index) scan over a flat buffer.

    Program `pid` loads `BLOCK_SIZE` consecutive elements of `inp`, scans
    them with `tl_cummin`, and stores the scanned values to `out` and the
    scanned indices to `out_indices`. With `NEED_PARTIAL` it additionally
    stores the block-wide reduction to `partial_min[pid]` /
    `partial_min_indices[pid]`.

    With `USE_OUT_INDICES` the indices fed to the scan are loaded from
    `out_indices`; otherwise the element offsets are used. `in_indices` is
    accepted but never read.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Out-of-range lanes are padded with the dtype maximum.
    pad_value = get_dtype_max(inp.type.element_ty)
    raw_vals = tl.load(inp + offset, mask=mask, other=pad_value)

    # 64-bit types are scanned as-is, other integers as int32, everything
    # else as float32. The conditions stay grouped pairwise on purpose —
    # presumably the Triton front end rejects `a or b or c`; TODO confirm.
    if (
        tl.constexpr(raw_vals.dtype.is_int64())
        or tl.constexpr(raw_vals.dtype.is_uint64())
    ) or tl.constexpr(raw_vals.dtype.is_fp64()):
        scan_vals = raw_vals
    elif tl.constexpr(raw_vals.dtype.is_int()):
        scan_vals = raw_vals.to(tl.int32)
    else:
        scan_vals = raw_vals.to(tl.float32)

    if tl.constexpr(USE_OUT_INDICES):
        scan_idx = tl.load(out_indices + offset, mask=mask)
    else:
        scan_idx = offset

    result, cummin_indices = tl_cummin(scan_vals, scan_idx, axis=0)

    tl.store(out + offset, result, mask=mask)
    tl.store(out_indices + offset, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        block_min, block_min_idx = tl_min_tie_break_right(
            scan_vals, scan_idx, axis=0
        )
        if tl.constexpr(not USE_OUT_INDICES):
            # NOTE(review): `scan_idx` is already `pid * BLOCK_SIZE + lane`
            # here, so this adds the block base a second time. Confirm it
            # matches what the reduce returns on the target backend.
            block_min_idx = pid * BLOCK_SIZE + block_min_idx
        tl.store(partial_min + pid, block_min)
        tl.store(partial_min_indices + pid, block_min_idx)
def scan_then_fan_col(inp, out, out_indices, n_ele, dtype, use_out_indices=False):
    """Run the scan over a flat buffer of `n_ele` elements.

    The buffer is cut into blocks of at most 512 elements and each block is
    scanned by `scan_part_min_kernel`. When more than one block is needed,
    the per-block partial results are scanned by a recursive call and then
    merged back into `out` / `out_indices` by `add_base_min_kernel`.

    `dtype` is the dtype of the partial-value buffer; `use_out_indices` is
    forwarded to the kernel (the recursive call passes True).
    Results are written in place; nothing is returned.
    """
    block_size = triton.next_power_of_2(n_ele) if n_ele <= 512 else 512
    num_parts = math.ceil(n_ele / block_size)
    multi_part = num_parts >= 2

    part_vals = None
    part_idx = None
    if multi_part:
        part_vals = torch.empty(num_parts, dtype=dtype, device=inp.device)
        part_idx = torch.empty(num_parts, dtype=torch.int64, device=inp.device)

    launch_grid = (num_parts,)
    with torch_device_fn.device(inp.device):
        scan_part_min_kernel[launch_grid](
            inp,
            out,
            out_indices,
            out_indices,
            part_vals,
            part_idx,
            n_ele,
            block_size,
            multi_part,
            use_out_indices,
        )

    if not multi_part:
        return

    # Scan the per-block partials in place, then fold them into every block.
    scan_then_fan_col(
        part_vals,
        part_vals,
        part_idx,
        num_parts,
        dtype,
        use_out_indices=True,
    )
    with torch_device_fn.device(inp.device):
        add_base_min_kernel[launch_grid](
            out, out_indices, part_vals, part_idx, n_ele, block_size
        )
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_min_abc_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Per-block (value, index) scan along the middle axis of an (A, B, C) layout.

    Program `(a, pid_b, c)` handles `BLOCK_SIZE` positions along B, addressed
    as `a * B * C + b * C + c`. Scanned values go to `out`, scanned indices
    to `out_indices`. With `NEED_PARTIAL` the block-wide reduction is stored
    at `a * part_num * C + pid_b * C + c` in `partial_min` /
    `partial_min_indices`.

    With `USE_OUT_INDICES` the indices fed to the scan are loaded from
    `out_indices`; otherwise the positions along B are used. `in_indices` is
    accepted but never read.
    """
    a_idx = tle.program_id(0)
    pid_b = tle.program_id(1)
    c_idx = tle.program_id(2)
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    offset = a_idx * B * C + b_idx * C + c_idx
    part_offset = a_idx * part_num * C + c_idx + pid_b * C
    mask = b_idx < B

    # Out-of-range lanes are padded with the dtype maximum.
    pad_value = get_dtype_max(inp.type.element_ty)
    raw_vals = tl.load(inp + offset, mask=mask, other=pad_value)

    # 64-bit types are scanned as-is, other integers as int32, everything
    # else as float32. The conditions stay grouped pairwise on purpose —
    # presumably the Triton front end rejects `a or b or c`; TODO confirm.
    if (
        tl.constexpr(raw_vals.dtype.is_int64())
        or tl.constexpr(raw_vals.dtype.is_uint64())
    ) or tl.constexpr(raw_vals.dtype.is_fp64()):
        scan_vals = raw_vals
    elif tl.constexpr(raw_vals.dtype.is_int()):
        scan_vals = raw_vals.to(tl.int32)
    else:
        scan_vals = raw_vals.to(tl.float32)

    if tl.constexpr(USE_OUT_INDICES):
        scan_idx = tl.load(out_indices + offset, mask=mask)
    else:
        scan_idx = b_idx

    result, cummin_indices = tl_cummin(scan_vals, scan_idx, axis=0)

    tl.store(out + offset, result, mask=mask)
    tl.store(out_indices + offset, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        block_min, block_min_idx = tl_min_tie_break_right(
            scan_vals, scan_idx, axis=0
        )
        if tl.constexpr(not USE_OUT_INDICES):
            # NOTE(review): `scan_idx` is already `pid_b * BLOCK_SIZE + lane`
            # here, so this adds the block base a second time. Confirm it
            # matches what the reduce returns on the target backend.
            block_min_idx = pid_b * BLOCK_SIZE + block_min_idx
        tl.store(partial_min + part_offset, block_min)
        tl.store(partial_min_indices + part_offset, block_min_idx)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_min_abc_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Merge each B-block's scanned results with the previous block's entry.

    Program `(a, pid_b, c)` owns `BLOCK_SIZE` positions along B of `out` /
    `out_indices` (addressed as `a * B * C + b * C + c`). For `pid_b > 0` it
    reads the entry at `a * part_num * C + (pid_b - 1) * C + c` of
    `partial_min` / `partial_min_indices` and rewrites its positions in place
    with whichever of (current, previous) wins the comparison below.
    Programs with `pid_b == 0` only perform the loads.
    """
    a_idx = tle.program_id(0)
    pid_b = tle.program_id(1)
    c_idx = tle.program_id(2)
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    offset = a_idx * B * C + c_idx + b_idx * C
    prev_part_offset = a_idx * part_num * C + c_idx + (pid_b - 1) * C
    mask = b_idx < B

    val_ptrs = out + offset
    idx_ptrs = out_indices + offset
    cur_vals = tl.load(val_ptrs, mask=mask)
    cur_idx = tl.load(idx_ptrs, mask=mask)

    if pid_b > 0:
        prev_val = tl.load(partial_min + prev_part_offset)
        prev_idx = tl.load(partial_min_indices + prev_part_offset)

        # x != x is true only for NaN.
        cur_nan = cur_vals != cur_vals
        prev_nan = prev_val != prev_val

        # Two NaNs count as a tie; ties go to the larger index.
        tied = (cur_vals == prev_val) | (cur_nan & prev_nan)
        take_cur = (
            (cur_vals < prev_val)
            | (cur_nan & ~prev_nan)
            | (tied & (cur_idx > prev_idx))
        )

        merged_vals = tl.where(take_cur, cur_vals, prev_val)
        merged_idx = tl.where(take_cur, cur_idx, prev_idx)
        tl.store(val_ptrs, merged_vals.to(cur_vals.dtype), mask=mask)
        tl.store(idx_ptrs, merged_idx, mask=mask)
def scan_then_fan(inp, out, out_indices, A, B, C, dtype, use_out_indices=False):
    """Run the scan along the middle axis of an (A, B, C) layout.

    B is cut into blocks of at most 512 positions and every (a, block, c)
    triple is scanned by `scan_part_min_abc_kernel`. When more than one block
    is needed, the (A, part_num, C) partial results are scanned by a
    recursive call and then merged back into `out` / `out_indices` by
    `add_base_min_abc_kernel`.

    `dtype` is the dtype of the partial-value buffer; `use_out_indices` is
    forwarded to the kernel (the recursive call passes True).
    Results are written in place; nothing is returned.
    """
    block_size = triton.next_power_of_2(B) if B <= 512 else 512
    num_parts = math.ceil(B / block_size)
    multi_part = num_parts >= 2

    part_vals = None
    part_idx = None
    if multi_part:
        part_vals = torch.empty(A, num_parts, C, dtype=dtype, device=inp.device)
        part_idx = torch.empty(
            A, num_parts, C, dtype=torch.int64, device=inp.device
        )

    launch_grid = (A, num_parts, C)
    with torch_device_fn.device(inp.device):
        scan_part_min_abc_kernel[launch_grid](
            inp,
            out,
            out_indices,
            out_indices,
            part_vals,
            part_idx,
            B,
            C,
            num_parts,
            block_size,
            multi_part,
            use_out_indices,
        )

    if not multi_part:
        return

    # Scan the per-block partials in place, then fold them into every block.
    scan_then_fan(
        part_vals,
        part_vals,
        part_idx,
        A,
        num_parts,
        C,
        dtype,
        use_out_indices=True,
    )
    with torch_device_fn.device(inp.device):
        add_base_min_abc_kernel[launch_grid](
            out,
            out_indices,
            part_vals,
            part_idx,
            B,
            C,
            num_parts,
            block_size,
        )
@libentry()
@triton.jit()
def scan_part_min_abc_loop_kernel(
    inp,
    out,
    out_indices,
    B,
    C,
    loop_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Scan along the middle axis of an (A, B, C) layout with one program per (a, c).

    The program walks B in `loop_num` chunks of `BLOCK_SIZE`, scanning each
    chunk with `tl_cummin` and merging it with a running (value, index) pair
    carried from the previous chunks. Merged values are stored to `out` and
    merged indices to `out_indices`.
    """
    a_idx = tle.program_id(0)
    c_idx = tle.program_id(1)
    lane = tl.arange(0, BLOCK_SIZE)
    ac_base = a_idx * B * C + c_idx

    # Out-of-range lanes and the initial carry use the dtype maximum.
    pad_value = get_dtype_max(inp.type.element_ty)

    # fp16/bf16 are computed in float32, int8/int16 in int32, the rest as-is.
    if tl.constexpr(inp.type.element_ty.is_fp16()) or tl.constexpr(
        inp.type.element_ty.is_bf16()
    ):
        compute_dtype = tl.float32
    elif tl.constexpr(inp.type.element_ty.is_int8()) or tl.constexpr(
        inp.type.element_ty.is_int16()
    ):
        compute_dtype = tl.int32
    else:
        compute_dtype = inp.type.element_ty

    running_min = tl.full([], pad_value, dtype=compute_dtype)
    running_idx = tl.full([], 0, dtype=tl.int64)
    is_last_lane = lane == (BLOCK_SIZE - 1)

    for step in tl.range(loop_num):
        b_idx = step * BLOCK_SIZE + lane
        mask = b_idx < B
        offset = ac_base + b_idx * C

        loaded = tl.load(inp + offset, mask=mask, other=pad_value)
        # Convert only when the compute dtype differs from the storage dtype.
        if tl.constexpr(compute_dtype != inp.type.element_ty):
            chunk_vals = loaded.to(compute_dtype)
        else:
            chunk_vals = loaded

        # Scan within the chunk; indices are positions along B.
        scanned, scanned_idx = tl_cummin(chunk_vals, b_idx, axis=0)

        # Expand the scalar carry so it can be compared lane by lane.
        carry_val = tl.broadcast_to(running_min, (BLOCK_SIZE,))
        carry_idx = tl.broadcast_to(running_idx, (BLOCK_SIZE,))

        if tl.constexpr(compute_dtype.is_floating()):
            # A NaN in the chunk always wins; a NaN carry always stays.
            # Otherwise `<=` lets the chunk win ties.
            carry_is_nan = tl.broadcast_to(
                running_min != running_min, (BLOCK_SIZE,)
            )
            scanned_is_nan = scanned != scanned
            take_chunk = scanned_is_nan | (~carry_is_nan & (scanned <= carry_val))
        else:
            # Integers: `<=` lets the chunk win ties.
            take_chunk = scanned <= carry_val

        merged_vals = tl.where(take_chunk, scanned, carry_val)
        merged_idx = tl.where(take_chunk, scanned_idx, carry_idx)

        # Carry the last lane forward: zero the other lanes, then sum.
        running_min = tl.sum(tl.where(is_last_lane, merged_vals, 0), axis=0)
        running_idx = tl.sum(tl.where(is_last_lane, merged_idx, 0), axis=0)

        tl.store(out + offset, merged_vals.to(out.type.element_ty), mask=mask)
        tl.store(out_indices + offset, merged_idx, mask=mask)
def scan_then_fan_loop(inp, out, out_indices, A, B, C, dtype):
    """Launch the looping scan kernel with one program per (a, c) pair.

    For `B < 4096` the whole axis is covered by a single power-of-two chunk;
    otherwise chunks of 512 are used. `dtype` is accepted for signature
    parity with the other launchers and is not used here.
    Results are written in place; nothing is returned.
    """
    # TODO(all): tune on target board
    block_size = triton.next_power_of_2(B) if B < 1024 * 4 else 512
    iterations = math.ceil(B / block_size)

    with torch_device_fn.device(inp.device):
        scan_part_min_abc_loop_kernel[(A, C)](
            inp,
            out,
            out_indices,
            B,
            C,
            iterations,
            block_size,
        )
def cummin(
    input: Tensor,
    dim: int,
    *,
    out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None,
) -> torch.return_types.cummin:
    """Compute the cumulative minimum of `input` along `dim`.

    Args:
        input: Source tensor. It is made contiguous before the kernels run.
        dim: Axis to scan; negative values count from the last axis.
        out: Accepted for signature compatibility only. It is NOT written to;
            fresh result tensors are always allocated.

    Returns:
        A `(values, indices)` tuple shaped like `input`. `indices` is int64.
        `values` has the dtype of `input`, except that bool input produces
        int64 values.

    Raises:
        AssertionError: if `dim` is out of range for `input`.
    """
    logger.debug("GEMS_ASCEND CUMMIN")

    # Bool input is reported as int64 values.
    value_dtype = torch.int64 if input.dtype is torch.bool else input.dtype

    if input.ndim == 0:
        # A scalar has no axis to take a modulus over; the scan of a single
        # element is the element itself at index 0.
        assert dim in (-1, 0), "Invalid dim"
        scalar_values = torch.empty_like(input, dtype=value_dtype).copy_(input)
        scalar_indices = torch.zeros_like(input, dtype=torch.int64)
        return scalar_values, scalar_indices

    assert dim >= -input.ndim and dim < input.ndim, "Invalid dim"
    dim = dim % input.ndim
    shape = input.shape
    M = math.prod(shape[:dim])  # product of the axes before `dim`
    N = shape[dim]  # length of the scanned axis

    input = input.contiguous()
    values = torch.empty_like(input, dtype=value_dtype)
    indices = torch.empty_like(input, dtype=torch.int64)

    if input.numel() == 0:
        # Nothing to scan. Returning here also avoids dividing by a zero
        # extent below and launching kernels on an empty grid.
        return values, indices

    K = input.numel() // M // N  # product of the axes after `dim`

    # Half-precision partial results are kept in float32.
    compute_dtype = values.dtype
    if input.dtype == torch.float16 or input.dtype == torch.bfloat16:
        compute_dtype = torch.float32

    if M == 1 and K == 1:
        scan_then_fan_col(input, values, indices, N, compute_dtype)
    elif M * K <= 16:
        scan_then_fan(input, values, indices, M, N, K, compute_dtype)
    else:
        scan_then_fan_loop(input, values, indices, M, N, K, compute_dtype)
    return values, indices