Coverage for src/flag_gems/runtime/backend/_cambricon/ops/arange.py: 0%
51 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.utils import libentry, libtuner
11from ..utils import TOTAL_CORE_NUM
13logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": block_size}, num_stages=3, num_warps=1)
        for block_size in (1024, 4096, 8192, 16384)
    ],
    key=["size"],
    strategy=["log"],
)
@triton.jit
def arange_func(y_ptr, start, end, step, size, BLOCK_SIZE: tl.constexpr):
    """Store ``start + i * step`` at ``y_ptr + i`` for every index ``i < size``.

    Each program begins at element ``program_id * BLOCK_SIZE`` and then moves
    forward by ``num_programs * BLOCK_SIZE`` elements per iteration until it
    reaches ``size``. ``end`` is part of the signature but is not read here.
    """
    program_idx = tl.program_id(axis=0)
    program_count = tl.num_programs(axis=0)
    stride = program_count * BLOCK_SIZE
    for base in range(program_idx * BLOCK_SIZE, size, stride):
        idx = tl.arange(0, BLOCK_SIZE) + base
        # NOTE(review): idx is built from tl.arange and program ids; confirm
        # that `idx * step` cannot overflow when both operands are 32-bit.
        value = idx * step + start
        # The mask keeps the final, partially filled block inside the buffer.
        tl.store(y_ptr + idx, value, mask=idx < size)
def arange_start(
    start, end, step=1, *, dtype=None, layout=None, device=None, pin_memory=None
):
    """Build a 1-D tensor with values ``start, start + step, ...`` below ``end``.

    Args:
        start: First value of the sequence.
        end: Exclusive bound of the sequence.
        step: Spacing between consecutive values; must be nonzero.
        dtype: Output dtype. When ``None`` it is inferred as in
            ``torch.arange``: the default floating dtype if any of
            ``start``/``end``/``step`` is a Python float, else ``torch.int64``.
        layout: Accepted for signature compatibility; not used by this
            implementation.
        device: Target device; ``None`` selects ``runtime.device.name``.
        pin_memory: Forwarded to ``torch.empty``; ``None`` is treated as
            ``False``.

    Returns:
        A tensor of shape ``(size,)`` whose contents are written by
        ``arange_func``.

    Raises:
        RuntimeError: If ``step`` is zero (for ``torch.int64``, after the
            arguments have been truncated to ``int``), or if the sign of
            ``step`` does not match the direction from ``start`` to ``end``.
        AssertionError: If the number of elements is not below the int32
            maximum.
    """
    logger.debug("GEMS_CAMBRICON ARANGE")

    if dtype is None:
        # Mirror torch.arange's inference rule instead of always using int64,
        # which would silently truncate fractional sequences.
        has_float_arg = any(isinstance(arg, float) for arg in (start, end, step))
        dtype = torch.get_default_dtype() if has_float_arg else torch.int64

    if dtype is torch.int64:
        # Truncate first so the element count uses exact integer arithmetic.
        start, end, step = int(start), int(end), int(step)
        if step == 0:
            raise RuntimeError("step must be nonzero")
        sgn = (step > 0) - (step < 0)
        size = (end - start + step - sgn) // step
    else:
        # Guard here as well; otherwise the division below would raise
        # ZeroDivisionError instead of the documented RuntimeError.
        if step == 0:
            raise RuntimeError("step must be nonzero")
        size = math.ceil((end - start) / step)
    size = int(size)

    if size < 0:
        # Fail with a message about the arguments rather than letting
        # torch.empty complain about a negative dimension.
        raise RuntimeError("upper bound and larger bound inconsistent with step sign")

    max_int32 = torch.iinfo(torch.int32).max
    assert (
        size < max_int32
    ), f"Size {size} is not less than the maximum int32 value {max_int32}"

    if pin_memory is None:
        pin_memory = False

    if device is None:
        # Note(Zhengzekang): torch's default device is CPU, but the Triton
        # kernel runs on the runtime's device.
        device = runtime.device.name

    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)
    if size == 0:
        # Nothing to fill, so skip the kernel launch.
        return result

    def grid(meta):
        # One program per block, capped at the number of available cores.
        return (min(triton.cdiv(size, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    arange_func[grid](result, start, end, step, size)
    return result
def arange(end, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Return ``arange_start(0, end, 1, ...)`` with the given tensor options."""
    tensor_options = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return arange_start(0, end, 1, **tensor_options)