Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/cummin.py: 0%
242 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
3from typing import List, Tuple, Union
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as ext
12from flag_gems.utils.limits import get_dtype_max
# Short alias for the tensor type used in the public signature of ``cummin``.
Tensor = torch.Tensor

# Module logger, created as a child of the package-wide "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
@triton.jit
def tl_cummin(input, index, axis=0):
    """Inclusive running minimum of ``input`` along ``axis``, tracked with ``index``.

    The (value, index) pairs are combined with
    ``ext.minimum_with_index_tie_break_right``; the scanned
    ``(values, indices)`` tuple is returned.
    """
    return tl.associative_scan(
        (input, index),
        axis=axis,
        combine_fn=ext.minimum_with_index_tie_break_right,
    )
@triton.jit
def tl_min_tie_break_right(input, index, axis=None, keep_dims=False):
    """Reduce (value, index) pairs along ``axis`` to their minimum.

    Uses the same combine function as ``tl_cummin``
    (``ext.minimum_with_index_tie_break_right``), so the reduced result agrees
    with the last element of the corresponding running scan.
    """
    return tl.reduce(
        (input, index),
        axis,
        combine_fn=ext.minimum_with_index_tie_break_right,
        keep_dims=keep_dims,
    )
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def add_base_min_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Fold the minimum of all earlier tiles into one BLOCK_SIZE tile of ``out``.

    Tile ``pid`` (only when ``pid > 0``) combines its in-tile running minimum,
    already stored in ``out`` / ``out_indices``, with slot ``pid - 1`` of
    ``partial_min`` / ``partial_min_indices``. The caller fills those slots with
    the scanned per-tile minima. Tile 0 is left untouched.

    Selection rule, matching ``scan_part_min_abc_loop_kernel``:
    - a NaN in the current tile wins;
    - otherwise a NaN carried from earlier tiles wins;
    - otherwise the current value wins on ``<=``, so ties keep the later index.
    """
    pid = ext.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_indices_ptrs = out_indices + offset
    cur_vals = tl.load(out_ptrs, mask=mask)
    cur_indices = tl.load(out_indices_ptrs, mask=mask)

    if pid > 0:
        prev_val = tl.load(partial_min + pid - 1)
        prev_idx = tl.load(partial_min_indices + pid - 1)

        if tl.constexpr(cur_vals.dtype.is_floating()):
            # tl.minimum's default NaN mode does not guarantee a NaN result,
            # and "<=" is always False against NaN. Deciding once with an
            # explicit NaN-aware predicate keeps values and indices in sync.
            cur_is_nan = cur_vals != cur_vals
            prev_is_nan = tl.broadcast_to(prev_val != prev_val, (BLOCK_SIZE,))
            use_cur = cur_is_nan | (~prev_is_nan & (cur_vals <= prev_val))
        else:
            use_cur = cur_vals <= prev_val

        final_vals = tl.where(use_cur, cur_vals, prev_val)
        final_indices = tl.where(use_cur, cur_indices, prev_idx)
        tl.store(out_ptrs, final_vals.to(cur_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def scan_part_min_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Scan one BLOCK_SIZE tile of a flat buffer for its running minimum.

    The in-tile running minimum and its indices go to ``out`` /
    ``out_indices``.

    When NEED_PARTIAL is set, the tile's overall (min, index) is also written
    to slot ``program_id(0)`` of ``partial_min`` / ``partial_min_indices``.

    With USE_OUT_INDICES, the seed indices are read from ``out_indices``
    (used by the recursive pass over partial minima). Otherwise the flat
    element position is the index. ``in_indices`` is never read.
    """
    tile = ext.program_id(0)
    pos = tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    valid = pos < n_elements

    # Out-of-range lanes are padded with the dtype's maximum.
    pad = get_dtype_max(inp.type.element_ty)
    vals = tl.load(inp + pos, mask=valid, other=pad)
    # 64-bit types are scanned as-is; narrower ints in int32; other floats in fp32.
    if (
        tl.constexpr(vals.dtype.is_int64()) or tl.constexpr(vals.dtype.is_uint64())
    ) or tl.constexpr(vals.dtype.is_fp64()):
        vals = vals
    elif tl.constexpr(vals.dtype.is_int()):
        vals = vals.to(tl.int32)
    else:
        vals = vals.to(tl.float32)

    if tl.constexpr(USE_OUT_INDICES):
        seed_idx = tl.load(out_indices + pos, mask=valid)
    else:
        seed_idx = pos

    run_vals, run_idx = tl_cummin(vals, seed_idx, axis=0)
    tl.store(out + pos, run_vals, mask=valid)
    tl.store(out_indices + pos, run_idx, mask=valid)

    if tl.constexpr(NEED_PARTIAL):
        # Reduces the already-loaded registers, so storing into ``out`` above
        # (which may alias ``inp`` in the recursive pass) does not affect it.
        tile_min, tile_min_idx = tl_min_tie_break_right(vals, seed_idx, axis=0)
        tl.store(partial_min + tile, tile_min)
        tl.store(partial_min_indices + tile, tile_min_idx)
def scan_then_fan_col(inp, out, out_indices, n_ele, dtype, use_out_indices=False):
    """Running minimum over a flat buffer of ``n_ele`` elements.

    Each program scans one tile with ``scan_part_min_kernel``. When more than
    one tile is needed:
    - the per-tile minima (stored with ``dtype``) are scanned recursively;
    - ``add_base_min_kernel`` then folds them back into every tile.
    """
    block = triton.next_power_of_2(n_ele) if n_ele <= 4 * 1024 else 1024
    tiles = math.ceil(n_ele / block)
    multi_tile = tiles >= 2

    partial_min = partial_min_indices = None
    if multi_tile:
        partial_min = torch.empty(tiles, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(tiles, dtype=torch.int64, device=inp.device)

    grid = (tiles,)
    with torch_device_fn.device(inp.device):
        scan_part_min_kernel[grid](
            inp,
            out,
            out_indices,
            out_indices,
            partial_min,
            partial_min_indices,
            n_ele,
            block,
            multi_tile,
            use_out_indices,
        )

    if not multi_tile:
        return

    # Turn per-tile minima into prefix minima (in place), then fan them out.
    scan_then_fan_col(
        partial_min,
        partial_min,
        partial_min_indices,
        tiles,
        dtype,
        use_out_indices=True,
    )
    with torch_device_fn.device(inp.device):
        add_base_min_kernel[grid](
            out, out_indices, partial_min, partial_min_indices, n_ele, block
        )
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_min_abc_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Scan one BLOCK_SIZE tile along B of a contiguous (A, B, C) buffer.

    Grid axes are (a, tile along B, c). Elements are addressed as
    ``a * B * C + b * C + c``.

    When NEED_PARTIAL is set, the tile's (min, index) is stored into an
    (A, part_num, C) partial buffer at ``a * part_num * C + tile * C + c``.

    With USE_OUT_INDICES, the seed indices come from ``out_indices``;
    otherwise the position ``b`` along B is used. ``in_indices`` is never read.
    """
    a = ext.program_id(0)
    tile = ext.program_id(1)
    c = ext.program_id(2)

    rows = tile * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    elem = a * B * C + rows * C + c
    slot = a * part_num * C + tile * C + c
    valid = rows < B

    # Out-of-range lanes are padded with the dtype's maximum.
    pad = get_dtype_max(inp.type.element_ty)
    vals = tl.load(inp + elem, mask=valid, other=pad)
    # 64-bit types are scanned as-is; narrower ints in int32; other floats in fp32.
    if (
        tl.constexpr(vals.dtype.is_int64()) or tl.constexpr(vals.dtype.is_uint64())
    ) or tl.constexpr(vals.dtype.is_fp64()):
        vals = vals
    elif tl.constexpr(vals.dtype.is_int()):
        vals = vals.to(tl.int32)
    else:
        vals = vals.to(tl.float32)

    if tl.constexpr(USE_OUT_INDICES):
        seed_idx = tl.load(out_indices + elem, mask=valid)
    else:
        seed_idx = rows

    run_vals, run_idx = tl_cummin(vals, seed_idx, axis=0)
    tl.store(out + elem, run_vals, mask=valid)
    tl.store(out_indices + elem, run_idx, mask=valid)

    if tl.constexpr(NEED_PARTIAL):
        tile_min, tile_min_idx = tl_min_tie_break_right(vals, seed_idx, axis=0)
        tl.store(partial_min + slot, tile_min)
        tl.store(partial_min_indices + slot, tile_min_idx)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_min_abc_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Fold earlier tiles' prefix minimum into one tile along B of an (A, B, C) buffer.

    Grid axes are (a, tile along B, c).

    For ``tile > 0``, the in-tile running minimum in ``out`` / ``out_indices``
    is combined with the scanned (A, part_num, C) partial buffer at slot
    ``tile - 1``. Tile 0 is left untouched.

    Selection rule, matching ``scan_part_min_abc_loop_kernel``:
    - a NaN in the current tile wins;
    - otherwise a NaN carried from earlier tiles wins;
    - otherwise the current value wins on ``<=``, so ties keep the later index.
    """
    pid_a = ext.program_id(0)
    pid_b = ext.program_id(1)
    pid_c = ext.program_id(2)

    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offset = pid_a * B * C + pid_c + b_idx * C
    prev_slot = pid_a * part_num * C + pid_c + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_indices_ptrs = out_indices + offset
    cur_vals = tl.load(out_ptrs, mask=mask)
    cur_indices = tl.load(out_indices_ptrs, mask=mask)

    if pid_b > 0:
        prev_val = tl.load(partial_min + prev_slot)
        prev_idx = tl.load(partial_min_indices + prev_slot)

        if tl.constexpr(cur_vals.dtype.is_floating()):
            # tl.minimum's default NaN mode does not guarantee a NaN result,
            # and "<=" is always False against NaN. Deciding once with an
            # explicit NaN-aware predicate keeps values and indices in sync.
            cur_is_nan = cur_vals != cur_vals
            prev_is_nan = tl.broadcast_to(prev_val != prev_val, (BLOCK_SIZE,))
            use_cur = cur_is_nan | (~prev_is_nan & (cur_vals <= prev_val))
        else:
            use_cur = cur_vals <= prev_val

        final_vals = tl.where(use_cur, cur_vals, prev_val)
        final_indices = tl.where(use_cur, cur_indices, prev_idx)
        tl.store(out_ptrs, final_vals.to(cur_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)
def scan_then_fan(inp, out, out_indices, A, B, C, dtype, use_out_indices=False):
    """Running minimum along the middle axis of a contiguous (A, B, C) buffer.

    Each (a, tile, c) program scans one tile of B. When B spans several tiles:
    - the (A, tiles, C) per-tile minima (stored with ``dtype``) are scanned
      recursively along their middle axis;
    - ``add_base_min_abc_kernel`` then folds them back into every tile.
    """
    block = triton.next_power_of_2(B) if B <= 4 * 1024 else 1024
    tiles = math.ceil(B / block)
    multi_tile = tiles >= 2

    partial_min = partial_min_indices = None
    if multi_tile:
        partial_min = torch.empty(A, tiles, C, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(
            A, tiles, C, dtype=torch.int64, device=inp.device
        )

    grid = (A, tiles, C)
    with torch_device_fn.device(inp.device):
        scan_part_min_abc_kernel[grid](
            inp,
            out,
            out_indices,
            out_indices,
            partial_min,
            partial_min_indices,
            B,
            C,
            tiles,
            block,
            multi_tile,
            use_out_indices,
        )

    if not multi_tile:
        return

    # Turn per-tile minima into prefix minima (in place), then fan them out.
    scan_then_fan(
        partial_min,
        partial_min,
        partial_min_indices,
        A,
        tiles,
        C,
        dtype,
        use_out_indices=True,
    )
    with torch_device_fn.device(inp.device):
        add_base_min_abc_kernel[grid](
            out,
            out_indices,
            partial_min,
            partial_min_indices,
            B,
            C,
            tiles,
            block,
        )
@libentry()
@triton.jit()
def scan_part_min_abc_loop_kernel(
    inp,
    out,
    out_indices,
    B,
    C,
    loop_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Running minimum along B for a single (a, c) column of an (A, B, C) buffer.

    Elements are addressed as ``a * B * C + b * C + c``. One program handles one
    (a, c) pair (grid axes 0 and 1). It walks B in ``loop_num`` tiles of
    BLOCK_SIZE, carrying the running (min, index) from each tile into the next,
    so no partial buffers or second pass are needed.
    """
    pid_a = ext.program_id(0)
    pid_c = ext.program_id(1)

    a_idx = pid_a
    c_idx = pid_c
    t_idx = tl.arange(0, BLOCK_SIZE)
    ac_offset = a_idx * B * C + c_idx

    # Pad value for out-of-range lanes and the initial carried minimum.
    max_value = get_dtype_max(inp.type.element_ty)
    # Scan fp16/bf16 in fp32 and int8/int16 in int32; other dtypes as-is.
    if tl.constexpr(inp.type.element_ty.is_fp16()) or tl.constexpr(
        inp.type.element_ty.is_bf16()
    ):
        compute_dtype = tl.float32
    elif tl.constexpr(inp.type.element_ty.is_int8()) or tl.constexpr(
        inp.type.element_ty.is_int16()
    ):
        compute_dtype = tl.int32
    else:
        compute_dtype = inp.type.element_ty

    # Carry: running minimum (and its index) over the tiles processed so far.
    prev_min_val = tl.full([], max_value, dtype=compute_dtype)
    prev_min_val_idx = tl.full([], 0, dtype=tl.int64)
    # Selects the last lane of a tile; used to extract the next carry.
    last_mask = t_idx == (BLOCK_SIZE - 1)

    for l_idx in tl.range(loop_num):
        b_idx = l_idx * BLOCK_SIZE + t_idx
        mask = b_idx < B
        offset = ac_offset + b_idx * C

        inp_vals = tl.load(inp + offset, mask=mask, other=max_value)
        if tl.constexpr(compute_dtype != inp.type.element_ty):
            vals = inp_vals.to(compute_dtype)
        else:
            vals = inp_vals
        # Indices are absolute positions along B.
        idxs = b_idx

        # Running minimum within this tile only.
        result, cummin_indices = tl_cummin(vals, idxs, axis=0)

        prev_min_val_b = tl.broadcast_to(prev_min_val, (BLOCK_SIZE,))
        prev_min_val_idx_b = tl.broadcast_to(prev_min_val_idx, (BLOCK_SIZE,))

        # Merge the tile's running minimum with the carry:
        # - a NaN in the tile always takes over;
        # - otherwise a NaN carry is kept;
        # - otherwise the tile wins on "<=", so ties keep the later index.
        if tl.constexpr(compute_dtype.is_floating()):
            prev_is_nan = prev_min_val != prev_min_val
            result_is_nan = result != result
            prev_nan_mask = tl.broadcast_to(prev_is_nan, (BLOCK_SIZE,))
            use_result = result_is_nan | (~prev_nan_mask & (result <= prev_min_val_b))
        else:
            use_result = result <= prev_min_val_b

        final_vals = tl.where(use_result, result, prev_min_val_b)
        final_indices = tl.where(use_result, cummin_indices, prev_min_val_idx_b)

        # New carry = value/index at the tile's last lane: zero out all other
        # lanes with last_mask, then sum to a scalar.
        prev_min_val = tl.sum(tl.where(last_mask, final_vals, 0), axis=0)
        prev_min_val_idx = tl.sum(tl.where(last_mask, final_indices, 0), axis=0)

        tl.store(out + offset, final_vals.to(out.type.element_ty), mask=mask)
        tl.store(out_indices + offset, final_indices, mask=mask)
def scan_then_fan_loop(inp, out, out_indices, A, B, C, dtype):
    """Running minimum along B of an (A, B, C) buffer, one program per (a, c).

    Each program loops over B in tiles, carrying the running minimum itself
    (see ``scan_part_min_abc_loop_kernel``). ``dtype`` is accepted for
    signature parity with the other launchers and is not used.
    """
    # Note the strict "<": B == 4096 uses 1024-wide tiles here, unlike the
    # other launchers, which use a single 4096-wide tile.
    block = 1024 if B >= 4 * 1024 else triton.next_power_of_2(B)
    n_iters = math.ceil(B / block)

    with torch_device_fn.device(inp.device):
        scan_part_min_abc_loop_kernel[(A, C)](
            inp,
            out,
            out_indices,
            B,
            C,
            n_iters,
            block,
            is_use_mask_zero=True,
        )
def _cummin_values_indices(inp, dim):
    """Run the cummin kernels on ``inp`` (ndim >= 1) along non-negative ``dim``.

    Returns freshly allocated, contiguous ``(values, indices)`` tensors.
    """
    shape = inp.shape
    M = 1
    for size in shape[:dim]:
        M *= size
    N = shape[dim]
    inp = inp.contiguous()

    # bool inputs produce int64 values (behavior kept from the original code).
    # NOTE(review): torch.cummin returns bool values for bool input — confirm
    # this difference is intended for this backend.
    dtype = torch.int64 if inp.dtype is torch.bool else inp.dtype
    values = torch.empty_like(inp, dtype=dtype)
    indices = torch.empty_like(inp, dtype=torch.int64)

    if inp.numel() == 0:
        # Nothing to scan. This also avoids dividing by a zero M or N below
        # and launching kernels with an empty grid.
        return values, indices
    K = inp.numel() // M // N

    compute_dtype = values.dtype
    if inp.dtype == torch.float16 or inp.dtype == torch.bfloat16:
        compute_dtype = torch.float32

    if M == 1 and K == 1:
        scan_then_fan_col(inp, values, indices, N, compute_dtype)
    elif M * K <= 16:
        scan_then_fan(inp, values, indices, M, N, K, compute_dtype)
    else:
        scan_then_fan_loop(inp, values, indices, M, N, K, compute_dtype)
    return values, indices


def cummin(
    input: Tensor,
    dim: int,
    *,
    out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None,
) -> torch.return_types.cummin:
    """Cumulative minimum of ``input`` along ``dim``.

    Args:
        input: tensor to scan. A 0-dim tensor is treated as a 1-element vector.
        dim: dimension to scan; negative values count from the end.
        out: optional ``(values, indices)`` pair. When given, the results are
            written into these tensors (resized as needed) and they are returned.

    Returns:
        ``torch.return_types.cummin`` holding ``(values, indices)``. Indices
        are int64.

    Raises:
        AssertionError: if ``dim`` is out of range.
        TypeError: if ``out`` is given but is not a pair of tensors.
    """
    logger.debug("GEMS_KUNLUNXIN CUMMIN")
    # Like PyTorch, a 0-dim tensor accepts dim in [-1, 0].
    ndim = max(input.ndim, 1)
    assert dim >= -ndim and dim < ndim, "Invalid dim"
    dim = dim % ndim

    if input.ndim == 0:
        values, indices = _cummin_values_indices(input.reshape(1), 0)
        values, indices = values.reshape(()), indices.reshape(())
    else:
        values, indices = _cummin_values_indices(input, dim)

    if out is None:
        return torch.return_types.cummin((values, indices))

    # Honor the out= argument instead of silently discarding it.
    if isinstance(out, Tensor) or len(out) != 2:
        raise TypeError("cummin(): out must be a (values, indices) pair of tensors")
    out_values, out_indices = out
    out_values.resize_(values.shape).copy_(values)
    out_indices.resize_(indices.shape).copy_(indices)
    return torch.return_types.cummin((out_values, out_indices))