Coverage for src/flag_gems/runtime/backend/_sunrise/__init__.py: 0%

400 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import os 

2 

3import torch_ptpu # noqa: F401 

4from backend_utils import VendorInfoBase 

5 

6from .monkey_patch import apply_sunrise_monkey_patches 

7 

# Descriptor of the Sunrise backend, built with the VendorInfoBase helper
# imported from backend_utils. `_install_autograd_dispatch_patch` below reads
# `vendor_info.vendor_name` to recognise registrars that target this vendor.
vendor_info = VendorInfoBase(
    vendor_name="sunrise",
    device_name="ptpu",
    device_query_cmd="pt_smi",
    triton_extra_name="tang",
    dispatch_key="PrivateUse1",
)

# Intentionally empty tuple.
# NOTE(review): presumably the names of ops this backend opts out of; the
# consumer is not visible in this file -- confirm against the runtime loader.
CUSTOMIZED_UNUSED_OPS = ()

17 

18 

def _sunrise_rebuild_ptpu_tensor_from_cpu(
    tensor_cls, cpu_tensor, device, requires_grad
):
    """Rebuild a tensor on ``device`` from a CPU copy.

    This is the rebuild callable returned by the multiprocessing reducer
    installed in ``_install_ptpu_multiprocessing_reduction_patch``.

    Args:
        tensor_cls: Class the original tensor had when it was reduced.
        cpu_tensor: CPU tensor holding the data.
        device: Target device passed to ``Tensor.to``.
        requires_grad: ``requires_grad`` flag to restore on the result.

    Returns:
        A ``Parameter`` when ``tensor_cls`` is ``Parameter``; otherwise the
        moved tensor, converted to ``tensor_cls`` when that is possible.
    """
    import torch
    from torch.nn.parameter import Parameter

    moved = cpu_tensor.to(device=device)

    if tensor_cls == Parameter:
        return Parameter(moved, requires_grad=requires_grad)

    already_right_class = tensor_cls in (torch.Tensor, type(moved))
    if not already_right_class:
        # Best effort: if the subclass cannot be restored, keep the plain
        # tensor rather than failing the whole rebuild.
        try:
            moved = moved.as_subclass(tensor_cls)
        except Exception:
            pass

    moved.requires_grad = requires_grad
    return moved

35 

36 

def _should_stage_ptpu_tensor_for_multiprocessing(tensor):
    """Return True when ``tensor`` must be staged through CPU for pickling.

    That is the case only for a ``torch.Tensor`` whose device type is
    ``"ptpu"``, whose layout is strided and which is not a nested tensor.
    """
    import torch

    if not isinstance(tensor, torch.Tensor):
        return False
    if tensor.device.type != "ptpu":
        return False
    if tensor.layout != torch.strided:
        return False
    return not tensor.is_nested

46 

47 

def _sunrise_monkey_patch_enabled():
    """Tell whether the Sunrise monkey patches should be applied.

    Reads ``FLAG_GEMS_SUNRISE_ENABLE_MONKEY_PATCH`` (default ``"1"``). After
    trimming whitespace and lower-casing, the values ``0``, ``false``, ``no``
    and ``off`` disable the patches; anything else enables them.
    """
    raw_setting = os.getenv("FLAG_GEMS_SUNRISE_ENABLE_MONKEY_PATCH", "1")
    normalized = raw_setting.strip().lower()
    disabling_values = {"0", "false", "no", "off"}
    return normalized not in disabling_values

51 

52 

# Import-time side effect: apply the patches from `.monkey_patch` unless they
# were disabled through the environment variable read by
# `_sunrise_monkey_patch_enabled()`.
if _sunrise_monkey_patch_enabled():
    apply_sunrise_monkey_patches()

# [Sunrise fix] Aten lower needed.
# Op names for which the register_impl wrapper installed by
# `_install_autograd_dispatch_patch` additionally requests the Autograd
# dispatch key when the registrar targets the Sunrise vendor.
CUSTOMIZED_AUTOGRAD_OPS = (
    "absolute",
    "arcsinh_",
    "arcsinh",
    "arcsinh.out",
    "arctanh_",
    "clip",
    "clip_",
    "concatenate",
    "conj_physical",
    "diag",
    "diff",
    "__ior__.Tensor",
    "__ior__.Scalar",
    "__or__.Tensor",
    "__or__.Scalar",
    "embedding_backward",
    "feature_dropout",
    "feature_dropout_",
    "gather_backward",
    "greater.Tensor",
    "greater.Scalar",
    "greater.Scalar_out",
    "hstack",
    "isclose",
    "isfinite",
    "kron",
    "log_sigmoid",
    "margin_ranking_loss",
    "nonzero_numpy",
    "pad",
    "prelu",
    "quantile",
    "relu6",
    "repeat_interleave.self_int",
    "repeat_interleave.self_Tensor",
    "resolve_conj",
    "resolve_neg",
    "selu",
    "selu_",
    "square",
    "square_",
    "square.out",
    "svd",
    "tile",
    "vstack",
)

104 

105 

def _sunrise_extra_config_entries():
    """Return extra ``(op_name, implementation)`` registrations for Sunrise.

    Some ops are not registered even by the shared library, so they are kept
    here for now so that the tests pass. The implementations are imported
    lazily from ``.ops`` on every call.
    """
    from .ops import (
        amax_out,
        amin,
        amin_out,
        aminmax_out,
        clamp_min,
        clamp_min_,
        clamp_min_out,
        hypot_out,
    )

    # Insertion order of the mapping is the order of the returned tuple.
    implementations = {
        "amax.out": amax_out,
        "amin": amin,
        "amin.out": amin_out,
        "aminmax.out": aminmax_out,
        "clamp_min.Tensor": clamp_min,
        "clamp_min.Tensor_out": clamp_min_out,
        "clamp_min_.Tensor": clamp_min_,
        "hypot.out": hypot_out,
    }
    return tuple(implementations.items())

128 

129 

def _install_autograd_dispatch_patch():
    """Wrap ``GeneralOpRegistrar.register_impl`` to add the Autograd key.

    For registrars whose device vendor is Sunrise, any op named in
    ``CUSTOMIZED_AUTOGRAD_OPS`` is registered with the Autograd dispatch key
    appended to ``extra_dispatch_keys`` (unless it is already present). All
    other registrations are forwarded unchanged. The wrapper is installed at
    most once; the wrapped method is kept on the class as
    ``_sunrise_original_register_impl``.
    """
    import torch

    from flag_gems.runtime.op_registrar import GeneralOpRegistrar

    if getattr(GeneralOpRegistrar, "_sunrise_autograd_dispatch_patched", False):
        return  # already wrapped

    wrapped_register_impl = GeneralOpRegistrar.register_impl
    autograd_key = torch._C.DispatchKey.Autograd.name
    autograd_ops = frozenset(CUSTOMIZED_AUTOGRAD_OPS)

    def register_impl(self, key, fn, extra_dispatch_keys=()):
        targets_sunrise = self.device.vendor_name == vendor_info.vendor_name
        if targets_sunrise and key in autograd_ops:
            dispatch_keys = tuple(extra_dispatch_keys)
            if autograd_key not in dispatch_keys:
                dispatch_keys += (autograd_key,)
            extra_dispatch_keys = dispatch_keys

        return wrapped_register_impl(self, key, fn, extra_dispatch_keys)

    GeneralOpRegistrar.register_impl = register_impl
    GeneralOpRegistrar._sunrise_autograd_dispatch_patched = True
    GeneralOpRegistrar._sunrise_original_register_impl = wrapped_register_impl

156 

157 

def _install_register_config_patch():
    """Wrap ``GeneralOpRegistrar.__init__`` to merge in the Sunrise extras.

    The wrapper appends the entries from ``_sunrise_extra_config_entries()``
    whose op name is not already in ``config``. When ``full_config_by_func``
    is given, it also adds each extra entry to the list stored under the
    implementation's ``__name__``. The patch is installed at most once; the
    wrapped initializer is kept on the class as ``_sunrise_original_init``.
    """
    from flag_gems.runtime.op_registrar import GeneralOpRegistrar

    if getattr(GeneralOpRegistrar, "_sunrise_config_patched", False):
        return

    wrapped_init = GeneralOpRegistrar.__init__

    def _extend_config(config, full_config_by_func):
        extras = _sunrise_extra_config_entries()

        known_names = set()
        for entry in config:
            known_names.add(entry[0])

        base_entries = tuple(config)
        additions = tuple(
            entry for entry in extras if entry[0] not in known_names
        )
        merged_config = base_entries + additions

        if full_config_by_func is None:
            return merged_config, None

        # Copy each list so the caller's mapping is left untouched.
        merged_map = {
            name: list(entries) for name, entries in full_config_by_func.items()
        }
        for entry in extras:
            impl = entry[1]
            try:
                func_name = impl.__name__
            except AttributeError:
                func_name = str(impl)
            bucket = merged_map.setdefault(func_name, [])
            if entry not in bucket:
                bucket.append(entry)
        return merged_config, merged_map

    def __init__(
        self,
        config,
        user_include_ops=None,
        user_exclude_ops=None,
        cpp_patched_ops=None,
        lib=None,
        full_config_by_func=None,
    ):
        merged_config, merged_map = _extend_config(config, full_config_by_func)
        forwarded = dict(
            user_include_ops=user_include_ops,
            user_exclude_ops=user_exclude_ops,
            cpp_patched_ops=cpp_patched_ops,
            lib=lib,
            full_config_by_func=merged_map,
        )
        return wrapped_init(self, merged_config, **forwarded)

    GeneralOpRegistrar.__init__ = __init__
    GeneralOpRegistrar._sunrise_config_patched = True
    GeneralOpRegistrar._sunrise_original_init = wrapped_init

210 

211 

def _install_typed_ptr_device_patch():
    """Make ``TypedPtr`` remember the device of the tensor it points into.

    Replaces ``TypedPtr.__init__`` with a version that takes an optional
    ``device`` argument, and replaces the ``from_tensor`` and
    ``reinterpret_tensor`` class methods so they pass ``tensor.device``
    along. The patch is installed at most once.
    """
    from flag_gems.utils.tensor_wrapper import TypedPtr

    if getattr(TypedPtr, "_sunrise_device_patched", False):
        return

    def __init__(self, ptr, dtype, device=None):
        self.ptr, self.dtype, self.device = ptr, dtype, device

    def from_tensor(cls, tensor, offset=0):
        # `offset` is counted in elements of the tensor's own dtype.
        base_address = tensor.data_ptr()
        byte_offset = tensor.element_size() * offset
        return cls(base_address + byte_offset, tensor.dtype, tensor.device)

    def reinterpret_tensor(cls, tensor, dtype, offset=0):
        # `offset` is counted in elements of the requested `dtype`.
        base_address = tensor.data_ptr()
        byte_offset = dtype.itemsize * offset
        return cls(base_address + byte_offset, dtype, tensor.device)

    TypedPtr.__init__ = __init__
    TypedPtr.from_tensor = classmethod(from_tensor)
    TypedPtr.reinterpret_tensor = classmethod(reinterpret_tensor)
    TypedPtr._sunrise_device_patched = True

239 

240 

def _install_pointwise_dynamic_complex_patch():
    """Replace the complex-number code paths of ``PointwiseDynamicFunction``.

    The helpers defined below are attached to ``PointwiseDynamicFunction`` at
    the end of this function. They add two kinds of fallback:

    * a CPU fallback for float64 / complex128 inputs when
      ``runtime.device.support_fp64`` is falsy, and
    * ``StridedBuffer``-based or CPU-based substitutes when
      ``aten::view_as_real`` / ``aten::view_as_complex`` raise
      ``NotImplementedError`` mentioning the ``ptpu`` backend.

    ``StridedBuffer`` also gains an ``is_contiguous`` method when it does not
    already have one. The patch is applied at most once, guarded by the
    ``_sunrise_complex_patched`` class attribute.
    """
    import torch

    from flag_gems.utils.pointwise_dynamic import ComplexMode, PointwiseDynamicFunction
    from flag_gems.utils.shape_utils import all_the_same_shape, all_the_same_stride
    from flag_gems.utils.tensor_wrapper import StridedBuffer

    if getattr(PointwiseDynamicFunction, "_sunrise_complex_patched", False):
        return

    def _tensor_is_contiguous(tensor):
        """Return True when ``tensor`` is laid out contiguously.

        Real tensors delegate to ``Tensor.is_contiguous``. Any other object is
        expected to expose ``shape`` and ``stride()``; its strides are compared
        against the row-major ones, walking from the last dimension to the
        first and ignoring dimensions of size 1.
        """
        if isinstance(tensor, torch.Tensor):
            return tensor.is_contiguous()
        expected_stride = 1
        for size, stride in zip(reversed(tensor.shape), reversed(tensor.stride())):
            if size == 1:
                continue
            if stride != expected_stride:
                return False
            expected_stride *= size
        return True

    # Installed as a method: the buffer instance is passed as `tensor`.
    if not hasattr(StridedBuffer, "is_contiguous"):
        StridedBuffer.is_contiguous = _tensor_is_contiguous

    def _call_real_impl(self, *args, _skip_tensor_check=False, **kwargs):
        """Run the real-valued kernel, falling back to CPU for float64 on ptpu.

        When the runtime device does not support fp64 and a float64 ``ptpu``
        tensor is found among the positional or keyword arguments, the
        function is evaluated on CPU copies and the result is moved back to
        the device of that tensor (or copied into ``out0`` when given).
        Otherwise the regular prepare/instantiate/launch path is used.
        """
        from flag_gems import runtime

        if not runtime.device.support_fp64:
            # First float64 ptpu tensor among the positional arguments, if any.
            ptpu_tensor = next(
                (
                    arg
                    for arg in args
                    if isinstance(arg, torch.Tensor)
                    and arg.device.type == "ptpu"
                    and arg.dtype == torch.float64
                ),
                None,
            )
            if ptpu_tensor is None:
                # Otherwise look through the keyword arguments.
                ptpu_tensor = next(
                    (
                        value
                        for value in kwargs.values()
                        if isinstance(value, torch.Tensor)
                        and value.device.type == "ptpu"
                        and value.dtype == torch.float64
                    ),
                    None,
                )
            if ptpu_tensor is not None:
                cpu_args = tuple(
                    arg.cpu() if isinstance(arg, torch.Tensor) else arg for arg in args
                )
                # Keyword arguments whose name starts with "out" are not
                # forwarded to the CPU call.
                cpu_kwargs = {
                    key: (value.cpu() if isinstance(value, torch.Tensor) else value)
                    for key, value in kwargs.items()
                    if not key.startswith("out")
                }
                # Use `_scalar_fn.fn` when that attribute exists, otherwise
                # `_scalar_fn` itself (presumably the undecorated Python
                # function -- confirm).
                py_fn = getattr(self._scalar_fn, "fn", self._scalar_fn)
                result = py_fn(*cpu_args, **cpu_kwargs)
                out = kwargs.get("out0")
                if out is not None:
                    # NOTE(review): assumes `result` is a single tensor here;
                    # a tuple result would fail on `.to` -- confirm.
                    out.copy_(result.to(out.device))
                    return out
                if isinstance(result, tuple):
                    return tuple(
                        item.to(ptpu_tensor.device)
                        if isinstance(item, torch.Tensor)
                        else item
                        for item in result
                    )
                if isinstance(result, torch.Tensor):
                    return result.to(ptpu_tensor.device)
                return result

        ndim, args, kwargs = self.prepare_args(
            *args, _skip_tensor_check=_skip_tensor_check, **kwargs
        )
        overload = self.instantiate(ndim)
        out = overload(*args, **kwargs)
        return self._unwrap(out)

    def _is_missing_backend_view_op(self, exc, aten_op):
        """Return True for a ``NotImplementedError`` naming ``aten_op`` and ptpu."""
        message = str(exc)
        return (
            isinstance(exc, NotImplementedError)
            and aten_op in message
            and "ptpu" in message
        )

    def _complex_real_view_buffer(self, tensor):
        """Describe ``tensor`` as a real buffer with a trailing dimension of 2.

        Strides are doubled because they are now counted in real elements; the
        new last dimension has stride 1.
        """
        real_dtype = tensor.dtype.to_real()
        shape = tuple(tensor.shape) + (2,)
        strides = tuple(stride * 2 for stride in tensor.stride()) + (1,)
        return StridedBuffer(tensor, shape=shape, strides=strides, dtype=real_dtype)

    def _complex_component_buffers(self, tensor):
        """Return two real ``StridedBuffer`` views over ``tensor``.

        Both use the tensor's shape with doubled strides; the second one is
        created with ``offset=1``.
        """
        real_dtype = tensor.dtype.to_real()
        strides = tuple(stride * 2 for stride in tensor.stride())
        real = StridedBuffer(
            tensor, shape=tensor.shape, strides=strides, dtype=real_dtype
        )
        imag = StridedBuffer(
            tensor,
            shape=tensor.shape,
            strides=strides,
            dtype=real_dtype,
            offset=1,
        )
        return real, imag

    def _view_as_real_for_kernel(self, tensor):
        """``torch.view_as_real`` with a ``StridedBuffer`` fallback."""
        try:
            return torch.view_as_real(tensor)
        except Exception as exc:
            if self._is_missing_backend_view_op(exc, "aten::view_as_real"):
                return self._complex_real_view_buffer(tensor)
            raise

    def _split_complex_components(self, tensor):
        """Return the two components of ``tensor`` as views or buffers."""
        try:
            real_view = torch.view_as_real(tensor)
            return real_view[..., 0], real_view[..., 1]
        except Exception as exc:
            if self._is_missing_backend_view_op(exc, "aten::view_as_real"):
                return self._complex_component_buffers(tensor)
            raise

    def _view_as_complex_result(self, tensor):
        """``torch.view_as_complex`` with a CPU round-trip fallback."""
        try:
            return torch.view_as_complex(tensor.contiguous())
        except Exception as exc:
            if self._is_missing_backend_view_op(exc, "aten::view_as_complex"):
                return torch.view_as_complex(tensor.cpu().contiguous()).to(
                    tensor.device
                )
            raise

    def _cross_components_for_kernel(self, tensor):
        """Return the two components of ``tensor`` as real tensors.

        When ``aten::view_as_real`` is missing on the backend, the view is
        computed on CPU and moved back to the tensor's device.
        """
        try:
            real_view = torch.view_as_real(tensor)
        except Exception as exc:
            if not self._is_missing_backend_view_op(exc, "aten::view_as_real"):
                raise
            real_view = torch.view_as_real(tensor.cpu()).to(tensor.device)
        return real_view[..., 0], real_view[..., 1]

    def _cpu_fallback_value(self, value):
        """Move tensors to CPU; return every other value unchanged."""
        if isinstance(value, torch.Tensor):
            return value.cpu()
        return value

    def _should_cpu_fallback_complex(self, result_dtype, device):
        """Return True for complex128 results on a non-CPU device without fp64."""
        if device is None or device.type == "cpu":
            return False
        if result_dtype != torch.complex128:
            return False
        from flag_gems import runtime

        return not runtime.device.support_fp64

    def _cpu_fallback_complex_dispatch(self, args, kwargs, device):
        """Evaluate the function on CPU and move the result to ``device``."""
        cpu_args = tuple(self._cpu_fallback_value(arg) for arg in args)
        out = kwargs.get("out0")
        py_fn = getattr(self._scalar_fn, "fn", self._scalar_fn)
        # NOTE(review): only positional arguments are forwarded here, whereas
        # the float64 fallback in `_call_real_impl` also forwards non-"out"
        # keyword arguments -- confirm this difference is intended.
        result = py_fn(*cpu_args)
        if out is not None:
            out.copy_(result.to(out.device))
            return out
        if isinstance(result, torch.Tensor):
            return result.to(device)
        if isinstance(result, tuple):
            return tuple(
                item.to(device) if isinstance(item, torch.Tensor) else item
                for item in result
            )
        return result

    def _call_complex_dispatch(self, *args, **kwargs):
        """Entry point for calls that involve complex operands."""
        strategy = self.complex_strategy
        operands, others = self._split_args(args)

        device = self._infer_device(operands)
        result_dtype = self._infer_complex_dtype(operands)

        # complex128 without fp64 support is computed on CPU.
        if self._should_cpu_fallback_complex(result_dtype, device):
            return self._cpu_fallback_complex_dispatch(args, kwargs, device)

        if strategy.tensorize_scalars and strategy.fallback_target is not None:
            operands = self._tensorize_scalar_operands(operands, result_dtype, device)
            new_args = self._merge_args(operands, others)
            return strategy.fallback_target(*new_args, **kwargs)

        for i in list(operands.keys()):
            operands[i] = self._to_complex_tensor(operands[i], result_dtype, device)

        # Broadcast all operands together and store them back under the same
        # keys, in sorted key order.
        complex_tensors = [operands[i] for i in sorted(operands.keys())]
        complex_tensors = torch.broadcast_tensors(*complex_tensors)
        for idx, key in enumerate(sorted(operands.keys())):
            operands[key] = complex_tensors[idx]

        classification = self._classify_complex_inputs(operands)

        if strategy.mode == ComplexMode.CROSS and classification == "all_complex":
            return self._call_complex_cross(operands, result_dtype)
        if classification in ("all_complex", "mixed"):
            return self._call_complex_elementwise(
                operands, others, result_dtype, kwargs
            )
        new_args = self._merge_args(operands, others)
        return self._call_real_impl(*new_args, **kwargs)

    def _call_complex_elementwise(self, operands, others, result_dtype, kwargs):
        """Run the real kernel over real views of the operands and the output.

        Returns the complex output tensor (``out0`` when supplied, otherwise a
        freshly allocated tensor shaped like the first operand).
        """
        real_tensors = {
            i: self._view_as_real_for_kernel(t) for i, t in operands.items()
        }
        out_kwargs = dict(kwargs)
        out_complex = out_kwargs.get("out0")
        if out_complex is None:
            first_operand = operands[sorted(operands.keys())[0]]
            out_complex = torch.empty(
                first_operand.shape,
                dtype=result_dtype,
                device=first_operand.device,
            )
            # NOTE(review): this value is overwritten by the assignment just
            # below, so the line looks redundant -- confirm.
            out_kwargs["out0"] = out_complex

        # The kernel writes through a real view of the complex output.
        out_kwargs["out0"] = self._view_as_real_for_kernel(out_complex)
        new_args = self._merge_args(real_tensors, others)
        self._call_real_impl(*new_args, _skip_tensor_check=True, **out_kwargs)
        return out_complex

    def _call_complex_cross(self, operands, result_dtype):
        """Run the strategy's cross kernel on the components of two operands."""
        sorted_keys = sorted(operands.keys())
        a_tensor, b_tensor = operands[sorted_keys[0]], operands[sorted_keys[1]]
        ar, ai = self._cross_components_for_kernel(a_tensor)
        br, bi = self._cross_components_for_kernel(b_tensor)

        # Bring both operands to a common real dtype before the kernel call.
        common_dtype = torch.promote_types(ar.dtype, br.dtype)
        if ar.dtype != common_dtype:
            ar, ai = ar.to(common_dtype), ai.to(common_dtype)
        if br.dtype != common_dtype:
            br, bi = br.to(common_dtype), bi.to(common_dtype)

        cross_kernel = self.complex_strategy.cross_kernel
        real, imag = cross_kernel._call_real_impl(
            ar, ai, br, bi, _skip_tensor_check=True
        )
        out = torch.stack((real, imag), dim=-1)
        return self._view_as_complex_result(out).to(result_dtype)

    def use_fast_path(tensors):
        """Decide whether ``tensors`` qualify for the fast path.

        All shapes must match. Then either every tensor is contiguous, or all
        of them are real ``torch.Tensor`` objects with identical strides and
        the first one is non-overlapping and dense.
        """
        if not all_the_same_shape(tensors):
            return False
        if all(_tensor_is_contiguous(tensor) for tensor in tensors):
            return True
        return (
            all(isinstance(tensor, torch.Tensor) for tensor in tensors)
            and all_the_same_stride(tensors)
            and torch.ops.aten.is_non_overlapping_and_dense(tensors[0])
        )

    # Attach the helpers; the flag set last makes a second call a no-op.
    PointwiseDynamicFunction._call_real_impl = _call_real_impl
    PointwiseDynamicFunction._is_missing_backend_view_op = _is_missing_backend_view_op
    PointwiseDynamicFunction._complex_real_view_buffer = _complex_real_view_buffer
    PointwiseDynamicFunction._complex_component_buffers = _complex_component_buffers
    PointwiseDynamicFunction._view_as_real_for_kernel = _view_as_real_for_kernel
    PointwiseDynamicFunction._split_complex_components = _split_complex_components
    PointwiseDynamicFunction._view_as_complex_result = _view_as_complex_result
    PointwiseDynamicFunction._cross_components_for_kernel = _cross_components_for_kernel
    PointwiseDynamicFunction._cpu_fallback_value = _cpu_fallback_value
    PointwiseDynamicFunction._should_cpu_fallback_complex = _should_cpu_fallback_complex
    PointwiseDynamicFunction._cpu_fallback_complex_dispatch = (
        _cpu_fallback_complex_dispatch
    )
    PointwiseDynamicFunction._call_complex_dispatch = _call_complex_dispatch
    PointwiseDynamicFunction._call_complex_elementwise = _call_complex_elementwise
    PointwiseDynamicFunction._call_complex_cross = _call_complex_cross
    PointwiseDynamicFunction.use_fast_path = staticmethod(use_fast_path)
    PointwiseDynamicFunction._sunrise_complex_patched = True

522 

523 

def _install_pointwise_dynamic_post_import_hook():
    """Patch ``flag_gems.utils.pointwise_dynamic`` right after it is imported.

    Installs a wrapper around ``builtins.__import__``. After an import of
    ``flag_gems.utils.pointwise_dynamic`` (either by its full name or through
    ``from flag_gems.utils import pointwise_dynamic``) completes, the wrapper
    applies ``_install_pointwise_dynamic_complex_patch()`` once and removes
    itself. Installing is idempotent, guarded by the
    ``_sunrise_pointwise_import_hook_installed`` attribute on ``builtins``.
    """
    import builtins
    import sys

    if getattr(builtins, "_sunrise_pointwise_import_hook_installed", False):
        return

    original_import = builtins.__import__

    def maybe_patch():
        module = sys.modules.get("flag_gems.utils.pointwise_dynamic")
        if module is None or getattr(module, "_sunrise_complex_patch_attempted", False):
            return
        # Mark first so a failing patch is never retried on a later import.
        module._sunrise_complex_patch_attempted = True
        # Only unhook when this wrapper is still the active hook; restoring
        # unconditionally would drop any hook installed after this one.
        if builtins.__import__ is import_with_sunrise_pointwise_patch:
            builtins.__import__ = original_import
            builtins._sunrise_pointwise_import_hook_installed = False
        _install_pointwise_dynamic_complex_patch()

    def import_with_sunrise_pointwise_patch(
        name, globals=None, locals=None, fromlist=(), level=0
    ):
        module = original_import(name, globals, locals, fromlist, level)
        # A plain `import a.b` statement passes fromlist=None, so guard the
        # membership test against a missing fromlist.
        if name == "flag_gems.utils.pointwise_dynamic" or (
            name == "flag_gems.utils"
            and fromlist
            and "pointwise_dynamic" in fromlist
        ):
            maybe_patch()
        return module

    builtins.__import__ = import_with_sunrise_pointwise_patch
    builtins._sunrise_pointwise_import_hook_installed = True

554 

555 

def _install_ptpu_manual_seed_patch():
    """Give ``torch.ptpu`` a ``manual_seed_all`` and an ``_is_in_bad_fork``.

    Does nothing when ``torch.ptpu`` is missing or was already patched.
    """
    import torch

    ptpu_mod = getattr(torch, "ptpu", None)
    if ptpu_mod is None:
        return
    if getattr(ptpu_mod, "_sunrise_manual_seed_patched", False):
        return

    def _is_in_bad_fork():
        return False

    def manual_seed_all(seed):
        seed_bits = int(seed) & 0xFFFFFFFFFFFFFFFF

        # PTPU exposes RNG state get/set but not a Python seed API. The runtime
        # state is a 16-byte blob where the low 8 bytes act as the seed and the
        # high 8 bytes can be reset to zero for a fresh sequence start.
        # `torch.manual_seed()` can be called under `with torch.device("ptpu")`,
        # so build the state explicitly on CPU for `torch.ptpu.set_rng_state()`.
        rng_state = torch.zeros(16, dtype=torch.uint8, device="cpu")
        for position, byte in enumerate(seed_bits.to_bytes(8, "little")):
            rng_state[position] = byte

        for device_idx in range(ptpu_mod.device_count()):
            ptpu_mod.set_rng_state(rng_state, device_idx)

    ptpu_mod._is_in_bad_fork = _is_in_bad_fork
    ptpu_mod.manual_seed_all = manual_seed_all
    ptpu_mod._sunrise_manual_seed_patched = True

584 

585 

def _install_ptpu_default_generators_patch():
    """Provide ``torch.ptpu.default_generators`` when the backend lacks it.

    The installed object is indexable by device (index, ``torch.device``,
    device string or ``None`` for the current device) and yields lightweight
    generator objects built on ``torch.ptpu.get_rng_state`` /
    ``set_rng_state``. Does nothing when ``torch.ptpu`` is missing, already
    has ``default_generators``, or was already patched.
    """
    import torch

    ptpu_mod = getattr(torch, "ptpu", None)
    if ptpu_mod is None:
        return
    if hasattr(ptpu_mod, "default_generators"):
        return
    if getattr(ptpu_mod, "_sunrise_default_generators_patched", False):
        return

    class _SunrisePtpuGenerator:
        """Per-device RNG handle backed by the PTPU RNG state blob."""

        def __init__(self, device_idx):
            self.device_idx = int(device_idx)
            self.device = torch.device("ptpu", self.device_idx)

        def get_state(self):
            raw_state = ptpu_mod.get_rng_state(self.device_idx)
            return raw_state.detach().cpu().clone()

        def set_state(self, state):
            if not isinstance(state, torch.Tensor):
                raise TypeError("PTPU RNG state must be a torch.Tensor")
            if state.dtype != torch.uint8:
                raise TypeError("PTPU RNG state must be a torch.uint8 tensor")
            host_state = state.detach().cpu().contiguous()
            ptpu_mod.set_rng_state(host_state, self.device_idx)

        def manual_seed(self, seed):
            seed_bits = int(seed) & 0xFFFFFFFFFFFFFFFF
            # Low 8 bytes carry the seed (little endian); the rest stay zero.
            rng_state = torch.zeros(16, dtype=torch.uint8, device="cpu")
            for position, byte in enumerate(seed_bits.to_bytes(8, "little")):
                rng_state[position] = byte
            self.set_state(rng_state)
            return self

    class _SunrisePtpuDefaultGenerators:
        """Lazily populated, device-indexed collection of generators."""

        def __init__(self):
            self._generators = {}

        def _normalize_device(self, device):
            if device is None:
                return int(ptpu_mod.current_device())
            if isinstance(device, str):
                device = torch.device(device)
            if not isinstance(device, torch.device):
                return int(device)
            if device.type != "ptpu":
                raise RuntimeError(f"Expected a ptpu device, got {device}")
            if device.index is None:
                return int(ptpu_mod.current_device())
            return int(device.index)

        def __getitem__(self, device):
            device_idx = self._normalize_device(device)
            device_count = int(ptpu_mod.device_count())
            if device_idx < 0:
                # Negative indices count from the end, like sequences do.
                device_idx += device_count
            if not 0 <= device_idx < device_count:
                raise IndexError(
                    f"PTPU device index {device_idx} is out of range "
                    f"for {device_count} devices"
                )
            generator = self._generators.get(device_idx)
            if generator is None:
                generator = _SunrisePtpuGenerator(device_idx)
                self._generators[device_idx] = generator
            return generator

        def __iter__(self):
            for position in range(len(self)):
                yield self[position]

        def __len__(self):
            return int(ptpu_mod.device_count())

    ptpu_mod.default_generators = _SunrisePtpuDefaultGenerators()
    ptpu_mod._sunrise_default_generators_patched = True

662 

663 

def _install_ptpu_multiprocessing_reduction_patch():
    """Route pickling of PTPU tensors through a CPU copy.

    Replaces ``torch.multiprocessing.reductions.reduce_tensor`` and
    re-registers the reducer for every tensor class, ``torch.Tensor`` and
    ``Parameter``. Tensors selected by
    ``_should_stage_ptpu_tensor_for_multiprocessing`` are reduced to a CPU
    copy plus the data needed by ``_sunrise_rebuild_ptpu_tensor_from_cpu``;
    every other tensor goes to the original reducer. Installed at most once.
    """
    import multiprocessing.reduction as mp_reduction

    import torch
    import torch.multiprocessing.reductions as reductions
    from torch.nn.parameter import Parameter

    if getattr(reductions, "_sunrise_ptpu_reduce_tensor_patched", False):
        return

    original_reduce_tensor = reductions.reduce_tensor

    # Keep this in `_sunrise/__init__.py` instead of `monkey_patch.py` because
    # multiprocessing reducers are registered eagerly in Python's global
    # pickling table. This is import-time runtime wiring, not a call-site-level
    # torch API fallback that can be caught and retried after a NotImplemented.
    def reduce_tensor_with_ptpu_cpu_staging(tensor):
        needs_staging = _should_stage_ptpu_tensor_for_multiprocessing(tensor)
        if not needs_staging:
            return original_reduce_tensor(tensor)

        is_non_leaf_with_grad = tensor.requires_grad and not tensor.is_leaf
        if is_non_leaf_with_grad:
            raise RuntimeError(
                "Cowardly refusing to serialize non-leaf tensor which requires_grad, "
                "since autograd does not support crossing process boundaries.  "
                "If you just want to transfer the data, call detach() on the tensor "
                "before serializing (e.g., putting it on the queue)."
            )

        reductions.check_serializing_named_tensor(tensor)
        torch.utils.hooks.warn_if_has_hooks(tensor)

        rebuild_args = (
            type(tensor),
            tensor.detach().cpu(),
            tensor.device,
            tensor.requires_grad,
        )
        return reductions._sunrise_rebuild_ptpu_tensor_from_cpu, rebuild_args

    reductions._sunrise_rebuild_ptpu_tensor_from_cpu = (
        _sunrise_rebuild_ptpu_tensor_from_cpu
    )
    reductions._sunrise_original_reduce_tensor = original_reduce_tensor
    reductions.reduce_tensor = reduce_tensor_with_ptpu_cpu_staging

    # Same registration order as before: tensor classes, Tensor, Parameter.
    for reducible_cls in (*torch._tensor_classes, torch.Tensor, Parameter):
        mp_reduction.register(reducible_cls, reduce_tensor_with_ptpu_cpu_staging)

    reductions._sunrise_ptpu_reduce_tensor_patched = True

715 

716 

# Import-time wiring: each installer below is idempotent and is run once when
# this package is imported.
_install_ptpu_default_generators_patch()
_install_ptpu_manual_seed_patch()
_install_autograd_dispatch_patch()
# Some ops are not registered even by the shared library, so they are added
# here for now so that the tests pass.
_install_register_config_patch()
_install_pointwise_dynamic_post_import_hook()

722 

723 

# `__all__` must list attribute names: a literal "*" entry is looked up as an
# attribute on star-import and raises AttributeError. Export the module's
# public constants explicitly.
__all__ = ["vendor_info", "CUSTOMIZED_UNUSED_OPS", "CUSTOMIZED_AUTOGRAD_OPS"]