Coverage for src/flag_gems/runtime/backend/_mthreads/ops/arange.py: 0%
89 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
2import math
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.ops.arange import arange_start as default_arange_start
11from flag_gems.runtime import torch_device_fn
12from flag_gems.utils import libentry
13from flag_gems.utils import triton_lang_extension as ext
logger = logging.getLogger(
    f"flag_gems.runtime.backend._mthreads.ops.{__name__.split('.')[-1]}"
)

# Backend device descriptor; its ``.name`` is used as the default device
# when ``arange_start`` is called without an explicit device.
device_ = runtime.device

# Output dtypes the Triton kernel below is launched for; any other dtype is
# delegated to the generic flag_gems implementation.
_SUPPORTED_DTYPES = {
    torch.float16,
    torch.bfloat16,
    torch.float32,
    torch.int32,
    torch.int64,
}

# Autotuning search space as (BLOCK_SIZE, num_warps) pairs, all single-stage.
_AUTOTUNE_CONFIGS = [
    triton.Config({"BLOCK_SIZE": block_size}, num_warps=num_warps, num_stages=1)
    for block_size, num_warps in ((256, 4), (512, 8), (1024, 8), (2048, 8))
]
# Triton kernel writing out_ptr[i] = start + i * step for every i < n_elements.
# Each program instance covers one BLOCK_SIZE-wide chunk of the output; the
# BLOCK_SIZE/num_warps combination is chosen by the autotuner from
# _AUTOTUNE_CONFIGS, keyed on (n_elements, USE_INT64).
# start/step are listed in do_not_specialize so Triton does not specialize
# (and recompile) on their particular integer values.
@libentry()
@triton.autotune(configs=_AUTOTUNE_CONFIGS, key=["n_elements", "USE_INT64"])
@triton.jit(do_not_specialize=["start", "step"])
def arange_kernel(
    out_ptr,  # output buffer, written at offsets 0 .. n_elements - 1
    start,  # first value of the sequence
    step,  # increment between consecutive values
    n_elements,  # number of values to write
    IS_FLOAT: tl.constexpr,  # True -> compute values in float32
    USE_INT64: tl.constexpr,  # True -> 64-bit offsets/values (caller sets it for int64 outputs)
    BLOCK_SIZE: tl.constexpr,  # elements per program, supplied by autotune
):
    pid = ext.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Fix the index width once so offsets and the bound share one integer
    # type. NOTE(review): whether ext.program_id already returns a 64-bit id
    # (so pid * BLOCK_SIZE cannot overflow before this cast) is not visible
    # here -- TODO confirm.
    if USE_INT64:
        offsets = offsets.to(tl.int64)
        n_elements = tl.full((1,), n_elements, tl.int64)
    else:
        offsets = offsets.to(tl.int32)
        n_elements = tl.full((1,), n_elements, tl.int32)
    # Tail guard for the last, partially filled block.
    mask = offsets < n_elements

    if IS_FLOAT:
        # Floating outputs: value = idx * step + start as one fused
        # multiply-add in float32; tl.store converts the float32 result to
        # out_ptr's element type (e.g. fp16 / bf16).
        idx = offsets.to(tl.float32)
        step_val = tl.full((1,), step, tl.float32)
        start_val = tl.full((1,), start, tl.float32)
        values = tl.fma(idx, step_val, start_val)
    else:
        # Integer outputs: exact integer arithmetic in the selected width.
        # NOTE(review): if start/step arrive as floats, tl.full casts them to
        # the integer type (presumably truncating) -- TODO confirm.
        value_dtype = tl.int64 if USE_INT64 else tl.int32
        idx = offsets.to(value_dtype)
        step_val = tl.full((1,), step, value_dtype)
        start_val = tl.full((1,), start, value_dtype)
        values = start_val + idx * step_val

    tl.store(out_ptr + offsets, values, mask=mask)
72def _normalize_scalar(value):
73 if isinstance(value, torch.Tensor):
74 return value.item()
75 return value
78def _compute_size(start, end, step, is_float_dtype: bool) -> int:
79 if step == 0:
80 raise ValueError("arange(): step must be non-zero.")
81 if is_float_dtype:
82 size = math.ceil((end - start) / step)
83 else:
84 sgn = (step > 0) - (step < 0)
85 size = (end - start + step - sgn) // step
86 return int(size) if size > 0 else 0
89def _use_triton(dtype: torch.dtype, device: torch.device, size: int) -> bool:
90 if device.type != "musa":
91 return False
92 if dtype not in _SUPPORTED_DTYPES:
93 return False
94 return size > 0
def _launch_triton_kernel(
    out: torch.Tensor,
    start,
    step,
    size: int,
    *,
    is_float_dtype: bool,
    use_int64: bool,
):
    """Fill ``out[:size]`` with ``start + i * step`` via ``arange_kernel``.

    The launch happens inside ``out``'s device context; ``out`` itself is
    returned after the kernel has been enqueued.
    """

    def _grid(meta):
        # One program per BLOCK_SIZE chunk; BLOCK_SIZE comes from autotuning.
        return (triton.cdiv(size, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(out.device):
        launch = arange_kernel[_grid]
        launch(
            out,
            start,
            step,
            size,
            IS_FLOAT=is_float_dtype,
            USE_INT64=use_int64,
        )
    return out
def arange_start(
    start,
    end,
    step=1,
    *,
    dtype: Optional[torch.dtype] = None,
    layout=None,
    device=None,
    pin_memory: Optional[bool] = None,
):
    """Return a 1-D tensor of values from ``[start, end)`` spaced by ``step``.

    MThreads backend version of ``arange.start_step``. Non-empty results of a
    supported dtype on a ``musa`` device are filled by ``arange_kernel``;
    everything else is delegated to ``flag_gems.ops.arange.arange_start``.

    Args:
        start, end, step: Python numbers or single-element tensors (unwrapped
            with ``.item()``).
        dtype: Output dtype. When ``None`` it is inferred like
            ``torch.arange``: ``torch.int64`` if start, end and step are all
            integral (bool included), otherwise ``torch.get_default_dtype()``.
        layout: Only forwarded to the fallback implementation.
        device: Target device; defaults to the backend runtime device.
        pin_memory: Defaults to ``False``.

    Returns:
        The filled 1-D tensor.

    Raises:
        RuntimeError: if the resolved dtype is ``torch.int64`` and ``step`` is
            zero (including a float step that truncates to zero).
        ValueError: if ``step`` is zero for any other dtype (from
            ``_compute_size``).
    """
    logger.debug("GEMS_MTHREADS ARANGE")
    start = _normalize_scalar(start)
    end = _normalize_scalar(end)
    step = _normalize_scalar(step)

    if dtype is None:
        # Follow torch.arange's inference. Unconditionally defaulting to
        # int64 would truncate float ranges such as arange(0.5, 3.5) and
        # reject arange(0, 1, 0.25) because its step truncates to 0.
        all_integral = all(isinstance(v, int) for v in (start, end, step))
        dtype = torch.int64 if all_integral else torch.get_default_dtype()
    if pin_memory is None:
        pin_memory = False
    if device is None:
        device = torch.device(device_.name)
    else:
        device = torch.device(device)

    # Explicit int64 output with float bounds: truncate them toward zero
    # (int()) before sizing the result, then reject a zero step with the
    # same RuntimeError torch raises.
    if dtype is torch.int64:
        if (
            isinstance(start, float)
            or isinstance(end, float)
            or isinstance(step, float)
        ):
            start = int(start) if isinstance(start, float) else start
            end = int(end) if isinstance(end, float) else end
            step = int(step) if isinstance(step, float) else step
        if step == 0:
            raise RuntimeError("step must be nonzero")

    # Query the dtype directly instead of allocating a scratch tensor.
    is_float_dtype = dtype.is_floating_point
    use_int64 = dtype == torch.int64
    size = _compute_size(start, end, step, is_float_dtype)

    if not _use_triton(dtype, device, size):
        return default_arange_start(
            start,
            end,
            step,
            dtype=dtype,
            layout=layout,
            device=device,
            pin_memory=pin_memory,
        )

    result = torch.empty((size,), device=device, dtype=dtype, pin_memory=pin_memory)
    return _launch_triton_kernel(
        result,
        start,
        step,
        size,
        is_float_dtype=is_float_dtype,
        use_int64=use_int64,
    )
def arange(
    end,
    *,
    dtype: Optional[torch.dtype] = None,
    layout=None,
    device=None,
    pin_memory: Optional[bool] = None,
):
    """Return values in ``[0, end)`` with unit step.

    Thin wrapper equivalent to ``arange_start(0, end, 1, ...)``; every
    keyword option is forwarded unchanged.
    """
    options = dict(dtype=dtype, layout=layout, device=device, pin_memory=pin_memory)
    return arange_start(0, end, step=1, **options)