Coverage for src/flag_gems/ops/to.py: 87%

52 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2from typing import Optional 

3 

4import torch 

5import triton 

6 

7from flag_gems.utils import pointwise_dynamic 

8 

logger = logging.getLogger(__name__)

# Dispatch keyset that ``to_copy`` hands to ``aten._to_copy.default.redispatch``
# so that unsupported cases run PyTorch's CompositeExplicitAutograd kernel.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)

# ``torch.float8_e8m0fnu`` only exists in newer PyTorch builds; ``None`` marks
# the dtype as unavailable in the running version.
_FLOAT8_E8M0FNU = torch.float8_e8m0fnu if hasattr(torch, "float8_e8m0fnu") else None

17 

18 

# Elementwise identity used as the scalar body of the to_copy kernel; the value
# is returned unchanged. No dtype conversion is written here: presumably
# pointwise_dynamic casts the result to the dtype of the caller-supplied output
# tensor (``out0`` in to_copy) when storing — TODO confirm against
# flag_gems.utils.pointwise_dynamic.
# is_tensor=[True]: appears to mark the single argument ``x`` as a tensor input.
# promotion_methods=[(0, "DEFAULT")]: output dtype rule keyed on argument 0;
# presumably not decisive here because to_copy always passes an explicit out0.
@pointwise_dynamic(
    is_tensor=[
        True,
    ],
    promotion_methods=[(0, "DEFAULT")],
)
@triton.jit
def _to_copy_func(x):
    return x

28 

29 

def _resolve_dtype(x: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.dtype:
    """Return the dtype the copied tensor should have.

    ``None`` means "keep the source dtype"; any other value must already be a
    ``torch.dtype`` and is returned as-is.

    Raises:
        TypeError: if ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    if dtype is not None and not isinstance(dtype, torch.dtype):
        raise TypeError(f"Unsupported dtype argument type: {type(dtype)!r}")
    return x.dtype if dtype is None else dtype

36 

37 

def _resolve_device(x: torch.Tensor, device: Optional[torch.device]) -> torch.device:
    """Return the device the copied tensor should live on.

    ``None`` keeps the source tensor's device; anything else (a ``torch.device``
    or a value such as ``"cpu"``) is normalized through ``torch.device``.
    """
    return x.device if device is None else torch.device(device)

42 

43 

def _normalize_memory_format(
    memory_format: Optional[torch.memory_format],
) -> torch.memory_format:
    """Map an omitted memory format to ``torch.preserve_format``.

    Any explicit ``torch.memory_format`` is passed through unchanged.
    """
    return torch.preserve_format if memory_format is None else memory_format

50 

51 

def _allocate_preserve_format(x: torch.Tensor, empty_kwargs: dict) -> torch.Tensor:
    """Allocate an uninitialized tensor shaped like ``x`` under preserve_format rules.

    Args:
        x: source tensor whose size (and, when safe, strides) are mirrored.
        empty_kwargs: extra factory kwargs (e.g. ``dtype``/``device``) forwarded
            unchanged to the allocation call.

    Returns:
        A freshly allocated tensor; its contents are uninitialized.
    """
    dense = torch.ops.aten.is_non_overlapping_and_dense(x)
    if not dense:
        # Overlapping or gapped strides cannot be replicated verbatim, so let
        # empty_like choose a layout for preserve_format.
        return torch.empty_like(x, memory_format=torch.preserve_format, **empty_kwargs)
    # Dense, non-overlapping layouts are reproduced stride-for-stride.
    return torch.empty_strided(x.size(), x.stride(), **empty_kwargs)

58 

59 

# func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
#     bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
def _to_copy_fallback(
    x,
    *,
    dtype,
    layout,
    device,
    pin_memory,
    non_blocking,
    memory_format,
):
    """Redispatch ``aten::_to_copy`` to PyTorch's CompositeExplicitAutograd kernel.

    Used by ``to_copy`` for every case its Triton kernel does not handle.
    """
    return torch.ops.aten._to_copy.default.redispatch(
        _FALLBACK_KEYSET,
        x,
        dtype=dtype,
        layout=layout,
        device=device,
        pin_memory=pin_memory,
        non_blocking=non_blocking,
        memory_format=memory_format,
    )


def to_copy(
    x,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    non_blocking=False,
    memory_format=None,
):
    """FlagGems implementation of ``aten::_to_copy`` (schema in the comment above).

    Same-device copies on non-CPU devices between Triton-supported dtypes run
    through a Triton pointwise kernel. Complex or float8_e8m0fnu dtypes,
    cross-device transfers, and CPU copies are redispatched to PyTorch.

    Raises:
        NotImplementedError: for non-strided layouts, ``pin_memory=True``, or
            quantized inputs.
        TypeError: if ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    # Only the dense strided kernel is implemented; other layouts are rejected
    # with NotImplementedError rather than redispatched.
    if (layout is not None and layout != torch.strided) or x.layout != torch.strided:
        raise NotImplementedError(
            "FlagGems to_copy currently supports strided tensors only."
        )
    # pin_memory=False only asks for no pinning, which the paths below already
    # satisfy; only an explicit True is unsupported (as the message states).
    if pin_memory:
        raise NotImplementedError(
            "FlagGems to_copy does not yet support pin_memory=True."
        )
    if x.is_quantized:
        raise NotImplementedError(
            "Quantized tensors are not supported in FlagGems to_copy yet."
        )

    target_dtype = _resolve_dtype(x, dtype)
    target_device = _resolve_device(x, device)
    target_memory_format = _normalize_memory_format(memory_format)

    # Triton supports neither complex dtypes nor float8_e8m0fnu (when this
    # PyTorch build has it). Device transfers (d2h/h2d etc.) and CPU copies rely
    # on PyTorch's implementation. When the devices are equal, checking the
    # source device for "cpu" is equivalent to checking both.
    needs_fallback = (
        x.dtype.is_complex
        or target_dtype.is_complex
        or (
            _FLOAT8_E8M0FNU is not None
            and _FLOAT8_E8M0FNU in (x.dtype, target_dtype)
        )
        or target_device != x.device
        or x.device.type == "cpu"
    )
    if needs_fallback:
        return _to_copy_fallback(
            x,
            dtype=target_dtype,
            layout=layout,
            device=target_device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    logger.debug("GEMS TO_COPY")
    empty_kwargs = {"dtype": target_dtype, "device": target_device}

    if target_memory_format is torch.preserve_format:
        out = _allocate_preserve_format(x, empty_kwargs)
    else:
        out = torch.empty_like(x, memory_format=target_memory_format, **empty_kwargs)

    # The identity kernel writes x into the pre-allocated `out`; presumably the
    # store converts to out's dtype — confirm against pointwise_dynamic.
    return _to_copy_func(x, out0=out)