# Coverage for src/flag_gems/runtime/backend/_sunrise/ops/scatter.py: 0% (269 statements)
# Report generated by coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
8from flag_gems.utils.code_cache import code_cache_dir
9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
10from flag_gems.utils.shape_utils import (
11 MemOverlap,
12 has_internal_overlapping,
13 restride_dim,
14)
# Module-level logger: a child of the shared "flag_gems" logger, named after
# this module (leading dots, if any, are stripped from the module name).
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip("."),
)
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import preamble of a generated scatter module into ``code``.

    The preamble imports ``torch``, ``torch_ptpu`` (falling back to
    ``torch.cuda`` under that alias when the package is missing), ``triton``
    and ``triton.language``, followed by the ``flag_gems`` names the generated
    kernel and wrapper refer to.

    Args:
        code: Buffer that receives the generated source lines.

    Returns:
        The same ``code`` buffer, to allow chaining.
    """
    backend_imports = (
        "import torch",
        "try:",
        " import torch_ptpu",
        "except ImportError:",
        " import torch.cuda as torch_ptpu",
        "import triton",
        "import triton.language as tl",
    )
    flag_gems_imports = (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "import flag_gems",
    )

    for line in backend_imports:
        code.writeline(line)
    code.newline()

    for line in flag_gems_imports:
        code.writeline(line)
    # Two blank lines before the first generated definition.
    code.newline()
    code.newline()
    return code
def generate_scatter_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the source of a rank-specialised Triton scatter kernel.

    The generated module text contains, in order:

    * ``heur_block`` / ``loop_count`` helpers feeding ``@triton.heuristics``
      (``BLOCK`` is 256 for the ``metax`` / ``iluvatar`` vendors and 128
      otherwise, ``LOOP`` is always 4);
    * the ``@libentry()`` / ``@triton.heuristics`` / ``@triton.jit``
      decorators, with every scalar size/stride argument listed in
      ``do_not_specialize``;
    * the kernel itself, named ``kernel_name``.

    Each program of the generated kernel walks ``LOOP`` consecutive tiles of
    ``BLOCK`` linear indices.  A linear index is decomposed, innermost
    dimension first, into coordinates using ``shape_i``; the coordinates are
    turned into element offsets for the destination, ``index`` and ``src``.
    The loaded index value times ``stride_dim`` is then added to the
    destination offset and the source value is combined into ``out`` with
    ``tl.atomic_add`` (``IS_ADD``), a compare-and-swap multiply loop
    (``IS_MUL``) or a plain ``tl.store`` (neither).

    Note: the generated kernel accepts ``inp`` and ``inp_size_dim`` but its
    body never references them.

    Args:
        rank: Number of dimensions to unroll.  Must be at least 1: for
            ``rank == 0`` the emitted ``do_not_specialize`` list contains
            empty entries and the result is not valid Python.
        kernel_name: Name given to the generated ``@triton.jit`` function.
        code: Buffer that receives the generated source lines.

    Returns:
        The same ``code`` buffer, to allow chaining.
    """
    # Keep the kernel visually separated from whatever precedes it.
    code.newline()

    # Heuristic helpers referenced by the @triton.heuristics decorator.
    code.writeline("def heur_block(args):")
    with code.indent():
        code.writeline("if(flag_gems.vendor_name in ['metax', 'iluvatar']):")
        with code.indent():
            code.writeline("return 256")
        code.writeline("return 128")
    code.newline()
    code.newline()

    code.writeline("def loop_count(args):")
    with code.indent():
        code.writeline("return 4")
    code.newline()
    code.newline()

    # Decorators.
    code.writeline("@libentry()")
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("{")
        with code.indent():
            code.writeline('"BLOCK": heur_block,')
            code.writeline('"LOOP": loop_count,')
        code.writeline("}")
    code.writeline(")")
    inp_stride_vars = ",".join(f"'inp_stride_{i}'" for i in range(rank))
    index_stride_vars = ",".join(f"'index_stride_{i}'" for i in range(rank))
    src_stride_vars = ",".join(f"'src_stride_{i}'" for i in range(rank))
    shape_vars = ",".join(f"'shape_{i}'" for i in range(rank))
    code.writeline(
        f"@triton.jit(do_not_specialize=['N','stride_dim','inp_size_dim',"
        f"{inp_stride_vars},{index_stride_vars},{src_stride_vars},{shape_vars}])"
    )

    # Signature.
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("src_strided,")
            code.writeline("index,")
            code.writeline("inp,")
            code.writeline("out,")

            stride_args = ", ".join(f"inp_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(f"index_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for index")

            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape")
            code.writeline("inp_size_dim,")
            code.writeline("stride_dim,")
            code.writeline("N,")
            # Reduce options and compile-time tuning parameters.
            code.writeline("IS_ADD: tl.constexpr,")
            code.writeline("IS_MUL: tl.constexpr,")
            code.writeline("BLOCK: tl.constexpr,")
            code.writeline("LOOP: tl.constexpr,")
            code.writeline("INT32_OFFSET: tl.constexpr")

    code.writeline("):")

    # Kernel body.
    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        code.writeline("if not INT32_OFFSET:")
        with code.indent():
            code.writeline("pid = pid.to(tl.int64)")
        code.writeline("offsets = pid * LOOP * BLOCK + tl.arange(0, BLOCK)")

        # 1. Turn linear indices into per-tensor element offsets.
        code.writeline("for loop_iter in tl.static_range(LOOP):")
        with code.indent():
            code.writeline("mask = offsets < N")
            code.writeline("cur_idx = offsets")
            code.writeline("if INT32_OFFSET:")
            with code.indent():
                code.writeline("inp_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
                code.writeline("idx_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
                code.writeline("src_offsets = tl.zeros((BLOCK, ), dtype=tl.int32)")
            code.writeline("else:")
            with code.indent():
                code.writeline("inp_offsets = tl.zeros((BLOCK, ), dtype=tl.int64)")
                code.writeline("idx_offsets = tl.zeros((BLOCK, ), dtype=tl.int64)")
                code.writeline("src_offsets = tl.zeros((BLOCK, ), dtype=tl.int64)")
            # Peel coordinates off the linear index, innermost dimension first.
            for i in reversed(range(rank)):
                code.writeline("if INT32_OFFSET:")
                with code.indent():
                    code.writeline(f"shape_{i} = shape_{i}.to(tl.int32)")
                    code.writeline(f"inp_stride_{i} = inp_stride_{i}.to(tl.int32)")
                    code.writeline(f"index_stride_{i} = index_stride_{i}.to(tl.int32)")
                    code.writeline(f"src_stride_{i} = src_stride_{i}.to(tl.int32)")
                code.writeline(f"mod = cur_idx % shape_{i}")
                code.writeline(f"inp_offsets += mod * inp_stride_{i}")
                code.writeline(f"idx_offsets += mod * index_stride_{i}")
                code.writeline(f"src_offsets += mod * src_stride_{i}")
                if i != 0:
                    # The outermost dimension needs no further division.
                    code.writeline(f"cur_idx = cur_idx // shape_{i}")

            # 2. Use the offsets to scatter.
            code.writeline(
                "cur_src = tl.load(src_strided + src_offsets, mask=mask, other=0)"
            )
            code.writeline(
                "cur_index = tl.load(index + idx_offsets, mask=mask, other=0)"
            )
            code.writeline("if INT32_OFFSET:")
            with code.indent():
                code.writeline("cur_index = cur_index.to(tl.int32)")
                code.writeline("stride_dim = stride_dim.to(tl.int32)")

            code.writeline("dim_offsets = cur_index * stride_dim")
            code.writeline("inp_offsets += dim_offsets")
            code.newline()
            code.writeline("if IS_ADD: ")
            with code.indent():
                code.writeline(
                    "tl.atomic_add(out + inp_offsets, cur_src, mask=mask, sem='relaxed')"
                )
            code.writeline("elif IS_MUL: ")
            with code.indent():
                # Multiply via a compare-and-swap retry loop: masked-out lanes
                # start as "stopped"; a lane stops once its CAS observed the
                # value it had read, and the loop ends when all lanes stopped.
                code.writeline("stop = tl.where(mask, 0, 1).to(tl.int1)")
                code.writeline("block_stop = False")
                code.writeline("while not block_stop:")
                with code.indent():
                    code.writeline(
                        "cur_inp = tl.load(out + inp_offsets, mask=mask, other=0)"
                    )
                    code.writeline("res = tl.where(stop, cur_inp, cur_inp * cur_src)")
                    code.writeline(
                        "cas_res = tl.atomic_cas(out + inp_offsets, cur_inp, res, sem='relaxed')"
                    )
                    code.writeline("stop |= cur_inp == cas_res")
                    code.writeline("block_stop = tl.sum(stop.to(tl.int32)) == BLOCK")

            code.writeline("else: ")
            with code.indent():
                code.writeline("tl.store(out + inp_offsets, cur_src, mask=mask)")

            # Advance to the next tile handled by this program.
            code.writeline("offsets += BLOCK")

    code.newline()
    code.newline()
    return code
def parameter_for_wrapper() -> str:
    """Return the parameter list of the generated Python wrapper.

    The result is the comma-separated text placed between the parentheses of
    the wrapper's ``def`` line: the four tensors, the size and stride of the
    scatter dimension, the element count ``N`` and the two optional settings
    ``reduce`` and ``int32_offset`` (both defaulting to ``None``).
    """
    tensor_params = ("src_strided", "index", "inp", "out")
    scalar_params = ("dim_size", "dim_stride", "N")
    optional_params = (
        "reduce: tl.constexpr=None",
        "int32_offset: tl.constexpr=None",
    )
    return ", ".join((*tensor_params, *scalar_params, *optional_params))
def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Python wrapper that launches the generated scatter kernel.

    The generated wrapper reads strides and shape from its tensor arguments,
    derives the ``IS_ADD`` / ``IS_MUL`` flags from ``reduce`` (``"add"`` and
    ``"multiply"`` respectively), launches ``kernel_name`` on a 1-D grid of
    ``cdiv(N, BLOCK * LOOP)`` programs and returns ``out``.

    ``int32_offset`` defaults to ``True`` only when the caller passes
    ``None``.  The previous generated expression ``int32_offset or True``
    always evaluated to ``True``, so a caller asking for 64-bit offsets
    (``False``) was silently ignored.

    Args:
        rank: Number of per-dimension stride/shape arguments to forward.
        wrapper_name: Name given to the generated wrapper function.
        kernel_name: Name of the generated kernel to launch.
        code: Buffer that receives the generated source lines.

    Returns:
        The same ``code`` buffer, to allow chaining.
    """
    parameters: str = parameter_for_wrapper()
    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        code.writeline("inp_strides = list(inp.stride())")
        code.writeline("index_strides = index.stride()")
        code.writeline("src_strides = src_strided.stride()")
        code.writeline("index_shapes = list(index.shape)")
        code.writeline("inp_size_dim = dim_size")
        code.writeline("stride_dim = dim_stride")

        code.writeline('IS_ADD = reduce == "add"')
        code.writeline('IS_MUL = reduce == "multiply"')
        # Honour an explicit False; only an omitted value falls back to True.
        code.writeline(
            "int32_offset = True if int32_offset is None else int32_offset"
        )

        # Kernel launch.
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(N, meta["BLOCK"] * meta["LOOP"]), ')
        code.writeline(")")

        kernel_launch: str = f"{kernel_name}[grid]("
        code.writeline(kernel_launch)

        with code.indent():
            code.writeline("src_strided, index, inp, out, ")
            if rank > 0:
                s = ", ".join(f"inp_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"src_strides[{i}]" for i in range(rank))
                code.writeline(f"{s},")

                s = ", ".join(f"index_shapes[{i}]" for i in range(rank))
                code.writeline(f"{s},")

            code.writeline("inp_size_dim,")
            code.writeline("stride_dim,")
            code.writeline("N,")
            # Reduce options; BLOCK and LOOP come from the kernel's heuristics.
            code.writeline("IS_ADD,")
            code.writeline("IS_MUL,")
            code.writeline("INT32_OFFSET=int32_offset,")
        code.writeline(")")
        code.writeline("return out")

    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Generate a complete scatter module: imports, kernel and wrapper.

    Args:
        inputs: Positional arguments of the scatter call.  Only ``inputs[1]``
            (the index tensor) is inspected; the length of its shape is the
            rank the kernel and wrapper are specialised for.
        wrapper_name: Name of the generated wrapper function.
        kernel_name: Name of the generated Triton kernel.
        code: Buffer that receives the generated source.

    Returns:
        The buffer returned by the last generation step.
    """
    index_rank = len(inputs[1].shape)

    with_imports = generate_imports(code)
    with_kernel = generate_scatter_kernel(index_rank, kernel_name, with_imports)
    return generate_destination_passing_wrapper(
        index_rank, wrapper_name, kernel_name, with_kernel
    )
class ScatterFunction:
    """Callable that generates, caches and dispatches scatter wrappers.

    On the first call for a given key (the largest ``ndim`` among the tensor
    arguments) a module is generated, written to the code cache directory,
    loaded from that file, and its ``_scatter_wrapper`` function is memoised
    in ``overloads``.  Later calls with the same key reuse the cached wrapper.
    """

    def __init__(self):
        # Process id captured at construction time (not used for the file name).
        self.pid = os.getpid()
        # Maps the stringified rank key to the loaded wrapper function.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch to the wrapper for the rank of ``args``, building it if needed."""
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._build_overload(key, args)
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def _build_overload(self, key, args):
        """Generate, persist and import the module for ``key``; return its wrapper."""
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule has been loaded, so import it explicitly before use.
        import importlib.util

        code = generate_code(
            args,
            "_scatter_wrapper",
            "_scatter_jit_function",
            IndentedBuffer(),
        )

        file_name = f"scatter_rank_{key}.py"
        file_path = code_cache_dir() / file_name
        write_atomic(file_path, code.getvalue())

        # Load the generated file as a standalone module.
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}",
            file_path,
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_scatter_wrapper")

    def arg_key(self, *args):
        """Return the largest ``ndim`` among the tensor arguments.

        Raises:
            ValueError: If ``args`` contains no tensor (``max`` of an empty
                sequence).
        """
        return max(item.ndim for item in args if torch.is_tensor(item))


# Shared dispatcher used by ``scatter`` and ``scatter_``.
_scatter_func = ScatterFunction()
# Element offsets held in a signed 32-bit integer must stay below this bound.
_INT32_OFFSET_LIMIT = 2**31
# Head-room for the linear index of the padded last tile: the kernel iterates
# in BLOCK * LOOP chunks, so its linear index can run somewhat past N.
_INT32_TILE_SLACK = 1 << 16


def _max_element_offset(tensor):
    """Return the largest element offset addressable through ``tensor``'s view."""
    return sum(
        (size - 1) * stride
        for size, stride in zip(tensor.shape, tensor.stride())
        if size > 0
    )


def _fits_int32_offsets(n_elements, *tensors):
    """Tell whether every offset the kernel computes fits in a signed int32.

    Both the linear work index (bounded by ``n_elements`` plus tile padding)
    and the largest element offset of each tensor must stay below 2**31.
    """
    if n_elements + _INT32_TILE_SLACK >= _INT32_OFFSET_LIMIT:
        return False
    return all(_max_element_offset(t) < _INT32_OFFSET_LIMIT for t in tensors)


def _scatter_into(out, dim, index, src, reduce):
    """Scatter ``src`` into ``out`` in place along ``dim`` at ``index``.

    All destination geometry (size and stride of ``dim`` and the restrided
    view) is taken from ``out`` itself, the tensor the kernel writes to.
    """
    src_strided = src.as_strided(index.shape, src.stride())
    out_restrided = restride_dim(out, dim, index.shape)
    n_elements = index.numel()
    _scatter_func(
        src_strided,
        index,
        out_restrided,
        out,
        out.size(dim),
        out.stride(dim),
        n_elements,
        reduce,
        int32_offset=_fits_int32_offsets(n_elements, out, index, src),
    )


# Atomic operations do not support fp16, so when a reduction is requested the
# fp16 operands are converted to fp32, the scatter is computed in fp32 and the
# result is converted back to fp16.
def scatter(inp, dim, index, src, reduce=None):
    """Out-of-place scatter: return a copy of ``inp`` with ``src`` scattered in.

    Args:
        inp: Source tensor; it is not modified.
        dim: Dimension along which ``index`` selects destination positions.
        index: Integer tensor of destination indices.
        src: Tensor providing the values to scatter.
        reduce: ``None`` for plain assignment, ``"add"`` or ``"multiply"``
            for the corresponding reduction.

    Returns:
        A new tensor with the dtype of ``inp``.

    Raises:
        AssertionError: If ``reduce`` is given for a bfloat16 ``inp``.
    """
    logger.debug("GEMS SCATTER")
    is_fp16 = inp.dtype == torch.float16 and (reduce is not None)

    if reduce is not None:
        assert inp.dtype not in (
            torch.bfloat16,
        ), "Unsupported operation: reduce scatter bfloat tensors."

    if is_fp16:
        # The fp16 -> fp32 conversion already yields a fresh tensor, so no
        # additional clone is needed.
        out = inp.float()
        src = src.float()
    else:
        out = inp.clone()

    if has_internal_overlapping(out) == MemOverlap.Yes:
        out = out.contiguous()

    # The strides of ``out`` may differ from those of ``inp`` (a copy of a
    # non-dense tensor is laid out contiguously), so the launch geometry is
    # derived from ``out`` rather than ``inp``.
    _scatter_into(out, dim, index, src, reduce)

    return out.half() if is_fp16 else out


def scatter_(inp, dim, index, src, reduce=None):
    """In-place scatter: write ``src`` into ``inp`` and return ``inp``.

    Args:
        inp: Destination tensor, modified in place.
        dim: Dimension along which ``index`` selects destination positions.
        index: Integer tensor of destination indices.
        src: Tensor providing the values to scatter.
        reduce: ``None`` for plain assignment, ``"add"`` or ``"multiply"``
            for the corresponding reduction.

    Returns:
        ``inp`` itself.

    Raises:
        AssertionError: If ``reduce`` is given for a bfloat16 ``inp``, or if
            ``inp`` is internally overlapping.
    """
    logger.debug("GEMS SCATTER_")
    is_fp16 = inp.dtype == torch.float16 and (reduce is not None)

    if reduce is not None:
        assert inp.dtype not in (
            torch.bfloat16,
        ), "Unsupported operation: reduce scatter bfloat tensors."

    # Check the tensor that is really written in place, not a converted copy.
    assert (
        has_internal_overlapping(inp) != MemOverlap.Yes
    ), "Unsupported operation: trying to inplace write to an internally overlapping tensor."

    if is_fp16:
        # Compute in an fp32 scratch copy, then write the result back.
        out = inp.float()
        src = src.float()
    else:
        out = inp

    _scatter_into(out, dim, index, src, reduce)

    if is_fp16:
        inp.copy_(out)
    return inp