Coverage for src/flag_gems/runtime/backend/_sunrise/ops/index_add.py: 0%
158 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, List, Mapping, Tuple
6import torch
8from flag_gems.utils.code_cache import code_cache_dir
9from flag_gems.utils.code_utils import IndentedBuffer
# Module-level logger, obtained as a child of the "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Append the import header of the generated module to ``code``.

    Args:
        code: buffer receiving the generated source.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    header = (
        "import triton",
        "import triton.language as tl",
        "from flag_gems.utils import libentry",
    )
    for statement in header:
        code.writeline(statement)

    # Two blank lines between the imports and the first generated definition.
    for _ in range(2):
        code.newline()

    return code
def generate_index_add_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Triton JIT kernel for ``index_add`` specialized to ``rank``.

    Each program instance handles ``BLOCK_SIZE`` consecutive elements of
    ``src`` (``offsets = pid * BLOCK_SIZE + arange``). For each element the
    generated kernel:

    1. unravels the flat element number into per-dimension coordinates using
       the ``src_shape_{i}`` arguments, and recombines them with the
       ``src_stride_{i}`` arguments into the memory offset ``src_offset``;
    2. splits ``src_offset`` into the outer-slice number ``pre_idx`` and the
       coordinate along ``dim`` (``dim_idx``);
    3. loads ``index[dim_idx]`` and rebases the offset onto ``out`` by
       replacing ``dim_idx`` with that value and skipping ``delta`` extra rows
       of ``out`` per outer slice;
    4. adds ``alpha * src`` to ``out`` at that offset with a load/add/store.

    NOTE(review): steps 2-3 treat ``src_offset`` as a row-major flat index and
    ``inp_stride_dim`` as the product of the sizes after ``dim``. This appears
    to hold only when ``src`` and ``out`` are contiguous, so callers must
    guarantee it. Confirm.
    NOTE(review): step 4 is not atomic (see the TODO below). Lanes that target
    the same output element, e.g. duplicate entries in ``index``, may lose
    updates.

    Args:
        rank: number of dimensions of ``src``. It sets how many
            ``src_stride_{i}`` / ``src_shape_{i}`` parameters are generated.
        kernel_name: name of the generated ``@triton.jit`` function.
        code: buffer receiving the generated source.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("index,")
            code.writeline("src,")
            code.writeline("out,")
            code.writeline("N,")
            code.writeline("inp_numel,")
            code.writeline("inp_stride_dim,")
            code.writeline("inp_shape_dim,")
            code.writeline("src_shape_dim,")
            code.writeline("delta,")
            code.writeline("alpha,")

            # One scalar stride argument per src dimension.
            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            # One scalar shape argument per src dimension.
            shape_args = ", ".join(f"src_shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape for src")

            code.writeline("BLOCK_SIZE: tl.constexpr,")

    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(axis=0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        # `mask` is computed before `offsets` is consumed by the unravel loop.
        code.writeline("mask = offsets < N")

        # Unravel the flat element number, innermost dimension first.
        for i in range(rank - 1, -1, -1):
            code.writeline(f"src_offset{i} = offsets % src_shape_{i}")
            code.writeline(f"offsets = offsets // src_shape_{i}")
        code.newline()
        # Memory offset of each element inside `src`.
        comp = [f"src_offset{i} * src_stride_{i}" for i in range(rank)]
        code.writeline(f"src_offset = {' + '.join(comp)}")

        # Number of src elements in one outer slice (everything from `dim` inward).
        code.writeline("pre_cal = (inp_stride_dim * src_shape_dim)")

        # index add
        code.writeline("pre_idx = (src_offset // pre_cal).to(tl.int64)")
        code.writeline(
            "dim_idx = (src_offset % pre_cal // inp_stride_dim).to(tl.int64)"
        )
        # Destination row along `dim`: index[dim_idx].
        code.writeline(
            "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)"
        )
        code.writeline(
            'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"'
        )
        # Shift by `delta` rows per outer slice (inp is longer than src along
        # `dim` by delta), and swap the src row dim_idx for src_dim_idx.
        code.writeline(
            "input_idx = (src_offset + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)"
        )

        code.writeline("input_mask = (input_idx < inp_numel) & mask")
        code.writeline(
            "add_on = tl.load(src + src_offset, mask=mask, other=0) * alpha"
        )
        # code.writeline(
        #     "tl.atomic_add(out + input_idx, add_on, mask=input_mask, sem='relaxed')"
        # )
        # TODO: tl.atomic_add doesn't support bfloat16! The following method may be unsafe.
        # Non-atomic read-modify-write: concurrent writers to the same element may race.
        code.writeline("cur_out = tl.load(out + input_idx, mask=input_mask)")
        code.writeline(
            "tl.store(out + input_idx, cur_out + add_on, mask=input_mask)"
        )

    code.newline()
    code.newline()
    return code
def parameter_for_wrapper() -> str:
    """Return the generated host wrapper's parameter list as one string.

    The order mirrors the positional arguments that ``index_add`` and
    ``index_add_`` pass to ``_index_add_func``: out, index, src, dim,
    inp_stride_dim, inp_shape_dim, src_shape_dim, delta, N, inp_numel, alpha.
    """
    names = (
        "out",
        "index",
        "src",
        "dim",
        "inp_stride_dim",
        "inp_shape_dim",
        "src_shape_dim",
        "delta",
        "N",
        "inp_numel",
        "alpha",
    )
    return ", ".join(names)
def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Python host function that launches the generated kernel.

    The emitted function takes the parameters listed by
    ``parameter_for_wrapper``. It launches ``kernel_name`` over
    ``triton.cdiv(N, 128)`` programs of ``BLOCK_SIZE = 128`` elements, passing
    ``src``'s strides and shapes as scalar arguments when ``rank > 0``, and
    returns ``out``.

    Args:
        rank: number of dimensions of ``src``.
        wrapper_name: name of the generated host function.
        kernel_name: name of the generated Triton kernel to launch.
        code: buffer receiving the generated source.

    Returns:
        The same ``code`` buffer, for chaining.
    """
    code.writeline(f"def {wrapper_name} ({parameter_for_wrapper()}):")

    with code.indent():
        prologue = (
            "src_strides = list(src.stride())",
            "src_shapes = list(src.shape)",
            # kernel launch configuration
            "BLOCK_SIZE = 128",
            "grid = (triton.cdiv(N, BLOCK_SIZE),)",
            f"{kernel_name}[grid](",
        )
        for line in prologue:
            code.writeline(line)

        with code.indent():
            code.writeline(
                "index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, alpha, "
            )
            if rank > 0:
                # All strides on one line, then all shapes on the next.
                for template in ("src_strides[{}]", "src_shapes[{}]"):
                    joined = ", ".join(template.format(axis) for axis in range(rank))
                    code.writeline(f"{joined},")
            code.writeline("BLOCK_SIZE=BLOCK_SIZE")

        code.writeline(")")
        code.writeline("return out")

    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Generate the complete index_add module (imports, kernel, host wrapper).

    Args:
        inputs: positional call arguments in ``parameter_for_wrapper`` order.
            Only ``inputs[2]`` (``src``) is inspected, for its rank.
        wrapper_name: name of the generated host function.
        kernel_name: name of the generated Triton kernel.
        code: buffer receiving the generated source.

    Returns:
        The buffer holding the full generated module.
    """
    src_rank = len(inputs[2].shape)

    code = generate_imports(code)
    code = generate_index_add_kernel(src_rank, kernel_name, code)
    return generate_destination_passing_wrapper(
        src_rank, wrapper_name, kernel_name, code
    )
class IndexAddFunction:
    """Dispatch index_add calls to generated, rank-specialized wrappers.

    The cache key is the largest tensor rank among the arguments (see
    ``arg_key``). On the first call for a key, a module is generated, written
    to the code cache directory, imported, and its ``_index_add_wrapper`` is
    memoized. Later calls with the same key reuse it.
    """

    def __init__(self):
        # Embedded in generated file/module names so that processes sharing
        # the cache directory do not overwrite each other's files.
        self.pid = os.getpid()
        # Cache key (stringified rank) -> generated wrapper callable.
        self.overloads: dict[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Run the wrapper for ``args``' rank, generating it on first use."""
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._load_overload(key, args)
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def _load_overload(self, key, args):
        """Generate the wrapper module for ``key``, write it to disk and import it."""
        # A bare `import importlib` does not guarantee that the
        # `importlib.util` submodule is loaded, so import it explicitly.
        import importlib.util

        code = generate_code(
            args,
            "_index_add_wrapper",
            "_index_add_jit_function",
            IndentedBuffer(),
        )

        file_path = code_cache_dir() / f"index_add_rank_{key}_pid_{self.pid}.py"
        with open(file_path, "wt", encoding="utf-8") as f:
            f.write(code.getvalue())

        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}_pid_{self.pid}",
            str(file_path),
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_index_add_wrapper")

    def arg_key(self, *args):
        """Return the largest ``ndim`` among the tensor arguments.

        Raises:
            ValueError: if none of ``args`` is a tensor.
        """
        return max(item.ndim for item in args if torch.is_tensor(item))
_index_add_func = IndexAddFunction()


def _check_index_add_args(inp, dim, index, src):
    """Validate index_add arguments; return ``dim`` normalized to ``[0, inp.ndim)``.

    Raises:
        AssertionError: if ``dim`` is out of range, ``index`` has values
            outside ``[0, inp.size(dim))``, or ``src``'s shape does not match.
    """
    # Validate `dim` first so the later checks cannot fail with an unrelated
    # IndexError from `inp.size(dim)` / `src.size(dim)`.
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim %= inp.ndim
    assert bool(
        ((index >= 0) & (index < inp.size(dim))).all()
    ), "0 <= index < self.size(dim)"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    # `all(...)` is required: asserting a bare generator expression always passes.
    assert all(
        inp.size(i) == src.size(i) for i in range(inp.ndim) if i != dim
    ), "src.size(d) == self.size(d) for all dimensions d != dim"
    return dim


def _launch_index_add(out, dim, index, src, alpha):
    """Add ``alpha * src`` into ``out`` in place at rows ``index`` along ``dim``.

    ``out`` must be contiguous and ``dim`` already normalized: the generated
    kernel addresses ``out`` with row-major offsets.

    NOTE(review): the generated kernel uses a non-atomic load/add/store (see
    its TODO), so duplicate entries in ``index`` may lose updates. Confirm.
    """
    # The kernel treats src's memory offset as its row-major element number
    # when computing destinations, and reads `index` with unit stride.
    src = src.contiguous()
    index = index.contiguous()
    N = src.numel()
    if N == 0:
        # Empty src/index: nothing to add, and no empty-grid launch.
        return
    src_shape_dim = src.size(dim)
    inp_shape_dim = out.size(dim)
    # Row-major stride of `dim`, computed from the shape rather than
    # `out.stride(dim)`: a contiguous tensor may report any stride for a
    # size-1 dimension.
    inp_stride_dim = out.shape[dim + 1 :].numel()
    delta = inp_shape_dim - src_shape_dim

    _index_add_func(
        out,
        index,
        src,
        dim,
        inp_stride_dim,
        inp_shape_dim,
        src_shape_dim,
        delta,
        N,
        out.numel(),
        alpha,
    )


def index_add(inp, dim, index, src, alpha=1):
    """Out-of-place index_add: return a copy of ``inp`` with ``alpha * src``
    added at rows ``index`` along ``dim``.

    Raises:
        AssertionError: on invalid ``dim``, out-of-range ``index`` or a
            ``src`` shape that does not match ``inp``.
    """
    logger.debug("GEMS INDEX ADD")
    dim = _check_index_add_args(inp, dim, index, src)
    # Contiguous copy: the kernel addresses `out` with row-major offsets.
    out = inp.clone(memory_format=torch.contiguous_format)
    _launch_index_add(out, dim, index, src, alpha)
    return out


def index_add_(inp, dim, index, src, alpha=1):
    """In-place index_add: add ``alpha * src`` into ``inp`` at rows ``index``
    along ``dim`` and return ``inp``.

    Raises:
        AssertionError: on invalid ``dim``, out-of-range ``index`` or a
            ``src`` shape that does not match ``inp``.
    """
    logger.debug("GEMS INDEX ADD_")
    dim = _check_index_add_args(inp, dim, index, src)
    if inp.is_contiguous():
        _launch_index_add(inp, dim, index, src, alpha)
    else:
        # The kernel needs row-major addressing: accumulate into a contiguous
        # copy, then write the result back through inp's own strides.
        work = inp.contiguous()
        _launch_index_add(work, dim, index, src, alpha)
        inp.copy_(work)
    return inp