Coverage for src/flag_gems/runtime/backend/_sunrise/ops/index_add.py: 0%

158 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

import importlib
import importlib.util
import logging
import os
from typing import Any, Callable, Dict, List, Mapping, Tuple

import torch

from flag_gems.utils.code_cache import code_cache_dir
from flag_gems.utils.code_utils import IndentedBuffer

10 

# Module logger, registered as a child of the "flag_gems" logger so it follows
# that logger's level/handler configuration.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

12 

13 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Write the import header of the generated index_add module into ``code``.

    The header pulls in ``triton``, ``triton.language`` (as ``tl``) and
    ``libentry``, which the generated kernel and wrapper rely on. The same
    buffer is returned so calls can be chained.
    """
    header_lines = (
        "import triton",
        "import triton.language as tl",
        "from flag_gems.utils import libentry",
    )
    for header_line in header_lines:
        code.writeline(header_line)

    # Two blank lines separate the imports from the first generated definition.
    for _ in range(2):
        code.newline()

    return code

23 

24 

def generate_index_add_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Append the Triton kernel ``kernel_name`` for a ``rank``-dimensional ``src``.

    Each program handles ``BLOCK_SIZE`` consecutive linear positions of ``src``
    (``N`` positions in total). For each position the kernel does four things:

    * recovers the per-dimension coordinates from ``src_shape_*``;
    * computes the element's memory offset with ``src_stride_*``;
    * reads the target coordinate from ``index``;
    * adds ``alpha * src[...]`` to the corresponding element of ``out``.

    The mapping from the ``src`` offset to the ``out`` offset (``pre_idx`` /
    ``dim_idx`` / ``delta``) is only valid when ``src`` and ``out`` are both
    contiguous. ``index`` is read with unit stride.

    Args:
        rank: number of dimensions of ``src``; one stride/shape argument per
            dimension is added to the kernel signature.
        kernel_name: name of the generated ``@triton.jit`` function.
        code: buffer the kernel source is appended to.

    Returns:
        ``code``, so calls can be chained.
    """
    # the decorators
    code.writeline("@libentry()")
    code.writeline("@triton.jit")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        if rank > 0:
            code.writeline("index,")
            code.writeline("src,")
            code.writeline("out,")
            code.writeline("N,")
            code.writeline("inp_numel,")
            code.writeline("inp_stride_dim,")
            code.writeline("inp_shape_dim,")
            code.writeline("src_shape_dim,")
            code.writeline("delta,")
            code.writeline("alpha,")

            stride_args = ", ".join(f"src_stride_{i}: int" for i in range(rank))
            code.writeline(f"{stride_args}, # stride for src")

            shape_args = ", ".join(f"src_shape_{i}: int" for i in range(rank))
            code.writeline(f"{shape_args}, # shape for src")

            code.writeline("BLOCK_SIZE: tl.constexpr,")

    code.writeline("):")

    # Kernel Code
    with code.indent():
        code.writeline("pid = tl.program_id(axis=0)")
        code.writeline("offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)")
        code.writeline("mask = offsets < N")

        # Peel the linear position into per-dimension coordinates of src,
        # innermost dimension first, then rebuild its strided memory offset.
        for i in range(rank - 1, -1, -1):
            code.writeline(f"src_offset{i} = offsets % src_shape_{i}")
            code.writeline(f"offsets = offsets // src_shape_{i}")
        code.newline()
        comp = [f"src_offset{i} * src_stride_{i}" for i in range(rank)]
        code.writeline(f"src_offset = {' + '.join(comp)}")

        # For contiguous src: elements covered by one step of the dimensions
        # preceding `dim`.
        code.writeline("pre_cal = (inp_stride_dim * src_shape_dim)")

        # index add
        # pre_idx: flattened coordinate over the dimensions before `dim`.
        code.writeline("pre_idx = (src_offset // pre_cal).to(tl.int64)")
        # dim_idx: coordinate of this element along `dim` inside src.
        code.writeline(
            "dim_idx = (src_offset % pre_cal // inp_stride_dim).to(tl.int64)"
        )
        # src_dim_idx: destination coordinate along `dim`, taken from index.
        code.writeline(
            "src_dim_idx = (tl.load(index + dim_idx, mask=mask, other=0)).to(tl.int64)"
        )
        code.writeline(
            'assert src_dim_idx >= 0 and src_dim_idx < inp_shape_dim, "0 <= index < self.size(dim)"'
        )
        # Shift the src offset into out: each leading step is `delta` slices
        # larger in out, and the coordinate along `dim` becomes src_dim_idx.
        code.writeline(
            "input_idx = (src_offset + (delta * pre_idx + src_dim_idx - dim_idx) * inp_stride_dim).to(tl.int64)"
        )

        code.writeline("input_mask = (input_idx < inp_numel) & mask")
        code.writeline(
            "add_on = tl.load(src + src_offset, mask=mask, other=0) * alpha"
        )

        # Accumulate into out.
        #
        # Race in the fallback path: a plain load + store is a read-modify-write.
        # When `index` contains repeated values, several lanes/programs update
        # the same element and all but one contribution is lost.
        #
        # Atomic path: tl.atomic_add is used for float32/float16 instead.
        #
        # Per the original note, tl.atomic_add doesn't support bfloat16, so
        # bfloat16 and any other dtype keep the non-atomic fallback.
        # TODO: that fallback is still unsafe when index has duplicates.
        #
        # NOTE(review): the dtype comparison is expected to be folded at compile
        # time by Triton (constexpr `if`), so only one branch is emitted;
        # confirm on the sunrise Triton build.
        atomic_line = "tl.atomic_add(out + input_idx, add_on, mask=input_mask)"
        code.writeline("if out.dtype.element_ty == tl.float32:")
        with code.indent():
            code.writeline(atomic_line)
        code.writeline("elif out.dtype.element_ty == tl.float16:")
        with code.indent():
            code.writeline(atomic_line)
        code.writeline("else:")
        with code.indent():
            code.writeline("cur_out = tl.load(out + input_idx, mask=input_mask)")
            code.writeline(
                "tl.store(out + input_idx, cur_out + add_on, mask=input_mask)"
            )

    code.newline()
    code.newline()
    return code

105 

106 

def parameter_for_wrapper() -> str:
    """Return the comma-separated parameter list of the generated wrapper.

    The order matches the positional arguments ``index_add`` / ``index_add_``
    pass to ``_index_add_func``:
    out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim,
    delta, N, inp_numel, alpha.
    """
    wrapper_params = (
        "out",
        "index",
        "src",
        "dim",
        "inp_stride_dim",
        "inp_shape_dim",
        "src_shape_dim",
        "delta",
        "N",
        "inp_numel",
        "alpha",
    )
    return ", ".join(wrapper_params)

123 

124 

def generate_destination_passing_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the Python wrapper that launches ``kernel_name`` and returns ``out``.

    The wrapper accepts the parameters listed by ``parameter_for_wrapper()``.
    It forwards them to the kernel, followed by ``rank`` stride values and
    ``rank`` size values read from ``src``.

    Args:
        rank: number of dimensions of ``src``.
        wrapper_name: name of the generated Python function.
        kernel_name: name of the generated Triton kernel to launch.
        code: buffer the wrapper source is appended to.

    Returns:
        ``code``, so calls can be chained.
    """
    code.writeline(f"def {wrapper_name} ({parameter_for_wrapper()}):")

    with code.indent():
        code.writeline("src_strides = list(src.stride())")
        code.writeline("src_shapes = list(src.shape)")

        # Fixed launch configuration: a 1-D grid with one program per
        # BLOCK_SIZE elements of N.
        code.writeline("BLOCK_SIZE = 128")
        code.writeline("grid = (triton.cdiv(N, BLOCK_SIZE),)")
        code.writeline(f"{kernel_name}[grid](")
        with code.indent():
            code.writeline(
                "index, src, out, N, inp_numel, inp_stride_dim, inp_shape_dim, src_shape_dim, delta, alpha, "
            )
            if rank > 0:
                # All strides first, then all sizes, matching the kernel signature.
                for per_dim_list in ("src_strides", "src_shapes"):
                    per_dim_args = ", ".join(
                        f"{per_dim_list}[{dim_pos}]" for dim_pos in range(rank)
                    )
                    code.writeline(f"{per_dim_args},")
            code.writeline("BLOCK_SIZE=BLOCK_SIZE")
        code.writeline(")")
        code.writeline("return out")

    return code

159 

160 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Write a complete index_add module (imports, kernel, wrapper) into ``code``.

    ``inputs`` is the positional argument tuple later handed to the wrapper:
    (out, index, src, dim, inp_stride_dim, inp_shape_dim, src_shape_dim,
    delta, N, inp_numel, alpha). Only the rank of ``src`` (``inputs[2]``)
    influences the generated text.
    """
    src_rank = len(inputs[2].shape)

    buffer = generate_imports(code)
    buffer = generate_index_add_kernel(src_rank, kernel_name, buffer)
    return generate_destination_passing_wrapper(
        src_rank, wrapper_name, kernel_name, buffer
    )

175 

176 

class IndexAddFunction:
    """Cache of generated, dynamically imported index_add wrappers.

    Overloads are keyed by ``arg_key`` (the largest ``ndim`` among the tensor
    arguments). The first call for a key generates the module source, writes it
    to the code cache directory, and imports it. Later calls with the same key
    reuse the loaded wrapper.
    """

    def __init__(self):
        # The pid is embedded in the cache file and module names so separate
        # processes do not write to the same file.
        self.pid = os.getpid()
        # key (str of arg_key) -> loaded `_index_add_wrapper` function
        self.overloads: Dict[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        """Dispatch to the wrapper for these arguments, building it if needed."""
        key = f"{self.arg_key(*args)}"
        overload = self.overloads.get(key)
        if overload is None:
            overload = self._build_overload(key, args)
            self.overloads[key] = overload

        return overload(*args, **kwargs)

    def _build_overload(self, key, args):
        """Generate, write and import the wrapper module for ``key``; return the wrapper."""
        code = IndentedBuffer()
        code = generate_code(
            args,
            "_index_add_wrapper",
            "_index_add_jit_function",
            code,
        )

        file_path = code_cache_dir() / f"index_add_rank_{key}_pid_{self.pid}.py"
        with open(file_path, "wt", encoding="utf-8") as f:
            f.write(code.getvalue())

        # load: importlib.util must be imported explicitly; `import importlib`
        # alone does not guarantee the `util` submodule is available.
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}_pid_{self.pid}",
            os.fspath(file_path),
        )
        m = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(m)
        return getattr(m, "_index_add_wrapper")

    def arg_key(self, *args):
        """Return the largest ``ndim`` among the tensor positional arguments."""
        return max(arg.ndim for arg in args if torch.is_tensor(arg))

217 

218 

# Module-level dispatcher instance. Its generated-wrapper cache (`overloads`)
# is shared by index_add and index_add_ below.
_index_add_func = IndexAddFunction()

220 

221 

def _check_index_add_args(inp, dim, index, src):
    """Validate index_add arguments and return ``dim`` normalized to ``[0, inp.ndim)``.

    Raises:
        AssertionError: if ``dim`` is out of range, the ranks or non-``dim``
            sizes of ``inp`` and ``src`` differ, ``index`` length does not match
            ``src.size(dim)``, or any index value is outside
            ``[0, inp.size(dim))``.
    """
    # dim and rank are checked first because every later check indexes
    # inp/src with dim.
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    assert (
        inp.ndim == src.ndim
    ), "Self and source should have the same number of dimensions"
    dim %= inp.ndim
    assert index.numel() == src.size(
        dim
    ), "The dimth dimension of source must have the same size as the length of index"
    # all(...) is required: asserting a bare generator expression is always true.
    assert all(
        inp.size(i) == src.size(i) for i in range(inp.ndim) if i != dim
    ), "src.size(d) == self.size(d) for all dimensions d != dim"
    assert bool(
        ((index >= 0) & (index < inp.size(dim))).all()
    ), "0 <= index < self.size(dim)"
    return dim


def _run_index_add(out, dim, index, src, alpha):
    """Add ``alpha * src`` into the contiguous tensor ``out`` in place along ``dim``.

    ``dim`` must already be normalized and ``out`` must be contiguous.

    The generated kernel reads ``index`` with unit stride. It also derives each
    ``out`` offset from an element's linear position in ``src``, assuming a
    contiguous layout. For these reasons ``index`` and ``src`` are made
    contiguous here.
    """
    if src.numel() == 0:
        # Nothing to add; avoids launching an empty grid.
        return

    index = index.contiguous()
    src = src.contiguous()
    src_shape_dim = src.size(dim)
    inp_shape_dim = out.size(dim)

    _index_add_func(
        out,
        index,
        src,
        dim,
        out.stride(dim),  # inp_stride_dim
        inp_shape_dim,
        src_shape_dim,
        inp_shape_dim - src_shape_dim,  # delta
        src.numel(),  # N
        out.numel(),  # inp_numel
        alpha,
    )


def index_add(inp, dim, index, src, alpha=1):
    """Return a copy of ``inp`` with ``alpha * src`` added at ``index`` along ``dim``."""
    logger.debug("GEMS INDEX ADD")
    dim = _check_index_add_args(inp, dim, index, src)

    # The kernel addresses `out` assuming a contiguous layout.
    out = inp.clone(memory_format=torch.contiguous_format)
    _run_index_add(out, dim, index, src, alpha)
    return out


def index_add_(inp, dim, index, src, alpha=1):
    """In-place variant of ``index_add``: modify and return ``inp``."""
    logger.debug("GEMS INDEX ADD_")
    dim = _check_index_add_args(inp, dim, index, src)

    if inp.is_contiguous():
        _run_index_add(inp, dim, index, src, alpha)
    else:
        # The kernel needs a contiguous destination: accumulate into a
        # contiguous copy, then write the result back into inp's own layout.
        tmp = inp.contiguous()
        _run_index_add(tmp, dim, index, src, alpha)
        inp.copy_(tmp)
    return inp

298 ) 

299 return inp