Coverage for src/flag_gems/runtime/backend/_sunrise/ops/gather.py: 0%
134 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import importlib
2import logging
3import os
4from typing import Any, Callable, Mapping, Tuple
6import torch
8from flag_gems.ops.scatter import scatter_
9from flag_gems.utils.code_cache import code_cache_dir
10from flag_gems.utils.code_utils import IndentedBuffer, write_atomic
11from flag_gems.utils.shape_utils import restride_dim
# Module logger: a child of the shared "flag_gems" logger, named after this
# module with any leading dots stripped from __name__.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import preamble of a generated gather module into ``code``.

    The emitted preamble imports ``torch``, tries to import ``torch_ptpu`` and
    otherwise binds ``torch.cuda`` to that name, then imports Triton and the
    ``flag_gems`` helpers used by the generated kernel. Two blank lines are
    appended at the end.

    Args:
        code: Buffer that receives the generated source lines.

    Returns:
        The same ``code`` buffer that was passed in.
    """
    framework_lines = (
        "import torch",
        "try:",
        " import torch_ptpu",
        "except ImportError:",
        " import torch.cuda as torch_ptpu",
        "import triton",
        "import triton.language as tl",
    )
    flag_gems_lines = (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "from flag_gems.utils import triton_lang_extension as ext",
    )

    for line in framework_lines:
        code.writeline(line)
    code.newline()
    for line in flag_gems_lines:
        code.writeline(line)

    # Blank separation before whatever definition is written next.
    code.newline()
    code.newline()
    return code
def generate_gather_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write the source of a rank-specialised Triton gather kernel into ``code``.

    The emitted kernel is decorated with ``@libentry()``, a heuristic that
    fixes ``BLOCK_SIZE_N`` to 128, and ``@triton.jit``. Its parameters are the
    three pointers, then per-axis shapes and strides for ``inp``, ``index``
    and ``out``, followed by ``dim``, ``dim_stride``, ``N`` and the constexpr
    block size.

    In the emitted body each program covers ``BLOCK_SIZE_N`` flat offsets
    (masked by ``offset < N``). A flat offset is decomposed into per-axis
    coordinates using the ``index`` shape, walking from the last axis to the
    first. The kernel loads the index value at those coordinates, reads
    ``inp`` at the same coordinates plus ``cur_index * dim_stride``, and
    stores the result to ``out``.

    Args:
        rank: Number of axes the kernel is specialised for.
        kernel_name: Name given to the emitted kernel function.
        code: Buffer that receives the generated source lines.

    Returns:
        The same ``code`` buffer that was passed in.
    """
    axes = range(rank)

    def offset_expr(tensor: str) -> str:
        # Sum of "coordinate * stride" terms, one per axis, for ``tensor``.
        return " + ".join(f"index_idx{d} * {tensor}_stride{d}" for d in axes)

    code.newline()

    for decorator in (
        "@libentry()",
        "@triton.heuristics({'BLOCK_SIZE_N': lambda args: 128})",
        "@triton.jit",
    ):
        code.writeline(decorator)
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        params = ["inp, ", "index, ", "out, "]
        # The inp_shape entries are emitted without a trailing space,
        # unlike every other parameter line.
        params.extend([f"inp_shape{d}," for d in axes])
        params.extend([f"index_shape{d}, " for d in axes])
        params.extend([f"out_shape{d}, " for d in axes])
        params.extend([f"inp_stride{d}, " for d in axes])
        params.extend([f"index_stride{d}, " for d in axes])
        params.extend([f"out_stride{d}, " for d in axes])
        params.extend(
            ["dim, ", "dim_stride, ", "N, ", "BLOCK_SIZE_N: tl.constexpr, "]
        )
        code.writelines(params)
    code.writeline("):")

    with code.indent():
        # Flat offsets handled by this program.
        code.writeline("pid = ext.program_id(0)")
        code.writeline(
            "offset = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)"
        )
        code.newline()

        # Decompose the flat offset into coordinates, last axis first.
        code.writeline("cur_offset = offset")
        for d in reversed(axes):
            code.writeline(f"index_idx{d} = cur_offset % index_shape{d}")
            code.writeline(f"cur_offset = cur_offset // index_shape{d}")
        code.newline()

        # Load the gather index for every in-range element.
        code.writeline(f"index_offset = {offset_expr('index')}")
        code.writeline("mask = offset < N")
        code.writeline("cur_index = tl.load(index + index_offset, mask=mask, other=0)")
        code.newline()

        # Read the source element selected by the loaded index.
        code.writeline(f"inp_offset = {offset_expr('inp')}")
        code.writeline("inp_offset += cur_index * dim_stride")
        code.writeline("cur_inp = tl.load(inp + inp_offset, mask=mask, other=0)")
        code.newline()

        # Write the gathered value to the output.
        code.writeline(f"out_offset = {offset_expr('out')}")
        code.writeline("tl.store(out + out_offset, value=cur_inp, mask=mask)")

    code.newline()
    code.newline()
    return code
def generate_gather_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write the Python launcher for the generated gather kernel into ``code``.

    The emitted function has the signature
    ``(inp, dim, index, out, dim_stride, N)``. It reads the shapes and strides
    of ``inp``, ``index`` and ``out``, launches ``kernel_name`` on a 1-D grid
    of ``triton.cdiv(N, BLOCK_SIZE_N)`` programs, and returns ``out``.

    Args:
        rank: Number of axes whose shapes and strides are passed to the kernel.
        wrapper_name: Name given to the emitted launcher function.
        kernel_name: Name of the kernel the launcher invokes.
        code: Buffer that receives the generated source lines.

    Returns:
        The same ``code`` buffer that was passed in.
    """
    tensors = ("inp", "index", "out")

    code.writeline(f"def {wrapper_name}(inp, dim, index, out, dim_stride, N):")
    with code.indent():
        for name in tensors:
            code.writeline(f"{name}_shape = {name}.shape")
            code.writeline(f"{name}_stride = {name}.stride()")
        code.writeline("grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE_N']), )")
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            # Order must mirror the kernel signature: pointers, all shapes,
            # all strides, then the scalar arguments.
            launch_args = [f"{name}, " for name in tensors]
            for attr in ("shape", "stride"):
                for name in tensors:
                    launch_args.extend(
                        [f"{name}_{attr}[{axis}], " for axis in range(rank)]
                    )
            launch_args.extend(["dim, ", "dim_stride, ", "N, "])
            code.writelines(launch_args)
        code.writeline(")")
        code.writeline("return out")
    code.newline()
    code.newline()
    return code
def generate_code(
    inputs: Tuple[Any, ...],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write a complete gather module (imports, kernel, launcher) into ``code``.

    The rank used to specialise the generated kernel and launcher is the
    ``ndim`` of the first element of ``inputs``.

    Args:
        inputs: Positional arguments of the gather call. Only ``inputs[0]`` is
            inspected here, but the tuple may have any length, hence
            ``Tuple[Any, ...]`` rather than the one-element ``Tuple[Any]``.
        wrapper_name: Name given to the emitted launcher function.
        kernel_name: Name given to the emitted Triton kernel.
        code: Buffer that receives the generated source lines.

    Returns:
        The buffer returned by the last generation step.
    """
    rank = inputs[0].ndim

    code = generate_imports(code)
    code = generate_gather_kernel(rank, kernel_name, code)
    code = generate_gather_wrapper(rank, wrapper_name, kernel_name, code)
    return code
class GatherFunction:
    """Callable that builds, caches and dispatches rank-specific gather wrappers.

    On the first call for a given key (the ``ndim`` of the first positional
    argument, as a string) the source of a gather module is generated, written
    to the code cache directory, imported from that file, and its
    ``_gather_wrapper`` function is stored in ``overloads``. Later calls with
    the same key reuse the stored function.
    """

    def __init__(self):
        # Process id captured at construction; nothing in this class reads it.
        self.pid = os.getpid()
        # Generated wrapper functions, keyed by the stringified rank.
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch to the wrapper matching the arguments, building it if needed.

        All positional and keyword arguments are forwarded unchanged to the
        selected wrapper, and its return value is returned.
        """
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._build_overload(key, args)
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def _build_overload(self, key, args):
        """Generate, persist and import the gather module for ``key``.

        Raises:
            ImportError: If no import spec or loader can be created for the
                generated file.
        """
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule is loaded, so import it explicitly before using it.
        import importlib.util

        code = IndentedBuffer()
        code = generate_code(
            args,
            "_gather_wrapper",
            "_gather_flaggems_jit_function",
            code,
        )

        file_name = f"gather_rank_{key}.py"
        file_path = code_cache_dir() / file_name
        write_atomic(file_path, code.getvalue())

        # Import the freshly written file as a standalone module.
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}",
            file_path,
        )
        if spec is None or spec.loader is None:
            raise ImportError(f"cannot load generated gather module: {file_path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_gather_wrapper")

    def arg_key(self, *args):
        """Return the cache key source: ``ndim`` of the first positional argument."""
        return args[0].ndim
# Module-level GatherFunction instance; gather() dispatches through it.
_gather_func = GatherFunction()
def gather(inp, dim, index, out=None, sparse_grad=False):
    """Gather values from ``inp`` along ``dim`` at the positions given by ``index``.

    Args:
        inp: Source tensor.
        dim: Axis of ``inp`` along which values are selected.
        index: Tensor of positions along ``dim``; the result has its shape.
        out: Optional destination tensor. When ``None`` a new tensor with the
            shape of ``index`` and the dtype and device of ``inp`` is allocated.
        sparse_grad: Accepted for signature compatibility; not used here.

    Returns:
        ``out`` (the given tensor, or the newly allocated one).
    """
    logger.debug("GEMS GATHER")
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    # Reading the stride first keeps the dimension check for every input.
    dim_stride = inp.stride(dim)
    N = index.numel()
    if N == 0:
        # Nothing to gather: skip code generation, the cache-file write and a
        # zero-sized kernel launch.
        return out
    inp_strided = restride_dim(inp, dim, index.shape)
    _gather_func(inp_strided, dim, index, out, dim_stride, N)
    return out
def gather_backward(grad, self, dim, index, sparse_grad):
    """Compute the gradient of gather with respect to its input.

    A zero tensor with the shape of ``self`` is created from ``grad`` and the
    values of ``grad`` are scatter-added into it along ``dim`` at ``index``.
    ``sparse_grad`` is accepted but not used.
    """
    logger.debug("GEMS GATHER BACKWARD")
    grad_input = grad.new_zeros(self.shape)
    return scatter_(grad_input, dim, index, grad, reduce="add")