Coverage for src/flag_gems/runtime/backend/_ascend/ops/cummax.py: 0%
260 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import math
3from typing import List, Tuple, Union
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
12from flag_gems.utils.limits import get_dtype_min
# Short alias for torch.Tensor, used in the type annotations below.
Tensor = torch.Tensor

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@triton.jit
def tl_cummax(input, index, axis=0):
    # Scan (value, index) pairs along `axis`, combining neighbours with
    # tle.maximum_with_index_tie_break_right, and hand back the scanned
    # values together with the indices that travelled with them.
    running_vals, running_idx = tl.associative_scan(
        (input, index), axis, tle.maximum_with_index_tie_break_right
    )
    return running_vals, running_idx
@triton.jit
def tl_max_tie_break_right(input, index, axis=None, keep_dims=False):
    # Reduce (value, index) pairs along `axis` with the same combine function
    # that tl_cummax uses, returning the winning value and its index.
    best_val, best_idx = tl.reduce(
        (input, index),
        axis,
        tle.maximum_with_index_tie_break_right,
        keep_dims=keep_dims,
    )
    return best_val, best_idx
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def add_base_max_kernel(
    out,
    out_indices,
    partial_max,
    partial_max_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Merge the (value, index) stored in slot `pid - 1` of the partial buffers
    # into this program's BLOCK_SIZE-wide slice of `out` / `out_indices`.
    # Program 0 has no predecessor slot and leaves its slice untouched.
    # NOTE(review): this only yields a global running maximum if the caller
    # has already turned `partial_max` into running maxima - confirm at the
    # call site.
    block_id = tle.program_id(0)
    lane = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lane < n_elements

    val_ptrs = out + lane
    idx_ptrs = out_indices + lane
    cur_val = tl.load(val_ptrs, mask=in_bounds)
    cur_idx = tl.load(idx_ptrs, mask=in_bounds)

    if block_id > 0:
        prev_val = tl.load(partial_max + block_id - 1)
        prev_idx = tl.load(partial_max_indices + block_id - 1)

        # NaN-aware maximum (same semantics as maximum_with_index_tie_break_right)
        cur_nan = cur_val != cur_val
        prev_nan = prev_val != prev_val
        # Two NaNs count as a tie, just like two equal ordinary values.
        tied = (cur_val == prev_val) | (cur_nan & prev_nan)
        # Keep the current lane when it is strictly larger, when it alone is
        # NaN, or when it ties and carries the larger index.
        take_cur = (
            (cur_val > prev_val)
            | (cur_nan & ~prev_nan)
            | (tied & (cur_idx > prev_idx))
        )

        merged_val = tl.where(take_cur, cur_val, prev_val)
        merged_idx = tl.where(take_cur, cur_idx, prev_idx)
        tl.store(val_ptrs, merged_val.to(cur_val.dtype), mask=in_bounds)
        tl.store(idx_ptrs, merged_idx, mask=in_bounds)
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def scan_part_max_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_max,
    partial_max_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    # Block-local running maximum over a flat buffer of `n_elements` values.
    # Each program scans one BLOCK_SIZE-wide slice of `inp`, stores the
    # running (value, index) pairs to `out` / `out_indices`, and, when
    # NEED_PARTIAL is set, also stores the slice's overall (max, index) in
    # slot `pid` of `partial_max` / `partial_max_indices`.
    #
    # `in_indices` is never read in this body: when USE_OUT_INDICES is set the
    # incoming indices are loaded from `out_indices` instead.
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Lanes past `n_elements` are filled with the dtype minimum returned by
    # get_dtype_min (presumably so padding never beats a real element).
    min_value = get_dtype_min(inp.type.element_ty)
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
    # Choose the compute type: 64-bit integers and fp64 are kept as loaded,
    # every other integer type is widened to int32, and everything else is
    # converted to float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        # Reuse the indices already stored in `out_indices`. Masked lanes have
        # no `other=` value, so their contents are unspecified.
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        # First pass: each element's index is its global offset.
        in_indices_vals = offset
    result, cummax_indices = tl_cummax(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        part_max_via_max, part_max_indices_via_max = tl_max_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )
        # NOTE(review): on this path `in_indices_vals` is `offset`, which
        # already includes `pid * BLOCK_SIZE`. If the reduce returns the index
        # values it was given, this addition double-counts the block base for
        # pid > 0. It is only correct if the backend's reduce yields
        # block-local positions - confirm on device.
        if tl.constexpr(not USE_OUT_INDICES):
            part_max_indices_via_max = pid * BLOCK_SIZE + part_max_indices_via_max

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummax_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        # One scalar (value, index) pair per program, stored unmasked.
        partial_max_ptrs = partial_max + pid
        tl.store(partial_max_ptrs, part_max_via_max)

        partial_max_indices_ptrs = partial_max_indices + pid
        tl.store(partial_max_indices_ptrs, part_max_indices_via_max)
def scan_then_fan_col(inp, out, out_indices, n_ele, dtype, use_out_indices=False):
    """Run the block-wise cummax over a flat buffer of ``n_ele`` elements.

    ``scan_part_max_kernel`` is launched with one program per block of at most
    512 elements. If more than one block is needed, the per-block maxima are
    scanned in place by a recursive call and then merged back into ``out`` /
    ``out_indices`` by ``add_base_max_kernel``.

    Args:
        inp: Buffer to scan.
        out: Buffer that receives the running maxima.
        out_indices: Buffer that receives the matching indices. It is also
            passed in the kernel's ``in_indices`` slot.
        n_ele: Number of elements to scan.
        dtype: Element type of the temporary per-block maxima buffer.
        use_out_indices: Forwarded to the kernel as ``USE_OUT_INDICES``.
    """
    max_block = 512
    block = max_block if n_ele > max_block else triton.next_power_of_2(n_ele)
    num_parts = math.ceil(n_ele / block)
    multi_part = num_parts >= 2

    # The per-block summaries are only allocated when there is more than one
    # block; otherwise the kernel receives None for both buffers.
    part_vals = part_idx = None
    if multi_part:
        part_vals = torch.empty(num_parts, dtype=dtype, device=inp.device)
        part_idx = torch.empty(num_parts, dtype=torch.int64, device=inp.device)

    launch_grid = (num_parts,)
    with torch_device_fn.device(inp.device):
        scan_part_max_kernel[launch_grid](
            inp,
            out,
            out_indices,
            out_indices,
            part_vals,
            part_idx,
            n_ele,
            block,
            multi_part,
            use_out_indices,
        )

    if not multi_part:
        return

    # Scan the per-block maxima in place, then merge them into each block.
    scan_then_fan_col(
        part_vals,
        part_vals,
        part_idx,
        num_parts,
        dtype,
        use_out_indices=True,
    )
    with torch_device_fn.device(inp.device):
        add_base_max_kernel[launch_grid](
            out, out_indices, part_vals, part_idx, n_ele, block
        )
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_max_abc_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_max,
    partial_max_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    # Block-local running maximum along the middle axis of a buffer addressed
    # with strides (B * C, C, 1). Grid axis 0 selects the outer position,
    # axis 1 selects a BLOCK_SIZE-wide tile along the scanned axis, and axis 2
    # selects the inner position. When NEED_PARTIAL is set, each tile's
    # overall (max, index) is also written to the partial buffers, which are
    # addressed with strides (part_num * C, C, 1).
    #
    # `in_indices` is never read in this body: when USE_OUT_INDICES is set the
    # incoming indices are loaded from `out_indices` instead.
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    # Lanes past the end of the scanned axis are filled with the dtype minimum
    # returned by get_dtype_min (presumably so padding never beats a real
    # element).
    mask = b_idx < B
    inp_ptrs = inp + offset
    min_value = get_dtype_min(inp.type.element_ty)
    inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
    # Choose the compute type: 64-bit integers and fp64 are kept as loaded,
    # every other integer type is widened to int32, and everything else is
    # converted to float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        # Reuse the indices already stored in `out_indices`. Masked lanes have
        # no `other=` value, so their contents are unspecified.
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        # First pass: each element's index is its position on the scanned axis.
        in_indices_vals = b_idx
    result, cummax_indices = tl_cummax(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        part_max_via_max, part_max_indices_via_max = tl_max_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )
        # NOTE(review): on this path `in_indices_vals` is `b_idx`, which
        # already includes `pid_b * BLOCK_SIZE`. If the reduce returns the
        # index values it was given, this addition double-counts the tile
        # base for pid_b > 0. It is only correct if the backend's reduce
        # yields tile-local positions - confirm on device.
        if tl.constexpr(not USE_OUT_INDICES):
            part_max_indices_via_max = pid_b * BLOCK_SIZE + part_max_indices_via_max

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummax_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        # One scalar (value, index) pair per program, stored unmasked.
        partial_max_ptrs = partial_max + part_offset
        tl.store(partial_max_ptrs, part_max_via_max)

        partial_max_indices_ptrs = partial_max_indices + part_offset
        tl.store(partial_max_indices_ptrs, part_max_indices_via_max)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_max_abc_kernel(
    out,
    out_indices,
    partial_max,
    partial_max_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Merge the (value, index) stored for tile `pid_b - 1` in the partial
    # buffers into this program's tile of `out` / `out_indices`. `out` is
    # addressed with strides (B * C, C, 1) and the partial buffers with
    # strides (part_num * C, C, 1). Tile 0 has no predecessor and is left
    # untouched.
    # NOTE(review): this only yields a global running maximum if the caller
    # has already turned `partial_max` into running maxima - confirm at the
    # call site.
    outer = tle.program_id(0)
    tile = tle.program_id(1)
    inner = tle.program_id(2)

    pos = tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    elem_offset = outer * B * C + inner + pos * C
    prev_part_offset = outer * part_num * C + inner + (tile - 1) * C

    in_bounds = pos < B
    val_ptrs = out + elem_offset
    cur_val = tl.load(val_ptrs, mask=in_bounds)
    idx_ptrs = out_indices + elem_offset
    cur_idx = tl.load(idx_ptrs, mask=in_bounds)

    if tile > 0:
        prev_val = tl.load(partial_max + prev_part_offset)
        prev_idx = tl.load(partial_max_indices + prev_part_offset)

        cur_nan = cur_val != cur_val
        prev_nan = prev_val != prev_val
        # Two NaNs count as a tie, just like two equal ordinary values.
        tied = (cur_val == prev_val) | (cur_nan & prev_nan)
        # Keep the current lane when it is strictly larger, when it alone is
        # NaN, or when it ties and carries the larger index.
        take_cur = (
            (cur_val > prev_val)
            | (cur_nan & ~prev_nan)
            | (tied & (cur_idx > prev_idx))
        )

        merged_val = tl.where(take_cur, cur_val, prev_val)
        merged_idx = tl.where(take_cur, cur_idx, prev_idx)
        tl.store(val_ptrs, merged_val.to(cur_val.dtype), mask=in_bounds)
        tl.store(idx_ptrs, merged_idx, mask=in_bounds)
def scan_then_fan(inp, out, out_indices, A, B, C, dtype, use_out_indices=False):
    """Run the tiled cummax along the middle axis of an ``(A, B, C)`` problem.

    ``scan_part_max_abc_kernel`` is launched on an ``(A, part_num, C)`` grid,
    each program covering at most 512 positions of the scanned axis. If more
    than one tile is needed, the per-tile maxima are scanned in place by a
    recursive call and then merged back into ``out`` / ``out_indices`` by
    ``add_base_max_abc_kernel``.

    Args:
        inp: Buffer to scan.
        out: Buffer that receives the running maxima.
        out_indices: Buffer that receives the matching indices. It is also
            passed in the kernel's ``in_indices`` slot.
        A: Extent of the first grid axis.
        B: Length of the scanned axis.
        C: Extent of the last grid axis.
        dtype: Element type of the temporary per-tile maxima buffer.
        use_out_indices: Forwarded to the kernel as ``USE_OUT_INDICES``.
    """
    max_block = 512
    block = max_block if B > max_block else triton.next_power_of_2(B)
    num_parts = math.ceil(B / block)
    multi_part = num_parts >= 2

    # The per-tile summaries are only allocated when there is more than one
    # tile; otherwise the kernel receives None for both buffers.
    part_vals = part_idx = None
    if multi_part:
        part_vals = torch.empty(A, num_parts, C, dtype=dtype, device=inp.device)
        part_idx = torch.empty(
            A, num_parts, C, dtype=torch.int64, device=inp.device
        )

    launch_grid = (A, num_parts, C)
    with torch_device_fn.device(inp.device):
        scan_part_max_abc_kernel[launch_grid](
            inp,
            out,
            out_indices,
            out_indices,
            part_vals,
            part_idx,
            B,
            C,
            num_parts,
            block,
            multi_part,
            use_out_indices,
        )

    if not multi_part:
        return

    # Scan the per-tile maxima in place, then merge them into each tile.
    scan_then_fan(
        part_vals,
        part_vals,
        part_idx,
        A,
        num_parts,
        C,
        dtype,
        use_out_indices=True,
    )
    with torch_device_fn.device(inp.device):
        add_base_max_abc_kernel[launch_grid](
            out,
            out_indices,
            part_vals,
            part_idx,
            B,
            C,
            num_parts,
            block,
        )
@libentry()
@triton.jit()
def scan_part_max_abc_loop_kernel(
    inp,
    out,
    out_indices,
    B,
    C,
    loop_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Single-program cummax along the middle axis of a buffer addressed with
    # strides (B * C, C, 1). Grid axis 0 selects the outer position and axis 1
    # the inner position. The program walks the scanned axis in `loop_num`
    # tiles of BLOCK_SIZE lanes, carrying the running (max, index) from one
    # tile to the next in `prev_max_val` / `prev_max_val_idx`.
    pid_a = tle.program_id(0)
    pid_c = tle.program_id(1)

    a_idx = pid_a
    c_idx = pid_c
    t_idx = tl.arange(0, BLOCK_SIZE)
    ac_offset = a_idx * B * C + c_idx

    # init, promote low precision types
    # fp16/bf16 are computed in float32, int8/int16 in int32; every other
    # element type is used unchanged.
    min_value = get_dtype_min(inp.type.element_ty)
    if tl.constexpr(inp.type.element_ty.is_fp16()) or tl.constexpr(
        inp.type.element_ty.is_bf16()
    ):
        compute_dtype = tl.float32
    elif tl.constexpr(inp.type.element_ty.is_int8()) or tl.constexpr(
        inp.type.element_ty.is_int16()
    ):
        compute_dtype = tl.int32
    else:
        compute_dtype = inp.type.element_ty

    # The carried state starts at (dtype minimum, index 0).
    prev_max_val = tl.full([], min_value, dtype=compute_dtype)
    prev_max_val_idx = tl.full([], 0, dtype=tl.int64)
    # True only for the final lane of a tile; used below to extract the carry.
    last_mask = t_idx == (BLOCK_SIZE - 1)

    for l_idx in tl.range(loop_num):
        b_idx = l_idx * BLOCK_SIZE + t_idx
        mask = b_idx < B
        offset = ac_offset + b_idx * C

        # Lanes past the end of the scanned axis load the dtype minimum.
        inp_vals = tl.load(inp + offset, mask=mask, other=min_value)
        # Only promote if necessary
        if tl.constexpr(compute_dtype != inp.type.element_ty):
            vals = inp_vals.to(compute_dtype)
        else:
            vals = inp_vals
        idxs = b_idx

        # cummax
        result, cummax_indices = tl_cummax(vals, idxs, axis=0)

        # broadcast
        prev_max_val_b = tl.broadcast_to(prev_max_val, (BLOCK_SIZE,))
        prev_max_val_idx_b = tl.broadcast_to(prev_max_val_idx, (BLOCK_SIZE,))

        # Handle NaN and tie-breaking logic
        if tl.constexpr(compute_dtype.is_floating()):
            # For floats: handle NaN propagation + tie-break right
            prev_is_nan = prev_max_val != prev_max_val
            result_is_nan = result != result
            prev_nan_mask = tl.broadcast_to(prev_is_nan, (BLOCK_SIZE,))

            # A NaN in this tile always wins. Otherwise the tile's value wins
            # only when the carry is not NaN and the value is >= the carry.
            use_result = result_is_nan | (~prev_nan_mask & (result >= prev_max_val_b))
        else:
            # For integers: simple tie-break right
            use_result = result >= prev_max_val_b

        final_vals = tl.where(use_result, result, prev_max_val_b)
        final_indices = tl.where(use_result, cummax_indices, prev_max_val_idx_b)

        # update global max val and idx
        # Every lane except the last is zeroed, so the sum yields the last
        # lane's value and index, which become the carry for the next tile.
        prev_max_val = tl.sum(tl.where(last_mask, final_vals, 0), axis=0)
        prev_max_val_idx = tl.sum(tl.where(last_mask, final_indices, 0), axis=0)

        # store result
        tl.store(out + offset, final_vals.to(out.type.element_ty), mask=mask)
        tl.store(out_indices + offset, final_indices, mask=mask)
def scan_then_fan_loop(inp, out, out_indices, A, B, C, dtype):
    """Launch the looping cummax kernel on an ``(A, C)`` grid.

    Each program handles one (outer, inner) position and walks the scanned
    axis of length ``B`` in ``loop_num`` tiles of at most 512 lanes.

    Args:
        inp: Buffer to scan.
        out: Buffer that receives the running maxima.
        out_indices: Buffer that receives the matching indices.
        A: Extent of the first grid axis.
        B: Length of the scanned axis.
        C: Extent of the second grid axis.
        dtype: Accepted for signature parity with the other launchers; this
            function does not use it.
    """
    max_block = 512
    block = triton.next_power_of_2(B) if B < max_block else max_block
    tiles = math.ceil(B / block)

    launch_grid = (A, C)
    with torch_device_fn.device(inp.device):
        scan_part_max_abc_loop_kernel[launch_grid](
            inp, out, out_indices, B, C, tiles, block
        )
def cummax(
    input: Tensor,
    dim: int,
    *,
    out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None,
) -> torch.return_types.cummax:
    """Compute the cumulative maximum of ``input`` along ``dim``.

    The tensor is viewed as ``(M, N, K)``, where ``N`` is the extent of
    ``dim``, ``M`` the product of the leading extents and ``K`` the product of
    the trailing ones. One of three launch strategies is chosen from ``M`` and
    ``K``.

    Args:
        input: Tensor to scan. It is made contiguous before launching.
        dim: Dimension to scan along; negative values count from the end.
        out: Accepted for signature compatibility only. It is not written to;
            freshly allocated tensors are always returned.

    Returns:
        A ``(values, indices)`` tuple. ``values`` has the dtype of ``input``,
        except that bool inputs produce int64 values. ``indices`` is int64.

    Raises:
        AssertionError: If ``dim`` is outside ``[-input.ndim, input.ndim)``.
    """
    logger.debug("GEMS_ASCEND CUMMAX")
    assert dim >= -input.ndim and dim < input.ndim, "Invalid dim"
    shape = input.shape
    dim = dim % input.ndim
    M = 1
    N = shape[dim]
    for extent in shape[:dim]:
        M *= extent
    input = input.contiguous()

    dtype = input.dtype
    if dtype is torch.bool:
        dtype = torch.int64
    values = torch.empty_like(input, dtype=dtype)
    indices = torch.empty_like(input, dtype=torch.int64)

    # Nothing to scan: return the empty outputs directly. Without this guard
    # the computation of K divides by zero whenever M or N is 0, and the
    # launchers would be handed zero-sized problems.
    if input.numel() == 0:
        return values, indices

    K = input.numel() // M // N

    # Half-precision inputs keep their partial maxima in float32.
    compute_dtype = values.dtype
    if input.dtype == torch.float16 or input.dtype == torch.bfloat16:
        compute_dtype = torch.float32

    if M == 1 and K == 1:
        # Effectively one-dimensional: use the flat scan.
        scan_then_fan_col(input, values, indices, N, compute_dtype)
    elif M * K <= 16:
        # Few independent rows: split each row across several programs.
        scan_then_fan(input, values, indices, M, N, K, compute_dtype)
    else:
        # Many independent rows: one looping program per row.
        scan_then_fan_loop(input, values, indices, M, N, K, compute_dtype)
    return values, indices