Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/gather.py: 0%

197 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer 

10from flag_gems.utils.shape_utils import restride_dim 

11 

12from .scatter import scatter_ 

13 

# Module logger: a child of the "flag_gems" logger, named after this module
# (any leading "." stripped from __name__).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

15 

16 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of the generated gather module into ``code``.

    Third-party/stdlib imports come first, followed by one blank line and the
    flag_gems imports, then two blank lines before the generated definitions.
    Returns the same buffer so calls can be chained.
    """
    third_party = (
        "import torch",
        "import triton",
        "import triton.language as tl",
        "import builtins",
    )
    project = (
        "from flag_gems.utils import libentry",
        "from flag_gems import runtime",
        "from flag_gems.utils import triton_lang_extension as ext",
    )

    for stmt in third_party:
        code.writeline(stmt)
    code.newline()

    for stmt in project:
        code.writeline(stmt)
    # two blank lines separate the header from the first generated def
    code.newline()
    code.newline()
    return code

30 

31 

def generate_gather_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Append a rank-specialized Triton gather kernel (plus helpers) to ``code``.

    The generated text defines:

    * ``cfggen`` - an autotune config list, only referenced by the
      commented-out ``@triton.autotune`` decorator;
    * ``heur_block_m`` / ``heur_block_n`` - block-size heuristics consumed by
      ``@triton.heuristics``;
    * ``kernel_name`` - the ``@triton.jit`` kernel.

    The kernel tiles the flat position ``row * N + col`` into
    ``BLOCK_M x BLOCK_N`` tiles. For each position it does the following:

    * decomposes the position into a multi-index using the constexpr
      ``index_shape_*`` values;
    * turns that multi-index into an element offset into ``index``/``out``
      (``index_stride_*``) and into ``inp`` (``inp_stride_*``);
    * loads the index value and adds ``index_value * stride_dim`` to the
      ``inp`` offset;
    * copies that ``inp`` element into ``out``.

    Args:
        rank: number of dimensions of the index tensor; one stride and one
            shape parameter per dimension are emitted.
        kernel_name: name of the generated jit function.
        code: buffer the source is appended to.

    Returns:
        ``code`` (the same buffer), to allow chaining.
    """
    code.newline()

    # Autotune config generator. Only used if the commented-out
    # ``@triton.autotune`` decorator below is re-enabled.
    code.writeline("def cfggen():")
    with code.indent():
        code.writeline("block_m = [1, 2, 4, 8]")
        code.writeline("block_n = [256, 512, 1024, 2048]")
        code.writeline("configs = [")
        with code.indent():
            code.writeline('triton.Config({"BLOCK_M": m, "BLOCK_N": n}, num_warps=4)')
            code.writeline("for m in block_m")
            code.writeline("for n in block_n")
        code.writeline("]")
        code.writeline("return configs")

    code.newline()
    code.newline()

    # NOTE(review): the divisor 12 is a backend-specific tuning constant whose
    # origin is not visible here.
    code.writeline("def heur_block_m(args):")
    with code.indent():
        code.writeline('return triton.next_power_of_2(triton.cdiv(args["M"], 12))')

    code.newline()

    code.writeline("def heur_block_n(args):")
    with code.indent():
        code.writeline('return builtins.min(triton.next_power_of_2(args["N"]), 4096)')

    code.newline()
    code.newline()

    # the decorators
    code.writeline("@libentry()")
    # code.writeline('@triton.autotune(configs=cfggen(), key=["M", "N"])')
    code.writeline("@triton.heuristics(")
    with code.indent():
        code.writeline("values={")
        with code.indent():
            code.writeline('"BLOCK_M": heur_block_m,')
            code.writeline('"BLOCK_N": heur_block_n,')
        code.writeline("},")
    code.writeline(")")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        # The launcher always passes these three tensors positionally.
        code.writeline("inp,")
        code.writeline("out,")
        code.writeline("index,")
        if rank > 0:
            # Guarded because an empty join would emit a bare ", # ..." line.
            stride_args = ", ".join(
                f"inp_stride_{i}: tl.constexpr" for i in range(rank)
            )
            code.writeline(f"{stride_args}, # stride for inp")

            stride_args = ", ".join(
                f"index_stride_{i}: tl.constexpr" for i in range(rank)
            )
            code.writeline(f"{stride_args}, # stride for index")

            shape_args = ", ".join(
                f"index_shape_{i}: tl.constexpr" for i in range(rank)
            )
            code.writeline(f"{shape_args}, # shape for index")

        code.writeline("dim: tl.constexpr,")
        code.writeline("stride_dim: tl.constexpr,")
        code.writeline("inp_dim_size: tl.constexpr,")
        code.writeline("M: tl.constexpr,")
        code.writeline("N: tl.constexpr,")
        code.writeline("BLOCK_M: tl.constexpr,")
        code.writeline("BLOCK_N: tl.constexpr,")
    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid_x = ext.program_id(0)")
        code.writeline("pid_y = ext.program_id(1)")
        code.writeline(
            "rows_offsets = pid_x * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]"
        )
        code.writeline(
            "cols_offsets = pid_y * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]"
        )
        code.writeline("rows_mask = rows_offsets < M")
        code.writeline("cols_mask = cols_offsets < N")
        code.writeline("mask = rows_mask & cols_mask")

        # 1. Offset accumulators and the flat position of each tile element.
        code.writeline("inp_offsets = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int64)")
        code.writeline("idx_offsets = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int64)")
        code.writeline("cur_idx = rows_offsets * N + cols_offsets")

        # 2. Decompose the flat position into a multi-index, starting from the
        # last (fastest-varying, row-major) dimension.
        # Any decomposition order is a bijection over the elements, so the
        # result is the same either way. This order maps neighbouring tile
        # columns to neighbouring elements of a contiguous index/out, which
        # keeps the loads and stores contiguous.
        for i in reversed(range(rank)):
            code.writeline(f"mod = cur_idx % index_shape_{i}")
            code.writeline(f"inp_offsets += mod * inp_stride_{i}")
            code.writeline(f"idx_offsets += mod * index_stride_{i}")
            if i != 0:
                code.writeline(f"cur_idx //= index_shape_{i}")

        # 3. Use the offsets to gather.
        code.writeline("cur_index = tl.load(index + idx_offsets, mask=mask, other=0)")
        code.writeline("inp_offsets += cur_index * stride_dim")
        code.writeline("cur_inp = tl.load(inp + inp_offsets, mask=mask, other=0)")
        code.writeline("tl.store(out + idx_offsets, cur_inp, mask=mask)")

    code.newline()
    code.newline()
    return code

152 

153 

def parameter_for_wrapper() -> str:
    """Return the comma-separated parameter list of the generated wrapper.

    The order is the positional order in which ``gather`` passes its
    arguments: inp_strided, out, index, dim, stride_dim, inp_dim_size, M, N.
    """
    names = (
        "inp_strided",
        "out",
        "index",
        "dim",
        "stride_dim",
        "inp_dim_size",
        "M",
        "N",
    )
    return ", ".join(names)

168 

169 

def generate_gather_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write the Python launcher for the generated gather kernel into ``code``.

    The launcher takes the parameters listed by ``parameter_for_wrapper``. It
    reads the strides of ``inp_strided``/``index`` and the shape of ``index``.
    It then launches ``kernel_name`` over a
    ``cdiv(M, BLOCK_M) x cdiv(N, BLOCK_N)`` grid and returns ``out``.
    ``rank`` controls how many per-dimension stride/shape arguments are
    forwarded.
    """
    code.writeline(f"def {wrapper_name}({parameter_for_wrapper()}):")

    with code.indent():
        for stmt in (
            "inp_strides = inp_strided.stride()",
            "index_strides = index.stride()",
            "index_shapes = list(index.shape)",
        ):
            code.writeline(stmt)

        # kernel launch: one program per (BLOCK_M, BLOCK_N) tile
        code.writeline("grid = lambda meta: (")
        with code.indent():
            code.writeline('triton.cdiv(M, meta["BLOCK_M"]),')
            code.writeline('triton.cdiv(N, meta["BLOCK_N"])')
        code.writeline(")")

        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            code.writeline("inp_strided, out, index, ")
            if rank > 0:
                # same order as the kernel signature: inp strides,
                # index strides, index shapes
                for seq in ("inp_strides", "index_strides", "index_shapes"):
                    joined = ", ".join(f"{seq}[{i}]" for i in range(rank))
                    code.writeline(joined + ",")
            for arg in ("dim", "stride_dim", "inp_dim_size", "M", "N"):
                code.writeline(f"{arg},")
        code.writeline(")")
        code.writeline("return out")

    return code

216 

217 

def generate_code(
    inputs: Tuple[Any, ...],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Generate the complete gather module source: imports, kernel, wrapper.

    Args:
        inputs: the positional arguments the wrapper will be called with,
            ``(inp_strided, out, index, dim, stride_dim, inp_dim_size, M, N)``.
            Only the rank of ``index`` (``inputs[2]``) affects the output.
        wrapper_name: name of the generated Python launcher function.
        kernel_name: name of the generated Triton kernel.
        code: buffer the source is appended to.

    Returns:
        The buffer containing the generated module source.
    """
    rank = len(inputs[2].shape)

    code = generate_imports(code)
    code = generate_gather_kernel(rank, kernel_name, code)
    code = generate_gather_wrapper(rank, wrapper_name, kernel_name, code)
    return code

232 

233 

class GatherFunction:
    """Lazily generate, cache and dispatch to rank-specialized gather wrappers.

    On the first call for a given key (the largest ``ndim`` among the tensor
    arguments), this class does the following:

    * generates the module source with ``generate_code``;
    * writes it to ``cache_dir()`` and imports it;
    * memoizes its ``_gather_wrapper`` in ``self.overloads``.

    Later calls with the same key reuse the memoized wrapper.
    """

    def __init__(self):
        # Part of the generated file/module names, keeping them distinct
        # per process.
        self.pid = os.getpid()
        # str(rank) -> loaded ``_gather_wrapper`` callable
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch to the wrapper for this rank, building it on first use."""
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._build_overload(key, args)
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def _build_overload(self, key, args):
        """Generate, write and import the module for ``key``; return its wrapper."""
        # ``importlib.util`` is a submodule: a plain ``import importlib`` does
        # not guarantee the attribute exists, so import it explicitly.
        import importlib.util

        code = generate_code(
            args,
            "_gather_wrapper",
            "_gather_jit_function",
            IndentedBuffer(),
        )

        file_path = cache_dir() / f"gather_rank_{key}_pid_{self.pid}.py"
        with open(file_path, "wt", encoding="utf-8") as f:
            f.write(code.getvalue())

        # load
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}_pid_{self.pid}",
            os.fspath(file_path),
        )
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return getattr(module, "_gather_wrapper")

    def arg_key(self, *args):
        """Return the largest ``ndim`` among the tensor arguments.

        Raises:
            ValueError: if no argument is a tensor (``max`` of an empty
                sequence).
        """
        tensors = [item for item in args if torch.is_tensor(item)]
        return max(item.ndim for item in tensors)

274 

275 

# Module-level dispatcher shared by every ``gather`` call in this process;
# it memoizes one generated wrapper per cache key (tensor rank).
_gather_func = GatherFunction()

277 

278 

def gather(inp, dim, index, out=None, sparse_grad=False):
    """Kunlunxin implementation of ``torch.gather``.

    For every position ``p`` of ``index``, ``out[p]`` is the element of
    ``inp`` at ``p`` with its ``dim`` coordinate replaced by ``index[p]``.

    Args:
        inp: tensor to gather from.
        dim: dimension to gather along; negative values count from the end.
        index: integer tensor with the same number of dimensions as ``inp``.
        out: optional destination tensor. NOTE(review): assumed to already
            have ``index``'s shape; it is not resized here.
        sparse_grad: accepted for API compatibility; unused.

    Returns:
        ``out``. When not supplied, it is allocated with ``index``'s shape and
        ``inp``'s dtype/device.

    Raises:
        IndexError: if ``inp`` and ``index`` differ in number of dimensions,
            or (via ``inp.stride``) if ``dim`` is out of range.
    """
    logger.debug("GEMS_KUNLUNXIN GATHER")
    if dim < 0:
        dim += inp.ndim
    if inp.ndim != index.ndim:
        raise IndexError(
            f"Index tensor must have the same number of dimensions as input tensor. "
            f"Got {index.ndim} and {inp.ndim}."
        )
    inp = inp.contiguous()
    index = index.contiguous()
    if out is None:
        out = torch.empty_like(index, dtype=inp.dtype, device=inp.device)
    # Also validates ``dim`` before the empty-index early return below.
    stride_dim = inp.stride(dim)

    if index.numel() == 0:
        # Nothing to gather. Returning here also avoids ``M = numel // N``
        # dividing by zero when the last dimension is empty, and a kernel
        # launch over an empty grid.
        return out

    # The kernel addresses ``out`` with the (contiguous) strides of
    # ``index``, so it must write into a contiguous buffer. A non-contiguous
    # caller-supplied ``out`` is filled through a temporary and copied back,
    # so the caller's tensor is actually updated.
    if out.is_contiguous():
        dst = out
    else:
        dst = torch.empty_like(out, memory_format=torch.contiguous_format)

    inp_strided = restride_dim(inp, dim, index.shape)
    N = index.shape[-1]
    M = index.numel() // N
    inp_dim_size = inp.size(dim)

    _gather_func(inp_strided, dst, index, dim, stride_dim, inp_dim_size, M, N)
    if dst is not out:
        out.copy_(dst)
    return out

303 

304 

def gather_backward(grad, self, dim, index, sparse_grad):
    """Backward of ``gather`` with respect to its input.

    Builds a zero tensor shaped like ``self`` (the forward input), with
    ``grad``'s dtype/device via ``new_zeros``. It then accumulates ``grad``
    into it at ``index`` along ``dim`` with ``scatter_(..., reduce="add")``
    and returns what ``scatter_`` returns. ``sparse_grad`` is ignored.
    """
    logger.debug("GEMS_KUNLUNXIN GATHER_BACKWARD")
    return scatter_(grad.new_zeros(self.shape), dim, index, grad, reduce="add")
308 return scatter_(result, dim, index, grad, reduce="add")