Coverage for src/flag_gems/ops/index_copy_.py: 99%
160 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
import importlib
import importlib.util
import logging
import os
from typing import Any, Callable, Dict, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer
# Module-level logger; index_copy / index_copy_ emit a debug record on every call.
logger = logging.getLogger(__name__)
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of the generated index_copy module.

    Appends the ``triton`` / ``triton.language`` / ``libentry`` imports and
    two blank lines to ``code``, then returns that same buffer.
    """
    header_lines = (
        "import triton",
        "import triton.language as tl",
        "from flag_gems.utils import libentry",
    )
    for text in header_lines:
        code.writeline(text)

    # two blank lines before the first top-level definition (PEP 8 spacing)
    for _ in range(2):
        code.newline()
    return code
def generate_index_copy_kernel(
    rank: int,
    kernel_name: str,
    code: "IndentedBuffer",
) -> "IndentedBuffer":
    """Write the ``@triton.jit`` kernel that copies ``src`` into ``out`` along a dim.

    The kernel is specialised on ``rank``: besides its fixed scalar and pointer
    arguments it takes one ``src_stride_i`` and one ``src_shape_i`` argument per
    dimension of ``src``.

    Each program instance handles ``BLOCK_SIZE`` consecutive flat element ids
    (masked against ``N``, the number of src elements passed by the caller).
    For each element the kernel:

    * unravels the id into per-dimension coordinates using the src shapes and
      gathers the value at ``src + sum(coord_i * src_stride_i)``;
    * splits that offset into (outer, dim, inner) parts using
      ``inp_stride_dim * src_shape_dim``;
    * stores the value at
      ``src_offset + (delta * pre_idx + index[dim_idx] - dim_idx) * inp_stride_dim``.

    That destination arithmetic is only correct when ``src`` and ``out`` are
    both row-major contiguous and ``inp_stride_dim`` is the product of the
    sizes after ``dim``. Callers must guarantee this.

    Args:
        rank: number of dimensions of ``src``; must be >= 1.
        kernel_name: name given to the generated jit function.
        code: buffer the kernel source is appended to.

    Returns:
        The same ``code`` buffer, for chaining.

    Raises:
        ValueError: if ``rank < 1``. The generated ``src_offset = ...`` line
            would be empty, so the module could not even be parsed.
    """
    if rank < 1:
        raise ValueError(f"index_copy kernel requires rank >= 1, got {rank}")

    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature: fixed arguments, then one stride and one shape per src dim
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("index,")
        code.writeline("src,")
        code.writeline("out,")
        code.writeline("N,")
        code.writeline("inp_numel,")
        code.writeline("inp_stride_dim,")
        code.writeline("inp_shape_dim,")
        code.writeline("src_shape_dim,")
        code.writeline("delta,")

        stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
        code.writeline(f"{stride_args}, # stride for src")

        shape_args = ", ".join(f"src_shape_{i}: int" for i in range(rank))
        code.writeline(f"{shape_args}, # shape for src")

        code.writeline("BLOCK_SIZE: tl.constexpr,")
    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(axis=0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offsets < N")

        # Unravel the flat id into src coordinates, innermost dim first.
        # The generated code consumes ``offsets`` while doing so.
        for i in range(rank - 1, -1, -1):
            code.writeline(f"src_offset{i} = offsets % src_shape_{i}")
            code.writeline(f"offsets = offsets // src_shape_{i}")
        code.newline()
        comp = " + ".join(f"src_offset{i} * src_stride_{i}" for i in range(rank))
        code.writeline(f"src_offset = {comp}")

        # extent of one "outer" slab of src: dim size times the inner size
        code.writeline("pre_cal = (inp_stride_dim * src_shape_dim)")

        # index copy: outer index, position along dim, and its target index
        code.writeline("pre_idx = (src_offset // pre_cal).to(tl.int64)")
        code.writeline(
            "dim_idx = (src_offset % pre_cal // inp_stride_dim).to(tl.int64)"
        )
        code.writeline(
            "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)"
        )
        code.writeline(
            'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"'
        )
        # shift by the dim-size difference per outer slab, and swap dim_idx for index[dim_idx]
        code.writeline(
            "input_idx = (src_offset + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)"
        )

        code.writeline("input_mask = (input_idx >= 0) & (input_idx < inp_numel)")
        code.writeline("store_mask = mask & input_mask")
        code.writeline("src_val = tl.load(src + src_offset, mask=mask, other=0)")
        code.writeline("tl.store(out + input_idx, src_val, mask=store_mask)")

    code.newline()
    code.newline()
    return code
def parameter_for_wrapper() -> str:
    """Return the comma-separated parameter list of the generated wrapper.

    The order is out, index, src, dim, inp_stride_dim, inp_shape_dim,
    src_shape_dim, delta, N, inp_numel. It must match the positional
    arguments that index_copy / index_copy_ pass to the compiled wrapper.
    """
    wrapper_params = (
        "out",
        "index",
        "src",
        "dim",
        "inp_stride_dim",
        "inp_shape_dim",
        "src_shape_dim",
        "delta",
        "N",
        "inp_numel",
    )
    return ", ".join(wrapper_params)
def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write the Python wrapper that launches the generated kernel.

    The wrapper has the parameters listed by ``parameter_for_wrapper``. It
    launches ``kernel_name`` on a 1-D grid of ``cdiv(N, 128)`` programs with
    ``BLOCK_SIZE=128``, forwards the ``rank`` src strides and ``rank`` src
    shapes as extra scalar arguments, and returns ``out``.
    """
    code.writeline(f"def {wrapper_name} ({parameter_for_wrapper()}):")

    with code.indent():
        code.writeline("src_strides = list(src.stride())")
        code.writeline("src_shapes = list(src.shape)")

        # launch configuration: fixed block size, one program per block of src elements
        code.writeline("BLOCK_SIZE = 128")
        code.writeline("grid = (triton.cdiv(N, BLOCK_SIZE),)")
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            code.writeline(
                "index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, "
            )
            if rank > 0:
                # per-dimension scalars: all strides first, then all shapes
                per_dim_groups = (
                    ", ".join(f"src_strides[{d}]" for d in range(rank)),
                    ", ".join(f"src_shapes[{d}]" for d in range(rank)),
                )
                for group in per_dim_groups:
                    code.writeline(f"{group},")
            code.writeline("BLOCK_SIZE=BLOCK_SIZE")
        code.writeline(")")
        code.writeline("return out")

    return code
def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble the complete generated module for one call signature.

    ``inputs`` is the positional argument tuple later handed to the wrapper:
    (out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim,
    delta, N, inp_numel). Only ``src`` (position 2) is inspected. Its rank
    selects the kernel and wrapper specialisation.
    """
    src_rank = len(inputs[2].shape)

    code = generate_imports(code)
    code = generate_index_copy_kernel(src_rank, kernel_name, code)
    return generate_destination_passing_wrapper(
        src_rank, wrapper_name, kernel_name, code
    )
class IndexCopyFunction:
    """Generate, cache and dispatch rank-specialised index_copy wrappers.

    On the first call for a given key (see ``arg_key``), the Triton module
    source is generated, written under ``code_cache_dir()`` and imported. The
    module's ``_index_copy_wrapper`` is memoised in ``self.overloads``, so later
    calls with the same key go straight to it.
    """

    def __init__(self):
        # pid of the process that created this instance (kept for compatibility;
        # cache file names use the pid of the *calling* process, see __call__)
        self.pid = os.getpid()
        # str(key) -> ``_index_copy_wrapper`` of the generated module
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            return self.overloads[key](*args, **kwargs)

        code = IndentedBuffer()
        code = generate_code(
            args,
            "_index_copy_wrapper",
            "_index_copy_jit_function",
            code,
        )

        # Use the current pid rather than self.pid. After a fork the child still
        # holds the parent's self.pid, and parent and child generating the same
        # key would then write and import the same file concurrently.
        pid = os.getpid()
        file_name = f"index_copy_rank_{key}_pid_{pid}.py"

        try:
            file_path = code_cache_dir() / file_name
            with open(file_path, "wt", encoding="utf-8") as f:
                f.write(code.getvalue())

            # load the freshly written module
            spec = importlib.util.spec_from_file_location(
                f"_gen_module_rank_{key}_pid_{pid}",
                f.name,
            )
            m = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(m)
            overload = getattr(m, "_index_copy_wrapper")
        except Exception as e:
            raise RuntimeError(
                f"Failed to generate or load index_copy kernel: {e}"
            ) from e

        self.overloads[key] = overload
        return overload(*args, **kwargs)

    def arg_key(self, *args):
        """Return the cache key: the largest ``ndim`` among the tensor arguments."""
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank
# Shared dispatcher instance. It caches one generated wrapper per key, where the
# key is the maximum tensor rank among the call arguments (IndexCopyFunction.arg_key).
_index_copy_func = IndexCopyFunction()


# Dispatch key set used to redispatch aten ops (the clone in index_copy) straight
# to their CompositeExplicitAutograd kernels. NOTE(review): presumably this
# bypasses FlagGems' own registrations for those ops, per the "native clone"
# comment in index_copy. Confirm.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)
def _check_index_copy_args(inp, dim, index, src):
    """Validate index_copy / index_copy_ arguments (raises AssertionError).

    Checks are ordered so each can rely on the previous ones: ``dim`` is
    validated before it is used to index ``inp``/``src``, and the ranks are
    compared before per-dimension sizes. A bad argument is therefore reported
    by its own message instead of an unrelated IndexError.
    """
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    assert ((0 <= index) * (index < inp.size(dim))).equal(
        torch.ones(tuple(index.shape), dtype=torch.bool, device=inp.device)
    ), "0 <= index < self.size(dim)"
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    # Compare against the normalised dim. With a negative dim, ``i == dim``
    # never matches, so the copy dim itself would wrongly be required to match.
    norm_dim = dim % inp.ndim
    assert all(
        (inp.size(i) == src.size(i)) or i == norm_dim for i in range(0, inp.ndim)
    ), "src.size(d) == self.size(d) for all dimensions d != dim"


def _contiguous_clone(inp):
    """Return a row-major (contiguous) copy of ``inp`` made via the native clone."""
    # Use native clone to avoid potential issues with FlagGems copy_ dispatch.
    # Force contiguous_format: the default preserve_format would keep a
    # permuted/transposed layout, which the kernel's offset arithmetic cannot address.
    return torch.ops.aten.clone.default.redispatch(
        _FALLBACK_KEYSET, inp, memory_format=torch.contiguous_format
    )


def _launch_index_copy(out, dim, index, src):
    """Copy ``src`` into the contiguous tensor ``out`` along ``dim`` in place; return ``out``."""
    dim %= out.ndim
    # The kernel treats src's physical offset as its row-major flat index and
    # reads ``index`` as a dense buffer. Both hold only for contiguous tensors.
    # contiguous() returns the tensor itself when it already is.
    src = src.contiguous()
    index = index.contiguous()

    src_shape_dim = src.size(dim)
    inp_shape_dim = out.size(dim)
    # Row-major stride of ``dim``: the product of the trailing sizes. Unlike
    # out.stride(dim), this is also right when dim has size 1, whose stride
    # is_contiguous() ignores and may therefore be arbitrary.
    inp_stride_dim = out.shape[dim + 1 :].numel()
    delta = inp_shape_dim - src_shape_dim
    N = src.numel()
    if N == 0:
        # nothing to copy; skip kernel generation and an empty launch
        return out

    _index_copy_func(
        out,
        index,
        src,
        dim,
        inp_stride_dim,
        inp_shape_dim,
        src_shape_dim,
        delta,
        N,
        out.numel(),
    )
    return out


def index_copy(inp, dim, index, src):
    """Out-of-place index_copy: return a copy of ``inp`` with ``src`` written along ``dim``.

    For each ``j``, slice ``j`` of ``src`` along ``dim`` is written to slice
    ``index[j]`` of the result. The result is always contiguous.
    """
    logger.debug("GEMS INDEX_COPY")
    _check_index_copy_args(inp, dim, index, src)
    out = _contiguous_clone(inp)
    return _launch_index_copy(out, dim, index, src)


def index_copy_(inp, dim, index, src):
    """In-place index_copy: write ``src`` into ``inp`` along ``dim`` and return ``inp``."""
    logger.debug("GEMS INDEX_COPY_")
    _check_index_copy_args(inp, dim, index, src)
    if inp.is_contiguous():
        return _launch_index_copy(inp, dim, index, src)

    # The kernel can only address row-major storage, so run it on a contiguous
    # copy and write the result back into the strided ``inp``.
    work = _launch_index_copy(_contiguous_clone(inp), dim, index, src)
    inp.copy_(work)
    return inp