Coverage for src/flag_gems/runtime/backend/_sunrise/ops/to.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2from typing import Optional 

3 

4import torch 

5import triton 

6 

7from flag_gems.utils import pointwise_dynamic 

8 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Dispatch key set passed to `redispatch` by `_fallback_to_copy`; it holds only
# the CompositeExplicitAutograd key.
# NOTE(review): presumably chosen so the redispatched call reaches PyTorch's own
# `_to_copy` instead of re-entering this override -- confirm against the
# registration code.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)

14 

15 

# Elementwise identity kernel: returns its input unchanged.
# It is wrapped by `pointwise_dynamic` with one tensor argument and "DEFAULT"
# promotion on argument 0; `to_copy` invokes it with a pre-allocated `out0`.
# NOTE(review): the dtype conversion presumably happens when the generated
# wrapper stores the result into `out0` -- confirm against `pointwise_dynamic`.
@pointwise_dynamic(
    is_tensor=[
        True,
    ],
    promotion_methods=[(0, "DEFAULT")],
)
@triton.jit
def _to_copy_func(x):
    return x

25 

26 

def _resolve_dtype(x: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.dtype:
    """Choose the dtype of the copy.

    Args:
        x: Source tensor; its dtype is used when ``dtype`` is ``None``.
        dtype: Requested dtype, or ``None`` to keep the source dtype.

    Returns:
        ``dtype`` when it is a ``torch.dtype``, otherwise ``x.dtype``.

    Raises:
        TypeError: If ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if dtype is not None:
        raise TypeError(f"Unsupported dtype argument type: {type(dtype)!r}")
    return x.dtype

33 

34 

def _resolve_device(x: torch.Tensor, device: Optional[torch.device]) -> torch.device:
    """Choose the device of the copy.

    Args:
        x: Source tensor; its device is used when ``device`` is ``None``.
        device: Requested device, or ``None`` to stay on the source device.
            Whatever is given is passed through ``torch.device(...)``.

    Returns:
        The resolved ``torch.device``.
    """
    return x.device if device is None else torch.device(device)

39 

40 

def _normalize_memory_format(
    memory_format: Optional[torch.memory_format],
) -> torch.memory_format:
    """Return ``memory_format``, substituting ``torch.preserve_format`` for ``None``."""
    return torch.preserve_format if memory_format is None else memory_format

47 

48 

def _allocate_preserve_format(x: torch.Tensor, empty_kwargs: dict) -> torch.Tensor:
    """Allocate an uninitialized output for a ``preserve_format`` copy of ``x``.

    Args:
        x: Tensor whose size (and, when possible, strides) the output mirrors.
        empty_kwargs: Extra keyword arguments forwarded to the allocator
            (``to_copy`` passes ``dtype`` and ``device``).

    Returns:
        A new tensor with the same size as ``x``. The strides of ``x`` are
        replicated exactly when ``x`` is non-overlapping and dense; otherwise
        the layout is whatever ``torch.empty_like`` picks for
        ``preserve_format``.
    """
    if not torch.ops.aten.is_non_overlapping_and_dense(x):
        # Copying the strides verbatim is unsafe here, so defer to PyTorch's
        # best-effort layout choice.
        return torch.empty_like(x, memory_format=torch.preserve_format, **empty_kwargs)
    shape = x.size()
    strides = x.stride()
    return torch.empty_strided(shape, strides, **empty_kwargs)

55 

56 

def _fallback_to_copy(
    x: torch.Tensor,
    *,
    dtype: torch.dtype,
    layout,
    device: torch.device,
    pin_memory,
    non_blocking: bool,
    memory_format: torch.memory_format,
):
    """Redispatch ``aten::_to_copy`` with the module's fallback key set.

    All keyword arguments are forwarded unchanged to
    ``torch.ops.aten._to_copy.default``; the result of that call is returned.
    """
    forwarded = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
        "non_blocking": non_blocking,
        "memory_format": memory_format,
    }
    aten_to_copy = torch.ops.aten._to_copy.default
    return aten_to_copy.redispatch(_FALLBACK_KEYSET, x, **forwarded)

77 

78 

# func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
#   bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
def to_copy(
    x,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    non_blocking=False,
    memory_format=None,
):
    """Copy ``x`` to a new tensor with the requested dtype/device/memory format.

    Three paths are taken, in this order:

    1. Complex source or target dtype: the conversion is staged through the
       CPU using PyTorch's own ``_to_copy``.
    2. The target device differs from ``x.device``, or both are CPU: the call
       is handed to PyTorch's own ``_to_copy``.
    3. Otherwise (same non-CPU device): an output tensor is allocated and
       filled by the ``_to_copy_func`` kernel.

    Args:
        x: Source tensor. Must be strided and not quantized.
        dtype: Target ``torch.dtype``; ``None`` keeps ``x.dtype``.
        layout: Must be ``None`` or ``torch.strided``.
        device: Target device; ``None`` keeps ``x.device``.
        pin_memory: ``None`` or ``False``. A truthy value is rejected.
        non_blocking: Forwarded to PyTorch on path 2. Ignored on the complex
            path, where every staged copy is blocking.
        memory_format: Target memory format; ``None`` means
            ``torch.preserve_format``.

    Returns:
        The newly created tensor.

    Raises:
        NotImplementedError: For non-strided layouts, ``pin_memory=True`` or
            quantized inputs.
        TypeError: If ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    # Only the dense strided case is implemented; every other layout is
    # rejected here.
    if (layout is not None and layout != torch.strided) or x.layout != torch.strided:
        raise NotImplementedError(
            "FlagGems to_copy currently supports strided tensors only."
        )
    # Only an actual pinning request is unsupported. An explicit
    # pin_memory=False means the same as leaving it unset and must not be
    # rejected.
    if pin_memory:
        raise NotImplementedError(
            "FlagGems to_copy does not yet support pin_memory=True."
        )
    if x.is_quantized:
        raise NotImplementedError(
            "Quantized tensors are not supported in FlagGems to_copy yet."
        )

    target_dtype = _resolve_dtype(x, dtype)
    target_device = _resolve_device(x, device)
    target_memory_format = _normalize_memory_format(memory_format)

    # PTPU can hold complex tensors, but the same-device ptpu cast path drops into
    # a backend copy_ implementation that does not handle complex. Stage through CPU
    # to avoid ptpu complex copy_/view_as_real gaps.
    if x.dtype.is_complex or target_dtype.is_complex:
        logger.debug("GEMS_SUNRISE _TO_COPY COMPLEX VIA CPU")
        host = torch.device("cpu")
        # Every staged copy is blocking. The host copy is read immediately by
        # the CPU cast, and the host result is a temporary that is dropped as
        # soon as this function returns, so an asynchronous copy could read
        # from or write to a buffer that is not ready or no longer alive.
        staged_kwargs = {
            "layout": layout,
            "pin_memory": pin_memory,
            "non_blocking": False,
            "memory_format": target_memory_format,
        }
        cpu_x = x
        if x.device.type != "cpu":
            cpu_x = _fallback_to_copy(
                x, dtype=x.dtype, device=host, **staged_kwargs
            )
        cpu_res = _fallback_to_copy(
            cpu_x, dtype=target_dtype, device=host, **staged_kwargs
        )
        if target_device.type == "cpu":
            return cpu_res
        return _fallback_to_copy(
            cpu_res, dtype=target_dtype, device=target_device, **staged_kwargs
        )

    if target_device != x.device or (
        x.device.type == "cpu" and target_device.type == "cpu"
    ):
        # Device transfer (d2h/h2d etc.) relies on PyTorch's implementation.
        return _fallback_to_copy(
            x,
            dtype=target_dtype,
            layout=layout,
            device=target_device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    logger.debug("GEMS_SUNRISE _TO_COPY")
    empty_kwargs = {"dtype": target_dtype, "device": target_device}

    if target_memory_format is torch.preserve_format:
        out = _allocate_preserve_format(x, empty_kwargs)
    else:
        out = torch.empty_like(x, memory_format=target_memory_format, **empty_kwargs)

    # The kernel writes the elements of x into the pre-allocated output.
    return _to_copy_func(x, out0=out)