Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/arange.py: 0%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2import math
4import torch
5import triton
6import triton.language as tl
8from flag_gems import runtime
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as ext
# Module logger: a child of the shared "flag_gems" logger, named after this
# module with any leading dots removed from ``__name__``.
_module_name = __name__.lstrip(".")
logger = logging.getLogger("flag_gems").getChild(_module_name)
@libentry()
@triton.jit
def arange_func(
    y_ptr,
    start,
    end,
    step,
    size,
    BLOCK_SIZE: tl.constexpr,
):
    """Store one BLOCK_SIZE-wide tile of the sequence ``start + i * step``.

    Each program covers the output indices
    ``[pid * BLOCK_SIZE, pid * BLOCK_SIZE + BLOCK_SIZE)`` and writes only the
    ones that are below ``size``. ``end`` is accepted but never read here.
    """
    lane = tl.arange(0, BLOCK_SIZE)
    block_start = ext.program_id(0) * BLOCK_SIZE

    # Lanes that fall past the end of the output are excluded from the store.
    in_bounds = lane + block_start < size

    # Per-lane term first, then the per-block term, then the start value.
    values = lane * step + block_start * step + start
    tl.store(y_ptr + block_start + lane, values, mask=in_bounds)
def arange_start(
    start, end, step=1, *, dtype=None, layout=None, device=None, pin_memory=None
):
    """Return a 1-D tensor holding ``start, start + step, ...`` bounded by ``end``.

    Args:
        start: First value of the sequence.
        end: Bound of the sequence; it only determines the element count.
        step: Spacing between consecutive values; must be nonzero.
        dtype: Output dtype. ``torch.int64`` is used when omitted.
        layout: Accepted for signature compatibility; not used here.
        device: Target device; defaults to ``runtime.device.name``.
        pin_memory: Forwarded to ``torch.empty``; ``False`` when omitted.

    Returns:
        A tensor of shape ``(size,)`` filled by the ``arange_func`` kernel.

    Raises:
        RuntimeError: If ``step`` is zero (after integer truncation when
            ``dtype`` is ``torch.int64``), or if the sign of ``step`` points
            away from ``end`` so that the element count would be negative.
    """
    logger.debug("GEMS_KUNLUNXIN ARANGE")

    wants_int64 = dtype is torch.int64
    if wants_int64:
        # Truncate first so the element count uses exact integer arithmetic.
        start = int(start)
        end = int(end)
        step = int(step)

    # Validate for every dtype: previously only the int64 path raised, and
    # other dtypes hit a ZeroDivisionError in the division below.
    if step == 0:
        raise RuntimeError("step must be nonzero")

    if wants_int64:
        sgn = (step > 0) - (step < 0)
        size = (end - start + step - sgn) // step
    else:
        size = math.ceil((end - start) / step)
    size = int(size)
    if size < 0:
        # Fail here with a clear message instead of inside torch.empty.
        raise RuntimeError(
            "upper bound and larger bound inconsistent with step sign"
        )

    # NOTE(review): torch.arange documents inferring the default floating
    # dtype when any argument is a float; this fallback is always int64 --
    # confirm that is intended.
    if dtype is None:
        dtype = torch.int64

    if pin_memory is None:
        pin_memory = False

    if device is None:
        device = runtime.device.name

    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)
    if size == 0:
        # Nothing to write, so skip the kernel launch.
        return result

    # Size-based heuristic for BLOCK_SIZE and num_warps
    if size <= 1024:
        BLOCK_SIZE = 256
        num_warps = 2
    elif size <= 8192:
        BLOCK_SIZE = 1024
        num_warps = 4
    elif size <= 65536:
        BLOCK_SIZE = 4096
        num_warps = 8
    else:
        BLOCK_SIZE = 8192
        num_warps = 8

    grid = triton.cdiv(size, BLOCK_SIZE)
    arange_func[grid,](result, start, end, step, size, BLOCK_SIZE, num_warps=num_warps)
    return result
def arange(end, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Build a sequence that starts at 0 with step 1, bounded by ``end``.

    Delegates to ``arange_start`` and forwards every keyword option unchanged.
    """
    options = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
    }
    return arange_start(0, end, 1, **options)