Coverage for src/flag_gems/runtime/backend/_sunrise/monkey_patch.py: 0%

1104 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import functools 

2import inspect 

3import json 

4import logging 

5import math 

6import numbers 

7import os 

8import time 

9 

10import torch 

11import torch.nn.functional as F 

12 

13_PTPU_DEVICE = "ptpu" 

14_LOGGER = logging.getLogger(__name__) 

15 

16 

17def _is_ptpu_tensor(value): 

18 return isinstance(value, torch.Tensor) and value.device.type == _PTPU_DEVICE 

19 

20 

21def _is_ptpu_device(device): 

22 if device is None: 

23 return False 

24 if isinstance(device, torch.device): 

25 return device.type == _PTPU_DEVICE 

26 if isinstance(device, str): 

27 return device.split(":", 1)[0] == _PTPU_DEVICE 

28 return False 

29 

30 

31def _is_cpu_device(device): 

32 if device is None: 

33 return False 

34 if isinstance(device, torch.device): 

35 return device.type == "cpu" 

36 if isinstance(device, str): 

37 return device.split(":", 1)[0] == "cpu" 

38 return False 

39 

40 

41def _has_tensor_base_view(tensor): 

42 return ( 

43 isinstance(tensor, torch.Tensor) and getattr(tensor, "_base", None) is not None 

44 ) 

45 

46 

47def _to_cpu_if_ptpu(value): 

48 if _is_ptpu_tensor(value): 

49 return value.cpu() 

50 return value 

51 

52 

53def _to_device_if_tensor(value, device): 

54 if isinstance(value, torch.Tensor): 

55 return value.to(device=device) 

56 if isinstance(value, tuple): 

57 return tuple(_to_device_if_tensor(item, device) for item in value) 

58 return value 

59 

60 

61def _should_fallback_to_cpu(exc, tensor, aten_op): 

62 if not _is_ptpu_tensor(tensor): 

63 return False 

64 message = str(exc).lower() 

65 return aten_op.lower() in message and _PTPU_DEVICE in message 

66 

67 

68def _copy_cpu_result_to_out(result, out): 

69 if isinstance(out, torch.Tensor): 

70 out.copy_(_to_device_if_tensor(result, out.device)) 

71 return out 

72 if isinstance(out, tuple): 

73 for result_item, out_item in zip(result, out): 

74 _copy_cpu_result_to_out(result_item, out_item) 

75 return out 

76 return None 

77 

78 

79def _finalize_cpu_result(result, out, device): 

80 copied_out = _copy_cpu_result_to_out(result, out) 

81 if copied_out is not None: 

82 return copied_out 

83 return _to_device_if_tensor(result, device) 

84 

85 

86def _copy_result_to_tensor(result, tensor): 

87 tensor.copy_(_to_device_if_tensor(result, tensor.device)) 

88 return tensor 

89 

90 

91def _cpu_fallback(tensor, args, kwargs, original_fn): 

92 cpu_tensor = tensor.cpu() 

93 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

94 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

95 result = original_fn(cpu_tensor, *cpu_args, **cpu_kwargs) 

96 return _finalize_cpu_result(result, kwargs.get("out"), tensor.device) 

97 

98 

99def _inplace_cpu_fallback(tensor, args, kwargs, original_fn): 

100 cpu_tensor = tensor.cpu() 

101 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

102 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

103 result = original_fn(cpu_tensor, *cpu_args, **cpu_kwargs) 

104 return _copy_result_to_tensor(result, tensor) 

105 

106 

107def _torch_function_cpu_fallback(tensor, args, kwargs, original_fn): 

108 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

109 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

110 result = original_fn(*cpu_args, **cpu_kwargs) 

111 return _finalize_cpu_result(result, kwargs.get("out"), tensor.device) 

112 

113 

114def _torch_function_inplace_cpu_fallback(tensor, args, kwargs, original_fn): 

115 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

116 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

117 result = original_fn(*cpu_args, **cpu_kwargs) 

118 return _copy_result_to_tensor(result, tensor) 

119 

120 

121def _patch_tensor_copy_scalar_fill_fallback(): 

122 patched_attr = "_flag_gems_sunrise_copy_scalar_fill_patched" 

123 if getattr(torch.Tensor, patched_attr, False): 

124 return 

125 

126 original_fn = torch.Tensor.copy_ 

127 

128 def _scalar_fill_value(src): 

129 if isinstance(src, torch.Tensor): 

130 if src.ndim != 0: 

131 return None 

132 src = _to_cpu_if_ptpu(src) 

133 return src.item() 

134 if isinstance(src, numbers.Number): 

135 return src 

136 return None 

137 

138 @functools.wraps(original_fn) 

139 def copy_with_scalar_fill_fallback(self, src, *args, **kwargs): 

140 try: 

141 return original_fn(self, src, *args, **kwargs) 

142 except RuntimeError as exc: 

143 if _flag_gems_use_gems_active() or not _is_ptpu_tensor(self): 

144 raise 

145 if "cannot copy src shape: []" not in str(exc): 

146 raise 

147 fill_value = _scalar_fill_value(src) 

148 if fill_value is None: 

149 raise 

150 return self.fill_(fill_value) 

151 

152 torch.Tensor.copy_ = copy_with_scalar_fill_fallback 

153 setattr(torch.Tensor, patched_attr, True) 

154 

155 

156def _patch_tensor_method(name, aten_op, inplace=False): 

157 patched_attr = f"_flag_gems_sunrise_{name}_patched" 

158 if getattr(torch.Tensor, patched_attr, False): 

159 return 

160 

161 original_fn = getattr(torch.Tensor, name) 

162 

163 @functools.wraps(original_fn) 

164 def tensor_method_with_ptpu_cpu_fallback(self, *args, **kwargs): 

165 try: 

166 return original_fn(self, *args, **kwargs) 

167 except NotImplementedError as exc: 

168 if _flag_gems_use_gems_active(): 

169 raise 

170 if not _should_fallback_to_cpu(exc, self, aten_op): 

171 raise 

172 if inplace: 

173 return _inplace_cpu_fallback(self, args, kwargs, original_fn) 

174 return _cpu_fallback(self, args, kwargs, original_fn) 

175 

176 setattr(torch.Tensor, name, tensor_method_with_ptpu_cpu_fallback) 

177 setattr(torch.Tensor, patched_attr, True) 

178 

179 

180def _patch_tensor_property(name, aten_op): 

181 """Patch a `getset_descriptor` property on `torch.Tensor` (e.g. `real`, `imag`). 

182 

183 Wrap only the getter. Re-raise on non-PTPU dispatches or unrelated aten ops. 

184 Keep the original setter intact so alias-write semantics (`t.real = ...`) 

185 still go through the C-side descriptor. 

186 """ 

187 patched_attr = f"_flag_gems_sunrise_{name}_patched" 

188 if getattr(torch.Tensor, patched_attr, False): 

189 return 

190 

191 original_descriptor = getattr(torch.Tensor, name) 

192 original_get = original_descriptor.__get__ 

193 original_set = getattr(original_descriptor, "__set__", None) 

194 

195 def getter(self): 

196 try: 

197 return original_get(self) 

198 except NotImplementedError as exc: 

199 if _flag_gems_use_gems_active(): 

200 raise 

201 if not _should_fallback_to_cpu(exc, self, aten_op): 

202 raise 

203 cpu_result = original_get(self.cpu()) 

204 device_result = _to_device_if_tensor(cpu_result, self.device) 

205 if isinstance(cpu_result, torch.Tensor) and cpu_result.is_neg(): 

206 return torch._neg_view(device_result) 

207 return device_result 

208 

209 if original_set is None: 

210 new_descriptor = property(getter) 

211 else: 

212 

213 def setter(self, value): 

214 return original_set(self, value) 

215 

216 new_descriptor = property(getter, setter) 

217 

218 setattr(torch.Tensor, name, new_descriptor) 

219 setattr(torch.Tensor, patched_attr, True) 

220 

221 

222def _patch_torch_function(name, aten_op, inplace=False): 

223 patched_attr = f"_flag_gems_sunrise_{name}_patched" 

224 if getattr(torch, patched_attr, False): 

225 return 

226 

227 original_fn = getattr(torch, name) 

228 

229 @functools.wraps(original_fn) 

230 def function_with_ptpu_cpu_fallback(*args, **kwargs): 

231 tensor = args[0] if args else kwargs.get("input") 

232 try: 

233 return original_fn(*args, **kwargs) 

234 except NotImplementedError as exc: 

235 if _flag_gems_use_gems_active(): 

236 raise 

237 if not _should_fallback_to_cpu(exc, tensor, aten_op): 

238 raise 

239 if inplace: 

240 return _torch_function_inplace_cpu_fallback( 

241 tensor, args, kwargs, original_fn 

242 ) 

243 return _torch_function_cpu_fallback(tensor, args, kwargs, original_fn) 

244 

245 setattr(torch, name, function_with_ptpu_cpu_fallback) 

246 setattr(torch, patched_attr, True) 

247 

248 

249def _patch_torch_nn_functional(name, aten_op): 

250 """Patch `torch.nn.functional.<name>(...)` for PTPU CPU fallback. 

251 

252 Use when the failing call site is inside a `torch.nn` module's `forward` 

253 that routes through `torch.nn.functional.<name>(...)` (e.g. `F.pad`, 

254 `F.interpolate`) and the C++ dispatcher does not surface in the Python 

255 `torch.ops.aten.<op>(...)` packet path. 

256 """ 

257 patched_attr = f"_flag_gems_sunrise_nn_functional_{name}_patched" 

258 if getattr(F, patched_attr, False): 

259 return 

260 

261 original_fn = getattr(F, name) 

262 

263 @functools.wraps(original_fn) 

264 def functional_with_ptpu_cpu_fallback(*args, **kwargs): 

265 tensor = args[0] if args else kwargs.get("input") 

266 try: 

267 return original_fn(*args, **kwargs) 

268 except NotImplementedError as exc: 

269 if _flag_gems_use_gems_active(): 

270 raise 

271 if not _should_fallback_to_cpu(exc, tensor, aten_op): 

272 raise 

273 return _torch_function_cpu_fallback(tensor, args, kwargs, original_fn) 

274 

275 setattr(F, name, functional_with_ptpu_cpu_fallback) 

276 setattr(F, patched_attr, True) 

277 

278 

279def _vector_norm_arg(args, kwargs, index, name, default=None): 

280 return args[index] if len(args) > index else kwargs.get(name, default) 

281 

282 

283def _normalize_vector_norm_dims(tensor, dim): 

284 if dim is None: 

285 return tuple(range(tensor.ndim)) 

286 if isinstance(dim, int): 

287 return (dim % tensor.ndim,) 

288 return tuple(d % tensor.ndim for d in dim) 

289 

290 

291def _maybe_stable_cpu_vector_norm_reference(args, kwargs): 

292 """Use an explicit high-precision CPU reference for long finite norms. 

293 

294 PyTorch CPU `torch.linalg.vector_norm` can undercount long float32 

295 reductions on this environment, especially for multi-dim reductions over 

296 non-unit-stride slices. The Sunrise/PTPU Triton kernel is much closer to a 

297 double-precision reference, so keep the device path native and only correct 

298 the CPU reference helper path outside `flag_gems.use_gems()`. 

299 """ 

300 tensor = args[0] if args else kwargs.get("input") or kwargs.get("x") 

301 if ( 

302 _flag_gems_use_gems_active() 

303 or not isinstance(tensor, torch.Tensor) 

304 or tensor.device.type != "cpu" 

305 or not tensor.is_floating_point() 

306 or tensor.dtype not in (torch.float16, torch.float32, torch.bfloat16) 

307 ): 

308 return None 

309 

310 ord_value = _vector_norm_arg(args, kwargs, 1, "ord", 2) 

311 if ord_value not in (1, 2): 

312 return None 

313 

314 dim = _vector_norm_arg(args, kwargs, 2, "dim", None) 

315 dims = _normalize_vector_norm_dims(tensor, dim) 

316 if not dims: 

317 return None 

318 

319 reduction_numel = math.prod(tensor.shape[d] for d in dims) 

320 if reduction_numel < 2048: 

321 return None 

322 

323 keepdim = _vector_norm_arg(args, kwargs, 3, "keepdim", False) 

324 dtype = kwargs.get("dtype", None) or tensor.dtype 

325 if isinstance(dtype, str): 

326 dtype = getattr(torch, dtype) 

327 out = kwargs.get("out", None) 

328 

329 work = tensor.to(torch.float64) 

330 if ord_value == 1: 

331 result = work.abs().sum(dim=dims, keepdim=keepdim) 

332 else: 

333 result = torch.sqrt((work * work).sum(dim=dims, keepdim=keepdim)) 

334 result = result.to(dtype=dtype) 

335 

336 if out is not None: 

337 out.copy_(result) 

338 return out 

339 return result 

340 

341 

342def _patch_torch_linalg_function(name, aten_op): 

343 """Patch `torch.linalg.<name>(...)` for Sunrise reference/fallback quirks.""" 

344 patched_attr = f"_flag_gems_sunrise_linalg_{name}_patched" 

345 if getattr(torch.linalg, patched_attr, False): 

346 return 

347 

348 original_fn = getattr(torch.linalg, name) 

349 

350 @functools.wraps(original_fn) 

351 def linalg_with_ptpu_cpu_fallback(*args, **kwargs): 

352 tensor = args[0] if args else kwargs.get("input") 

353 if name == "vector_norm": 

354 stable_result = _maybe_stable_cpu_vector_norm_reference(args, kwargs) 

355 if stable_result is not None: 

356 return stable_result 

357 try: 

358 return original_fn(*args, **kwargs) 

359 except NotImplementedError as exc: 

360 if _flag_gems_use_gems_active(): 

361 raise 

362 if not _should_fallback_to_cpu(exc, tensor, aten_op): 

363 raise 

364 return _torch_function_cpu_fallback(tensor, args, kwargs, original_fn) 

365 

366 setattr(torch.linalg, name, linalg_with_ptpu_cpu_fallback) 

367 setattr(torch.linalg, patched_attr, True) 

368 

369 

370def _patch_torch_tensor_out(packet_name, aten_op): 

371 packet = getattr(torch.ops.aten, packet_name) 

372 patched_attr = "_flag_gems_sunrise_tensor_out_patched" 

373 if getattr(packet, patched_attr, False): 

374 return 

375 

376 original_fn = packet.Tensor_out 

377 

378 @functools.wraps(original_fn) 

379 def tensor_out_with_ptpu_cpu_fallback(*args, **kwargs): 

380 tensor = args[0] if args else kwargs.get("self") 

381 try: 

382 return original_fn(*args, **kwargs) 

383 except NotImplementedError as exc: 

384 if _flag_gems_use_gems_active(): 

385 raise 

386 if not _should_fallback_to_cpu(exc, tensor, aten_op): 

387 raise 

388 return _torch_function_cpu_fallback(tensor, args, kwargs, original_fn) 

389 

390 packet.Tensor_out = tensor_out_with_ptpu_cpu_fallback 

391 setattr(packet, patched_attr, True) 

392 

393 

394def _patch_torch_out(packet_name, aten_op): 

395 packet = getattr(torch.ops.aten, packet_name) 

396 patched_attr = "_flag_gems_sunrise_out_patched" 

397 if getattr(packet, patched_attr, False): 

398 return 

399 

400 original_fn = packet.out 

401 

402 @functools.wraps(original_fn) 

403 def out_with_ptpu_cpu_fallback(*args, **kwargs): 

404 tensor = args[0] if args else kwargs.get("self") or kwargs.get("input") 

405 try: 

406 return original_fn(*args, **kwargs) 

407 except NotImplementedError as exc: 

408 if _flag_gems_use_gems_active(): 

409 raise 

410 if not _should_fallback_to_cpu(exc, tensor, aten_op): 

411 raise 

412 return _torch_function_cpu_fallback(tensor, args, kwargs, original_fn) 

413 

414 packet.out = out_with_ptpu_cpu_fallback 

415 setattr(packet, patched_attr, True) 

416 

417 

418def _patch_torch_creation_function(name, aten_op): 

419 """Patch a `torch.<name>(...)` creation op (no dispatch-driving tensor input). 

420 

421 Detect a PTPU target via the `device=` kwarg, fall back by calling the 

422 original function on CPU, then move the result to the requested device. 

423 """ 

424 patched_attr = f"_flag_gems_sunrise_{name}_patched" 

425 if getattr(torch, patched_attr, False): 

426 return 

427 

428 original_fn = getattr(torch, name) 

429 

430 @functools.wraps(original_fn) 

431 def creation_with_ptpu_cpu_fallback(*args, **kwargs): 

432 device = kwargs.get("device") 

433 try: 

434 return original_fn(*args, **kwargs) 

435 except NotImplementedError as exc: 

436 if _flag_gems_use_gems_active(): 

437 raise 

438 if not _is_ptpu_device(device): 

439 raise 

440 message = str(exc).lower() 

441 if aten_op.lower() not in message or _PTPU_DEVICE not in message: 

442 raise 

443 cpu_kwargs = dict(kwargs) 

444 cpu_kwargs["device"] = "cpu" 

445 out = kwargs.get("out") 

446 if isinstance(out, torch.Tensor) and _is_ptpu_tensor(out): 

447 cpu_kwargs["out"] = None 

448 result = original_fn(*args, **cpu_kwargs) 

449 return _finalize_cpu_result( 

450 result, 

451 kwargs.get("out"), 

452 torch.device(device) 

453 if not isinstance(device, torch.device) 

454 else device, 

455 ) 

456 

457 setattr(torch, name, creation_with_ptpu_cpu_fallback) 

458 setattr(torch, patched_attr, True) 

459 

460 

461def _patch_torch_randn_complex_dtype(): 

462 """Generate complex-dtype `torch.randn(...)` on CPU when targeting PTPU. 

463 

464 PTPU's `randn` implementation calls `normal_` internally, which raises 

465 `RuntimeError: normal_ does not support complex tensors on PTPU, but got 

466 c10::complex<...>` for any complex dtype. This is a quirk: the failure 

467 text is a plain `RuntimeError`, not `NotImplementedError`, and it does 

468 not name an `aten::...` symbol, so `_should_fallback_to_cpu(...)` and 

469 `_patch_torch_creation_function(...)` do not fit. 

470 

471 Narrow guard: 

472 

473 - Wrap only `torch.randn` 

474 - Only divert when `dtype` is a complex dtype AND `device` is PTPU 

475 - Only divert when the raised `RuntimeError` matches the known quirk text 

476 - Real-dtype `torch.randn(..., device='ptpu')` is untouched 

477 """ 

478 patched_attr = "_flag_gems_sunrise_randn_complex_dtype_patched" 

479 if getattr(torch, patched_attr, False): 

480 return 

481 

482 original_fn = torch.randn 

483 complex_quirk_marker = "normal_ does not support complex tensors" 

484 float64_quirk_marker = "supports only float16, bfloat16 and float32 tensors" 

485 

486 @functools.wraps(original_fn) 

487 def randn_with_ptpu_complex_cpu_fallback(*args, **kwargs): 

488 dtype = kwargs.get("dtype") 

489 device = kwargs.get("device") 

490 if ( 

491 isinstance(dtype, torch.dtype) 

492 and dtype == torch.float64 

493 and _is_ptpu_device(device) 

494 and not _flag_gems_use_gems_active() 

495 ): 

496 cpu_kwargs = dict(kwargs) 

497 cpu_kwargs["device"] = "cpu" 

498 result = original_fn(*args, **cpu_kwargs) 

499 target_device = ( 

500 device if isinstance(device, torch.device) else torch.device(device) 

501 ) 

502 return _to_device_if_tensor(result, target_device) 

503 if ( 

504 isinstance(dtype, torch.dtype) 

505 and dtype.is_complex 

506 and _is_ptpu_device(device) 

507 ): 

508 try: 

509 return original_fn(*args, **kwargs) 

510 except RuntimeError as exc: 

511 if _flag_gems_use_gems_active(): 

512 raise 

513 if complex_quirk_marker not in str( 

514 exc 

515 ) and float64_quirk_marker not in str(exc): 

516 raise 

517 cpu_kwargs = dict(kwargs) 

518 cpu_kwargs["device"] = "cpu" 

519 result = original_fn(*args, **cpu_kwargs) 

520 target_device = ( 

521 device if isinstance(device, torch.device) else torch.device(device) 

522 ) 

523 return _to_device_if_tensor(result, target_device) 

524 return original_fn(*args, **kwargs) 

525 

526 torch.randn = randn_with_ptpu_complex_cpu_fallback 

527 setattr(torch, patched_attr, True) 

528 

529 

530def _patch_torch_cudnn_convolution(): 

531 """Run `torch.cudnn_convolution(...)` on CPU via `F.conv{1,2,3}d` for PTPU. 

532 

533 `aten::cudnn_convolution` is a CUDA/cuDNN-only op — it is unimplemented on 

534 PTPU AND on CPU, so the usual "bounce the same call to CPU" trick fails. 

535 The math is plain (bias-free) convolution, which CPU *does* support through 

536 `torch.nn.functional.conv{1,2,3}d`. So the fallback both moves to CPU and 

537 re-expresses the op as the corresponding functional conv, then moves the 

538 result back to the PTPU device. 

539 

540 Signature mapping (note `cudnn_convolution` has no bias arg, and its 

541 `benchmark` / `deterministic` / `allow_tf32` tuning flags have no CPU 

542 analogue and are dropped): 

543 

544 cudnn_convolution(input, weight, *, padding, stride, dilation, groups, 

545 benchmark, deterministic, allow_tf32) 

546 -> F.conv{1,2,3}d(input, weight, bias=None, 

547 stride=stride, padding=padding, 

548 dilation=dilation, groups=groups) 

549 

550 The conv rank is selected by `input.dim()` (3->1d, 4->2d, 5->3d). 

551 """ 

552 patched_attr = "_flag_gems_sunrise_cudnn_convolution_patched" 

553 if getattr(torch, patched_attr, False): 

554 return 

555 

556 original_fn = torch.cudnn_convolution 

557 conv_by_rank = { 

558 3: F.conv1d, 

559 4: F.conv2d, 

560 5: F.conv3d, 

561 } 

562 

563 @functools.wraps(original_fn) 

564 def cudnn_convolution_with_ptpu_cpu_fallback(*args, **kwargs): 

565 tensor = args[0] if args else kwargs.get("input") or kwargs.get("self") 

566 try: 

567 return original_fn(*args, **kwargs) 

568 except NotImplementedError as exc: 

569 if _flag_gems_use_gems_active(): 

570 raise 

571 if not _should_fallback_to_cpu(exc, tensor, "aten::cudnn_convolution"): 

572 raise 

573 

574 call_args = list(args) 

575 call_kwargs = dict(kwargs) 

576 

577 def _take(name, position): 

578 if len(call_args) > position: 

579 return call_args[position] 

580 return call_kwargs.get(name) 

581 

582 inp = _take("input", 0) 

583 weight = _take("weight", 1) 

584 padding = _take("padding", 2) 

585 stride = _take("stride", 3) 

586 dilation = _take("dilation", 4) 

587 groups = _take("groups", 5) 

588 

589 conv_fn = conv_by_rank.get(inp.dim()) 

590 if conv_fn is None: 

591 raise 

592 cpu_out = conv_fn( 

593 _to_cpu_if_ptpu(inp), 

594 _to_cpu_if_ptpu(weight), 

595 bias=None, 

596 stride=stride, 

597 padding=padding, 

598 dilation=dilation, 

599 groups=groups, 

600 ) 

601 return _to_device_if_tensor(cpu_out, tensor.device) 

602 

603 torch.cudnn_convolution = cudnn_convolution_with_ptpu_cpu_fallback 

604 setattr(torch, patched_attr, True) 

605 

606 

607def _patch_torch_div_floor_trunc_integer_dtype(): 

608 """Force `torch.div(int_tensor, ..., rounding_mode='floor'|'trunc')` to 

609 return an integer dtype on PTPU. 

610 

611 PTPU's `aten::div.Tensor` returns float for integer-typed inputs even when 

612 `rounding_mode` requests integer-style rounding (CPU returns int). This is 

613 a wrong-dtype quirk, not a NotImplementedError, so it does not fit any of 

614 the `_should_fallback_to_cpu` helpers above. 

615 

616 Narrow guard: 

617 

618 - Wrap only `torch.div` 

619 - Only divert when `rounding_mode` is `'floor'` / `'trunc'` 

620 - Only divert when at least one participating operand is a PTPU integer 

621 (non-floating, non-complex) tensor and every participating operand keeps 

622 integer floor/trunc semantics 

623 - True division (`rounding_mode=None`) is left untouched even for int inputs 

624 (returning float there is the correct PyTorch semantics) 

625 """ 

626 patched_attr = "_flag_gems_sunrise_div_floor_trunc_dtype_patched" 

627 if getattr(torch, patched_attr, False): 

628 return 

629 

630 original_fn = torch.div 

631 

632 def _is_integer_like_div_operand(value): 

633 if isinstance(value, torch.Tensor): 

634 return not value.is_floating_point() and not value.is_complex() 

635 return isinstance(value, (bool, int)) 

636 

637 def _find_ptpu_integer_tensor(args, kwargs): 

638 candidates = [] 

639 if len(args) > 0: 

640 candidates.append(args[0]) 

641 if len(args) > 1: 

642 candidates.append(args[1]) 

643 candidates.extend( 

644 [ 

645 kwargs.get("input"), 

646 kwargs.get("other"), 

647 kwargs.get("tensor"), 

648 kwargs.get("value"), 

649 ] 

650 ) 

651 for value in candidates: 

652 if ( 

653 isinstance(value, torch.Tensor) 

654 and value.device.type == _PTPU_DEVICE 

655 and _is_integer_like_div_operand(value) 

656 ): 

657 return value 

658 return None 

659 

660 @functools.wraps(original_fn) 

661 def div_with_ptpu_integer_dtype_fix(*args, **kwargs): 

662 rounding_mode = kwargs.get("rounding_mode") 

663 if rounding_mode in ("floor", "trunc"): 

664 if _flag_gems_use_gems_active(): 

665 return original_fn(*args, **kwargs) 

666 tensor = _find_ptpu_integer_tensor(args, kwargs) 

667 operands = ( 

668 args[:2] 

669 if len(args) >= 2 

670 else ( 

671 tuple(args) 

672 + tuple( 

673 value 

674 for value in (kwargs.get("input"), kwargs.get("other")) 

675 if value is not None 

676 ) 

677 ) 

678 ) 

679 if ( 

680 tensor is not None 

681 and operands 

682 and all(_is_integer_like_div_operand(value) for value in operands) 

683 ): 

684 return _torch_function_cpu_fallback(tensor, args, kwargs, original_fn) 

685 return original_fn(*args, **kwargs) 

686 

687 torch.div = div_with_ptpu_integer_dtype_fix 

688 setattr(torch, patched_attr, True) 

689 

690 

691def _patch_tensor_to_cpu_for_complex_views(): 

692 """Route complex PTPU view copies to CPU through the base tensor safely. 

693 

694 Sunrise/PTPU has two related host-copy gaps for complex tensors: 

695 

696 - conjugate views can segfault on `.cpu()` / `.to('cpu')` 

697 - sliced / non-contiguous complex views can fail with 

698 `direct_copy_kernel_ptpu ... failed to dispatch data type ComplexFloat` 

699 

700 For these cases, copy the root base tensor to CPU first, rebuild the 

701 original view metadata on CPU with `as_strided`, then reapply lazy conj/neg 

702 bits on the CPU tensor. 

703 """ 

704 to_attr = "_flag_gems_sunrise_tensor_to_complex_view_cpu_patched" 

705 cpu_attr = "_flag_gems_sunrise_tensor_cpu_complex_view_patched" 

706 if getattr(torch.Tensor, to_attr, False) and getattr(torch.Tensor, cpu_attr, False): 

707 return 

708 

709 original_to = torch.Tensor.to 

710 original_cpu = torch.Tensor.cpu 

711 

712 def _should_route_through_base(self): 

713 return ( 

714 isinstance(self, torch.Tensor) 

715 and self.device.type == _PTPU_DEVICE 

716 and self.is_complex() 

717 and (self.is_conj() or self.is_neg() or _has_tensor_base_view(self)) 

718 ) 

719 

720 def _to_targets_cpu(args, kwargs): 

721 if _is_cpu_device(kwargs.get("device")): 

722 return True 

723 if not args: 

724 return False 

725 first = args[0] 

726 if _is_cpu_device(first): 

727 return True 

728 if isinstance(first, torch.Tensor): 

729 return first.device.type == "cpu" 

730 return False 

731 

732 def _to_target_dtype(args, kwargs): 

733 dtype = kwargs.get("dtype") 

734 if isinstance(dtype, torch.dtype): 

735 return dtype 

736 if not args: 

737 return None 

738 first = args[0] 

739 if isinstance(first, torch.dtype): 

740 return first 

741 if isinstance(first, torch.Tensor): 

742 return first.dtype 

743 return None 

744 

745 def _rebuild_complex_view_on_cpu(self): 

746 if self.is_conj(): 

747 cpu_view = _rebuild_complex_view_on_cpu(self.conj()).conj() 

748 if self.is_neg(): 

749 cpu_view = torch._neg_view(cpu_view) 

750 return cpu_view 

751 

752 root = self 

753 while _has_tensor_base_view(root): 

754 root = root._base 

755 

756 cpu_root = original_cpu(root) 

757 cpu_view = cpu_root 

758 if root is not self: 

759 cpu_view = torch.as_strided( 

760 cpu_root, 

761 self.size(), 

762 self.stride(), 

763 self.storage_offset(), 

764 ) 

765 if self.is_neg(): 

766 cpu_view = torch._neg_view(cpu_view) 

767 return cpu_view 

768 

769 @functools.wraps(original_to) 

770 def to_with_complex_conj_cpu_route(self, *args, **kwargs): 

771 if _flag_gems_use_gems_active(): 

772 return original_to(self, *args, **kwargs) 

773 if _should_route_through_base(self) and _to_targets_cpu(args, kwargs): 

774 cpu_view = _rebuild_complex_view_on_cpu(self) 

775 return original_to(cpu_view, *args, **kwargs) 

776 try: 

777 return original_to(self, *args, **kwargs) 

778 except RuntimeError as exc: 

779 target_dtype = _to_target_dtype(args, kwargs) 

780 if ( 

781 not _is_ptpu_tensor(self) 

782 or self.is_complex() 

783 or not isinstance(target_dtype, torch.dtype) 

784 or not target_dtype.is_complex 

785 or "failed to dispatch data type complex" not in str(exc).lower() 

786 ): 

787 raise 

788 cpu_cast = original_to(original_cpu(self), *args, **kwargs) 

789 return original_to(cpu_cast, device=self.device) 

790 

791 @functools.wraps(original_cpu) 

792 def cpu_with_complex_conj_cpu_route(self, *args, **kwargs): 

793 if _flag_gems_use_gems_active(): 

794 return original_cpu(self, *args, **kwargs) 

795 if _should_route_through_base(self): 

796 return _rebuild_complex_view_on_cpu(self) 

797 return original_cpu(self, *args, **kwargs) 

798 

799 torch.Tensor.to = to_with_complex_conj_cpu_route 

800 torch.Tensor.cpu = cpu_with_complex_conj_cpu_route 

801 setattr(torch.Tensor, to_attr, True) 

802 setattr(torch.Tensor, cpu_attr, True) 

803 

804 

805def _patch_complex_tensor_scalar_mul_runtime_error(): 

806 """Fallback complex-tensor scalar mul to CPU on the PTPU runtime quirk. 

807 

808 Sunrise/PTPU currently fails outside `flag_gems.use_gems()` for: 

809 

810 - `x * 2.0` 

811 - `x.mul(2.0)` 

812 - `torch.mul(x, 2.0)` 

813 

814 when `x` is a PTPU complex tensor. The failure is a plain `RuntimeError` 

815 whose text looks like: 

816 

817 `...BINARY_MUL... failed to dispatch data type ComplexFloat` 

818 

819 This is not a `NotImplementedError` and does not name an `aten::...` 

820 symbol, so the generic `_should_fallback_to_cpu(...)` helpers do not fit. 

821 

822 Narrow guard: 

823 

824 - only `torch.mul`, `Tensor.mul`, and `Tensor.__mul__` 

825 - only when the left-hand side is a PTPU complex tensor 

826 - only when the right-hand side is a non-tensor scalar 

827 - only on the known runtime error substring 

828 """ 

829 tensor_mul_attr = "_flag_gems_sunrise_tensor_mul_complex_scalar_patched" 

830 function_mul_attr = "_flag_gems_sunrise_function_mul_complex_scalar_patched" 

831 if getattr(torch.Tensor, tensor_mul_attr, False) and getattr( 

832 torch, function_mul_attr, False 

833 ): 

834 return 

835 

836 quirk_marker = "failed to dispatch data type complex" 

837 

838 def _should_fallback_complex_scalar_mul(tensor, other): 

839 return ( 

840 isinstance(tensor, torch.Tensor) 

841 and tensor.device.type == _PTPU_DEVICE 

842 and not isinstance(other, torch.Tensor) 

843 and (tensor.is_complex() or isinstance(other, complex)) 

844 ) 

845 

846 def _ptpu_mul_reference_tensor(*values): 

847 scalar_complex = any(isinstance(value, complex) for value in values) 

848 for value in values: 

849 if not isinstance(value, torch.Tensor) or value.device.type != _PTPU_DEVICE: 

850 continue 

851 if value.is_complex() or value.dtype == torch.float64: 

852 return value 

853 if scalar_complex: 

854 return value 

855 return None 

856 

857 original_tensor_mul = torch.Tensor.mul 

858 original_tensor_dunder_mul = torch.Tensor.__mul__ 

859 original_tensor_dunder_rmul = torch.Tensor.__rmul__ 

860 original_function_mul = torch.mul 

861 

862 @functools.wraps(original_tensor_mul) 

863 def tensor_mul_with_complex_scalar_cpu_fallback(self, other): 

864 reference_tensor = _ptpu_mul_reference_tensor(self, other) 

865 if reference_tensor is not None and not _flag_gems_use_gems_active(): 

866 return original_tensor_mul( 

867 _to_cpu_if_ptpu(self), _to_cpu_if_ptpu(other) 

868 ).to(reference_tensor.device) 

869 try: 

870 return original_tensor_mul(self, other) 

871 except RuntimeError as exc: 

872 if _flag_gems_use_gems_active(): 

873 raise 

874 if not _should_fallback_complex_scalar_mul(self, other): 

875 raise 

876 if quirk_marker not in str(exc).lower(): 

877 raise 

878 return original_tensor_mul(self.cpu(), other).to(self.device) 

879 

880 @functools.wraps(original_tensor_dunder_mul) 

881 def tensor_dunder_mul_with_complex_scalar_cpu_fallback(self, other): 

882 reference_tensor = _ptpu_mul_reference_tensor(self, other) 

883 if reference_tensor is not None and not _flag_gems_use_gems_active(): 

884 return original_tensor_dunder_mul( 

885 _to_cpu_if_ptpu(self), _to_cpu_if_ptpu(other) 

886 ).to(reference_tensor.device) 

887 try: 

888 return original_tensor_dunder_mul(self, other) 

889 except RuntimeError as exc: 

890 if _flag_gems_use_gems_active(): 

891 raise 

892 if not _should_fallback_complex_scalar_mul(self, other): 

893 raise 

894 if quirk_marker not in str(exc).lower(): 

895 raise 

896 return original_tensor_dunder_mul(self.cpu(), other).to(self.device) 

897 

898 @functools.wraps(original_tensor_dunder_rmul) 

899 def tensor_dunder_rmul_with_complex_scalar_cpu_fallback(self, other): 

900 reference_tensor = _ptpu_mul_reference_tensor(self, other) 

901 if reference_tensor is not None and not _flag_gems_use_gems_active(): 

902 return original_tensor_dunder_rmul( 

903 _to_cpu_if_ptpu(self), _to_cpu_if_ptpu(other) 

904 ).to(reference_tensor.device) 

905 try: 

906 return original_tensor_dunder_rmul(self, other) 

907 except RuntimeError as exc: 

908 if _flag_gems_use_gems_active(): 

909 raise 

910 if not _should_fallback_complex_scalar_mul(self, other): 

911 raise 

912 if quirk_marker not in str(exc).lower(): 

913 raise 

914 return original_tensor_dunder_rmul(self.cpu(), other).to(self.device) 

915 

916 @functools.wraps(original_function_mul) 

917 def function_mul_with_complex_scalar_cpu_fallback(*args, **kwargs): 

918 tensor = next((arg for arg in args[:2] if isinstance(arg, torch.Tensor)), None) 

919 if tensor is None: 

920 tensor = kwargs.get("input") 

921 if tensor is None: 

922 tensor = kwargs.get("other") 

923 other = None 

924 if len(args) > 1: 

925 other = args[1] if tensor is args[0] else args[0] 

926 else: 

927 other = ( 

928 kwargs.get("other") 

929 if tensor is kwargs.get("input") 

930 else kwargs.get("input") 

931 ) 

932 reference_tensor = _ptpu_mul_reference_tensor(tensor, other) 

933 if reference_tensor is not None and not _flag_gems_use_gems_active(): 

934 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

935 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

936 result = original_function_mul(*cpu_args, **cpu_kwargs) 

937 return _to_device_if_tensor(result, reference_tensor.device) 

938 try: 

939 return original_function_mul(*args, **kwargs) 

940 except RuntimeError as exc: 

941 if _flag_gems_use_gems_active(): 

942 raise 

943 if not _should_fallback_complex_scalar_mul(tensor, other): 

944 raise 

945 if quirk_marker not in str(exc).lower(): 

946 raise 

947 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

948 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

949 result = original_function_mul(*cpu_args, **cpu_kwargs) 

950 return _to_device_if_tensor(result, tensor.device) 

951 

952 torch.Tensor.mul = tensor_mul_with_complex_scalar_cpu_fallback 

953 torch.Tensor.__mul__ = tensor_dunder_mul_with_complex_scalar_cpu_fallback 

954 torch.Tensor.__rmul__ = tensor_dunder_rmul_with_complex_scalar_cpu_fallback 

955 torch.mul = function_mul_with_complex_scalar_cpu_fallback 

956 setattr(torch.Tensor, tensor_mul_attr, True) 

957 setattr(torch, function_mul_attr, True) 

958 

959 

960def _patch_complex_tensor_add_runtime_error(): 

961 """Fallback complex add to CPU on the Sunrise/PTPU runtime quirk. 

962 

963 Outside `flag_gems.use_gems()`, raw complex add can fail with a plain 

964 runtime error like: 

965 

966 `...BINARY_ADD... failed to dispatch data type ComplexFloat` 

967 

968 This typically shows up in reference expressions such as `a + b * alpha` 

969 inside tests. Keep the guard narrow so the real device add path under 

970 `use_gems()` remains visible. 

971 """ 

972 tensor_add_attr = "_flag_gems_sunrise_tensor_add_complex_patched" 

973 function_add_attr = "_flag_gems_sunrise_function_add_complex_patched" 

974 if getattr(torch.Tensor, tensor_add_attr, False) and getattr( 

975 torch, function_add_attr, False 

976 ): 

977 return 

978 

979 quirk_marker = "failed to dispatch data type complex" 

980 

981 def _first_ptpu_complex_tensor(*values): 

982 for value in values: 

983 if ( 

984 isinstance(value, torch.Tensor) 

985 and value.device.type == _PTPU_DEVICE 

986 and value.is_complex() 

987 ): 

988 return value 

989 return None 

990 

991 def _should_route_complex_scalar_add(tensor, other): 

992 return ( 

993 isinstance(tensor, torch.Tensor) 

994 and tensor.device.type == _PTPU_DEVICE 

995 and tensor.is_complex() 

996 and isinstance(other, complex) 

997 ) 

998 

999 def _ptpu_add_reference_tensor(*values): 

1000 for value in values: 

1001 if ( 

1002 isinstance(value, torch.Tensor) 

1003 and value.device.type == _PTPU_DEVICE 

1004 and (value.is_complex() or value.dtype == torch.float64) 

1005 ): 

1006 return value 

1007 return None 

1008 

1009 original_tensor_add = torch.Tensor.add 

1010 original_tensor_dunder_add = torch.Tensor.__add__ 

1011 original_function_add = torch.add 

1012 

1013 @functools.wraps(original_tensor_add) 

1014 def tensor_add_with_complex_cpu_fallback(self, other, *args, **kwargs): 

1015 reference_tensor = _ptpu_add_reference_tensor(self, other) 

1016 if reference_tensor is not None and not _flag_gems_use_gems_active(): 

1017 return original_tensor_add( 

1018 _to_cpu_if_ptpu(self), _to_cpu_if_ptpu(other), *args, **kwargs 

1019 ).to(reference_tensor.device) 

1020 if not _flag_gems_use_gems_active() and _should_route_complex_scalar_add( 

1021 self, other 

1022 ): 

1023 return original_tensor_add(self.cpu(), other, *args, **kwargs).to( 

1024 self.device 

1025 ) 

1026 try: 

1027 return original_tensor_add(self, other, *args, **kwargs) 

1028 except RuntimeError as exc: 

1029 if _flag_gems_use_gems_active(): 

1030 raise 

1031 tensor = _first_ptpu_complex_tensor(self, other) 

1032 if tensor is None or quirk_marker not in str(exc).lower(): 

1033 raise 

1034 cpu_self = _to_cpu_if_ptpu(self) 

1035 cpu_other = _to_cpu_if_ptpu(other) 

1036 result = original_tensor_add(cpu_self, cpu_other, *args, **kwargs) 

1037 return _to_device_if_tensor(result, tensor.device) 

1038 

1039 @functools.wraps(original_tensor_dunder_add) 

1040 def tensor_dunder_add_with_complex_cpu_fallback(self, other): 

1041 reference_tensor = _ptpu_add_reference_tensor(self, other) 

1042 if reference_tensor is not None and not _flag_gems_use_gems_active(): 

1043 return original_tensor_dunder_add( 

1044 _to_cpu_if_ptpu(self), _to_cpu_if_ptpu(other) 

1045 ).to(reference_tensor.device) 

1046 if not _flag_gems_use_gems_active() and _should_route_complex_scalar_add( 

1047 self, other 

1048 ): 

1049 return original_tensor_dunder_add(self.cpu(), other).to(self.device) 

1050 try: 

1051 return original_tensor_dunder_add(self, other) 

1052 except RuntimeError as exc: 

1053 if _flag_gems_use_gems_active(): 

1054 raise 

1055 tensor = _first_ptpu_complex_tensor(self, other) 

1056 if tensor is None or quirk_marker not in str(exc).lower(): 

1057 raise 

1058 cpu_self = _to_cpu_if_ptpu(self) 

1059 cpu_other = _to_cpu_if_ptpu(other) 

1060 result = original_tensor_dunder_add(cpu_self, cpu_other) 

1061 return _to_device_if_tensor(result, tensor.device) 

1062 

1063 @functools.wraps(original_function_add) 

1064 def function_add_with_complex_cpu_fallback(*args, **kwargs): 

1065 tensor = _first_ptpu_complex_tensor( 

1066 *( 

1067 args[:2] 

1068 if len(args) >= 2 

1069 else (kwargs.get("input"), kwargs.get("other")) 

1070 ) 

1071 ) 

1072 other = args[1] if len(args) > 1 else kwargs.get("other") 

1073 reference_tensor = _ptpu_add_reference_tensor( 

1074 *( 

1075 args[:2] 

1076 if len(args) >= 2 

1077 else (kwargs.get("input"), kwargs.get("other")) 

1078 ) 

1079 ) 

1080 if reference_tensor is not None and not _flag_gems_use_gems_active(): 

1081 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

1082 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

1083 result = original_function_add(*cpu_args, **cpu_kwargs) 

1084 return _to_device_if_tensor(result, reference_tensor.device) 

1085 if not _flag_gems_use_gems_active() and _should_route_complex_scalar_add( 

1086 tensor, other 

1087 ): 

1088 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

1089 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

1090 result = original_function_add(*cpu_args, **cpu_kwargs) 

1091 return _to_device_if_tensor(result, tensor.device) 

1092 try: 

1093 return original_function_add(*args, **kwargs) 

1094 except RuntimeError as exc: 

1095 if _flag_gems_use_gems_active(): 

1096 raise 

1097 if tensor is None or quirk_marker not in str(exc).lower(): 

1098 raise 

1099 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

1100 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

1101 result = original_function_add(*cpu_args, **cpu_kwargs) 

1102 return _to_device_if_tensor(result, tensor.device) 

1103 

1104 torch.Tensor.add = tensor_add_with_complex_cpu_fallback 

1105 torch.Tensor.__add__ = tensor_dunder_add_with_complex_cpu_fallback 

1106 torch.add = function_add_with_complex_cpu_fallback 

1107 setattr(torch.Tensor, tensor_add_attr, True) 

1108 setattr(torch, function_add_attr, True) 

1109 

1110 

1111def _patch_torch_isclose_allclose_complex_dtype(): 

1112 """Fallback `torch.isclose` / `torch.allclose` for PTPU complex/fp64 tensors. 

1113 

1114 `torch.testing.assert_close(...)` on Sunrise/PTPU complex tensors reaches 

1115 `torch.isclose(...)`, which can raise: 

1116 

1117 `RuntimeError: unsupported scalar type: ComplexFloat` 

1118 

1119 This is a plain runtime quirk outside `flag_gems.use_gems()`, not an 

1120 `aten::...`-tagged `NotImplementedError`, so the normal helper path does 

1121 not catch it. 

1122 

1123 Narrow guard: 

1124 

1125 - only `torch.isclose` and `torch.allclose` 

1126 - only when the first argument is a PTPU complex/fp64 tensor 

1127 - only on the known runtime error substring for the complex case 

1128 """ 

1129 patched_attr = "_flag_gems_sunrise_isclose_allclose_complex_dtype_patched" 

1130 if getattr(torch, patched_attr, False): 

1131 return 

1132 

1133 quirk_marker = "unsupported scalar type: complex" 

1134 original_isclose = torch.isclose 

1135 original_allclose = torch.allclose 

1136 

1137 def _should_fallback_compare(tensor): 

1138 return ( 

1139 isinstance(tensor, torch.Tensor) 

1140 and tensor.device.type == _PTPU_DEVICE 

1141 and (tensor.is_complex() or tensor.dtype == torch.float64) 

1142 ) 

1143 

1144 @functools.wraps(original_isclose) 

1145 def isclose_with_complex_cpu_fallback(*args, **kwargs): 

1146 tensor = args[0] if args else kwargs.get("input") 

1147 if not _flag_gems_use_gems_active() and _should_fallback_compare(tensor): 

1148 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

1149 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

1150 result = original_isclose(*cpu_args, **cpu_kwargs) 

1151 return _to_device_if_tensor(result, tensor.device) 

1152 try: 

1153 return original_isclose(*args, **kwargs) 

1154 except RuntimeError as exc: 

1155 if _flag_gems_use_gems_active(): 

1156 raise 

1157 if not _should_fallback_compare(tensor): 

1158 raise 

1159 if tensor.is_complex() and quirk_marker not in str(exc).lower(): 

1160 raise 

1161 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

1162 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

1163 result = original_isclose(*cpu_args, **cpu_kwargs) 

1164 return _to_device_if_tensor(result, tensor.device) 

1165 

1166 @functools.wraps(original_allclose) 

1167 def allclose_with_complex_cpu_fallback(*args, **kwargs): 

1168 tensor = args[0] if args else kwargs.get("input") 

1169 if not _flag_gems_use_gems_active() and _should_fallback_compare(tensor): 

1170 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

1171 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

1172 return original_allclose(*cpu_args, **cpu_kwargs) 

1173 try: 

1174 return original_allclose(*args, **kwargs) 

1175 except RuntimeError as exc: 

1176 if _flag_gems_use_gems_active(): 

1177 raise 

1178 if not _should_fallback_compare(tensor): 

1179 raise 

1180 if tensor.is_complex() and quirk_marker not in str(exc).lower(): 

1181 raise 

1182 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

1183 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

1184 return original_allclose(*cpu_args, **cpu_kwargs) 

1185 

1186 torch.isclose = isclose_with_complex_cpu_fallback 

1187 torch.allclose = allclose_with_complex_cpu_fallback 

1188 setattr(torch, patched_attr, True) 

1189 

1190 

1191def _patch_complex_matmul_runtime_error(): 

1192 """Fallback reference matmul-family ops to CPU on Sunrise/PTPU. 

1193 

1194 Complex reconstruction paths such as `u @ diag(s) @ v.mH` can fail outside 

1195 `flag_gems.use_gems()` with runtime errors from lowerings like: 

1196 

1197 - `addbmm_out not implemented for ComplexFloat` 

1198 - `baddbmm_out only supports float/half/bfloat16, got ComplexFloat` 

1199 

1200 Separately, some real-valued *degenerate batched matmuls* such as 

1201 `(..., 17, 1) @ (..., 1, 1)` can silently produce garbage on PTPU in the 

1202 same reference-style reconstruction path, even though the upstream SVD 

1203 factors themselves are correct. Route those narrow cases to CPU too. 

1204 

1205 This is a reference-path/runtime gap rather than a FlagGems kernel bug. 

1206 Keep the guard tight: 

1207 

1208 - only outside `flag_gems.use_gems()` 

1209 - always for PTPU complex/fp64 tensors 

1210 - additionally for PTPU real floating batched matmuls where at least one 

1211 tensor has a singleton matrix dimension (`min(shape[-2:]) == 1`) 

1212 - wrap matmul-family entry points that the Python surface can hit during 

1213 reconstruction: `Tensor.__matmul__`, `Tensor.matmul`, `torch.matmul`, 

1214 `torch.bmm`, `torch.addbmm`, `torch.baddbmm` 

1215 """ 

1216 tensor_attr = "_flag_gems_sunrise_tensor_matmul_complex_patched" 

1217 function_attr = "_flag_gems_sunrise_function_matmul_complex_patched" 

1218 if getattr(torch.Tensor, tensor_attr, False) and getattr( 

1219 torch, function_attr, False 

1220 ): 

1221 return 

1222 

1223 quirk_markers = ( 

1224 "addbmm_out not implemented for complex", 

1225 "baddbmm_out only supports float/half/bfloat16, got complex", 

1226 "unsupported scalar type: complex", 

1227 ) 

1228 

1229 def _ptpu_matmul_reference_tensor(*values): 

1230 first_ptpu_tensor = None 

1231 for value in values: 

1232 if not isinstance(value, torch.Tensor) or value.device.type != _PTPU_DEVICE: 

1233 continue 

1234 if first_ptpu_tensor is None: 

1235 first_ptpu_tensor = value 

1236 if value.is_complex() or value.dtype == torch.float64: 

1237 return value 

1238 return first_ptpu_tensor 

1239 

1240 def _ptpu_tensor_args(*values): 

1241 return [ 

1242 value 

1243 for value in values 

1244 if isinstance(value, torch.Tensor) and value.device.type == _PTPU_DEVICE 

1245 ] 

1246 

1247 def _should_route_reference_matmul(*values): 

1248 tensors = _ptpu_tensor_args(*values) 

1249 if not tensors: 

1250 return False 

1251 if any(t.is_complex() or t.dtype == torch.float64 for t in tensors): 

1252 return True 

1253 return any( 

1254 t.ndim >= 3 

1255 and t.is_floating_point() 

1256 and not t.is_complex() 

1257 and min(t.shape[-2:]) == 1 

1258 for t in tensors 

1259 ) 

1260 

1261 def _cpu_dispatch_to_reference_device(reference_tensor, original_fn, args, kwargs): 

1262 cpu_args = tuple(_to_cpu_if_ptpu(arg) for arg in args) 

1263 cpu_kwargs = {key: _to_cpu_if_ptpu(value) for key, value in kwargs.items()} 

1264 result = original_fn(*cpu_args, **cpu_kwargs) 

1265 out = kwargs.get("out") 

1266 return _finalize_cpu_result(result, out, reference_tensor.device) 

1267 

1268 original_tensor_matmul = torch.Tensor.matmul 

1269 original_tensor_dunder_matmul = torch.Tensor.__matmul__ 

1270 original_function_matmul = torch.matmul 

1271 original_function_bmm = torch.bmm 

1272 original_function_addbmm = torch.addbmm 

1273 original_function_baddbmm = torch.baddbmm 

1274 

1275 @functools.wraps(original_tensor_matmul) 

1276 def tensor_matmul_with_complex_cpu_fallback(self, other): 

1277 reference_tensor = _ptpu_matmul_reference_tensor(self, other) 

1278 if not _flag_gems_use_gems_active() and _should_route_reference_matmul( 

1279 self, other 

1280 ): 

1281 return _cpu_dispatch_to_reference_device( 

1282 reference_tensor, original_tensor_matmul, (self, other), {} 

1283 ) 

1284 try: 

1285 return original_tensor_matmul(self, other) 

1286 except RuntimeError as exc: 

1287 if _flag_gems_use_gems_active(): 

1288 raise 

1289 if reference_tensor is None or not any( 

1290 marker in str(exc).lower() for marker in quirk_markers 

1291 ): 

1292 raise 

1293 return _cpu_dispatch_to_reference_device( 

1294 reference_tensor, original_tensor_matmul, (self, other), {} 

1295 ) 

1296 

1297 @functools.wraps(original_tensor_dunder_matmul) 

1298 def tensor_dunder_matmul_with_complex_cpu_fallback(self, other): 

1299 reference_tensor = _ptpu_matmul_reference_tensor(self, other) 

1300 if not _flag_gems_use_gems_active() and _should_route_reference_matmul( 

1301 self, other 

1302 ): 

1303 return _cpu_dispatch_to_reference_device( 

1304 reference_tensor, original_tensor_dunder_matmul, (self, other), {} 

1305 ) 

1306 try: 

1307 return original_tensor_dunder_matmul(self, other) 

1308 except RuntimeError as exc: 

1309 if _flag_gems_use_gems_active(): 

1310 raise 

1311 if reference_tensor is None or not any( 

1312 marker in str(exc).lower() for marker in quirk_markers 

1313 ): 

1314 raise 

1315 return _cpu_dispatch_to_reference_device( 

1316 reference_tensor, original_tensor_dunder_matmul, (self, other), {} 

1317 ) 

1318 

1319 def _patch_torch_matmul_like(name, original_fn): 

1320 @functools.wraps(original_fn) 

1321 def fn_with_complex_cpu_fallback(*args, **kwargs): 

1322 reference_tensor = _ptpu_matmul_reference_tensor(*args, *kwargs.values()) 

1323 if not _flag_gems_use_gems_active() and _should_route_reference_matmul( 

1324 *args, *kwargs.values() 

1325 ): 

1326 return _cpu_dispatch_to_reference_device( 

1327 reference_tensor, original_fn, args, kwargs 

1328 ) 

1329 try: 

1330 return original_fn(*args, **kwargs) 

1331 except RuntimeError as exc: 

1332 if _flag_gems_use_gems_active(): 

1333 raise 

1334 if reference_tensor is None or not any( 

1335 marker in str(exc).lower() for marker in quirk_markers 

1336 ): 

1337 raise 

1338 return _cpu_dispatch_to_reference_device( 

1339 reference_tensor, original_fn, args, kwargs 

1340 ) 

1341 

1342 setattr(torch, name, fn_with_complex_cpu_fallback) 

1343 

1344 torch.Tensor.matmul = tensor_matmul_with_complex_cpu_fallback 

1345 torch.Tensor.__matmul__ = tensor_dunder_matmul_with_complex_cpu_fallback 

1346 _patch_torch_matmul_like("matmul", original_function_matmul) 

1347 _patch_torch_matmul_like("bmm", original_function_bmm) 

1348 _patch_torch_matmul_like("addbmm", original_function_addbmm) 

1349 _patch_torch_matmul_like("baddbmm", original_function_baddbmm) 

1350 setattr(torch.Tensor, tensor_attr, True) 

1351 setattr(torch, function_attr, True) 

1352 

1353 

1354def _flag_gems_use_gems_active(): 

1355 """Return True while a `flag_gems.use_gems()` context is active. 

1356 

1357 `use_gems()` sets the module-level `current_work_registrar` on enter and 

1358 `del`s it on exit, so `getattr(flag_gems, "current_work_registrar", None)` 

1359 is a reliable, side-effect-free signal for "are we currently dispatching 

1360 aten ops through FlagGems device kernels?". 

1361 """ 

1362 import flag_gems 

1363 

1364 return getattr(flag_gems, "current_work_registrar", None) is not None 

1365 

1366 

1367def _patch_torch_einsum_low_precision_reference(): 

1368 """Compute low-precision `torch.einsum(...)` reference matmuls on CPU. 

1369 

1370 This is a precision quirk, not a `NotImplementedError`. `torch.einsum` 

1371 lowers its contraction to a matmul/bmm. On Sunrise/PTPU the fp16 / bf16 

1372 matmul accumulates in low precision, while CPU (and CUDA) accumulate fp16 

1373 matmuls internally in fp32. Tests such as `test_flash_attn_varlen_func.py` 

1374 build their CPU "golden" reference with raw `torch.einsum("hqk,khd->qhd", 

1375 attn, v)` on tensors that happen to live on PTPU (the test wraps setup in 

1376 `with torch.device("ptpu")` and never routes the reference through 

1377 `accuracy_utils.to_reference()`), so the *reference itself* drifts by up to 

1378 ~0.5 versus the true CPU result and the assertion fails even though the 

1379 Sunrise flash-attention kernel under test is correct (~1e-3). 

1380 

1381 The fix mirrors the "wrong ref operator → CPU" rule: redirect only the 

1382 reference-path einsum to CPU. The guard is intentionally tight so that the 

1383 real device-under-test einsum (`test_einsum.py`, `test_fp8_einsum.py`, ...) 

1384 is never diverted: 

1385 

1386 - Skip entirely while `flag_gems.use_gems()` is active. The device path in 

1387 `test_einsum.py` runs einsum under `use_gems()`; the reference paths do 

1388 not. (No FlagGems op implementation calls `torch.einsum`, so this never 

1389 touches kernel internals.) 

1390 - Only divert when at least one operand is a PTPU tensor. 

1391 - Only divert when the contraction dtype is fp16 / bf16. fp32 / fp64 

1392 references (e.g. `to_reference(.., upcast=True)`, `q.float()`) already 

1393 match CPU and are left on device. 

1394 - Equivalent to upcasting the einsum to fp32 on device, but computing on 

1395 CPU keeps the reference identical to a `--ref cpu` golden value. 

1396 """ 

1397 patched_attr = "_flag_gems_sunrise_einsum_low_precision_patched" 

1398 if getattr(torch, patched_attr, False): 

1399 return 

1400 

1401 original_fn = torch.einsum 

1402 low_precision_dtypes = (torch.float16, torch.bfloat16) 

1403 

1404 def _operand_tensors(operands): 

1405 # torch.einsum accepts either (equation, *tensors) or 

1406 # (equation, [tensors]); flatten the sublist form too. 

1407 for operand in operands: 

1408 if isinstance(operand, torch.Tensor): 

1409 yield operand 

1410 elif isinstance(operand, (list, tuple)): 

1411 for item in operand: 

1412 if isinstance(item, torch.Tensor): 

1413 yield item 

1414 

1415 @functools.wraps(original_fn) 

1416 def einsum_with_ptpu_low_precision_cpu_reference(equation, *operands): 

1417 if not _flag_gems_use_gems_active(): 

1418 tensors = list(_operand_tensors(operands)) 

1419 if any(_is_ptpu_tensor(t) for t in tensors) and any( 

1420 t.dtype in low_precision_dtypes for t in tensors 

1421 ): 

1422 cpu_operands = tuple( 

1423 _to_cpu_if_ptpu(operand) 

1424 if isinstance(operand, torch.Tensor) 

1425 else ( 

1426 [_to_cpu_if_ptpu(item) for item in operand] 

1427 if isinstance(operand, (list, tuple)) 

1428 else operand 

1429 ) 

1430 for operand in operands 

1431 ) 

1432 device = next((t.device for t in tensors if _is_ptpu_tensor(t)), None) 

1433 result = original_fn(equation, *cpu_operands) 

1434 return _to_device_if_tensor(result, device) 

1435 return original_fn(equation, *operands) 

1436 

1437 torch.einsum = einsum_with_ptpu_low_precision_cpu_reference 

1438 setattr(torch, patched_attr, True) 

1439 

1440 

1441def _patch_bool_sum_cpu_reference(): 

1442 """Compute PTPU bool-tensor `sum` reductions on CPU outside `use_gems()`. 

1443 

1444 Sunrise/PTPU occasionally returns the wrong population count for boolean 

1445 masks in test-setup code such as `numel = mask.sum().item()`. This is a 

1446 silent semantic quirk rather than a `NotImplementedError`, so we cannot 

1447 rely on the normal exception-driven CPU fallback helpers. 

1448 

1449 Keep the guard intentionally tight: 

1450 

1451 - only `torch.Tensor.sum` / `torch.sum` 

1452 - only when the input tensor is a PTPU bool tensor 

1453 - only outside `flag_gems.use_gems()`, so the real reduction kernels under 

1454 test are still exercised inside the device path 

1455 """ 

1456 tensor_attr = "_flag_gems_sunrise_tensor_bool_sum_cpu_patched" 

1457 function_attr = "_flag_gems_sunrise_function_bool_sum_cpu_patched" 

1458 if getattr(torch.Tensor, tensor_attr, False) and getattr( 

1459 torch, function_attr, False 

1460 ): 

1461 return 

1462 

1463 original_tensor_sum = torch.Tensor.sum 

1464 original_function_sum = torch.sum 

1465 

1466 def _should_route_bool_sum(tensor): 

1467 return ( 

1468 isinstance(tensor, torch.Tensor) 

1469 and tensor.device.type == _PTPU_DEVICE 

1470 and tensor.dtype == torch.bool 

1471 ) 

1472 

1473 @functools.wraps(original_tensor_sum) 

1474 def tensor_sum_with_bool_cpu_fallback(self, *args, **kwargs): 

1475 if not _flag_gems_use_gems_active() and _should_route_bool_sum(self): 

1476 return _cpu_fallback(self, args, kwargs, original_tensor_sum) 

1477 return original_tensor_sum(self, *args, **kwargs) 

1478 

1479 @functools.wraps(original_function_sum) 

1480 def function_sum_with_bool_cpu_fallback(*args, **kwargs): 

1481 tensor = args[0] if args else kwargs.get("input") 

1482 if not _flag_gems_use_gems_active() and _should_route_bool_sum(tensor): 

1483 return _torch_function_cpu_fallback( 

1484 tensor, args, kwargs, original_function_sum 

1485 ) 

1486 return original_function_sum(*args, **kwargs) 

1487 

1488 torch.Tensor.sum = tensor_sum_with_bool_cpu_fallback 

1489 torch.sum = function_sum_with_bool_cpu_fallback 

1490 setattr(torch.Tensor, tensor_attr, True) 

1491 setattr(torch, function_attr, True) 

1492 

1493 

1494def _patch_torch_nn_functional_one_hot_cpu_reference(): 

1495 """Compute `torch.nn.functional.one_hot(...)` on CPU for PTPU inputs. 

1496 

1497 Tests such as `test_multinomial.py` build reference counts with 

1498 `torch.nn.functional.one_hot(...)` directly on tensors that may live on 

1499 PTPU. Route only that reference-style path to CPU: 

1500 

1501 - only `torch.nn.functional.one_hot` 

1502 - only when the input tensor is on PTPU 

1503 - only outside `flag_gems.use_gems()`, so the real backend one_hot path 

1504 remains available inside the device-under-test region 

1505 """ 

1506 patched_attr = "_flag_gems_sunrise_nn_functional_one_hot_cpu_patched" 

1507 if getattr(F, patched_attr, False): 

1508 return 

1509 

1510 original_fn = F.one_hot 

1511 

1512 @functools.wraps(original_fn) 

1513 def one_hot_with_ptpu_cpu_reference(*args, **kwargs): 

1514 tensor = args[0] if args else kwargs.get("tensor") or kwargs.get("input") 

1515 if not _flag_gems_use_gems_active() and _is_ptpu_tensor(tensor): 

1516 return _torch_function_cpu_fallback(tensor, args, kwargs, original_fn) 

1517 return original_fn(*args, **kwargs) 

1518 

1519 F.one_hot = one_hot_with_ptpu_cpu_reference 

1520 setattr(F, patched_attr, True) 

1521 

1522 

1523def _patch_torch_packet(packet_name, aten_op): 

1524 packet = getattr(torch.ops.aten, packet_name) 

1525 patched_attr = "_flag_gems_sunrise_packet_patched" 

1526 if getattr(packet, patched_attr, False): 

1527 return 

1528 

1529 original_fn = packet._op 

1530 

1531 @functools.wraps(original_fn) 

1532 def packet_with_ptpu_cpu_fallback(*args, **kwargs): 

1533 tensor = args[0] if args else kwargs.get("self") or kwargs.get("input") 

1534 try: 

1535 return original_fn(*args, **kwargs) 

1536 except NotImplementedError as exc: 

1537 if _flag_gems_use_gems_active(): 

1538 raise 

1539 if not _should_fallback_to_cpu(exc, tensor, aten_op): 

1540 raise 

1541 return _torch_function_cpu_fallback(tensor, args, kwargs, original_fn) 

1542 

1543 packet._op = packet_with_ptpu_cpu_fallback 

1544 setattr(packet, patched_attr, True) 

1545 

1546 

1547def _patch_torch_ptpu_get_device_index(): 

1548 """Work around torch_ptpu's `_get_device_index()` choking on an index-less 

1549 `torch.device('ptpu')`: `device.index` is None, so the trailing 

1550 `device >= 0` raises `TypeError: '>=' not supported between NoneType and 

1551 int`. flag_gems constructor/RNG ops pass exactly such a device into 

1552 `torch_device_fn.device(device)` under `use_gems()`. Coerce a None index to 

1553 `current_device()`. Every torch_ptpu device helper and the device guard 

1554 resolve `_get_device_index` from the `torch_ptpu.ptpu` module globals, so 

1555 rebinding it there fixes them all. 

1556 """ 

1557 try: 

1558 import torch_ptpu.ptpu as _ptpu 

1559 except Exception: 

1560 return 

1561 

1562 if getattr(_ptpu, "_flag_gems_sunrise_gdi_patched", False): 

1563 return 

1564 

1565 original_fn = getattr(_ptpu, "_get_device_index", None) 

1566 if original_fn is None: 

1567 return 

1568 

1569 @functools.wraps(original_fn) 

1570 def get_device_index_with_index_fallback(device): 

1571 if isinstance(device, torch.device) and device.index is None: 

1572 device = _ptpu.current_device() 

1573 return original_fn(device) 

1574 

1575 _ptpu._get_device_index = get_device_index_with_index_fallback 

1576 _ptpu._flag_gems_sunrise_gdi_patched = True 

1577 

1578 

1579def _pytest_terminal_summary_frame(): 

1580 for frame_info in inspect.stack(context=0): 

1581 frame_path = os.path.normpath(frame_info.filename) 

1582 if frame_info.function == "pytest_terminal_summary" and frame_path.endswith( 

1583 os.path.join("tests", "conftest.py") 

1584 ): 

1585 return frame_info 

1586 return None 

1587 

1588 

1589def _backup_corrupt_accuracy_report(frame_info, payload): 

1590 if not payload: 

1591 return None 

1592 frame = frame_info.frame 

1593 json_file = frame.f_locals.get("json_file") 

1594 report_path = getattr(json_file, "name", None) or frame.f_globals.get("REPORT_FILE") 

1595 if not report_path: 

1596 return None 

1597 report_path = os.path.abspath(os.fspath(report_path)) 

1598 backup_path = ( 

1599 f"{report_path}.corrupt." f"{os.getpid()}." f"{int(time.time() * 1000)}" 

1600 ) 

1601 with open(backup_path, "w", encoding="utf-8") as backup_file: 

1602 backup_file.write(payload) 

1603 return backup_path 

1604 

1605 

1606def _sanitize_accuracy_report_json(value): 

1607 if isinstance(value, torch.Tensor): 

1608 return ( 

1609 { 

1610 "__tensor__": True, 

1611 "dtype": str(value.dtype), 

1612 "shape": list(value.shape), 

1613 "device": str(value.device), 

1614 "requires_grad": bool(value.requires_grad), 

1615 }, 

1616 1, 

1617 ) 

1618 if isinstance(value, dict): 

1619 sanitized = {} 

1620 replaced = 0 

1621 for key, item in value.items(): 

1622 if isinstance(key, (str, int, float, bool)) or key is None: 

1623 safe_key = key 

1624 else: 

1625 safe_key = str(key) 

1626 safe_item, item_replaced = _sanitize_accuracy_report_json(item) 

1627 sanitized[safe_key] = safe_item 

1628 replaced += item_replaced 

1629 return sanitized, replaced 

1630 if isinstance(value, (list, tuple)): 

1631 sanitized = [] 

1632 replaced = 0 

1633 for item in value: 

1634 safe_item, item_replaced = _sanitize_accuracy_report_json(item) 

1635 sanitized.append(safe_item) 

1636 replaced += item_replaced 

1637 return sanitized, replaced 

1638 if isinstance(value, (set, frozenset)): 

1639 sanitized = [] 

1640 replaced = 0 

1641 for item in value: 

1642 safe_item, item_replaced = _sanitize_accuracy_report_json(item) 

1643 sanitized.append(safe_item) 

1644 replaced += item_replaced 

1645 return sanitized, replaced 

1646 return value, 0 

1647 

1648 

1649def _patch_json_loads_for_accuracy_result(): 

1650 """Ignore a truncated `accuracy_result.json` in test summary on Sunrise. 

1651 

1652 Some CI jobs finish all pytest cases successfully, then fail in 

1653 `tests/conftest.py::pytest_terminal_summary` while merging the accumulated 

1654 `accuracy_result.json`. The failure is a plain `json.JSONDecodeError` on a 

1655 previously truncated file, so keep the fix narrow and Sunrise-local: 

1656 

1657 - patch only `json.loads` 

1658 - only intercept `json.JSONDecodeError` 

1659 - only when the caller is `tests/conftest.py::pytest_terminal_summary` 

1660 - backup the corrupt payload before falling back to `{}` 

1661 """ 

1662 patched_attr = "_flag_gems_sunrise_accuracy_json_loads_patched" 

1663 if getattr(json, patched_attr, False): 

1664 return 

1665 

1666 original_fn = json.loads 

1667 

1668 @functools.wraps(original_fn) 

1669 def loads_with_accuracy_result_fallback(*args, **kwargs): 

1670 try: 

1671 return original_fn(*args, **kwargs) 

1672 except json.JSONDecodeError: 

1673 frame_info = _pytest_terminal_summary_frame() 

1674 if frame_info is None: 

1675 raise 

1676 payload = args[0] if args else kwargs.get("s") 

1677 backup_path = None 

1678 try: 

1679 backup_path = _backup_corrupt_accuracy_report(frame_info, payload) 

1680 except OSError as backup_exc: 

1681 _LOGGER.warning( 

1682 "Sunrise skipped corrupt accuracy_result backup: %s", backup_exc 

1683 ) 

1684 if backup_path is not None: 

1685 _LOGGER.warning( 

1686 "Sunrise ignored corrupt accuracy_result JSON and backed it up to %s", 

1687 backup_path, 

1688 ) 

1689 else: 

1690 _LOGGER.warning("Sunrise ignored corrupt accuracy_result JSON") 

1691 return {} 

1692 

1693 json.loads = loads_with_accuracy_result_fallback 

1694 setattr(json, patched_attr, True) 

1695 

1696 

1697def _patch_json_dump_for_accuracy_result(): 

1698 """Sanitize tensor payloads before pytest summary writes JSON on Sunrise.""" 

1699 patched_attr = "_flag_gems_sunrise_accuracy_json_dump_patched" 

1700 if getattr(json, patched_attr, False): 

1701 return 

1702 

1703 original_fn = json.dump 

1704 

1705 @functools.wraps(original_fn) 

1706 def dump_with_accuracy_result_sanitize(*args, **kwargs): 

1707 frame_info = _pytest_terminal_summary_frame() 

1708 if frame_info is None or not args: 

1709 return original_fn(*args, **kwargs) 

1710 payload = args[0] 

1711 safe_payload, replaced = _sanitize_accuracy_report_json(payload) 

1712 if replaced: 

1713 _LOGGER.warning( 

1714 "Sunrise sanitized %d tensor value(s) before writing accuracy_result JSON", 

1715 replaced, 

1716 ) 

1717 args = (safe_payload, *args[1:]) 

1718 return original_fn(*args, **kwargs) 

1719 

1720 json.dump = dump_with_accuracy_result_sanitize 

1721 setattr(json, patched_attr, True) 

1722 

1723 

1724def apply_sunrise_monkey_patches(): 

1725 _patch_torch_ptpu_get_device_index() 

1726 _patch_json_loads_for_accuracy_result() 

1727 _patch_json_dump_for_accuracy_result() 

1728 _patch_tensor_copy_scalar_fill_fallback() 

1729 # triu 

1730 _patch_tensor_method("triu", "aten::triu.out") 

1731 _patch_tensor_method("triu_", "aten::triu.out", inplace=True) 

1732 _patch_torch_function("triu", "aten::triu.out") 

1733 

1734 # tanh 

1735 _patch_tensor_method("tanh", "aten::tanh.out") 

1736 _patch_tensor_method("tanh_", "aten::tanh.out", inplace=True) 

1737 _patch_torch_function("tanh", "aten::tanh.out") 

1738 

1739 # relu 

1740 _patch_tensor_method("relu", "aten::relu") 

1741 _patch_tensor_method("relu_", "aten::relu", inplace=True) 

1742 _patch_torch_function("relu", "aten::relu") 

1743 

1744 # clamp_min 

1745 _patch_tensor_method("clamp_min", "aten::clamp_min") 

1746 _patch_tensor_method("clamp_min_", "aten::clamp_min", inplace=True) 

1747 _patch_torch_function("clamp_min", "aten::clamp_min") 

1748 _patch_torch_function("clamp_min_", "aten::clamp_min", inplace=True) 

1749 _patch_torch_tensor_out("clamp_min", "aten::clamp_min.Tensor_out") 

1750 

1751 # remainder / mod 

1752 _patch_tensor_method("__mod__", "aten::remainder") 

1753 _patch_tensor_method("remainder", "aten::remainder") 

1754 _patch_tensor_method("remainder_", "aten::remainder", inplace=True) 

1755 _patch_torch_function("remainder", "aten::remainder") 

1756 _patch_torch_tensor_out("remainder", "aten::remainder.Tensor_out") 

1757 

1758 # floor_divide 

1759 _patch_tensor_method("__floordiv__", "aten::floor_divide") 

1760 _patch_tensor_method("floor_divide", "aten::floor_divide") 

1761 _patch_tensor_method("floor_divide_", "aten::floor_divide", inplace=True) 

1762 _patch_torch_function("floor_divide", "aten::floor_divide") 

1763 _patch_bool_sum_cpu_reference() 

1764 

1765 # reductions used in tests 

1766 _patch_torch_function("amin", "aten::amin") 

1767 _patch_torch_function("amax", "aten::amax") 

1768 _patch_tensor_method("min", "aten::min") 

1769 _patch_torch_function("min", "aten::min") 

1770 _patch_tensor_method("median", "aten::median") 

1771 _patch_torch_function("median", "aten::median") 

1772 _patch_tensor_method("amax", "aten::amax.out") 

1773 _patch_tensor_method("logsumexp", "aten::amax.out") 

1774 _patch_torch_function("logsumexp", "aten::amax.out") 

1775 _patch_tensor_method("mean", "aten::mean") 

1776 _patch_torch_function("mean", "aten::mean") 

1777 _patch_torch_function("norm", "aten::linalg_vector_norm.out") 

1778 _patch_torch_linalg_function("vector_norm", "aten::linalg_vector_norm.out") 

1779 _patch_torch_linalg_function("qr", "aten::linalg_qr.out") 

1780 _patch_torch_function("unique_consecutive", "aten::unique_consecutive") 

1781 

1782 # misc test helpers 

1783 _patch_tensor_method("__invert__", "aten::bitwise_not.out") 

1784 _patch_tensor_method("bitwise_not", "aten::bitwise_not.out") 

1785 _patch_tensor_method("bitwise_not_", "aten::bitwise_not.out", inplace=True) 

1786 _patch_torch_function("bitwise_not", "aten::bitwise_not.out") 

1787 _patch_tensor_method("__and__", "aten::bitwise_and") 

1788 _patch_tensor_method("bitwise_and", "aten::bitwise_and") 

1789 _patch_tensor_method("bitwise_and_", "aten::bitwise_and", inplace=True) 

1790 _patch_torch_function("bitwise_and", "aten::bitwise_and") 

1791 _patch_torch_tensor_out("bitwise_and", "aten::bitwise_and.Tensor_out") 

1792 _patch_tensor_method("masked_select", "aten::masked_select") 

1793 _patch_torch_function("masked_select", "aten::masked_select") 

1794 _patch_tensor_method("__or__", "aten::bitwise_or") 

1795 _patch_torch_function("bitwise_or", "aten::bitwise_or") 

1796 _patch_torch_tensor_out("bitwise_or", "aten::bitwise_or.Tensor_out") 

1797 _patch_torch_function("isclose", "aten::bitwise_and.Tensor_out") 

1798 _patch_torch_function("allclose", "aten::bitwise_and.Tensor_out") 

1799 _patch_torch_function("complex", "aten::complex.out") 

1800 _patch_torch_creation_function("eye", "aten::eye.m_out") 

1801 _patch_torch_creation_function("linspace", "aten::linspace.out") 

1802 _patch_torch_out("hypot", "aten::hypot.out") 

1803 _patch_torch_creation_function("randperm", "aten::randperm.generator_out") 

1804 _patch_tensor_property("real", "aten::view_as_real") 

1805 _patch_tensor_property("imag", "aten::view_as_real") 

1806 _patch_torch_nn_functional("pad", "aten::replication_pad3d.out") 

1807 _patch_torch_nn_functional("logsigmoid", "aten::log_sigmoid_forward") 

1808 _patch_torch_nn_functional_one_hot_cpu_reference() 

1809 _patch_torch_randn_complex_dtype() 

1810 _patch_torch_cudnn_convolution() 

1811 _patch_torch_div_floor_trunc_integer_dtype() 

1812 _patch_tensor_to_cpu_for_complex_views() 

1813 _patch_complex_tensor_scalar_mul_runtime_error() 

1814 _patch_complex_tensor_add_runtime_error() 

1815 _patch_complex_tensor_add_runtime_error() 

1816 _patch_complex_matmul_runtime_error() 

1817 _patch_torch_isclose_allclose_complex_dtype() 

1818 _patch_torch_einsum_low_precision_reference() 

1819 _patch_torch_packet("elu", "aten::elu.out")