Coverage for src/flag_gems/runtime/backend/_arm/ops/scatter.py: 0%
253 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
8from flag_gems.utils.code_cache import code_cache_dir
9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
10from flag_gems.utils.shape_utils import (
11 MemOverlap,
12 has_internal_overlapping,
13 restride_dim,
14)
16logger = logging.getLogger(__name__)
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import preamble of a generated scatter module into ``code``.

    Two groups of import statements are emitted (torch/triton first, then
    flag_gems), separated by one blank line and followed by two blank lines.

    Args:
        code: Buffer that receives the generated source lines.

    Returns:
        The same buffer, to allow chaining with the other generators.
    """
    framework_imports = (
        "import torch",
        "import triton",
        "import triton.language as tl",
    )
    project_imports = (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "import flag_gems",
    )

    for statement in framework_imports:
        code.writeline(statement)
    code.newline()

    for statement in project_imports:
        code.writeline(statement)
    code.newline()
    code.newline()
    return code
def generate_scatter_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write the source of a rank-specialised Triton scatter kernel into ``code``.

    The emitted text consists of, in order:

    * ``heur_block`` / ``loop_count`` helper functions used as
      ``triton.heuristics`` for the ``BLOCK`` and ``LOOP`` meta-parameters;
    * the ``@triton.heuristics`` and ``@triton.jit`` decorators;
    * the kernel signature, with one stride argument per dimension for inp,
      index and src plus one shape argument per dimension;
    * the kernel body, which turns each linear offset into per-tensor
      offsets dimension by dimension (last dimension first), loads the source
      value and index, and then applies an atomic add, a compare-and-swap
      multiply loop, or a plain store depending on ``IS_ADD`` / ``IS_MUL``.

    Args:
        rank: Number of dimensions the kernel is generated for.
        kernel_name: Name given to the generated kernel function.
        code: Buffer that receives the generated source lines.

    Returns:
        The same buffer, to allow chaining with the other generators.
    """
    dims = range(rank)

    def emit(*lines: str) -> None:
        # Write every line at the buffer's current indentation level.
        for line in lines:
            code.writeline(line)

    def blank(count: int = 1) -> None:
        for _ in range(count):
            code.newline()

    blank()

    # Heuristic helpers referenced by the @triton.heuristics decorator below.
    emit("def heur_block(args):")
    with code.indent():
        emit("if(flag_gems.vendor_name in ['metax', 'iluvatar']):")
        with code.indent():
            emit("return 256")
        emit("return 128")
    blank(2)

    emit("def loop_count(args):")
    with code.indent():
        emit("return 4")
    blank(2)

    # Decorators.
    emit("@triton.heuristics(")
    with code.indent():
        emit("{")
        with code.indent():
            emit('"BLOCK": heur_block,', '"LOOP": loop_count,')
        emit("}")
    emit(")")

    # Every per-dimension stride/shape argument is excluded from specialisation.
    per_dim_names = ",".join(
        ",".join(f"'{prefix}_{i}'" for i in dims)
        for prefix in ("inp_stride", "index_stride", "src_stride", "shape")
    )
    emit(
        "@triton.jit(do_not_specialize=['N','stride_dim','inp_size_dim',"
        f"{per_dim_names}])"
    )

    # Signature.
    emit(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            emit("src_strided,", "index,", "inp,", "out,")
            for prefix, note in (
                ("inp_stride", "stride for inp"),
                ("index_stride", "stride for index"),
                ("src_stride", "stride for src"),
                ("shape", "shape"),
            ):
                typed_args = ", ".join(f"{prefix}_{i}: int" for i in dims)
                emit(f"{typed_args}, # {note}")
            emit(
                "inp_size_dim,",
                "stride_dim,",
                "N,",
                # reduce options and meta-parameters
                "IS_ADD: tl.constexpr,",
                "IS_MUL: tl.constexpr,",
                "BLOCK: tl.constexpr,",
                "LOOP: tl.constexpr,",
                "INT32_OFFSET: tl.constexpr",
            )
    emit("):")

    # Kernel body.
    with code.indent():
        emit("pid = tl.program_id(0)", "if not INT32_OFFSET:")
        with code.indent():
            emit("pid = pid.to(tl.int64)")
        emit(
            "offsets = pid * LOOP * BLOCK + tl.arange(0, BLOCK)",
            "for loop_iter in tl.static_range(LOOP):",
        )
        with code.indent():
            # 1. Compute the inp / index / src offsets for this block.
            emit("mask = offsets < N", "cur_idx = offsets")
            for branch, dtype in (
                ("if INT32_OFFSET:", "tl.int32"),
                ("else:", "tl.int64"),
            ):
                emit(branch)
                with code.indent():
                    for target in ("inp_offsets", "idx_offsets", "src_offsets"):
                        emit(f"{target} = tl.zeros((BLOCK, ), dtype={dtype})")

            # Peel dimensions off the linear index, innermost (last) first.
            for i in reversed(dims):
                emit("if INT32_OFFSET:")
                with code.indent():
                    for var in (
                        f"shape_{i}",
                        f"inp_stride_{i}",
                        f"index_stride_{i}",
                        f"src_stride_{i}",
                    ):
                        emit(f"{var} = {var}.to(tl.int32)")
                emit(
                    f"mod = cur_idx % shape_{i}",
                    f"inp_offsets += mod * inp_stride_{i}",
                    f"idx_offsets += mod * index_stride_{i}",
                    f"src_offsets += mod * src_stride_{i}",
                )
                if i != 0:
                    emit(f"cur_idx = cur_idx // shape_{i}")

            # 2. Load source values and indices, then scatter into `out`.
            emit(
                "cur_src = tl.load(src_strided + src_offsets, mask=mask, other=0)",
                "cur_index = tl.load(index + idx_offsets, mask=mask, other=0)",
                "if INT32_OFFSET:",
            )
            with code.indent():
                emit(
                    "cur_index = cur_index.to(tl.int32)",
                    "stride_dim = stride_dim.to(tl.int32)",
                )
            emit(
                "dim_offsets = cur_index * stride_dim",
                "inp_offsets += dim_offsets",
            )
            blank()

            emit("if IS_ADD: ")
            with code.indent():
                emit(
                    "tl.atomic_add(out + inp_offsets, cur_src, mask=mask, sem='relaxed')"
                )
            emit("elif IS_MUL: ")
            with code.indent():
                emit(
                    "stop = tl.where(mask, 0, 1).to(tl.int1)",
                    "block_stop = False",
                    "while not block_stop:",
                )
                with code.indent():
                    emit(
                        "cur_inp = tl.load(out + inp_offsets, mask=mask, other=0)",
                        "res = tl.where(stop, cur_inp, cur_inp * cur_src)",
                        "cas_res = tl.atomic_cas(out + inp_offsets, cur_inp, res, sem='relaxed')",
                        "stop |= cur_inp == cas_res",
                        "block_stop = tl.sum(stop.to(tl.int32)) == BLOCK",
                    )
            emit("else: ")
            with code.indent():
                emit("tl.store(out + inp_offsets, cur_src, mask=mask)")

            emit("offsets += BLOCK")

    blank(2)
    return code
def parameter_for_wrapper() -> str:
    """Return the parameter list used in the generated wrapper's signature.

    The result is a single ``", "``-joined string; ``reduce`` and
    ``int32_offset`` are declared with a ``tl.constexpr`` annotation and a
    default of ``None``.
    """
    return ", ".join(
        (
            "src_strided",
            "index",
            "inp",
            "out",
            "dim_size",
            "dim_stride",
            "N",
            "reduce: tl.constexpr=None",
            "int32_offset: tl.constexpr=None",
        )
    )
def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write the Python wrapper that launches the generated scatter kernel.

    The emitted wrapper reads strides/shapes from its tensor arguments,
    translates ``reduce`` into the ``IS_ADD`` / ``IS_MUL`` flags, builds a 1-D
    launch grid of ``cdiv(N, BLOCK * LOOP)`` programs, calls
    ``kernel_name[grid](...)`` and returns ``out``.

    Args:
        rank: Number of per-dimension stride/shape arguments to forward.
        wrapper_name: Name given to the generated wrapper function.
        kernel_name: Name of the kernel the wrapper launches.
        code: Buffer that receives the generated source lines.

    Returns:
        The same buffer, to allow chaining with the other generators.
    """
    parameters: str = parameter_for_wrapper()
    code.writeline(f"def {wrapper_name}({parameters}):")

    with code.indent():
        for statement in (
            "inp_strides = list(inp.stride())",
            "index_strides = index.stride()",
            "src_strides = src_strided.stride()",
            "index_shapes = list(index.shape)",
            "inp_size_dim = dim_size",
            "stride_dim = dim_stride",
            'IS_ADD = reduce == "add"',
            'IS_MUL = reduce == "multiply"',
        ):
            code.writeline(statement)

        # Default to 32-bit offsets only when the caller did not decide.
        # The previous `int32_offset or True` was always True, so an explicit
        # False (tensors too large for int32 addressing) was silently ignored.
        code.writeline(
            "int32_offset = True if int32_offset is None else int32_offset"
        )

        # kernel launch
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]), ')
        code.writeline(")")

        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            code.writeline("src_strided, index, inp, out, ")
            if rank > 0:
                # One line per argument family, in the kernel's signature order.
                for sequence in (
                    "inp_strides",
                    "index_strides",
                    "src_strides",
                    "index_shapes",
                ):
                    items = ", ".join(f"{sequence}[{i}]" for i in range(rank))
                    code.writeline(f"{items},")

                code.writeline("inp_size_dim,")
                code.writeline("stride_dim,")
                code.writeline("N,")
                # reduce options
                code.writeline("IS_ADD,")
                code.writeline("IS_MUL,")
                code.writeline("INT32_OFFSET=int32_offset,")
        code.writeline(")")
        code.writeline("return out")

    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Generate the full source of a scatter module into ``code``.

    The module consists of the import preamble, the Triton kernel and the
    launching wrapper, generated in that order. The rank used for
    specialisation is taken from the shape of ``inputs[1]``.

    Args:
        inputs: Positional arguments of the scatter call, laid out as
            ``(src_strided, index, inp, out, dim_size, dim_stride, N, reduce)``.
        wrapper_name: Name given to the generated wrapper function.
        kernel_name: Name given to the generated kernel function.
        code: Buffer that receives the generated source lines.

    Returns:
        The buffer returned by the last generator stage.
    """
    index = inputs[1]
    rank = len(index.shape)

    code = generate_imports(code)
    code = generate_scatter_kernel(rank, kernel_name, code)
    return generate_destination_passing_wrapper(
        rank, wrapper_name, kernel_name, code
    )
class ScatterFunction:
    """Callable that generates, caches and dispatches rank-specific scatter code.

    On the first call for a given key (the stringified maximum rank among the
    tensor arguments) the wrapper/kernel source is generated, written to
    ``code_cache_dir() / "scatter_rank_<key>.py"`` and loaded as a module.
    Later calls with the same key reuse the cached ``_scatter_wrapper``.
    """

    def __init__(self):
        self.pid = os.getpid()
        # Maps a rank key to the generated `_scatter_wrapper` callable.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch to the wrapper generated for the rank of ``args``."""
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._build_overload(key, args)
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def _build_overload(self, key, args):
        """Generate, persist and import the module for ``key``; return its wrapper."""
        # `import importlib` alone does not guarantee that the `util`
        # submodule is loaded, so import it explicitly before using it.
        import importlib.util

        code = generate_code(
            args,
            "_scatter_wrapper",
            "_scatter_jit_function",
            IndentedBuffer(),
        )

        file_name = f"scatter_rank_{key}.py"
        file_path = code_cache_dir() / file_name
        write_atomic(file_path, code.getvalue())

        # Load the freshly written file as a standalone module.
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}",
            file_path,
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_scatter_wrapper")

    def arg_key(self, *args):
        """Return the largest ``ndim`` among the tensor arguments.

        Raises:
            ValueError: If ``args`` contains no tensor (``max`` of an empty
                sequence).
        """
        return max(item.ndim for item in args if torch.is_tensor(item))


# Module-level instance shared by the scatter entry points.
_scatter_func = ScatterFunction()
def scatter(inp, dim, index, src, reduce=None):
    """Out-of-place scatter: return a copy of ``inp`` with ``src`` scattered in.

    Args:
        inp: Tensor whose values seed the result; it is not modified.
        dim: Dimension along which ``index`` selects destinations.
        index: Integer tensor of destination indices along ``dim``.
        src: Tensor providing the values to write.
        reduce: ``None`` for plain assignment or ``"add"`` for accumulation.

    Returns:
        A new tensor holding the scattered result.

    Raises:
        RuntimeError: If ``reduce == "multiply"`` (not supported on this backend).
        AssertionError: If a reduction is requested on a bfloat16 ``inp``.
    """
    logger.debug("GEMS SCATTER")
    if reduce == "multiply":
        raise RuntimeError(
            "scatter(reduce='multiply') is not supported on ARM Triton CPU backend yet."
        )

    out = inp.clone()

    if reduce is not None:
        assert inp.dtype not in (
            torch.bfloat16,
        ), "Unsupported operation: reduce scatter bfloat tensors."

    if has_internal_overlapping(out) == MemOverlap.Yes:
        out = out.contiguous()

    src_strided = src.as_strided(index.shape, src.stride())

    # The kernel writes through `out`, so destination addressing must follow
    # `out`'s layout. `clone()` only keeps `inp`'s strides when `inp` is dense
    # and non-overlapping; for any other input the strides differ, and using
    # `inp`'s strides would scatter to the wrong offsets.
    out_restrided = restride_dim(out, dim, index.shape)
    dim_size = out.size(dim)
    dim_stride = out.stride(dim)
    N = index.numel()

    # Kernel offsets are signed 32-bit when this flag is set, so the bound is
    # 2**31 rather than 2**32.
    # NOTE(review): only the extent along `dim` is inspected here; confirm
    # whether the full tensor extent should be checked as well.
    use_int32_offset = all(
        t.stride(dim) * t.size(dim) < 2**31 for t in (out, index, src)
    )
    _scatter_func(
        src_strided,
        index,
        out_restrided,
        out,
        dim_size,
        dim_stride,
        N,
        reduce,
        int32_offset=use_int32_offset,
    )

    return out
def scatter_(inp, dim, index, src, reduce=None):
    """In-place scatter: write ``src`` into ``inp`` at the positions in ``index``.

    Args:
        inp: Tensor that is updated in place.
        dim: Dimension along which ``index`` selects destinations.
        index: Integer tensor of destination indices along ``dim``.
        src: Tensor providing the values to write.
        reduce: ``None`` for plain assignment or ``"add"`` for accumulation.

    Returns:
        ``inp`` itself, after the update.

    Raises:
        RuntimeError: If ``reduce == "multiply"`` (not supported on this backend).
        AssertionError: If a reduction is requested on a bfloat16 ``inp``, or
            if ``inp`` has internally overlapping memory.
    """
    logger.debug("GEMS SCATTER_")
    if reduce == "multiply":
        raise RuntimeError(
            "scatter_(reduce='multiply') is not supported on ARM Triton CPU backend yet."
        )

    out = inp

    if reduce is not None:
        assert inp.dtype not in (
            torch.bfloat16,
        ), "Unsupported operation: reduce scatter bfloat tensors."

    assert (
        has_internal_overlapping(out) != MemOverlap.Yes
    ), "Unsupported operation: trying to inplace write to an internally overlapping tensor."

    src_restrided = src.as_strided(index.shape, src.stride())
    inp_restrided = restride_dim(inp, dim, index.shape)
    dim_size = inp.size(dim)
    dim_stride = inp.stride(dim)
    N = index.numel()

    # Kernel offsets are signed 32-bit when this flag is set, so the bound is
    # 2**31 rather than 2**32.
    # NOTE(review): only the extent along `dim` is inspected here; confirm
    # whether the full tensor extent should be checked as well.
    use_int32_offset = all(
        t.stride(dim) * t.size(dim) < 2**31 for t in (inp, index, src)
    )
    _scatter_func(
        src_restrided,
        index,
        inp_restrided,
        out,
        dim_size,
        dim_stride,
        N,
        reduce,
        int32_offset=use_int32_offset,
    )

    return inp