Coverage for src/flag_gems/ops/copy.py: 71%

65 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2from typing import Optional 

3 

4import torch 

5import triton 

6 

7from flag_gems.utils import pointwise_dynamic 

8 

# Module-level logger; the copy entry points emit debug records through it.
logger = logging.getLogger(__name__)

# Dispatch key set handed to ``redispatch`` whenever copy_ defers to PyTorch's
# own implementation instead of running the Triton kernel.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)

# ``torch.float8_e8m0fnu`` is not present in every PyTorch build; keep None
# when it is missing so dtype comparisons can be skipped.
_FLOAT8_E8M0FNU = torch.float8_e8m0fnu if hasattr(torch, "float8_e8m0fnu") else None

16 

17 

# Element-wise identity kernel used by copy_: it returns its single tensor
# argument unchanged. copy_ instantiates it for the source rank and passes the
# destination as ``out0``.
# NOTE(review): presumably the pointwise_dynamic wrapper performs the dtype
# conversion into ``out0`` and handles strides — confirm against
# flag_gems.utils.pointwise_dynamic.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _copy_kernel(src):
    # Identity: the value read from the source is the value produced.
    return src

22 

23 

def _can_use_triton(dst: torch.Tensor, src: torch.Tensor) -> bool:
    """Tell whether a ``dst``/``src`` pair is eligible for the Triton copy kernel.

    Args:
        dst: Destination tensor of the copy.
        src: Source tensor of the copy.

    Returns:
        ``False`` if either tensor is non-strided, quantized, complex or
        ``float8_e8m0fnu``, or if the two tensors live on different devices;
        ``True`` otherwise.
    """
    pair = (dst, src)

    if any(t.layout != torch.strided for t in pair):
        return False
    if dst.device != src.device:
        return False
    if any(t.is_quantized for t in pair):
        return False
    # Complex inputs go through the redispatch path so that PyTorch itself
    # issues its warning when a complex value is cast to a real one.
    if any(t.is_complex() for t in pair):
        return False
    # This float8 dtype is left to PyTorch, which has a reference
    # implementation; the check is skipped when the build lacks the dtype.
    if _FLOAT8_E8M0FNU is not None and _FLOAT8_E8M0FNU in (src.dtype, dst.dtype):
        return False
    return True

41 

42 

def _expand_like(src: torch.Tensor, target_shape: torch.Size) -> torch.Tensor:
    """Return ``src`` viewed with ``target_shape``.

    Args:
        src: Tensor to broadcast.
        target_shape: Shape the result must have.

    Returns:
        ``src`` itself when its shape already equals ``target_shape``,
        otherwise the result of ``src.expand(target_shape)``.
    """
    already_matching = src.shape == target_shape
    return src if already_matching else src.expand(target_shape)

47 

48 

def copy(
    template: torch.Tensor, src: torch.Tensor, *, non_blocking: Optional[bool] = False
):
    """Functional copy: return a new tensor laid out like ``template`` holding ``src``.

    Args:
        template: Tensor whose size, stride, dtype and device define the
            freshly allocated result. It is not written to.
        src: Value to copy into the result.
        non_blocking: Forwarded to ``copy_`` after conversion with ``bool``
            (so ``None`` is treated as ``False``).

    Returns:
        The newly allocated tensor after ``copy_`` has filled it.
    """
    logger.debug("GEMS COPY (functional)")
    result = torch.empty_strided(
        template.size(),
        template.stride(),
        dtype=template.dtype,
        device=template.device,
    )
    non_blocking_flag = bool(non_blocking)
    copy_(result, src, non_blocking=non_blocking_flag)
    return result

58 

59 

def _redispatch_copy(dst: torch.Tensor, src: torch.Tensor, non_blocking: bool):
    """Run PyTorch's own ``aten.copy_`` through the fallback dispatch key set."""
    return torch.ops.aten.copy_.default.redispatch(
        _FALLBACK_KEYSET, dst, src, non_blocking
    )


def copy_(dst: torch.Tensor, src: torch.Tensor, non_blocking: bool = False):
    """Copy ``src`` into ``dst`` in place, using the Triton kernel when eligible.

    Args:
        dst: Destination tensor; written in place.
        src: Source tensor, or a Python ``int``/``float``/``bool`` scalar, which
            is first wrapped with ``torch.tensor`` on ``dst``'s device.
        non_blocking: Forwarded to PyTorch's ``copy_`` whenever the call is
            redispatched; it is not used on the Triton path.

    Returns:
        ``dst`` on the no-op and Triton paths, ``dst.zero_()`` when ``src`` is
        a ZeroTensor, otherwise whatever the redispatched ``aten.copy_``
        returns.

    Raises:
        TypeError: If ``src`` is neither a tensor nor a supported scalar.
        RuntimeError: If ``dst`` is a ZeroTensor, if the shapes cannot be
            broadcast together, or if the broadcast shape differs from
            ``dst``'s shape.
    """
    if isinstance(src, (int, float, bool)):
        src = torch.tensor(src, device=dst.device)
    elif not isinstance(src, torch.Tensor):
        # A single formatted message: passing the type as a second positional
        # argument would make the exception render as a tuple.
        raise TypeError(f"unsupported src type for copy_: {type(src)}")

    # this is the same as PyTorch's check
    if dst._is_zerotensor():
        raise RuntimeError("ZeroTensors are immutable. Call clone() before copy_.")
    if src._is_zerotensor():
        return dst.zero_()

    if torch._C._is_alias_of(dst, src):
        # Align with PyTorch: if metadata fully matches, this is a no-op.
        identical_view = (
            dst.storage_offset() == src.storage_offset()
            and dst.stride() == src.stride()
            and dst.size() == src.size()
            and dst.dtype == src.dtype
            and dst.device == src.device
            and dst.is_conj() == src.is_conj()
            and dst.is_neg() == src.is_neg()
        )
        if identical_view:
            return dst
        # Otherwise defer to PyTorch for well-defined semantics on overlapping writes.
        return _redispatch_copy(dst, src, non_blocking)

    # Every remaining reason to skip the Triton kernel leads to the same
    # redispatch, so the checks are evaluated together:
    #  * element counts above 2**31 - 1
    #    (NOTE(review): presumably a 32-bit indexing limit of the kernel — confirm);
    #  * empty destinations, so PyTorch still validates the broadcast;
    #  * anything _can_use_triton rejects (non-strided, cross-device,
    #    quantized, complex, float8_e8m0fnu).
    int32_max = 2**31 - 1
    if (
        src.numel() > int32_max
        or dst.numel() > int32_max
        or dst.numel() == 0
        or not _can_use_triton(dst, src)
    ):
        return _redispatch_copy(dst, src, non_blocking)

    logger.debug("GEMS COPY_")

    # broadcast_shapes raises RuntimeError itself on incompatible shapes.
    broadcast_shape = torch.broadcast_shapes(dst.shape, src.shape)
    if torch.Size(broadcast_shape) != dst.shape:
        raise RuntimeError(
            f"The broadcast shape {broadcast_shape} does not match destination shape {tuple(dst.shape)}"
        )

    expanded_src = _expand_like(src, dst.shape)

    # Instantiate the identity kernel for the source rank and let it write
    # straight into the destination.
    overload = _copy_kernel.instantiate(expanded_src.ndim)
    overload(expanded_src, out0=dst)
    return dst

128 return dst