Coverage for src/flag_gems/runtime/backend/_arm/ops/cumsum.py: 0%
242 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import device, torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as tle
12device = device.name
# @libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_sum_kernel(
    inp,
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Scan one BLOCK_SIZE-wide slice of a flat buffer and record its total.

    Each program on grid axis 0 loads its slice of ``inp``, stores the
    inclusive prefix sum of that slice to ``out`` and stores the sum of the
    slice to ``partial_sum[pid]``.

    Args:
        inp: pointer to the input elements.
        out: pointer receiving the per-slice prefix sums.
        partial_sum: pointer receiving one total per program.
        n_elements: number of valid elements addressed through ``inp``.
        part_num: number of slices; not read inside this kernel.
        BLOCK_SIZE: compile-time slice width.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    inp_ptrs = inp + offset
    # Lanes past the end are loaded as 0: without `other`, their content is
    # undefined and would be folded into the reduction below.
    inp_vals = tl.load(inp_ptrs, mask=mask, other=0)
    # 64-bit integer and fp64 inputs are accumulated as-is, narrower integers
    # as int32 and everything else as fp32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)

    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + pid
    tl.store(partial_sum_ptrs, part_sum_via_sum)
# @libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_sum_kernel(
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Add the preceding slice's entry of ``partial_sum`` to one slice of ``out``.

    Program 0 leaves its slice untouched; every other program ``pid`` adds
    ``partial_sum[pid - 1]`` to its BLOCK_SIZE-wide slice of ``out`` in place.

    Args:
        out: pointer to the buffer updated in place.
        partial_sum: pointer to the per-slice carry values.
        n_elements: number of valid elements addressed through ``out``.
        part_num: number of slices; not read inside this kernel.
        BLOCK_SIZE: compile-time slice width.
    """
    block_id = tle.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    dst = out + idx
    scanned = tl.load(dst, mask=in_bounds)

    # The first slice has no predecessor, hence nothing to add.
    if block_id > 0:
        carry = tl.load(partial_sum + block_id - 1)
        shifted = scanned + carry
        # Cast back so the stored value keeps the dtype that was loaded.
        tl.store(dst, shifted.to(scanned.dtype), mask=in_bounds)
# @libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_sum_abc_kernel(
    inp,
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Scan one BLOCK_SIZE-wide slice along the middle axis of an (A, B, C) buffer.

    Element ``(a, b, c)`` is addressed at ``a * B * C + b * C + c``. Program
    ``(pid_a, pid_b, pid_c)`` handles ``b`` in
    ``[pid_b * BLOCK_SIZE, (pid_b + 1) * BLOCK_SIZE)`` for fixed ``a`` and
    ``c``: it stores the inclusive prefix sum of that slice to ``out`` and the
    slice total to ``partial_sum`` at ``a * part_num * C + pid_b * C + c``.

    Args:
        inp: pointer to the input elements.
        out: pointer receiving the per-slice prefix sums.
        partial_sum: pointer receiving one total per program.
        B: extent of the scanned axis.
        C: extent of the innermost axis (also the stride of the scanned axis).
        part_num: number of slices along the scanned axis.
        BLOCK_SIZE: compile-time slice width.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    # Lanes past the end of the scanned axis are loaded as 0: without `other`,
    # their content is undefined and would be folded into the reduction below.
    inp_vals = tl.load(inp_ptrs, mask=mask, other=0)
    # 64-bit integer and fp64 inputs are accumulated as-is, narrower integers
    # as int32 and everything else as fp32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    result = tl.cumsum(inp_vals, axis=0)

    part_sum_via_sum = tl.sum(inp_vals)

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    partial_sum_ptrs = partial_sum + part_offset
    tl.store(partial_sum_ptrs, part_sum_via_sum)
# @libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_sum_abc_kernel(
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Add the preceding slice's carry to one middle-axis slice of ``out``.

    ``out`` is addressed as ``a * B * C + b * C + c`` and ``partial_sum`` as
    ``a * part_num * C + p * C + c``. Programs with ``pid_b == 0`` do nothing;
    the others add the ``partial_sum`` entry of slice ``pid_b - 1`` (same
    ``a`` and ``c``) to their slice of ``out`` in place.

    Args:
        out: pointer to the buffer updated in place.
        partial_sum: pointer to the per-slice carry values.
        B: extent of the scanned axis.
        C: extent of the innermost axis.
        part_num: number of slices along the scanned axis.
        BLOCK_SIZE: compile-time slice width.
    """
    outer = tle.program_id(0)
    part = tle.program_id(1)
    inner = tle.program_id(2)

    scan_idx = part * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = scan_idx < B

    # Address of (outer, scan_idx, inner) in the output buffer.
    elem_offset = outer * B * C + inner + scan_idx * C
    # Address of the carry produced by the slice just before this one.
    carry_offset = outer * part_num * C + inner + (part - 1) * C

    dst = out + elem_offset
    scanned = tl.load(dst, mask=in_bounds)

    # The first slice along the scanned axis has no predecessor.
    if part > 0:
        carry = tl.load(partial_sum + carry_offset)
        shifted = scanned + carry
        # Cast back so the stored value keeps the dtype that was loaded.
        tl.store(dst, shifted.to(scanned.dtype), mask=in_bounds)
def scan_then_fan_col(inp, out, n_ele, dtype):
    """Write the inclusive prefix sum of a flat buffer of ``n_ele`` elements.

    The buffer is cut into BLOCK_SIZE-wide slices. Every slice is scanned
    independently and its total is written to a scratch buffer; the scratch
    buffer is then scanned recursively (in place) and the resulting carries
    are added back to the slices.

    Args:
        inp: tensor holding the input elements.
        out: tensor receiving the result; may be the same tensor as ``inp``
            (the recursive call passes the scratch buffer for both).
        n_ele: number of elements to scan.
        dtype: dtype of the scratch buffer holding the slice totals.

    Returns:
        None. ``out`` is filled in place.
    """
    # Nothing to scan: skip the allocation and the zero-sized kernel launches.
    if n_ele <= 0:
        return

    # TODO(all): tune on target board
    BLOCK_SIZE = 64
    # if n_ele <= 1024 * 4:
    #     BLOCK_SIZE = triton.next_power_of_2(n_ele)
    # Integer ceil-division; avoids the float round-trip of math.ceil(a / b).
    part_num = (n_ele + BLOCK_SIZE - 1) // BLOCK_SIZE
    partial_sum = torch.empty(part_num, dtype=dtype, device=inp.device)

    grid = (part_num,)
    # with torch_device_fn.device(inp.device):
    scan_part_sum_kernel[grid](inp, out, partial_sum, n_ele, part_num, BLOCK_SIZE)

    # With a single slice there is no carry to propagate.
    if part_num >= 2:
        scan_then_fan_col(partial_sum, partial_sum, part_num, dtype)
        # with torch_device_fn.device(inp.device):
        add_base_sum_kernel[grid](out, partial_sum, n_ele, part_num, BLOCK_SIZE)
def scan_then_fan(inp, out, A, B, C, dtype):
    """Write the inclusive prefix sum along the middle axis of an (A, B, C) buffer.

    The scanned axis is cut into BLOCK_SIZE-wide slices. Every slice is
    scanned independently and its total is written to an ``(A, part_num, C)``
    scratch buffer; the scratch buffer is then scanned recursively (in place)
    along its middle axis and the resulting carries are added back.

    Args:
        inp: tensor holding the input elements, addressed as (A, B, C).
        out: tensor receiving the result; may be the same tensor as ``inp``
            (the recursive call passes the scratch buffer for both).
        A: extent of the leading axis.
        B: extent of the scanned axis.
        C: extent of the trailing axis.
        dtype: dtype of the scratch buffer holding the slice totals.

    Returns:
        None. ``out`` is filled in place.
    """
    # Nothing to scan: skip the allocation and the zero-sized kernel launches.
    if A <= 0 or B <= 0 or C <= 0:
        return

    # TODO(all): tune on target board
    BLOCK_SIZE = 64
    # if B <= 1024 * 4:
    #     BLOCK_SIZE = triton.next_power_of_2(B)
    # Integer ceil-division; avoids the float round-trip of math.ceil(a / b).
    part_num = (B + BLOCK_SIZE - 1) // BLOCK_SIZE
    partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)

    grid = (A, part_num, C)
    # with torch_device_fn.device(inp.device):
    scan_part_sum_abc_kernel[grid](inp, out, partial_sum, B, C, part_num, BLOCK_SIZE)

    # With a single slice there is no carry to propagate.
    if part_num >= 2:
        scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype)
        # with torch_device_fn.device(inp.device):
        add_base_sum_abc_kernel[grid](out, partial_sum, B, C, part_num, BLOCK_SIZE)
def cumsum(inp, dim=1, *, dtype=None):
    """Return the cumulative sum of ``inp`` along ``dim``.

    Args:
        inp: input tensor.
        dim: axis to scan; negative values count from the end. For a 0-dim
            tensor only ``0`` and ``-1`` are accepted.
        dtype: dtype of the result. When omitted, the input dtype is used,
            except that bool and integer inputs produce ``torch.int64``.

    Returns:
        A new tensor with the shape of ``inp`` holding the running sums.

    Raises:
        AssertionError: if ``dim`` is out of range for ``inp``.
    """
    logging.debug("GEMS CUMSUM")
    ndim = inp.ndim
    # A 0-dim tensor is scanned as a single element, so it owns one axis.
    rank = max(ndim, 1)
    assert dim >= -rank and dim < rank, "Invalid dim"
    dim = dim % rank

    shape = inp.shape
    # View the tensor as (M, N, K) with N the scanned axis. Products are used
    # instead of numel() // M // N so that zero-sized axes cannot divide by 0.
    M = math.prod(shape[:dim])
    N = shape[dim] if ndim > 0 else 1
    K = math.prod(shape[dim + 1 :])
    inp = inp.contiguous()

    if dtype is None:
        dtype = inp.dtype
        if dtype in (
            torch.bool,
            torch.int8,
            torch.uint8,
            torch.int16,
            torch.int32,
            torch.int64,
        ):
            dtype = torch.int64
    out = torch.empty_like(inp, dtype=dtype)

    # Empty input: the (empty) result is already complete, nothing to launch.
    if inp.numel() == 0:
        return out

    # Half-precision inputs keep their slice totals in float32.
    compute_dtype = out.dtype
    if inp.dtype == torch.float16 or inp.dtype == torch.bfloat16:
        compute_dtype = torch.float32

    if M == 1 and K == 1:
        scan_then_fan_col(inp, out, N, compute_dtype)
    else:
        scan_then_fan(inp, out, M, N, K, compute_dtype)
    return out
@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    """Write ``cumsum(row) / sum(row)`` for one row of ``K`` elements.

    Each program on grid axis 0 handles the row starting at ``pid * K``. The
    whole row is loaded in a single BLOCK-wide tile, so only the first
    ``BLOCK`` elements of a row are covered.

    Args:
        inp: pointer to the input rows.
        out: pointer receiving the normalized running sums.
        K: row length.
        BLOCK: compile-time tile width.
    """
    row_start = tle.program_id(0) * K
    row_off = tl.arange(0, BLOCK)
    row_mask = row_off < K
    x = tl.load(inp + row_start + row_off, mask=row_mask, other=0)
    # Accumulate both half-precision formats in float32, as
    # block_cumsum_kernel does; previously only fp16 was upcast.
    if x.dtype.is_fp16() | x.dtype.is_bf16():
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0)
    y = y / y_sum
    tl.store(out + row_start + row_off, y, mask=row_mask)
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Scan a chunk of ``t`` tiles in each of up to ``r`` rows.

    Program ``(gridx, gridy)`` handles rows ``[gridy * r, (gridy + 1) * r)``
    clipped to ``R`` and columns ``[gridx * t * TILE, (gridx + 1) * t * TILE)``
    clipped to ``K``. Within the chunk it writes the running sum of ``inp`` to
    ``out``.

    Args:
        inp: pointer to the input.
        out: pointer receiving the chunk-local running sums.
        sums: pointer receiving the chunk totals at
            ``row * num_programs(0) + gridx``; only written if OUTPUT_SUMS.
        r: rows handled per program.
        t: tiles handled per program.
        R: total number of rows.
        K: row length.
        r_stride: element stride between rows of ``inp``.
        k_stride: element stride between columns of ``inp``.
        out_r_stride: row stride of ``out``; only used if HAS_OUT_LAYOUT.
        out_k_stride: column stride of ``out``; only used if HAS_OUT_LAYOUT.
        OUTPUT_SUMS: store the chunk total to ``sums``.
        NORMALIZE: divide the stored running sums by the chunk total.
        HAS_OUT_LAYOUT: address ``out`` with its own strides instead of the
            input strides.
        TILE: compile-time tile width.
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y * r, (grid.y + 1) * r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_chunks = tle.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # NOTE(review): the running total is float32 whatever the input dtype
        # is; confirm this is intended for float64 inputs.
        curr_cumsum = tl.zeros((1,), tl.float32)
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            # The output address lives in its own variable: reusing
            # row_offset here would make the next tile load `inp` through the
            # output row stride.
            if HAS_OUT_LAYOUT:
                out_offset = row * out_r_stride + cols * out_k_stride
            else:
                out_offset = row_offset + cols_offset
            tl.store(out + out_offset, tile_cumsum, mask=cols < K)
            if OUTPUT_SUMS:
                # Rewritten after every tile; the last write is the chunk total.
                tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
            cols += TILE
        if NORMALIZE:
            # Second sweep over what was just written, now that the chunk
            # total is known.
            cols = gridx * t * TILE + tl.arange(0, TILE)
            for _ in range(0, t):
                if HAS_OUT_LAYOUT:
                    out_offset = row * out_r_stride + cols * out_k_stride
                else:
                    out_offset = row_offset + cols * k_stride
                x = tl.load(out + out_offset, mask=cols < K, other=0)
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out + out_offset, x, mask=cols < K)
                cols += TILE
@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Apply a per-chunk offset and a per-row scale: ``out = (inp + base) / rscale``.

    Program ``(gridx, gridy)`` handles rows ``[gridy * r, (gridy + 1) * r)``
    clipped to ``R`` and columns ``[gridx * t * TILE, (gridx + 1) * t * TILE)``
    clipped to ``K``, the same partition block_cumsum_kernel uses.

    Args:
        inp: pointer to the chunk-local running sums.
        base: pointer to the per-(row, chunk) offsets, read at
            ``row * num_programs(0) + gridx``.
        rscale_ptr: pointer to the per-row divisors.
        out: pointer receiving the result.
        r: rows handled per program.
        t: tiles handled per program.
        R: total number of rows.
        K: row length.
        r_stride: element stride between rows of ``inp``.
        k_stride: element stride between columns of ``inp``.
        out_r_stride: row stride of ``out``; only used if HAS_OUT_LAYOUT.
        out_k_stride: column stride of ``out``; only used if HAS_OUT_LAYOUT.
        rscale_stride: element stride between consecutive row divisors.
        HAS_OUT_LAYOUT: address ``out`` with its own strides instead of the
            input strides.
        TILE: compile-time tile width.
    """
    # One CTA processes a (r, t*tile) chunk
    # rows = [ grid.y * r, (grid.y + 1) * r )
    # cols = [ grid.x * t * tile, (grid.x + 1) * t * tile )
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    # Chunks live on grid axis 0, so the row stride of `base` is
    # num_programs(0); axis 1 counts row batches, not chunks.
    n_chunks = tle.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # Address both per-row values from the row index itself instead of
        # advancing running pointers, so they stay in step with `row`.
        d = tl.load(base + row * n_chunks + gridx)
        rscale = tl.load(rscale_ptr + row * rscale_stride)
        row_offset = row * r_stride
        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            cols_offset = cols * k_stride
            x = tl.load(inp + row_offset + cols_offset, mask=cols < K, other=0)
            x += d
            x /= rscale
            # Separate variable for the output address, so the next tile
            # still loads `inp` through the input row stride.
            if HAS_OUT_LAYOUT:
                out_offset = row * out_r_stride + cols * out_k_stride
            else:
                out_offset = row_offset + cols_offset
            tl.store(out + out_offset, x, mask=cols < K)
            cols += TILE
# Upper bound used for the second grid axis; rows beyond it are batched.
GRID_Y_LIMIT = 65535


def normed_cumsum(inp, dim=-1):
    """Return the cumulative sum along ``dim`` divided by the total along ``dim``.

    Rows that fit one chunk are scanned and normalized in a single kernel.
    Otherwise each row is split into chunks: the chunks are scanned
    independently, the chunk totals are scanned, and a last pass adds the
    preceding chunks' total and divides by the row total.

    Args:
        inp: floating-point tensor (float16, bfloat16, float32 or float64).
        dim: axis to scan; negative values count from the end.

    Returns:
        A tensor with the shape of ``inp``. When ``dim`` is neither the
        outermost nor the innermost axis by stride, the result is a
        transposed view of a contiguous buffer.

    Raises:
        AssertionError: if ``inp`` has an unsupported dtype.
    """
    logging.debug("GEMS NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    dim = dim % inp.ndim
    N = inp.numel()
    K = inp.size(dim)
    # Empty input: nothing to compute, and the row/chunk arithmetic below
    # would divide by zero.
    if N == 0:
        return torch.empty_like(inp)
    # inp = inp.contiguous()
    # First and last dims are easier to handle, but transpose the middle dim to the last
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    scan_dim = dim
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        dim = -1
    out = torch.empty_like(inp)
    # The kernels fill `out` in the (possibly transposed) working layout;
    # `result` is a view of the same storage in the caller's original shape.
    result = out.transpose(scan_dim, -1) if is_mid_dim else out
    # with torch_device_fn.device(inp.device.index):
    # Pass one, scan a (batch, n_tiles * TILE) sized block within each cta
    device_props = torch_device_fn.get_device_properties(device)
    if isinstance(device_props, dict):
        num_sms = int(device_props.get("multi_processor_count", 1))
    else:
        num_sms = device_props.multi_processor_count
    TILE = 2048
    # Each row is split into n_chunks chunks, where each chunk is composed of
    # n_tiles tiles. Different chunks are assigned to different ctas.
    n_rows = N // K
    n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
    n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
    k_stride = inp.stride(dim)
    r_stride = inp.size(dim) if k_stride == 1 else 1
    if n_rows > GRID_Y_LIMIT:
        batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
        n_batch = triton.cdiv(n_rows, batch)
    else:
        batch = 1
        n_batch = n_rows

    grid = (n_chunks, n_batch)
    if n_chunks == 1:
        # A single chunk per row: scan and normalize in one launch.
        block_cumsum_kernel[grid](
            inp,
            out,
            0,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            OUTPUT_SUMS=False,
            NORMALIZE=True,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        return result

    if inp.dtype != torch.float64:
        acc_dtype = torch.float32
    else:
        acc_dtype = torch.float64
    sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=inp.device)
    cumsums = torch.empty_like(sums)
    block_cumsum_kernel[grid](
        inp,
        out,
        sums,
        batch,
        n_tiles,
        n_rows,
        K,
        r_stride,
        k_stride,
        r_stride,
        k_stride,
        OUTPUT_SUMS=True,
        NORMALIZE=False,
        HAS_OUT_LAYOUT=False,
        TILE=TILE,
    )
    # Pass two, scan partial cumsums
    block_cumsum_kernel[(1, n_batch)](
        sums,
        cumsums,
        0,
        batch,
        1,
        n_rows,
        n_chunks,
        n_chunks,
        1,
        n_chunks,
        1,
        OUTPUT_SUMS=False,
        NORMALIZE=False,
        HAS_OUT_LAYOUT=True,
        TILE=TILE,
    )
    # Last column of the scanned chunk totals is the full row total.
    rscale = cumsums[..., -1]
    # Pass three: add the total of the preceding chunks (cumsums - sums) and
    # divide by the row total.
    block_update_kernel[grid](
        out,
        cumsums - sums,
        rscale,
        out,
        batch,
        n_tiles,
        n_rows,
        K,
        r_stride,
        k_stride,
        r_stride,
        k_stride,
        n_chunks,
        HAS_OUT_LAYOUT=False,
        TILE=TILE,
    )
    return result