Coverage for src/flag_gems/runtime/backend/_sunrise/ops/pad.py: 0%

286 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7 

8from flag_gems.utils.code_cache import code_cache_dir 

9from flag_gems.utils.code_utils import IndentedBuffer, write_atomic 

10 

# Module-level logger; ``pad`` emits a debug record through it on every call.
logger = logging.getLogger(__name__)

12 

13 

14# --------------------------- padding wrapper generation -----------------------------------

def parameter_for_wrapper() -> str:
    """Return the parameter list of the generated functional pad wrapper.

    The wrapper mirrors ``pad(input, pad, mode, value)``; the returned string is
    ``"in0, pad, mode, value=0"``.
    """
    parameters = ("in0", "pad", "mode", "value=0")
    return ", ".join(parameters)

26 

27 

def parameter_for_wrapper_out() -> str:
    """Return the parameter list of the generated destination-passing wrapper.

    The returned string is
    ``"in0, out0, dst_shape, pad_before, pad_after, mode, value=0"``: the input,
    the pre-allocated output, the padded shape, the per-dimension padding
    amounts, the padding mode and the fill value (defaulting to 0).
    """
    parameters = (
        "in0",
        "out0",
        "dst_shape",
        "pad_before",
        "pad_after",
        "mode",
        "value=0",
    )
    return ", ".join(parameters)

43 

44 

def parameter_ref_for_wrapper() -> str:
    """Return the argument list the functional wrapper passes to the out-variant.

    The returned string is
    ``"in0, out0, dst_shape, pad_before, pad_after, mode, value"`` and lines up
    positionally with ``parameter_for_wrapper_out()``.
    """
    arguments = (
        "in0",
        "out0",
        "dst_shape",
        "pad_before",
        "pad_after",
        "mode",
        "value",
    )
    return ", ".join(arguments)

60 

61 

def output_ref_for_wrapper() -> str:
    """Name of the variable the generated functional wrapper binds its result to."""
    output_name = "out0"
    return output_name

64 

65 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Append the import header of the generated pad module to ``code``.

    ``None`` entries below stand for blank lines. Returns the same buffer so
    calls can be chained.
    """
    header = (
        "import math",
        "import torch",
        "import triton",
        "from triton import language as tl",
        None,
        "from flag_gems.utils.libentry import libentry",
        "from flag_gems.runtime import torch_device_fn",
        "from flag_gems.utils import triton_lang_extension as ext",
        "from flag_gems.utils.type_utils import type_promotion",
        None,
        None,
    )
    for line in header:
        if line is None:
            code.newline()
        else:
            code.writeline(line)
    return code

79 

80 

def generate_functional_padding_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Append the functional (allocating) pad wrapper to ``code``.

    The generated function takes ``(in0, pad, mode, value=0)``. It splits the
    flat ``pad`` sequence into per-dimension ``pad_before`` / ``pad_after``
    lists, where pair ``i`` of ``pad`` targets dimension ``ndim - i - 1``. It
    then allocates ``out0`` with the padded shape and forwards everything to
    ``destination_passing_func_name``.

    Args:
        wrapper_name: name of the generated function.
        destination_passing_func_name: name of the generated out-variant it calls.
        code: buffer the source is appended to.

    Returns:
        The same ``code`` buffer.
    """
    header = f"def {wrapper_name}({parameter_for_wrapper()}):"
    forward_call = (
        f"{output_ref_for_wrapper()} = {destination_passing_func_name}"
        f"({parameter_ref_for_wrapper()})"
    )

    code.writeline(header)
    with code.indent():
        # the flat pad list must consist of whole (before, after) pairs
        code.writeline("ndim = in0.ndim")
        code.writeline("pad_size = len(pad)")
        code.writeline("assert pad_size % 2 == 0")
        code.newline()

        for side in ("before", "after"):
            code.writeline(f"pad_{side} = [0 for _ in range(ndim)]")
        code.newline()

        # pair i of the flat pad list applies to dimension ndim - i - 1
        code.writeline("pad_pair = pad_size // 2 ")
        code.writeline("for i in range(pad_pair): ")
        with code.indent():
            code.writeline("pad_before[ndim - i - 1] = pad[2 * i]")
            code.writeline("pad_after[ndim - i - 1] = pad[2 * i + 1]")

        # padded output shape, then allocate the result
        code.writeline("dst_shape = list(in0.shape)")
        code.writeline("for i in range(ndim): ")
        with code.indent():
            code.writeline("dst_shape[i] += pad_before[i] + pad_after[i]")
        code.writeline(
            "out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)"
        )

        code.writeline(forward_call)
        code.writeline("return out0")
        code.newline()
        code.newline()

    return code

127 

128 

def generate_destination_passing_padding_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Append the destination-passing pad wrapper that launches the kernel.

    The generated function has the signature from ``parameter_for_wrapper_out``
    and writes into the caller-provided ``out0``. It is specialized on ``rank``:
    shapes, strides and valid index ranges are unrolled into scalar kernel
    arguments. Their order must match the signature emitted by
    ``generate_pad_kernel``.

    Args:
        rank: number of dimensions of the input (and output) tensor.
        wrapper_name: name of the generated wrapper function.
        kernel_name: name of the generated Triton kernel to launch.
        code: buffer the source is appended to.

    Returns:
        The same ``code`` buffer.
    """
    code.writeline(f"def {wrapper_name}({parameter_for_wrapper_out()}):")

    with code.indent():
        # one program per BLOCK_SIZE output elements
        code.writeline("BLOCK_SIZE = 256")
        code.writeline("grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)")
        code.newline()

        code.writeline("x_shape = in0.shape")
        code.writeline("in_strides0 = in0.stride()")
        code.writeline("out_strides = out0.stride()")

        if rank > 0:
            # Along each dim, output indices in [pad_before, dst_shape - pad_after)
            # map directly onto the input; everything else is padding.
            code.writeline("# valid (non-padded) index range of each output dim")
            for i in range(rank):
                code.writeline(f"valid_dim{i}_start = pad_before[{i}]")
                code.writeline(f"valid_dim{i}_end = dst_shape[{i}] - pad_after[{i}]")
            code.newline()

            code.writeline("# Check which dimensions have padding")
            for i in range(rank):
                code.writeline(
                    f"dim{i}_has_pad = pad_before[{i}] > 0 or pad_after[{i}] > 0"
                )

        # Mode flags are compile-time constants of the kernel; emitted for
        # every rank because the launch below always passes them.
        for flag, mode_name in (
            ("IS_CONSTANT", "constant"),
            ("IS_REFLECT", "reflect"),
            ("IS_REPLICATE", "replicate"),
            ("IS_CIRCULAR", "circular"),
        ):
            code.writeline(f"{flag} = mode == '{mode_name}'")
        code.newline()

        code.writeline("# kernel launch")
        code.writeline("with torch_device_fn.device(in0.device):")
        with code.indent():
            code.writeline(f"{kernel_name}[grid](")
            with code.indent():
                code.writeline("in0, out0, ")
                if rank > 0:
                    per_dim_args = (
                        ("x_shape[{}]", "shape for x"),
                        ("in_strides0[{}]", "stride for x"),
                        ("out_strides[{}]", "stride for out"),
                        ("valid_dim{}_start", "valid dim start"),
                        ("valid_dim{}_end", "valid dim end"),
                        ("bool(dim{}_has_pad)", "dim has padding flags"),
                    )
                    for template, comment in per_dim_args:
                        joined = ", ".join(template.format(j) for j in range(rank))
                        code.writeline(f"{joined}, # {comment}")
                for arg in (
                    "in0.numel(), ",
                    "out0.numel(), ",
                    "value, ",
                    "IS_CONSTANT, ",
                    "IS_REFLECT, ",
                    "IS_REPLICATE, ",
                    "IS_CIRCULAR, ",
                    "BLOCK_SIZE, ",
                ):
                    code.writeline(arg)
            code.writeline(")")

        code.writeline("return out0")
        code.newline()
        code.newline()
    return code

218 

219 

def generate_pad_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Append the rank-specialized Triton pad kernel to ``code``.

    Each program handles ``BLOCK_SIZE`` consecutive output elements. The flat
    output offset is decomposed into per-dim indices via ``out_stridesI``. This
    assumes ``out0`` is contiguous; the generated functional wrapper allocates
    it with ``torch.empty``. For every output position a source index into
    ``in0`` is derived according to the mode:

    * constant  -- positions outside ``[valid_dimI_start, valid_dimI_end)``
      receive ``value``; only in-range positions are loaded from ``in0``.
    * reflect   -- mirror around the edge element without repeating it.
    * replicate -- clamp to the first / last input element.
    * circular  -- wrap around to the opposite end of the input.

    Every load is masked so no lane (including tail lanes past
    ``out_elem_cnt``) reads outside ``in0``. The argument order must match the
    launch emitted by ``generate_destination_passing_padding_wrapper``.

    Args:
        rank: number of dimensions of the input (and output) tensor.
        kernel_name: name of the generated ``@triton.jit`` function.
        code: buffer the source is appended to.

    Returns:
        The same ``code`` buffer.
    """
    code.newline()

    # decorators: ``value`` changes per call, keep it out of specialization
    code.writeline("@libentry()")
    non_specialize_arg_names = ["value"]
    code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")

    # signature
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")
        if rank > 0:
            per_dim_params = (
                ("x_shape{}", "shape for x"),
                ("in_strides{}", "stride for x"),
                ("out_strides{}", "stride for out"),
                ("valid_dim{}_start", "valid dim start"),
                ("valid_dim{}_end", "valid dim end"),
            )
            for template, comment in per_dim_params:
                params = ", ".join(f"{template.format(j)}: int" for j in range(rank))
                code.writeline(f"{params}, # {comment}")
        for i in range(rank):
            code.writeline(f"dim{i}_has_pad: tl.constexpr, ")
        # Element counts are plain runtime ints: as tl.constexpr they would
        # force a recompilation for every distinct tensor size.
        code.writeline("in_elem_cnt, ")
        code.writeline("out_elem_cnt, ")
        code.writeline("value, # padding value")
        code.writeline("IS_CONSTANT: tl.constexpr, ")
        code.writeline("IS_REFLECT: tl.constexpr, ")
        code.writeline("IS_REPLICATE: tl.constexpr, ")
        code.writeline("IS_CIRCULAR: tl.constexpr, ")
        code.writeline("BLOCK_SIZE: tl.constexpr, ")
    code.writeline("):")

    with code.indent():
        code.writeline("pid = ext.program_id(0)")
        code.writeline("block_offset = pid * BLOCK_SIZE")
        code.writeline("offset = block_offset + tl.arange(0, BLOCK_SIZE)")
        code.writeline("out_mask = offset < out_elem_cnt")
        code.newline()

        if rank == 0:
            # A 0-dim tensor has no dimension to pad: the output (shape [])
            # is a copy of the single input element.
            code.writeline(
                "x_val = tl.load(in0_ptr + offset * 0, mask=out_mask, other=0)"
            )
            code.writeline("tl.store(out0_ptr + offset, x_val, mask=out_mask)")
            return code

        # decompose the flat output offset into per-dim output indices
        code.writeline("remaining = offset")
        for i in range(rank):
            code.writeline(f"dst_index_{i} = remaining // out_strides{i}")
            code.writeline(f"remaining = remaining - dst_index_{i} * out_strides{i}")
        code.newline()

        # cond: the output position lies inside the region copied from the input
        code.writeline(
            "cond = ((dst_index_0 >= valid_dim0_start) & (dst_index_0 < valid_dim0_end))"
        )
        for i in range(1, rank):
            code.writeline(
                f"cond &= ((dst_index_{i} >= valid_dim{i}_start) & "
                f"(dst_index_{i} < valid_dim{i}_end))"
            )
        code.newline()

        # default source index (exact inside the valid region), clamped at 0
        for i in range(rank):
            code.writeline(f"src_index_{i} = dst_index_{i} - valid_dim{i}_start")
            code.writeline(
                f"src_index_{i} = tl.where(src_index_{i} < 0, 0, src_index_{i})"
            )

        code.newline()
        code.writeline("if IS_REFLECT:")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"valid_dim{i}_start - dst_index_{i}, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"(x_shape{i} + valid_dim{i}_start - 1) * 2 - dst_index_{i} - valid_dim{i}_start, "
                    f"src_index_{i})"
                )

        code.newline()
        code.writeline("if IS_REPLICATE:")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"0, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"x_shape{i} - 1, src_index_{i})"
                )

        code.newline()
        code.writeline("if IS_CIRCULAR:")
        with code.indent():
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"dst_index_{i} + x_shape{i} - valid_dim{i}_start, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"dst_index_{i} - valid_dim{i}_end, src_index_{i})"
                )

        code.newline()
        code.writeline("src_offset = src_index_0 * in_strides0")
        for i in range(1, rank):
            code.writeline(f"src_offset += src_index_{i} * in_strides{i}")
        code.newline()

        code.writeline("if IS_CONSTANT:")
        with code.indent():
            # Only lanes inside the valid region touch memory. This keeps the
            # load in bounds even when an input dim is empty (x_shape == 0).
            code.writeline(
                "x_loaded = tl.load(in0_ptr + src_offset, mask=cond & out_mask, other=0)"
            )
            code.writeline("x_val = tl.where(cond, x_loaded, value)")
        code.writeline("else:")
        with code.indent():
            # Tail lanes past out_elem_cnt can yield negative reflect indices,
            # so bound each source index on both sides and honour out_mask.
            code.writeline("load_cond = out_mask")
            for i in range(rank):
                code.writeline(
                    f"load_cond &= (src_index_{i} >= 0) & (src_index_{i} < x_shape{i})"
                )
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=load_cond, other=0)"
            )
        code.writeline("tl.store(out0_ptr + offset, x_val, mask=out_mask)")

    return code

389 

390 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble the full generated pad module into ``code``.

    The module consists of the import header, the functional wrapper, the
    destination-passing wrapper and the Triton kernel, in that order. The rank
    of ``inputs[0]`` is the only property of the inputs that influences the
    generated source.

    Returns:
        The buffer holding the complete module source.
    """
    rank = len(inputs[0].shape)

    code = generate_imports(code)
    code = generate_functional_padding_wrapper(
        wrapper_name, destination_passing_func_name, code
    )
    code = generate_destination_passing_padding_wrapper(
        rank, destination_passing_func_name, kernel_name, code
    )
    return generate_pad_kernel(rank, kernel_name, code)

411 

412 

class PadFunction:
    """Lazily generate, cache and dispatch rank-specialized pad implementations.

    On the first call for a given rank, a Python module is generated (see
    ``generate_code``). The module contains a functional wrapper, a
    destination-passing wrapper and a Triton kernel. It is written to
    ``code_cache_dir()`` and imported. The module's ``_pad_wrapper`` is then
    memoized in ``self.overloads`` under the key returned by ``arg_key`` and
    reused for later calls.
    """

    def __init__(self):
        self.pid = os.getpid()
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # note: kwargs are forwarded to the generated Python wrapper, never
        # passed to the JITFunction directly
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            overload = self._generate_overload(key, args)
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    @staticmethod
    def _generate_overload(key, args):
        """Generate the module for ``key``, write it to the cache dir, import it."""
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule has been loaded, so import it explicitly here.
        import importlib.util

        code = generate_code(
            args,
            "_pad_wrapper",
            "_pad_wrapper_out",
            "_pad_jit_function",
            IndentedBuffer(),
        )

        file_name = f"constant_pad_rank_{key}.py"
        file_path = code_cache_dir() / file_name
        write_atomic(file_path, code.getvalue())

        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}",
            file_path,
        )
        module = importlib.util.module_from_spec(spec)
        # deliberately not registered in sys.modules
        spec.loader.exec_module(module)
        return getattr(module, "_pad_wrapper")

    def arg_key(self, *args):
        """Return the largest ``ndim`` among the tensor positional arguments."""
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank

461 

462 

# Module-level dispatcher used by ``pad``; it generates and caches one
# implementation per input rank (see ``PadFunction.arg_key``).
_pad_func = PadFunction()

464 

465 

def pad(self, pad, mode="constant", value=None):
    """Pad ``self`` following the ``torch.nn.functional.pad`` conventions.

    ``pad`` is a flat sequence of ``(before, after)`` pairs. Pair ``i`` applies
    to dimension ``ndim - 1 - i``, so the last dimension comes first.

    Args:
        self: input tensor.
        pad: flat padding sizes; at most ``2 * self.ndim`` entries.
        mode: ``"constant"``, ``"reflect"``, ``"replicate"`` or ``"circular"``.
        value: fill value for constant mode; ``None`` means ``0.0``.

    Returns:
        A newly allocated padded tensor.

    Raises:
        ValueError: for an unknown ``mode`` or if ``pad`` covers more
            dimensions than ``self`` has.
        AssertionError: if a reflect/circular padding amount is too large for
            the corresponding input dimension.
    """
    logger.debug("GEMS CONSTANT PAD ND")

    ndim = self.ndim

    if value is None:
        value = 0.0

    if mode not in ("constant", "reflect", "replicate", "circular"):
        raise ValueError(f"Unsupported padding mode: {mode!r}")

    pad_pairs = len(pad) // 2
    # Without this check ``ndim - 1 - i`` (here and in the generated wrapper)
    # goes negative and silently pads the wrong dimensions.
    if pad_pairs > ndim:
        raise ValueError(
            "Padding length should be less than or equal to two times the input "
            f"dimension but got padding length {len(pad)} and input of dimension {ndim}"
        )

    if mode == "reflect":
        for i in range(pad_pairs):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - 1 - i]
            assert pad_l < input_size and pad_r < input_size, (
                "padding size should be less than the corresponding input dimension, "
                f"but got padding size: {pad_l}, {pad_r}, input size: {self.shape}"
            )
    elif mode == "circular":
        for i in range(pad_pairs):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - 1 - i]
            assert (
                pad_l <= input_size and pad_r <= input_size
            ), "Padding value causes wrapping around more than once."

    out = _pad_func(self, pad, mode, float(value))
    return out

496 

497 

def constant_pad_nd(self, pad_list, value=0):
    """ATen-style ``constant_pad_nd``: ``pad`` in constant mode filling with ``value``."""
    padded = pad(self, pad_list, mode="constant", value=value)
    return padded