Coverage for src/flag_gems/runtime/backend/_sunrise/ops/to.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2from typing import Optional 

3 

4import torch 

5import triton 

6 

7from flag_gems.utils import pointwise_dynamic 

8 

9logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

10 

11_FALLBACK_KEYSET = torch._C.DispatchKeySet( 

12 torch._C.DispatchKey.CompositeExplicitAutograd 

13) 

14 

15 

# Element-wise identity kernel with a single tensor input. ``to_copy`` calls it
# with ``out0=<preallocated tensor>`` so the values of ``x`` land in a buffer
# that already has the requested dtype and strides.
# NOTE(review): the dtype conversion presumably happens when the generated
# wrapper stores into ``out0`` -- confirm against ``pointwise_dynamic``.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _to_copy_func(x):
    return x

25 

26 

27def _resolve_dtype(x: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.dtype: 

28 if dtype is None: 

29 return x.dtype 

30 if isinstance(dtype, torch.dtype): 

31 return dtype 

32 raise TypeError(f"Unsupported dtype argument type: {type(dtype)!r}") 

33 

34 

35def _resolve_device(x: torch.Tensor, device: Optional[torch.device]) -> torch.device: 

36 if device is None: 

37 return x.device 

38 return torch.device(device) 

39 

40 

41def _normalize_memory_format( 

42 memory_format: Optional[torch.memory_format], 

43) -> torch.memory_format: 

44 if memory_format is None: 

45 return torch.preserve_format 

46 return memory_format 

47 

48 

49def _allocate_preserve_format(x: torch.Tensor, empty_kwargs: dict) -> torch.Tensor: 

50 if torch.ops.aten.is_non_overlapping_and_dense(x): 

51 return torch.empty_strided(x.size(), x.stride(), **empty_kwargs) 

52 return torch.empty_like(x, memory_format=torch.preserve_format, **empty_kwargs) 

53 

54 

def _fallback_to_copy(
    x: torch.Tensor,
    *,
    dtype: torch.dtype,
    layout,
    device: torch.device,
    pin_memory,
    non_blocking: bool,
    memory_format: torch.memory_format,
):
    """Run ``aten._to_copy.default`` through ``redispatch``.

    The call is redispatched with ``_FALLBACK_KEYSET`` and every keyword
    argument is forwarded unchanged.

    Args:
        x: Source tensor.
        dtype: Target dtype.
        layout: Target layout, forwarded as-is.
        device: Target device.
        pin_memory: Forwarded as-is.
        non_blocking: Forwarded as-is.
        memory_format: Target memory format.

    Returns:
        Whatever the redispatched ``_to_copy`` call returns.
    """
    aten_to_copy = torch.ops.aten._to_copy.default
    options = {
        "dtype": dtype,
        "layout": layout,
        "device": device,
        "pin_memory": pin_memory,
        "non_blocking": non_blocking,
        "memory_format": memory_format,
    }
    return aten_to_copy.redispatch(_FALLBACK_KEYSET, x, **options)

75 

76 

77# func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, 

78# bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor 

def to_copy(
    x,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    non_blocking=False,
    memory_format=None,
):
    """Copy ``x`` into a new tensor, optionally changing dtype, device or format.

    The signature mirrors the ``aten::_to_copy`` schema. Conversions that stay
    on the same non-CPU device and involve no complex dtype run through the
    ``_to_copy_func`` kernel; every other case is forwarded to
    ``_fallback_to_copy``.

    Args:
        x: Source tensor. Must be strided and not quantized.
        dtype: Target dtype, or ``None`` to keep ``x.dtype``.
        layout: ``None`` or ``torch.strided``; other layouts are rejected.
        device: Target device, or ``None`` to keep ``x.device``.
        pin_memory: ``None`` or ``False``. Requesting pinned memory is rejected.
        non_blocking: Forwarded to the fallback copies that move data to the
            target device. The host-side staging hops of the complex path are
            always blocking.
        memory_format: Target memory format; ``None`` means
            ``torch.preserve_format``.

    Returns:
        A new tensor holding the converted data.

    Raises:
        NotImplementedError: For non-strided layouts, ``pin_memory=True`` or
            quantized inputs.
        TypeError: If ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    if (layout is not None and layout != torch.strided) or x.layout != torch.strided:
        raise NotImplementedError(
            "FlagGems to_copy currently supports strided tensors only."
        )
    # Only an explicit request for pinned memory is unsupported; ``False`` means
    # the same as leaving the argument unset, so it must not be rejected.
    if pin_memory:
        raise NotImplementedError(
            "FlagGems to_copy does not yet support pin_memory=True."
        )
    if x.is_quantized:
        raise NotImplementedError(
            "Quantized tensors are not supported in FlagGems to_copy yet."
        )

    target_dtype = _resolve_dtype(x, dtype)
    target_device = _resolve_device(x, device)
    target_memory_format = _normalize_memory_format(memory_format)

    # PTPU can hold complex tensors, but the same-device ptpu cast path drops into
    # a backend copy_ implementation that does not handle complex. Stage through CPU
    # to avoid ptpu complex copy_/view_as_real gaps.
    if x.dtype.is_complex or target_dtype.is_complex:
        logger.debug("GEMS_SUNRISE _TO_COPY COMPLEX VIA CPU")
        cpu_device = torch.device("cpu")
        cpu_x = x
        if x.device.type != "cpu":
            # The staged tensor is read on the host by the very next call, so
            # this device-to-host hop must complete before returning: never
            # forward a caller-supplied non_blocking=True here.
            cpu_x = _fallback_to_copy(
                x,
                dtype=x.dtype,
                layout=layout,
                device=cpu_device,
                pin_memory=pin_memory,
                non_blocking=False,
                memory_format=target_memory_format,
            )
        # Host-to-host dtype conversion; kept blocking as well.
        cpu_res = _fallback_to_copy(
            cpu_x,
            dtype=target_dtype,
            layout=layout,
            device=cpu_device,
            pin_memory=pin_memory,
            non_blocking=False,
            memory_format=target_memory_format,
        )
        if target_device.type == "cpu":
            return cpu_res
        # Final hop to the requested device honours the caller's non_blocking.
        return _fallback_to_copy(
            cpu_res,
            dtype=target_dtype,
            layout=layout,
            device=target_device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    # Cross-device copies and pure CPU copies are left to the fallback.
    if target_device != x.device or (
        x.device.type == "cpu" and target_device.type == "cpu"
    ):
        return _fallback_to_copy(
            x,
            dtype=target_dtype,
            layout=layout,
            device=target_device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    logger.debug("GEMS_SUNRISE _TO_COPY")
    empty_kwargs = {"dtype": target_dtype, "device": target_device}

    # Allocate the destination first, then let the identity kernel fill it.
    if target_memory_format is torch.preserve_format:
        out = _allocate_preserve_format(x, empty_kwargs)
    else:
        out = torch.empty_like(x, memory_format=target_memory_format, **empty_kwargs)

    return _to_copy_func(x, out0=out)

163 

164 return _to_copy_func(x, out0=out)