Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/cumprod.py: 0%
240 statements
coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import functools
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
8from torch._prims_common import is_boolean_dtype, is_integer_dtype
10from flag_gems.runtime import device as runtime_device
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import get_device_properties, libentry
13from flag_gems.utils import triton_lang_extension as ext
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Dispatch key set handed to ``redispatch`` whenever an op is sent back to
# the ATen implementation instead of the Triton kernels in this file.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)

# Block size used once a scan is longer than the small-scan limit.
DEFAULT_BLOCK_SIZE = 1024
# Longest scan handled by a single power-of-two block on non-Ascend vendors.
CUDA_SMALL_SCAN_LIMIT = 1024 * 4
# Same limit when the vendor is "ascend".
ASCEND_SCAN_LIMIT = 1024
# Multiprocessor count assumed when the device reports none.
DEFAULT_NUM_SMS = 40
@functools.lru_cache
def get_num_sms(idx: int) -> int:
    """Return the multiprocessor count reported for device ``idx``.

    Falls back to ``DEFAULT_NUM_SMS`` when the reported count is falsy
    (for example ``0`` or ``None``). Results are memoized per index.
    """
    reported_count = get_device_properties(idx).multi_processor_count
    if reported_count:
        return reported_count
    return DEFAULT_NUM_SMS
def _get_device_index(torch_device):
    """Return the index of ``torch_device``.

    When the device carries no explicit index, the index of the current
    device (as reported by ``torch_device_fn``) is returned instead.
    """
    explicit_index = torch_device.index
    if explicit_index is None:
        return torch_device_fn.current_device()
    return explicit_index
@tl.constexpr
def get_prod_accum_type(out_dtype: tl.dtype) -> tl.dtype:
    """Choose the Triton dtype in which products are accumulated.

    fp16/bf16 outputs accumulate in float32 and integer outputs in int64;
    every other dtype is returned unchanged.
    """
    if out_dtype.is_fp16() or out_dtype.is_bf16():
        return tl.float32
    return tl.int64 if out_dtype.is_int() else out_dtype
@triton.jit
def reduce_mul(a, b):
    # Combine function passed to ``tl.reduce``: multiplies its two operands.
    product = a * b
    return product
@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_product_kernel(
    inp,
    out,
    partial_product,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Scans one BLOCK_SIZE chunk of a flat buffer.
    # Each program writes the inclusive cumulative product of its chunk to
    # ``out`` and the product of the whole chunk to ``partial_product[pid]``.
    block_id = ext.program_id(0)
    lanes = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = lanes < n_elements

    accum_ty: tl.constexpr = get_prod_accum_type(out.type.element_ty)
    # Lanes past the end load 1 so they do not affect any product.
    values = tl.load(inp + lanes, mask=in_range, other=1).to(accum_ty)

    running = tl.cumprod(values, axis=0)
    block_total = tl.reduce(values, axis=0, combine_fn=reduce_mul)

    tl.store(out + lanes, running, mask=in_range)
    tl.store(partial_product + block_id, block_total)
@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def multiply_base_product_kernel(
    out,
    partial_product,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Scales every chunk after the first by ``partial_product[pid - 1]``.
    # The caller in this file passes the already-scanned chunk totals, so
    # that entry is the product of all preceding chunks.
    block_id = ext.program_id(0)
    lanes = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = lanes < n_elements

    scanned = tl.load(out + lanes, mask=in_range)

    # The first chunk has no predecessor and is left as it is.
    if block_id > 0:
        accum_ty: tl.constexpr = get_prod_accum_type(out.type.element_ty)
        carry = tl.load(partial_product + block_id - 1).to(accum_ty)
        scaled = scanned.to(accum_ty) * carry
        tl.store(out + lanes, scaled, mask=in_range)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_product_abc_kernel(
    inp,
    out,
    partial_product,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Scans one BLOCK_SIZE chunk along the middle axis of a buffer that is
    # addressed as [A, B, C] with element offset a * B * C + b * C + c.
    # The chunk total goes to ``partial_product``, which is addressed as
    # [A, part_num, C].
    outer = ext.program_id(0)
    chunk = ext.program_id(1)
    inner = ext.program_id(2)

    scan_pos = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = scan_pos < B

    elem_offsets = outer * B * C + scan_pos * C + inner
    totals_base = outer * part_num * C + inner
    total_slot = totals_base + chunk * C

    accum_ty: tl.constexpr = get_prod_accum_type(out.type.element_ty)
    # Positions past B load 1 so they do not affect any product.
    values = tl.load(inp + elem_offsets, mask=in_range, other=1).to(accum_ty)
    running = tl.cumprod(values, axis=0)
    chunk_total = tl.reduce(values, axis=0, combine_fn=reduce_mul)

    tl.store(out + elem_offsets, running, mask=in_range)
    tl.store(partial_product + total_slot, chunk_total)
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def multiply_base_product_abc_kernel(
    out,
    partial_product,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    # Scales every chunk after the first along the middle axis by the entry
    # of ``partial_product`` that belongs to the previous chunk of the same
    # (a, c) pair. ``out`` is addressed as [A, B, C] and ``partial_product``
    # as [A, part_num, C].
    outer = ext.program_id(0)
    chunk = ext.program_id(1)
    inner = ext.program_id(2)

    scan_pos = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = scan_pos < B

    elem_offsets = outer * B * C + scan_pos * C + inner
    totals_base = outer * part_num * C + inner
    previous_slot = totals_base + (chunk - 1) * C

    scanned = tl.load(out + elem_offsets, mask=in_range)

    # The first chunk has no predecessor and is left as it is.
    if chunk > 0:
        accum_ty: tl.constexpr = get_prod_accum_type(out.type.element_ty)
        carry = tl.load(partial_product + previous_slot).to(accum_ty)
        scaled = scanned.to(accum_ty) * carry
        tl.store(out + elem_offsets, scaled, mask=in_range)
def scan_then_fan_col(inp, out, n_ele, dtype):
    """Cumulative product over a flat buffer of ``n_ele`` elements.

    The buffer is scanned chunk by chunk. When more than one chunk is
    needed, the per-chunk totals are scanned recursively and every chunk
    after the first is multiplied by the scanned total of its predecessors.

    Args:
        inp: Input tensor, read with flat offsets.
        out: Output tensor, written with the same flat offsets.
        n_ele: Number of elements to scan.
        dtype: dtype of the buffers holding the per-chunk totals.
    """
    block = _scan_block_size(n_ele)
    num_parts = math.ceil(n_ele / block)
    launch_grid = (num_parts,)
    chunk_totals = torch.empty(num_parts, dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        scan_part_product_kernel[launch_grid](
            inp, out, chunk_totals, n_ele, num_parts, block
        )

    # A single chunk is already a complete scan.
    if num_parts < 2:
        return

    totals_prefix = torch.empty_like(chunk_totals)
    scan_then_fan_col(chunk_totals, totals_prefix, num_parts, dtype)
    with torch_device_fn.device(inp.device):
        multiply_base_product_kernel[launch_grid](
            out, totals_prefix, n_ele, num_parts, block
        )
def scan_then_fan(inp, out, A, B, C, dtype):
    """Cumulative product along the middle axis of an [A, B, C] buffer.

    The B axis is scanned chunk by chunk. When more than one chunk is
    needed, the per-chunk totals (an [A, part_num, C] buffer) are scanned
    recursively and every chunk after the first is multiplied by the
    scanned total of its predecessors.

    Args:
        inp: Input tensor, addressed as [A, B, C] by the kernels.
        out: Output tensor with the same addressing.
        A: Size of the leading axis.
        B: Length of the scanned axis.
        C: Size of the trailing axis.
        dtype: dtype of the buffers holding the per-chunk totals.
    """
    block = _scan_block_size(B)
    num_parts = math.ceil(B / block)
    launch_grid = (A, num_parts, C)
    chunk_totals = torch.empty(A, num_parts, C, dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        scan_part_product_abc_kernel[launch_grid](
            inp, out, chunk_totals, B, C, num_parts, block
        )

    # A single chunk per (a, c) pair is already a complete scan.
    if num_parts < 2:
        return

    totals_prefix = torch.empty_like(chunk_totals)
    scan_then_fan(chunk_totals, totals_prefix, A, num_parts, C, dtype)
    with torch_device_fn.device(inp.device):
        multiply_base_product_abc_kernel[launch_grid](
            out, totals_prefix, B, C, num_parts, block
        )
def _get_output_dtype(inp, dtype):
    """Resolve the dtype of the result.

    An explicit ``dtype`` always wins. Otherwise integer and boolean inputs
    produce ``torch.int64`` and every other input keeps its own dtype.
    """
    if dtype is None:
        widens_to_int64 = is_integer_dtype(inp.dtype) or is_boolean_dtype(
            inp.dtype
        )
        return torch.int64 if widens_to_int64 else inp.dtype
    return dtype
def _get_compute_dtype(dtype):
    """Return the torch dtype used for intermediate products.

    float16/bfloat16 map to float32, integer and boolean dtypes map to
    int64, and every other dtype is returned unchanged.
    """
    if dtype in (torch.float16, torch.bfloat16):
        return torch.float32
    needs_int64 = is_integer_dtype(dtype) or is_boolean_dtype(dtype)
    return torch.int64 if needs_int64 else dtype
def _should_redispatch_on_ascend(dtype):
    """Tell whether ``dtype`` must be handed back to ATen on Ascend.

    True only when the runtime vendor is "ascend" and ``dtype`` is an
    integer or boolean dtype.
    """
    if runtime_device.vendor_name != "ascend":
        return False
    return is_integer_dtype(dtype) or is_boolean_dtype(dtype)
def _scan_block_size(length):
    """Pick the block size for scanning ``length`` elements.

    Scans no longer than the vendor's small-scan limit use the next power
    of two of ``length``; longer scans use ``DEFAULT_BLOCK_SIZE``.
    """
    on_ascend = runtime_device.vendor_name == "ascend"
    small_scan_limit = ASCEND_SCAN_LIMIT if on_ascend else CUDA_SMALL_SCAN_LIMIT
    if length > small_scan_limit:
        return DEFAULT_BLOCK_SIZE
    return triton.next_power_of_2(length)
def cumprod_wrapper(inp, dim, dtype=None, out=None):
    """Compute the cumulative product of ``inp`` along ``dim``.

    The input is made contiguous and viewed as ``[M, N, K]`` where ``N`` is
    the length of ``dim``, ``M`` the product of the leading sizes and ``K``
    the product of the trailing sizes. ``K == 1`` uses the row scan,
    anything else uses the strided scan.

    Args:
        inp: Input tensor.
        dim: Dimension to scan; negative values count from the end.
        dtype: Optional result dtype. When omitted, integer and boolean
            inputs produce int64 and other inputs keep their dtype.
        out: Optional tensor that receives the result. Its dtype decides
            the accumulation dtype.

    Returns:
        ``out`` when it was supplied, otherwise a newly allocated tensor.

    Raises:
        AssertionError: If ``dim`` is outside the valid range.
    """
    if inp.ndim == 0:
        # A 0-dim tensor holds one element, so its cumulative product is the
        # element itself; only the dtype conversion is left to do. This also
        # avoids ``dim % 0`` below.
        if dim not in (-1, 0):
            raise AssertionError("Invalid dim")
        if out is None:
            return inp.to(_get_output_dtype(inp, dtype), copy=True)
        out.copy_(inp)
        return out

    # Raised explicitly (same exception type as the former ``assert``) so the
    # check is not stripped under ``python -O``, where ``dim % ndim`` would
    # otherwise silently wrap an out-of-range dim.
    if not -inp.ndim <= dim < inp.ndim:
        raise AssertionError("Invalid dim")
    dim = dim % inp.ndim
    out_dtype = _get_output_dtype(inp, dtype)

    inp = inp.contiguous()
    user_out = out
    if out is None:
        out = torch.empty_like(inp, dtype=out_dtype)
    elif not out.is_contiguous():
        # The kernels address the output as a dense buffer, so a strided
        # ``out`` is filled through a contiguous scratch tensor instead.
        out = torch.empty_like(inp, dtype=user_out.dtype)

    if inp.numel() == 0:
        return out if user_out is None else user_out

    shape = inp.shape
    M = math.prod(shape[:dim])
    N = shape[dim]
    K = inp.numel() // M // N
    compute_dtype = _get_compute_dtype(out.dtype)

    if K == 1:
        reduce_then_scan_row(inp, out, M, N, compute_dtype)
    else:
        scan_then_fan(inp, out, M, N, K, compute_dtype)

    if user_out is not None and user_out is not out:
        user_out.copy_(out)
        return user_out
    return out
def reduce_then_scan_row(x, out, M, N, compute_dtype):
    """Cumulative product over ``M`` rows of ``N`` consecutive elements.

    Short rows (up to the persistent limit) are scanned by one program per
    row. Longer rows use three launches: each CTA reduces its share of the
    row to one product, those products are scanned per row, and each CTA
    then rescans its share starting from the product of the CTAs before it.

    Args:
        x: Input tensor; row ``m`` is read at offsets ``m * N + [0, N)``.
        out: Output tensor, written with the same offsets.
        M: Number of rows.
        N: Row length (the scanned extent).
        compute_dtype: dtype of the per-CTA product buffers.

    Returns:
        ``out``.
    """
    persistent_limit = (
        ASCEND_SCAN_LIMIT if runtime_device.vendor_name == "ascend" else 16384
    )
    if N <= persistent_limit:
        TILE_SIZE = triton.next_power_of_2(N)
        num_warps = 8 if TILE_SIZE > 2048 else 4
        # Launch on the tensor's device, as the other scan drivers in this
        # file do; the current device may differ from ``x.device``.
        with torch_device_fn.device(x.device):
            reduce_then_scan_root_scan_kernel_row[(M, 1, 1)](
                x, out, N, TILE_SIZE, num_warps=num_warps
            )
        return out

    TILE_SIZE = min(_scan_block_size(N), triton.next_power_of_2(N))
    num_warps = 8 if TILE_SIZE > 2048 else 4
    num_tiles = triton.cdiv(N, TILE_SIZE)
    # Cap the CTAs per row at four per multiprocessor.
    max_ctas = get_num_sms(_get_device_index(x.device)) * 4
    num_ctas = min(num_tiles, max_ctas)
    ROOT_SCAN_TILE_SIZE = triton.next_power_of_2(num_ctas)
    tiles_per_cta = triton.cdiv(num_tiles, num_ctas)

    block_products = torch.empty((M, num_ctas), dtype=compute_dtype, device=x.device)
    block_inclusive_prefix = torch.empty_like(block_products)

    with torch_device_fn.device(x.device):
        # 1) One product per (row, CTA). The grid is three-dimensional, like
        #    the block-scan launch below.
        reduce_then_scan_block_product_kernel_row[(M, num_ctas, 1)](
            x, block_products, N, tiles_per_cta, TILE_SIZE, num_warps=num_warps
        )
        # 2) Inclusive scan of the per-CTA products of each row.
        reduce_then_scan_root_scan_kernel_row[(M, 1, 1)](
            block_products,
            block_inclusive_prefix,
            num_ctas,
            ROOT_SCAN_TILE_SIZE,
            num_warps=num_warps,
        )
        # 3) Rescan each CTA's share, seeded with its predecessors' product.
        reduce_then_scan_block_scan_kernel_row[(M, num_ctas, 1)](
            x,
            block_inclusive_prefix,
            out,
            N,
            num_ctas,
            tiles_per_cta,
            TILE_SIZE,
            num_warps=num_warps,
        )
    return out
@triton.jit
def reduce_then_scan_block_product_kernel_row(
    in_ptr,
    block_product_ptr,
    N,
    tiles_per_cta,
    TILE_SIZE: tl.constexpr,
):
    # Reduces the span of row ``program_id(0)`` owned by CTA ``program_id(1)``
    # to a single product and stores it at
    # ``block_product_ptr[row * num_programs(1) + cta]``.
    row = tl.program_id(0).to(tl.int64)
    cta = tl.program_id(1).to(tl.int64)
    ctas_per_row = tl.num_programs(1)

    span_start = cta * (tiles_per_cta * TILE_SIZE)
    span_stop = min(span_start + tiles_per_cta * TILE_SIZE, N)

    accum_ty: tl.constexpr = get_prod_accum_type(block_product_ptr.type.element_ty)
    # Per-lane running products, folded into one value after the loop.
    lane_products = tl.full((TILE_SIZE,), value=1, dtype=accum_ty)
    for tile_start in range(span_start, span_stop, TILE_SIZE):
        cols = tile_start + tl.arange(0, TILE_SIZE)
        tile = tl.load(in_ptr + row * N + cols, mask=cols < N, other=1).to(
            accum_ty
        )
        lane_products = lane_products * tile
    span_product = tl.reduce(lane_products, axis=0, combine_fn=reduce_mul)
    tl.store(
        block_product_ptr + row * ctas_per_row + cta,
        span_product,
        cache_modifier=".cg",
    )
@triton.jit
def reduce_then_scan_root_scan_kernel_row(in_ptr, out_ptr, N, TILE_SIZE: tl.constexpr):
    # Scans one whole row of ``N`` elements in a single tile: program ``pid``
    # reads offsets ``pid * N + [0, N)`` and writes their inclusive
    # cumulative product to the same offsets of ``out_ptr``.
    row = tl.program_id(0).to(tl.int64)
    cols = tl.arange(0, TILE_SIZE)
    in_row = cols < N
    accum_ty: tl.constexpr = get_prod_accum_type(out_ptr.type.element_ty)
    # Lanes past N load 1 so they do not affect the product.
    values = tl.load(in_ptr + row * N + cols, mask=in_row, other=1).to(accum_ty)
    scanned = tl.cumprod(values, 0)
    tl.store(out_ptr + row * N + cols, scanned, mask=in_row)
@triton.jit
def reduce_then_scan_block_scan_kernel_row(
    in_ptr,
    previous_product_ptr,
    out_ptr,
    N,
    num_tiles_n,
    tiles_per_cta,
    TILE_SIZE: tl.constexpr,
):
    # Scans the span of row ``program_id(0)`` owned by CTA ``program_id(1)``
    # tile by tile, starting from the value stored for the preceding CTA in
    # ``previous_product_ptr`` (row stride ``num_tiles_n``).
    row = tl.program_id(0).to(tl.int64)
    cta = tl.program_id(1).to(tl.int64)
    span_start = cta * (tiles_per_cta * TILE_SIZE)
    span_stop = min(span_start + tiles_per_cta * TILE_SIZE, N)
    accum_ty: tl.constexpr = get_prod_accum_type(out_ptr.type.element_ty)

    # The first CTA of a row has no predecessor; the masked load yields 1.
    carry = tl.load(
        previous_product_ptr + row * num_tiles_n + cta - 1,
        mask=cta > 0,
        other=1,
    ).to(accum_ty)
    for tile_start in range(span_start, span_stop, TILE_SIZE):
        cols = tile_start + tl.arange(0, TILE_SIZE)
        in_row = cols < N
        tile = tl.load(in_ptr + row * N + cols, mask=in_row, other=1).to(accum_ty)
        scanned = carry * tl.cumprod(tile, 0)
        # Fold this tile's total into the carry for the next tile.
        carry = carry * tl.reduce(tile, axis=0, combine_fn=reduce_mul)
        tl.store(
            out_ptr + row * N + cols, scanned, mask=in_row, cache_modifier=".cg"
        )
def cumprod(inp, dim, *, dtype=None):
    """Return the cumulative product of ``inp`` along ``dim``.

    Boolean inputs whose result dtype is also boolean are redispatched to
    ATen. Other boolean inputs are converted to uint8 first. On the
    "ascend" vendor, integer/boolean results are redispatched to ATen
    instead of using the Triton path.
    """
    logger.debug("GEMS CUMPROD")
    out_dtype = _get_output_dtype(inp, dtype)

    if not is_boolean_dtype(inp.dtype):
        if _should_redispatch_on_ascend(out_dtype):
            return torch.ops.aten.cumprod.default.redispatch(
                _FALLBACK_KEYSET, inp, dim, dtype=dtype
            )
        return cumprod_wrapper(inp, dim, dtype)

    # Boolean input from here on.
    if is_boolean_dtype(out_dtype):
        return torch.ops.aten.cumprod.default.redispatch(
            _FALLBACK_KEYSET, inp, dim, dtype=dtype
        )

    as_uint8 = inp.to(torch.uint8)
    if runtime_device.vendor_name == "ascend":
        return torch.ops.aten.cumprod.default.redispatch(
            _FALLBACK_KEYSET, as_uint8, dim, dtype=dtype
        )
    return cumprod_wrapper(as_uint8, dim, out_dtype)
def cumprod_(inp, dim, *, dtype=None):
    """In-place cumulative product of ``inp`` along ``dim``.

    The result is computed into a temporary tensor and copied back into
    ``inp``, which is returned. On the "ascend" vendor, integer inputs are
    redispatched to ATen.

    Raises:
        RuntimeError: If ``dtype`` is given and differs from ``inp.dtype``.
        NotImplementedError: If ``inp`` is a boolean tensor.
    """
    logger.debug("GEMS CUMPROD_")

    dtype_mismatch = dtype is not None and dtype != inp.dtype
    if dtype_mismatch:
        raise RuntimeError(
            "Bad in-place call: input tensor dtype and output tensor dtype should match"
        )
    if is_boolean_dtype(inp.dtype):
        raise NotImplementedError(
            "In-place cumprod is not supported for boolean tensors"
        )

    if not _should_redispatch_on_ascend(inp.dtype):
        result = cumprod_wrapper(inp, dim, inp.dtype)
        inp.copy_(result)
        return inp

    return torch.ops.aten.cumprod_.default.redispatch(
        _FALLBACK_KEYSET, inp, dim, dtype=dtype
    )