Coverage for src/flag_gems/runtime/backend/_ascend/ops/cumsum.py: 0%
258 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
7import triton.runtime.driver as driver
9from flag_gems.runtime import device, torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as ext
# Module logger, namespaced under the Ascend ops package and suffixed with the
# last dotted component of this module's ``__name__``.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rsplit(".", 1)[-1]
)
def get_npu_properties():
    """Return the active Triton driver's properties for the current NPU.

    The device index comes from ``torch.npu.current_device()``; the result is
    whatever ``driver.active.utils.get_device_properties`` returns (callers in
    this module index it with ``"num_vectorcore"``).
    """
    current_npu = torch.npu.current_device()
    properties = driver.active.utils.get_device_properties(current_npu)
    return properties
# Rebind the module-level name ``device`` (imported above from
# ``flag_gems.runtime``) to that object's ``.name`` attribute. From here on,
# ``device`` no longer refers to the runtime device object, so code in this
# module must not look up ``device.name`` again.
device = device.name
@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_sum_kernel(
    inp,
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Inclusive scan of one BLOCK_SIZE-wide slice of a flat buffer.

    Program ``pid`` handles elements ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``.
    It writes their inclusive prefix sums to ``out`` and stores the slice
    total to ``partial_sum[pid]``, so a later pass can add the totals of
    earlier slices.

    Args:
        inp: pointer to the input elements.
        out: pointer to the output elements; callers in this file pass the
            same buffer as ``inp`` for the recursive partial-sum scan.
        partial_sum: pointer to one slice total per program.
        n_elements: number of valid elements; lanes at or past it are masked.
        part_num: number of slices (not read by this kernel).
        BLOCK_SIZE: slice width handled by each program.
    """
    pid = ext.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    inp_ptrs = inp + offset
    # No ``other`` is given, so masked lanes hold unspecified values. They
    # trail the valid lanes, so the prefix sums of valid lanes are unaffected.
    # NOTE(review): the total of a partially filled last slice includes them.
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Choose the accumulation type at compile time:
    # - 64-bit ints and fp64 are scanned in their own type;
    # - other integer types are scanned in int32;
    # - everything else is scanned in fp32.
    # NOTE(review): int8..int32 inputs accumulate in int32 within a slice.
    # Confirm that overflow cannot occur for expected value ranges.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)

    # Slice total, consumed by the add_base_sum pass after partial_sum is scanned.
    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + pid
    tl.store(partial_sum_ptrs, part_sum_via_sum)
@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_sum_kernel(
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Add the total of all preceding slices to one slice of ``out``.

    In this file, ``partial_sum`` is inclusively scanned before this kernel
    runs. Therefore ``partial_sum[pid - 1]`` is the sum of every element
    before slice ``pid``. Slice 0 needs no offset and is left untouched.

    Args:
        out: pointer to per-slice prefix sums, updated in place.
        partial_sum: pointer to the scanned slice totals.
        n_elements: number of valid elements; lanes at or past it are masked.
        part_num: number of slices (not read by this kernel).
        BLOCK_SIZE: slice width; must match the scan pass.
    """
    pid = ext.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)

    if pid > 0:
        partial_sum_ptrs = partial_sum + pid - 1
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)

        final_vals = out_vals + last_part_sum_via_sum
        # Cast back so the stored values keep the output buffer's element type.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_sum_abc_kernel(
    inp,
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Scan one BLOCK_SIZE-long segment of the middle axis of an (A, B, C) buffer.

    Elements are addressed as ``a * B * C + b * C + c``. Program
    ``(pid_a, pid_b, pid_c)`` fixes ``a = pid_a`` and ``c = pid_c``, and scans
    ``b`` in ``[pid_b * BLOCK_SIZE, (pid_b + 1) * BLOCK_SIZE)``. It writes the
    inclusive prefix sums to ``out`` and the segment total to ``partial_sum``
    at ``a * part_num * C + pid_b * C + c``, i.e. an (A, part_num, C) layout.

    Args:
        inp, out: input / output pointers; callers in this file pass the same
            buffer for both in the recursive partial-sum scan.
        partial_sum: pointer to the (A, part_num, C) segment totals.
        B, C: sizes of the scanned axis and the trailing axis.
        part_num: number of segments along B.
        BLOCK_SIZE: segment length handled by each program.
    """
    pid_a = ext.program_id(0)
    pid_b = ext.program_id(1)
    pid_c = ext.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Consecutive b entries are C elements apart.
    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    # Masked lanes (no ``other``) hold unspecified values; they trail the
    # valid lanes, so valid prefix sums are unaffected.
    inp_vals = tl.load(inp_ptrs, mask=mask)
    # Accumulation type chosen at compile time:
    # - 64-bit types are kept;
    # - other integers use int32;
    # - everything else uses fp32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)

    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + part_offset
    tl.store(partial_sum_ptrs, part_sum_via_sum)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_sum_abc_kernel(
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Add the total of earlier B-segments to one segment of ``out``.

    Uses the same addressing as scan_part_sum_abc_kernel:
    - ``out`` is an (A, B, C) buffer addressed as ``a * B * C + b * C + c``;
    - ``partial_sum`` has an (A, part_num, C) layout.

    In this file, ``partial_sum`` is scanned along its middle axis before this
    kernel runs. Its entry for segment ``pid_b - 1`` is therefore the total of
    all segments before ``pid_b``. Segment 0 is left unchanged.

    Args:
        out: pointer to per-segment prefix sums, updated in place.
        partial_sum: pointer to the scanned segment totals.
        B, C: sizes of the scanned axis and the trailing axis.
        part_num: number of segments along B.
        BLOCK_SIZE: segment length; must match the scan pass.
    """
    pid_a = ext.program_id(0)
    pid_b = ext.program_id(1)
    pid_c = ext.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    base_part_offset = a_idx * part_num * C + c_idx
    # Computed for pid_b == 0 as well, but only dereferenced when pid_b > 0.
    last_part_offset = base_part_offset + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)

    if pid_b > 0:
        partial_sum_ptrs = partial_sum + last_part_offset
        last_part_sum_via_sum = tl.load(partial_sum_ptrs)

        final_vals = out_vals + last_part_sum_via_sum
        # Cast back so the stored values keep the output buffer's element type.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
def scan_then_fan_col(inp, out, n_ele, dtype):
    """Inclusive scan of ``n_ele`` flat elements from ``inp`` into ``out``.

    Uses a two-level "scan then fan" scheme:
    1. Each program scans one slice and records the slice total.
    2. The totals are scanned recursively.
    3. Each slice after the first is shifted by the total of all slices before it.

    Sizes up to 4096 use a single power-of-two slice; larger inputs use
    1024-wide slices.

    Args:
        inp: flat input buffer.
        out: flat output buffer (may be the same buffer as ``inp``).
        n_ele: number of elements to scan.
        dtype: dtype of the temporary slice-total buffer.
    """
    block = triton.next_power_of_2(n_ele) if n_ele <= 1024 * 4 else 1024
    n_parts = math.ceil(n_ele / block)
    part_totals = torch.empty(n_parts, dtype=dtype, device=inp.device)
    launch_grid = (n_parts,)

    with torch_device_fn.device(inp.device):
        scan_part_sum_kernel[launch_grid](
            inp, out, part_totals, n_ele, n_parts, block
        )

    # A single slice is already a complete scan.
    if n_parts < 2:
        return

    # Turn per-slice totals into running totals, then fan them back out.
    scan_then_fan_col(part_totals, part_totals, n_parts, dtype)
    with torch_device_fn.device(inp.device):
        add_base_sum_kernel[launch_grid](out, part_totals, n_ele, n_parts, block)
def scan_then_fan(inp, out, A, B, C, dtype):
    """Inclusive scan along the middle axis of an (A, B, C) buffer.

    The kernels address element ``(a, b, c)`` at ``a * B * C + b * C + c``.
    The middle axis is cut into segments:
    1. Each segment is scanned independently.
    2. The (A, n_parts, C) grid of segment totals is scanned recursively.
    3. Each segment after the first is shifted by the running total before it.

    Args:
        inp: input buffer.
        out: output buffer (may be the same buffer as ``inp``).
        A, B, C: leading, scanned and trailing sizes.
        dtype: dtype of the temporary segment-total buffer.
    """
    block = triton.next_power_of_2(B) if B <= 1024 * 4 else 1024
    n_parts = math.ceil(B / block)
    part_totals = torch.empty(A, n_parts, C, dtype=dtype, device=inp.device)
    launch_grid = (A, n_parts, C)

    with torch_device_fn.device(inp.device):
        scan_part_sum_abc_kernel[launch_grid](
            inp, out, part_totals, B, C, n_parts, block
        )

    # With one segment per (a, c) column the scan is already complete.
    if n_parts < 2:
        return

    scan_then_fan(part_totals, part_totals, A, n_parts, C, dtype)
    with torch_device_fn.device(inp.device):
        add_base_sum_abc_kernel[launch_grid](
            out, part_totals, B, C, n_parts, block
        )
def cumsum(inp, dim=1, *, dtype=None):
    """Inclusive cumulative sum of ``inp`` along ``dim``.

    The input is made contiguous and viewed as (M, N, K), with N the scanned
    axis. When M == K == 1 a flat scan is used; otherwise a strided
    middle-axis scan is used.

    Args:
        inp: input tensor.
        dim: dimension to scan; negative values count from the end.
        dtype: output dtype. When ``None`` the input dtype is used, except that
            bool, int8, int16, int32 and uint8 inputs produce int64 output (as
            ``torch.cumsum`` does for those types). An explicit ``dtype`` is
            always honoured.

    Returns:
        A new contiguous tensor with the same shape as ``inp``. Zero-element
        inputs return an empty tensor of the resolved dtype.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logger.debug("GEMS_ASCEND CUMSUM")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim
    inp = inp.contiguous()

    if dtype is None:
        dtype = inp.dtype
        # Only promote when the caller did not ask for a specific dtype.
        if dtype in (torch.bool, torch.int8, torch.int16, torch.int32, torch.uint8):
            dtype = torch.int64
    out = torch.empty_like(inp, dtype=dtype)

    # Nothing to scan. Returning early also avoids dividing by a zero-sized
    # dimension (M or N) and a zero block size in the scan helpers.
    if inp.numel() == 0:
        return out

    shape = inp.shape
    M = math.prod(shape[:dim])
    N = shape[dim]
    K = inp.numel() // M // N

    # Half-precision inputs accumulate their partial sums in fp32.
    compute_dtype = out.dtype
    if inp.dtype in (torch.float16, torch.bfloat16):
        compute_dtype = torch.float32

    if M == 1 and K == 1:
        scan_then_fan_col(inp, out, N, compute_dtype)
    else:
        scan_then_fan(inp, out, M, N, K, compute_dtype)
    return out
@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    """Normalized inclusive scan of one row per program.

    Program ``i`` reads the row starting at element ``i * K`` and stores
    ``cumsum(row) / sum(row)`` back at the same positions of ``out``.

    Only the first ``min(BLOCK, K)`` elements of the row are read, so callers
    must pass ``BLOCK >= K`` to cover whole rows.

    Args:
        inp, out: row-major buffers with rows of length ``K``.
        K: row length.
        BLOCK: lanes per program.
    """
    row_start = ext.program_id(0) * K
    row_off = tl.arange(0, BLOCK)
    x = tl.load(inp + row_start + row_off, mask=row_off < K, other=0)
    # Accumulate both half-precision formats (fp16 and bf16) in fp32, as
    # block_cumsum_kernel does.
    if x.dtype.is_fp16() | x.dtype.is_bf16():
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0)
    y = y / y_sum
    tl.store(out + row_start + row_off, y, mask=row_off < K)
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Row-wise inclusive scan of one (rows x columns) chunk per program.

    Program (gridx, gridy) handles:
    - rows ``[gridy * r, min((gridy + 1) * r, R))``;
    - columns ``[gridx * t * TILE, (gridx + 1) * t * TILE)``, masked to ``K``.

    The columns are walked in ``t`` tiles of ``TILE`` lanes, carrying the
    running total across tiles.

    Args:
        inp, out: input / output base pointers.
        sums: when OUTPUT_SUMS, receives each (row, chunk) total at
            ``row * num_programs(0) + gridx``; otherwise unused.
        r: rows per program.
        t: tiles per program.
        R: total rows.
        K: row length.
        r_stride, k_stride: element strides of ``inp`` between rows / columns.
        out_r_stride, out_k_stride: strides of ``out``, used only when
            HAS_OUT_LAYOUT. Otherwise ``out`` is addressed with ``inp``'s strides.
        OUTPUT_SUMS: also store each chunk's row total to ``sums``.
        NORMALIZE: afterwards divide this chunk of each row by the chunk's
            total. This equals the row total only when one chunk spans a row.
        TILE: lanes per tile.
    """
    gridx = ext.program_id(0).to(tl.int64)
    gridy = ext.program_id(1).to(tl.int64)
    n_chunks = ext.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # Running fp32 total of this row within the chunk.
        # NOTE(review): a float64 ``inp`` would change this loop-carried
        # value's type; fp64 support of this kernel is unverified.
        curr_cumsum = tl.zeros((1,), tl.float32)
        row_offset = row * r_stride
        # Keep the output row offset separate so the input row offset is not
        # clobbered between tiles when the layouts differ.
        if HAS_OUT_LAYOUT:
            out_row_offset = row * out_r_stride
        else:
            out_row_offset = row_offset
        cols_base = gridx * t * TILE + tl.arange(0, TILE)
        for ti in range(0, t):
            cols = cols_base + ti * TILE
            x = tl.load(inp + row_offset + cols * k_stride, mask=cols < K, other=0)
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            if HAS_OUT_LAYOUT:
                out_cols_offset = cols * out_k_stride
            else:
                out_cols_offset = cols * k_stride
            tl.store(
                out + out_row_offset + out_cols_offset, tile_cumsum, mask=cols < K
            )
        if OUTPUT_SUMS:
            tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
        if NORMALIZE:
            # Second sweep over this chunk of ``out``: divide by the chunk total.
            for ti in range(0, t):
                cols = cols_base + ti * TILE
                if HAS_OUT_LAYOUT:
                    out_cols_offset = cols * out_k_stride
                else:
                    out_cols_offset = cols * k_stride
                x = tl.load(
                    out + out_row_offset + out_cols_offset, mask=cols < K, other=0
                )
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out + out_row_offset + out_cols_offset, x, mask=cols < K)
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Shift and rescale one (rows x columns) chunk: ``out = (inp + base) / rscale``.

    Program (gridx, gridy) handles the same tiling as block_cumsum_kernel:
    - rows ``[gridy * r, min((gridy + 1) * r, R))``;
    - columns ``[gridx * t * TILE, (gridx + 1) * t * TILE)``, masked to ``K``.

    Args:
        inp, out: input / output base pointers (may be the same buffer).
        base: per-(row, chunk) offsets, assumed row-major with one column per
            program along grid axis 0, matching how block_cumsum_kernel writes
            its ``sums``.
        rscale_ptr: per-row divisor, ``rscale_stride`` elements apart.
        r: rows per program.
        t: tiles per program.
        R: total rows.
        K: row length.
        r_stride, k_stride: element strides of ``inp`` between rows / columns.
        out_r_stride, out_k_stride: strides of ``out``, used only when
            HAS_OUT_LAYOUT. Otherwise ``out`` uses ``inp``'s strides.
        TILE: lanes per tile.
    """
    gridx = ext.program_id(0).to(tl.int64)
    gridy = ext.program_id(1).to(tl.int64)
    # Row stride of ``base``: one entry per chunk, i.e. per program on axis 0.
    n_gridx = ext.num_programs(0)
    row_start = gridy * r

    base += row_start * n_gridx + gridx
    rscale_ptr += row_start * rscale_stride

    for row in range(row_start, min((gridy + 1) * r, R)):
        d = tl.load(base)
        rscale = tl.load(rscale_ptr)
        # Advance to the same chunk of the next row.
        base += n_gridx
        rscale_ptr += rscale_stride
        row_offset = row * r_stride
        # Keep input and output row offsets separate so tiles after the first
        # still read from the correct input row.
        if HAS_OUT_LAYOUT:
            out_row_offset = row * out_r_stride
        else:
            out_row_offset = row_offset
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            x = tl.load(inp + row_offset + cols * k_stride, mask=cols < K, other=0)
            x += d
            x /= rscale
            if HAS_OUT_LAYOUT:
                out_cols_offset = cols * out_k_stride
            else:
                out_cols_offset = cols * k_stride
            tl.store(out + out_row_offset + out_cols_offset, x, mask=cols < K)
            cols += TILE
# Maximum number of programs launched along grid axis 1 (2**16 - 1) by
# normed_cumsum; larger row counts are folded into batches of rows per program.
GRID_Y_LIMIT = 0xFFFF
def _is_non_overlapping_and_dense(t):
    """Return True if the elements of ``t`` exactly tile a contiguous span.

    Dimensions are visited from smallest to largest stride, and each stride
    must equal the product of the sizes visited before it. Size-1 dimensions
    are skipped because their stride never affects addressing.
    """
    expected_stride = 1
    for d in sorted(range(t.ndim), key=t.stride):
        size = t.size(d)
        if size == 1:
            continue
        if t.stride(d) != expected_stride:
            return False
        expected_stride *= size
    return True


def normed_cumsum(inp, dim=-1):
    """Cumulative sum along ``dim`` divided by the total along ``dim``.

    Each row (all elements sharing the other indices) is scanned and divided
    by its own total.
    - When a row fits in one chunk, a single launch scans and normalizes it.
    - Otherwise three passes are used:
      1. scan the chunks of each row and record the chunk totals;
      2. scan the chunk totals;
      3. shift each chunk by the totals before it and divide by the row total.

    Args:
        inp: fp16, bf16, fp32 or fp64 tensor.
        dim: dimension to scan; negative values count from the end.

    Returns:
        A tensor with the same shape as ``inp``. It may be a non-contiguous
        view when ``dim`` was a middle dimension.

    Raises:
        AssertionError: if ``inp`` has an unsupported dtype.
    """
    logger.debug("GEMS_ASCEND NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    dim = dim % inp.ndim
    if inp.numel() == 0:
        # Nothing to scan; also avoids ``N // K`` with K == 0 below.
        return torch.empty_like(inp)
    # The kernels place rows at ``row * r_stride`` with r_stride derived from
    # K alone. That only holds when the elements densely fill their storage
    # (not, e.g., for strided slices), so copy otherwise.
    if not _is_non_overlapping_and_dense(inp):
        inp = inp.contiguous()
    N = inp.numel()
    K = inp.size(dim)
    # First and last dims are easier to handle, but transpose the middle dim to the last
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    orig_dim = dim
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        dim = -1
    # For dense ``inp``, empty_like keeps the same strides, so ``out`` can be
    # addressed with ``inp``'s strides.
    out = torch.empty_like(inp)
    with torch_device_fn.device(inp.device.index):
        # Pass one: scan a (batch, n_tiles * TILE) sized block within each program.
        num_sms = get_npu_properties()["num_vectorcore"]
        TILE = 2048
        # Each row is split into n_chunks chunks, each comprised of n_tiles
        # tiles. Different chunks are assigned to different programs.
        n_rows = N // K
        n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
        n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
        k_stride = inp.stride(dim)
        # Rows are K apart when dim is innermost, adjacent when it is outermost.
        r_stride = inp.size(dim) if k_stride == 1 else 1
        # Fold rows into batches so grid axis 1 stays within GRID_Y_LIMIT.
        if n_rows > GRID_Y_LIMIT:
            batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
            n_batch = triton.cdiv(n_rows, batch)
        else:
            batch = 1
            n_batch = n_rows

        grid = (n_chunks, n_batch)
        if n_chunks == 1:
            # One chunk spans each row: scan and normalize in a single launch.
            block_cumsum_kernel[grid](
                inp,
                out,
                0,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=False,
                NORMALIZE=True,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
        else:
            # Chunk totals accumulate in fp64 for fp64 inputs, fp32 otherwise.
            acc_dtype = torch.float64 if inp.dtype == torch.float64 else torch.float32
            sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=inp.device)
            cumsums = torch.empty_like(sums)
            block_cumsum_kernel[grid](
                inp,
                out,
                sums,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=True,
                NORMALIZE=False,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
            # Pass two: scan the per-chunk totals of each row.
            block_cumsum_kernel[(1, n_batch)](
                sums,
                cumsums,
                0,
                batch,
                1,
                n_rows,
                n_chunks,
                n_chunks,
                1,
                n_chunks,
                1,
                OUTPUT_SUMS=False,
                NORMALIZE=False,
                HAS_OUT_LAYOUT=True,
                TILE=TILE,
            )
            # The last inclusive chunk total of each row is the row total.
            rscale = cumsums[..., -1]
            # Pass three: add the exclusive chunk offsets and normalize.
            block_update_kernel[grid](
                out,
                cumsums - sums,
                rscale,
                out,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                n_chunks,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
    # Undo the middle-dim transpose so the result matches ``inp``'s shape.
    return out.transpose(orig_dim, -1) if is_mid_dim else out