Coverage for src/flag_gems/runtime/backend/_cambricon/ops/pad.py: 0%

349 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import importlib 

2import logging 

3import os 

4from typing import Any, Callable, List, Mapping, Tuple 

5 

6import torch 

7import triton 

8import triton.language as tl 

9 

10from flag_gems.utils import libentry 

11from flag_gems.utils.code_cache import code_cache_dir 

12from flag_gems.utils.code_utils import IndentedBuffer 

13 

14from ..utils import TOTAL_CORE_NUM 

15 

# Module logger, created as a child of the top-level "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


# --------------------------- padding wrapper generation -----------------------------------

def parameter_for_wrapper() -> str:
    """Return the parameter list of the generated functional wrapper.

    The wrapper receives the input tensor ``in0``, the flat ``pad`` sequence,
    the padding ``mode`` and an optional fill ``value`` defaulting to 0.
    No type annotations are emitted.

    Example: ``"in0, pad, mode, value=0"``
    """
    wrapper_params = ("in0", "pad", "mode", "value=0")
    return ", ".join(wrapper_params)

31 

32 

def parameter_for_wrapper_out() -> str:
    """Return the parameter list of the generated destination-passing wrapper.

    On top of the input ``in0`` it takes the pre-allocated output ``out0``,
    the output shape ``dst_shape``, the per-dimension ``pad_before`` /
    ``pad_after`` lists, the ``mode`` and a fill ``value`` defaulting to 0.
    No type annotations are emitted.
    """
    tensors = ["in0", "out0"]
    geometry = ["dst_shape", "pad_before", "pad_after"]
    options = ["mode", "value=0"]
    return ", ".join(tensors + geometry + options)

48 

49 

def parameter_ref_for_wrapper() -> str:
    """Return the argument list the functional wrapper forwards to the
    destination-passing wrapper.

    The names line up positionally with ``parameter_for_wrapper_out`` but
    carry no default values, e.g.
    ``"in0, out0, dst_shape, pad_before, pad_after, mode, value"``.
    """
    forwarded = "in0 out0 dst_shape pad_before pad_after mode value".split()
    return ", ".join(forwarded)

65 

66 

def output_ref_for_wrapper() -> str:
    """Return the variable name the generated wrapper binds its result to."""
    output_name = "out0"
    return output_name

69 

70 

def generate_imports(code: IndentedBuffer) -> IndentedBuffer:
    """Append the import header of the generated module to ``code``.

    Writes the math/torch/triton imports and the flag_gems helpers available
    to the generated code, then two blank lines, and returns ``code``.
    """
    header_lines = (
        "import math",
        "import torch",
        "import triton",
        "from triton import language as tl",
        None,
        "from flag_gems.utils.libentry import libentry",
        "from flag_gems.runtime import torch_device_fn",
        "from flag_gems.utils import triton_lang_extension as ext",
        "from flag_gems.utils.type_utils import type_promotion",
        None,
        None,
    )
    for entry in header_lines:
        # ``None`` marks a blank separator line.
        if entry is None:
            code.newline()
        else:
            code.writeline(entry)
    return code

84 

85 

def generate_functional_padding_wrapper(
    wrapper_name: str,
    destination_passing_func_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the user-facing functional wrapper ``wrapper_name`` into ``code``.

    The generated function does four things:

    1. It asserts that ``pad`` has an even length.
    2. It spreads the flat ``pad`` sequence into per-dimension ``pad_before`` /
       ``pad_after`` lists. Pair ``i`` targets dimension ``ndim - i - 1``.
    3. It computes the padded shape and allocates ``out0`` with ``torch.empty``.
    4. It delegates to ``destination_passing_func_name`` and returns ``out0``.

    Returns:
        The same ``code`` buffer.
    """
    code.writeline(f"def {wrapper_name}({parameter_for_wrapper()}):")

    with code.indent():
        prologue = (
            "ndim = in0.ndim",
            "pad_size = len(pad)",
            "assert pad_size % 2 == 0",
        )
        for statement in prologue:
            code.writeline(statement)
        code.newline()

        for side in ("before", "after"):
            code.writeline(f"pad_{side} = [0 for _ in range(ndim)]")
        code.newline()

        # ``pad`` lists (before, after) pairs starting from the last dimension.
        code.writeline("pad_pair = pad_size // 2 ")
        code.writeline("for i in range(pad_pair): ")
        with code.indent():
            code.writeline("pad_before[ndim - i - 1] = pad[2 * i]")
            code.writeline("pad_after[ndim - i - 1] = pad[2 * i + 1]")

        code.writeline("dst_shape = list(in0.shape)")
        code.writeline("for i in range(ndim): ")
        with code.indent():
            code.writeline("dst_shape[i] += pad_before[i] + pad_after[i]")

        code.writeline(
            "out0 = torch.empty(dst_shape, device=in0.device, dtype=in0.dtype)"
        )

        # Hand the pre-allocated output to the destination-passing wrapper.
        forwarded_args = parameter_ref_for_wrapper()
        code.writeline(
            f"{output_ref_for_wrapper()} = "
            f"{destination_passing_func_name}({forwarded_args})"
        )
        code.writeline("return out0")
        code.newline()
        code.newline()

    return code

132 

133 

def generate_destination_passing_padding_wrapper(
    rank: int,
    wrapper_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the destination-passing wrapper ``wrapper_name`` that launches the kernel.

    The generated function prepares the kernel launch in four steps:

    1. It receives the pre-allocated ``out0``; see ``parameter_for_wrapper_out``.
    2. For every dimension it computes the half-open range
       ``[valid_dim{i}_start, valid_dim{i}_end)`` of output indices that copy
       straight from the input.
    3. It flags which dimensions carry any padding, and turns ``mode`` into
       four boolean flags.
    4. It launches ``kernel_name`` on a 1-D grid of
       ``cdiv(out0.numel(), BLOCK_SIZE)`` programs.

    Args:
        rank: number of input dimensions the code is specialised for.
        wrapper_name: name of the generated function.
        kernel_name: name of the generated Triton kernel to launch.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer.
    """
    # wrapper signature
    parameters: str = parameter_for_wrapper_out()

    wrapper_signature: str = f"def {wrapper_name}({parameters}):"
    code.writeline(wrapper_signature)

    with code.indent():
        # launch configuration: one program per BLOCK_SIZE output elements
        code.writeline("BLOCK_SIZE = 2048")
        code.writeline("grid = (triton.cdiv(out0.numel(), BLOCK_SIZE), 1, 1)")
        code.newline()

        code.writeline("x_shape = in0.shape")
        code.writeline("in_strides0 = in0.stride()")
        code.writeline("out_strides = out0.stride()")

        # Per-dimension output index range that maps onto the input; indices
        # outside [start, end) are padding. (The emitted comment text below
        # is part of the generated source and is kept as-is.)
        if rank > 0:
            code.writeline("# strides of each tensor argument w.r.t the task space")
            for i in range(rank):
                code.writeline(f"valid_dim{i}_start = pad_before[{i}]")

                code.writeline(f"valid_dim{i}_end = dst_shape[{i}] - pad_after[{i}]")

        code.newline()

        # Per-dimension "has padding" flags (declared tl.constexpr by the
        # kernel generator) and one boolean per supported mode.
        code.writeline("# Check which dimensions have padding")
        for i in range(rank):
            code.writeline(
                f"dim{i}_has_pad = pad_before[{i}] > 0 or pad_after[{i}] > 0"
            )
        code.writeline("IS_CONSTANT = mode == 'constant'")
        code.writeline("IS_REFLECT = mode == 'reflect'")
        code.writeline("IS_REPLICATE = mode == 'replicate'")
        code.writeline("IS_CIRCULAR = mode == 'circular'")

        code.newline()

        # grid
        code.writeline("# kernel launch")

        # launch kernel; argument order must match the signature emitted by
        # generate_pad_kernel exactly (all arguments are positional)
        code.writeline("with torch_device_fn.device(in0.device):")
        with code.indent():
            kernel_launch: str = f"{kernel_name}[grid]("
            code.writeline(kernel_launch)

            with code.indent():
                code.writeline("in0, out0, ")

                if rank > 0:
                    s = ", ".join(f"x_shape[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # shape for x")

                    s = ", ".join(f"in_strides0[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for x")

                    s = ", ".join(f"out_strides[{j}]" for j in range(rank))
                    code.writeline(f"{s}, # stride for out")

                    s = ", ".join(f"valid_dim{j}_start" for j in range(rank))
                    code.writeline(f"{s}, # valid dim start")

                    s = ", ".join(f"valid_dim{j}_end" for j in range(rank))
                    code.writeline(f"{s}, # valid dim end")

                    s = ", ".join(f"bool(dim{i}_has_pad)" for i in range(rank))
                    code.writeline(f"{s}, # dim has padding flags")

                code.writeline("in0.numel(), ")
                code.writeline("out0.numel(), ")
                code.writeline("value, ")
                code.writeline("IS_CONSTANT, ")
                code.writeline("IS_REFLECT, ")
                code.writeline("IS_REPLICATE, ")
                code.writeline("IS_CIRCULAR, ")
                code.writeline("BLOCK_SIZE, ")
            code.writeline(")")

        code.writeline("return out0")
        code.newline()
        code.newline()
    return code

223 

224 

def generate_pad_kernel(
    rank: int,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Emit the rank-specialised Triton padding kernel ``kernel_name``.

    Each program handles ``BLOCK_SIZE`` consecutive flat output offsets. The
    kernel processes them in four steps:

    1. Each offset is decomposed into per-dimension ``dst_index_{i}`` using the
       output strides.
    2. Each ``dst_index_{i}`` is mapped to a source index according to the
       mode flags (constant / reflect / replicate / circular).
    3. The gathered value is loaded through the input strides.
    4. It is stored to ``out0_ptr + offset``, a flat offset, so ``out0`` is
       expected to be contiguous, as allocated by the functional wrapper.

    The generated body indexes dimension 0 unconditionally, so ``rank`` must
    be at least 1.

    Args:
        rank: number of input dimensions the kernel is specialised for.
        kernel_name: name of the generated kernel function.
        code: buffer the generated source is appended to.

    Returns:
        The same ``code`` buffer.
    """
    # make the inlined function visible in the context
    code.newline()

    # decorators: ``value`` changes between calls and must not be specialised
    code.writeline("@libentry()")
    non_specialize_arg_names = ["value"]
    code.writeline(f"@triton.jit(do_not_specialize={non_specialize_arg_names})")

    # signature (order must match the launch in the destination-passing wrapper)
    code.writeline(f"def {kernel_name}(")
    with code.indent():
        code.writeline("in0_ptr: tl.tensor, # of tl.pointer_type")
        code.writeline("out0_ptr: tl.tensor, # of tl.pointer_type")

        if rank > 0:
            per_dim_params = (
                ("x_shape{}: int", "shape for x"),
                ("in_strides{}: int", "stride for x"),
                ("out_strides{}: int", "stride for out"),
                ("valid_dim{}_start: int", "valid dim start"),
                ("valid_dim{}_end: int", "valid dim end"),
            )
            for template, comment in per_dim_params:
                args = ", ".join(template.format(j) for j in range(rank))
                code.writeline(f"{args}, # {comment}")

        for i in range(rank):
            code.writeline(f"dim{i}_has_pad: tl.constexpr, ")

        # Element counts are plain runtime ints. Declaring them tl.constexpr
        # would compile a separate kernel for every distinct tensor size.
        code.writeline("in_elem_cnt, ")
        code.writeline("out_elem_cnt, ")
        code.writeline("value, # padding value")
        for flag in (
            "IS_CONSTANT",
            "IS_REFLECT",
            "IS_REPLICATE",
            "IS_CIRCULAR",
            "BLOCK_SIZE",
        ):
            code.writeline(f"{flag}: tl.constexpr, ")

    code.writeline("):")

    with code.indent():
        code.writeline("pid = tl.program_id(0)")
        code.writeline("block_offset = pid * BLOCK_SIZE")
        code.writeline("offset = block_offset + tl.arange(0, BLOCK_SIZE)")
        code.newline()

        # Decompose the flat output offset into per-dimension indices.
        code.writeline("remaining = offset ")
        for i in range(rank):
            code.writeline(f"idx = remaining // out_strides{i}")
            code.writeline(f"dst_index_{i} = idx")
            code.writeline(f"remaining = remaining - idx * out_strides{i}")
        code.newline()

        # if_pad is 1 where the output element lies outside the copied region.
        code.writeline("if_pad_false_mask = tl.zeros((BLOCK_SIZE, ), dtype=tl.int32)")
        code.writeline("if_pad_true_mask = tl.full((BLOCK_SIZE, ), 1, dtype=tl.int32)")
        for i in range(rank):
            op = "=" if i == 0 else "&="
            code.writeline(
                f"cond {op} ((dst_index_{i} >= valid_dim{i}_start) "
                f"& (dst_index_{i} < valid_dim{i}_end))"
            )
        code.writeline(
            "if_pad = tl.where(cond, if_pad_false_mask, if_pad_true_mask).to(tl.int1)"
        )

        # Default mapping: shift by the leading pad, clamped at 0.
        for i in range(rank):
            code.writeline(f"src_index_{i} = dst_index_{i} - valid_dim{i}_start ")
        for i in range(rank):
            code.writeline(
                f"src_index_{i} = tl.where(src_index_{i} < 0, 0, src_index_{i})"
            )

        code.newline()
        code.writeline("if IS_REFLECT: ")
        with code.indent():
            # Mirror around the first / last input element; the edge element
            # itself is not repeated.
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"valid_dim{i}_start - dst_index_{i}, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"(x_shape{i} + valid_dim{i}_start - 1) * 2 - dst_index_{i} - valid_dim{i}_start, "
                    f"src_index_{i})"
                )

        code.newline()
        code.writeline("if IS_REPLICATE: ")
        with code.indent():
            # Repeat the first / last input element.
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"0, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"x_shape{i} - 1, src_index_{i})"
                )

        code.newline()
        code.writeline("if IS_CIRCULAR: ")
        with code.indent():
            # Wrap around to the opposite end of the input.
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} < valid_dim{i}_start), "
                    f"dst_index_{i} + x_shape{i} - valid_dim{i}_start, src_index_{i})"
                )
            for i in range(rank):
                code.writeline(
                    f"src_index_{i} = tl.where(dim{i}_has_pad & (dst_index_{i} >= valid_dim{i}_end), "
                    f"dst_index_{i} - valid_dim{i}_end, src_index_{i})"
                )

        code.newline()

        code.writeline("src_offset = src_index_0 * in_strides0")
        for i in range(1, rank):
            code.writeline(f"src_offset += src_index_{i} * in_strides{i}")

        # Only lanes inside the output may touch memory, and only with
        # in-range source indices. Tail lanes of the last block decode to
        # dst_index_0 beyond the output's first dimension. In reflect mode that
        # yields negative src_index_0, i.e. a load before in0_ptr.
        code.writeline("load_cond = offset < out_elem_cnt")
        for i in range(rank):
            code.writeline(
                f"load_cond &= (src_index_{i} >= 0) & (src_index_{i} < x_shape{i})"
            )

        code.writeline("if IS_CONSTANT: ")
        with code.indent():
            # padded lanes take ``value`` via ``other``
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=((if_pad == 0) & load_cond), other=value)"
            )
        code.writeline("else: ")
        with code.indent():
            code.writeline(
                "x_val = tl.load(in0_ptr + src_offset, mask=load_cond, other=0)"
            )
        code.writeline("tl.store(out0_ptr + offset, x_val, mask=offset < out_elem_cnt)")

    return code

382 

383 

def generate_code(
    inputs: Tuple[Any],
    wrapper_name: str,
    destination_passing_func_name: str,
    kernel_name: str,
    code: IndentedBuffer,
) -> IndentedBuffer:
    """Assemble the complete source of a rank-specialised padding module.

    Only the rank of ``inputs[0]`` influences the emitted code. Sections are
    appended in this order:

    1. the import header;
    2. the functional wrapper ``wrapper_name``;
    3. the destination-passing wrapper ``destination_passing_func_name``;
    4. the Triton kernel ``kernel_name``.

    Returns:
        The populated buffer.
    """
    task_rank = len(inputs[0].shape)

    buf = generate_imports(code)
    buf = generate_functional_padding_wrapper(
        wrapper_name, destination_passing_func_name, buf
    )
    buf = generate_destination_passing_padding_wrapper(
        task_rank, destination_passing_func_name, kernel_name, buf
    )
    return generate_pad_kernel(task_rank, kernel_name, buf)

404 

405 

class PadFunction:
    """Generate, cache and dispatch rank-specialised padding implementations.

    The first call for a given input rank does the following:

    1. ``generate_code`` emits the module source: a functional wrapper
       (``_pad_wrapper``), a destination-passing wrapper (``_pad_wrapper_out``)
       and a Triton kernel (``_pad_jit_function``).
    2. The source is written to ``code_cache_dir()``.
    3. It is imported as a standalone module, not registered in
       ``sys.modules``.
    4. Its ``_pad_wrapper`` is memoised in ``self.overloads``.

    Later calls with the same rank reuse the cached wrapper.
    """

    def __init__(self):
        # The pid is embedded in the generated file and module names
        # (presumably so processes sharing the cache directory do not
        # overwrite each other's files).
        self.pid = os.getpid()
        # str(rank) -> generated ``_pad_wrapper`` callable
        self.overloads: Mapping[str, Callable] = {}

    def __call__(self, *args, **kwargs):
        # note: kwargs should not be used in JITFunction directly
        key = f"{self.arg_key(*args)}"
        if key in self.overloads:
            overload = self.overloads[key]
        else:
            overload = self._generate_overload(key, args)
            self.overloads[key] = overload
        return overload(*args, **kwargs)

    def _generate_overload(self, key, args):
        """Write, import and return the generated ``_pad_wrapper`` for ``key``."""
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule is loaded; import it explicitly.
        import importlib.util

        code = generate_code(
            args,
            "_pad_wrapper",
            "_pad_wrapper_out",
            "_pad_jit_function",
            IndentedBuffer(),
        )

        file_path = code_cache_dir() / f"constant_pad_rank_{key}_pid_{self.pid}.py"
        with open(file_path, "wt", encoding="utf-8") as f:
            f.write(code.getvalue())

        # Use the path we wrote to rather than the name of the closed handle.
        spec = importlib.util.spec_from_file_location(
            f"_gen_module_rank_{key}_pid_{self.pid}",
            str(file_path),
        )
        module = importlib.util.module_from_spec(spec)
        # deliberately not exposed through sys.modules
        spec.loader.exec_module(module)
        return getattr(module, "_pad_wrapper")

    def arg_key(self, *args):
        """Return the largest ``ndim`` among tensor positional arguments.

        This rank is the cache key. ``max`` raises ``ValueError`` if no
        argument is a tensor.
        """
        tensors = [item for item in args if torch.is_tensor(item)]
        max_rank = max(item.ndim for item in tensors)
        return max_rank

455 

456 

# Module-level dispatcher; ``pad`` routes every case without a dedicated
# constant-pad kernel through it.
_pad_func = PadFunction()

458 

459 

# Constant padding of a 1-D input: out = [pad_value] * pad_left + inp + [pad_value] * pad_right.
# Autotuned over BLOCK_SIZE in {1024, 4096, 16384} and num_stages in {1, 3}.
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 2**n}, num_stages=s)
        for n in range(10, 16, 2)
        for s in [1, 3]
    ],
    key=["inp_elements"],
)
@triton.jit
def pad_1d_constant_kernel(
    inp_ptr,  # flat input; the caller in this file passes ``.contiguous()``
    out_ptr,
    inp_elements,
    pad_value,
    pad_left,
    pad_right,
    BLOCK_SIZE: tl.constexpr,
):
    # Grid-stride loop over output blocks: program ``pid`` handles blocks
    # pid, pid + num_programs, ... so a capped grid still covers the output.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    start = pid * BLOCK_SIZE
    step = num_jobs * BLOCK_SIZE
    out_elements = pad_left + inp_elements + pad_right
    for off in range(start, out_elements, step):
        # Map output positions back to input positions; positions outside
        # [0, inp_elements) are padding and read ``pad_value``.
        inp_offset = off + tl.arange(0, BLOCK_SIZE) - pad_left
        inp_mask = inp_offset >= 0 and inp_offset < inp_elements
        inp = tl.load(inp_ptr + inp_offset, mask=inp_mask, other=pad_value)
        out_offset = off + tl.arange(0, BLOCK_SIZE)
        out_mask = out_offset < out_elements
        tl.store(out_ptr + out_offset, inp, mask=out_mask)

491 

492 

# Constant padding of a row-major 2-D (H, W) input into an
# (pad_top + H + pad_bottom, pad_left + W + pad_right) output.
# Autotuned over the number of output rows per tile (BLOCK_H).
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_H": n}, num_stages=s)
        for n in [1, 4, 8, 12, 16, 24]
        for s in [1, 3]
    ],
    key=["H", "W"],
)
@triton.jit
def pad_2d_constant_kernel(
    inp_ptr,  # callers in this file pass a ``.contiguous()`` (H, W) tensor
    out_ptr,  # addressed row-major with row length out_W
    H,
    W: tl.constexpr,
    pad_value,
    pad_left: tl.constexpr,
    pad_right: tl.constexpr,
    pad_top,
    pad_bottom,
    BLOCK_H: tl.constexpr,
):
    # Each program processes tiles of BLOCK_H output rows, striding by
    # num_programs * BLOCK_H rows until all out_H rows are covered.
    pid = tl.program_id(0)
    num_jobs = tl.num_programs(0)
    block_start = pid * BLOCK_H
    step = num_jobs * BLOCK_H
    out_W: tl.constexpr = pad_left + W + pad_right
    out_H = pad_top + H + pad_bottom
    # ``batch_idx`` is the first output row of the current tile.
    for batch_idx in range(block_start, out_H, step):
        # Input coordinates of the tile; a tile spans the full output width.
        # NOTE(review): upstream Triton requires tl.arange extents to be powers
        # of two; confirm the Cambricon backend accepts arbitrary out_W.
        offset_h = tl.arange(0, BLOCK_H) + batch_idx - pad_top
        offset_w = tl.arange(0, out_W) - pad_left
        offsets = offset_h[:, None] * W + offset_w[None, :]
        # Coordinates outside the input are padding and read ``pad_value``.
        mask = (offset_h[:, None] >= 0 and offset_h[:, None] < H) and (
            offset_w[None, :] >= 0 and offset_w[None, :] < W
        )
        inp = tl.load(inp_ptr + offsets, mask=mask, other=pad_value)

        out_offset_c = tl.arange(0, out_W)
        out_offset_n = tl.arange(0, BLOCK_H) + batch_idx
        out_offsets = out_offset_n[:, None] * out_W + out_offset_c[None, :]
        # Only the row bound can actually fail; the column bound always holds.
        out_mask = out_offset_n[:, None] < out_H and out_offset_c[None, :] < out_W
        tl.store(out_ptr + out_offsets, inp, mask=out_mask)

535 

536 

def pad(self, pad, mode="constant", value=None):
    """Pad ``self`` following the ``torch.nn.functional.pad`` conventions.

    Args:
        self: input tensor.
        pad: flat sequence of ``(before, after)`` pairs. The first pair applies
            to the last dimension, the next to the one before it, and so on.
        mode: one of ``"constant"``, ``"reflect"``, ``"replicate"`` or
            ``"circular"``.
        value: fill value for ``"constant"`` mode; ``None`` means 0.

    Returns:
        A new tensor of the padded shape with the dtype and device of ``self``.

    Raises:
        AssertionError: the padding spec is invalid. That is, ``len(pad)`` is
            odd, there are more pad pairs than dimensions, or the padding is
            too large for ``"reflect"`` / ``"circular"``.
        ValueError: ``mode`` is not one of the supported modes.

    Constant padding of 1-D and 2-D inputs uses dedicated kernels. 3-D inputs
    launch the 2-D kernel once per slice along dim 0. Every other case goes
    through the generated rank-specialised kernel (``_pad_func``).
    """
    logger.debug("GEMS_CAMBRICON CONSTANT PAD ND")

    ndim = self.ndim
    pad_size = len(pad)
    assert pad_size % 2 == 0
    pad_pairs = pad_size // 2
    # Pair i targets dimension ``ndim - i - 1``. Extra pairs would index the
    # per-dimension lists with negative values and silently pad the wrong axis.
    assert (
        pad_pairs <= ndim
    ), f"Padding length too large: got {pad_size} pad values for a {ndim}-D input"
    if mode not in ("constant", "reflect", "replicate", "circular"):
        raise ValueError(f"Unsupported padding mode: {mode!r}")

    if value is None:
        value = 0.0

    if mode == "constant":
        pad_before = [0] * ndim
        pad_after = [0] * ndim
        for i in range(pad_pairs):
            pad_before[ndim - i - 1] = pad[2 * i]
            pad_after[ndim - i - 1] = pad[2 * i + 1]

        inp_shape = list(self.shape)
        out_shape = [
            size + before + after
            for size, before, after in zip(inp_shape, pad_before, pad_after)
        ]
        out = torch.empty(out_shape, dtype=self.dtype, device=self.device)
        # Nothing to write; also avoids launching kernels with an empty grid.
        if out.numel() == 0:
            return out

        if ndim == 1:
            grid = lambda meta: (
                min(triton.cdiv(out_shape[0], meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),
            )
            pad_1d_constant_kernel[grid](
                self.contiguous(),
                out,
                inp_shape[0],
                value,
                pad_before[-1],
                pad_after[-1],
            )
            return out

        if ndim == 2:
            grid = lambda meta: (
                min(triton.cdiv(out_shape[0], meta["BLOCK_H"]), TOTAL_CORE_NUM),
            )
            pad_2d_constant_kernel[grid](
                self.contiguous(),
                out,
                inp_shape[0],
                inp_shape[1],
                value,
                pad_before[-1],
                pad_after[-1],
                pad_before[-2],
                pad_after[-2],
            )
            return out

        if ndim == 3:
            # Leading and trailing slices along dim 0 are pure padding.
            out[: pad_before[0]].fill_(value)
            out[pad_before[0] + inp_shape[0] :].fill_(value)

            # Every interior slice is a 2-D constant pad of one input slice.
            grid = lambda meta: (
                min(triton.cdiv(out_shape[1], meta["BLOCK_H"]), TOTAL_CORE_NUM),
            )
            for i in range(inp_shape[0]):
                pad_2d_constant_kernel[grid](
                    self[i].contiguous(),
                    out[pad_before[0] + i],
                    inp_shape[1],
                    inp_shape[2],
                    value,
                    pad_before[-1],
                    pad_after[-1],
                    pad_before[-2],
                    pad_after[-2],
                )
            return out

    if mode == "reflect":
        for i in range(pad_pairs):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - 1 - i]
            assert pad_l < input_size and pad_r < input_size, (
                "padding size should be less than the corresponding input dimension, "
                f"but got padding size: {pad_l}, {pad_r}, input size: {self.shape}"
            )
    elif mode == "circular":
        for i in range(pad_pairs):
            pad_l, pad_r = pad[2 * i], pad[2 * i + 1]
            input_size = self.shape[ndim - 1 - i]
            assert (
                pad_l <= input_size and pad_r <= input_size
            ), "Padding value causes wrapping around more than once."

    # The generated wrapper receives the fill value as a float in every mode.
    return _pad_func(self, pad, mode, float(value))

645 

646 

def constant_pad_nd(self, pad_list, value=0):
    """Pad ``self`` with the constant ``value``.

    Thin wrapper over ``pad`` in ``"constant"`` mode; ``pad_list`` uses the
    same last-dimension-first ``(before, after)`` pair layout.
    """
    return pad(self, pad_list, value=value, mode="constant")