Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/to.py: 0%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import os 

3from typing import Optional 

4 

5import torch 

6import triton 

7from _kunlunxin.utils.codegen_config_utils import CodeGenConfig 

8 

9from ..utils.pointwise_dynamic import pointwise_dynamic 

10 

# Child logger of the shared "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Dispatch key set handed to `redispatch` whenever `to_copy` defers to
# PyTorch's own `aten::_to_copy` implementation instead of the Triton kernel.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd,
)

16 

17 

# Identity element-wise kernel with a single tensor operand. `to_copy` invokes
# it as `_to_copy_func(x, out0=out)`, so the result lands in the pre-allocated
# output buffer (presumably converted to that buffer's dtype on store --
# TODO confirm against pointwise_dynamic).
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _to_copy_func(x):
    return x

27 

28 

# Code-generation settings used by `_to_copy_func_close_interleave` (the kernel
# `to_copy` selects for bfloat16 inputs). The positional arguments follow
# CodeGenConfig's own parameter order, which is defined outside this file.
close_interleave_config = CodeGenConfig(
    512,
    (65536,) * 3,
    32,
    True,
    prefer_1d_tile=True,
    isCloseInterleave=True,
)

37 

38 

# Same identity kernel as `_to_copy_func`, but generated with
# `close_interleave_config`; `to_copy` picks this variant when the input
# tensor's dtype is bfloat16.
@pointwise_dynamic(
    is_tensor=[True],
    promotion_methods=[(0, "DEFAULT")],
    config=close_interleave_config,
)
@triton.jit
def _to_copy_func_close_interleave(x):
    return x

49 

50 

def _resolve_dtype(x: torch.Tensor, dtype: Optional[torch.dtype]) -> torch.dtype:
    """Return the dtype the copy should have.

    Args:
        x: Source tensor; its dtype is used when ``dtype`` is ``None``.
        dtype: Requested dtype, or ``None`` to keep the source dtype.

    Returns:
        ``dtype`` when given, otherwise ``x.dtype``.

    Raises:
        TypeError: If ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    # Reject anything that is not a real torch.dtype before picking a result.
    if dtype is not None and not isinstance(dtype, torch.dtype):
        raise TypeError(f"Unsupported dtype argument type: {type(dtype)!r}")
    return x.dtype if dtype is None else dtype

57 

58 

def _resolve_device(x: torch.Tensor, device: Optional[torch.device]) -> torch.device:
    """Return the device the copy should live on.

    Args:
        x: Source tensor; its device is used when ``device`` is ``None``.
        device: Requested device, or ``None`` to stay on the source device.
            The value is passed through ``torch.device(...)``.

    Returns:
        ``x.device`` when ``device`` is ``None``, else ``torch.device(device)``.
    """
    return x.device if device is None else torch.device(device)

63 

64 

def _normalize_memory_format(
    memory_format: Optional[torch.memory_format],
) -> torch.memory_format:
    """Map a missing memory format to ``torch.preserve_format``.

    Args:
        memory_format: Requested memory format, or ``None``.

    Returns:
        ``torch.preserve_format`` when ``memory_format`` is ``None``; otherwise
        the argument unchanged.
    """
    return torch.preserve_format if memory_format is None else memory_format

71 

72 

def _allocate_preserve_format(x: torch.Tensor, empty_kwargs: dict) -> torch.Tensor:
    """Allocate an uninitialized output for ``x`` under preserve_format rules.

    Args:
        x: Tensor whose size (and, when possible, strides) the output mirrors.
        empty_kwargs: Extra keyword arguments forwarded to the allocation call
            (``to_copy`` passes ``dtype`` and ``device``).

    Returns:
        A new uninitialized tensor with the same size as ``x``.
    """
    layout_is_replicable = torch.ops.aten.is_non_overlapping_and_dense(x)
    if not layout_is_replicable:
        # Copying the strides verbatim is unsafe here, so let PyTorch choose
        # its best-effort layout for preserve_format.
        return torch.empty_like(
            x, memory_format=torch.preserve_format, **empty_kwargs
        )
    # Non-overlapping and dense: reproduce the exact strides of the source.
    return torch.empty_strided(x.size(), x.stride(), **empty_kwargs)

79 

80 

# func: _to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None,
#   bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor
def to_copy(
    x,
    *,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    non_blocking=False,
    memory_format=None,
):
    """Copy ``x`` into a new tensor, optionally changing its dtype.

    Same-device, non-complex copies run through a Triton identity kernel that
    writes into a freshly allocated output. Copies that involve a complex
    dtype, a different target device, or CPU tensors are redispatched to
    PyTorch's ``aten::_to_copy``.

    Args:
        x: Source tensor; must be strided and not quantized.
        dtype: Target dtype, or ``None`` to keep ``x.dtype``.
        layout: Must be ``None`` or ``torch.strided``.
        device: Target device, or ``None`` to keep ``x.device``.
        pin_memory: Must be ``None``; any explicit value is rejected.
        non_blocking: Only forwarded to the PyTorch fallback.
        memory_format: Target memory format; ``None`` means
            ``torch.preserve_format``.

    Returns:
        The result of the Triton kernel launch (or of the PyTorch fallback).

    Raises:
        NotImplementedError: For non-strided layouts, an explicit
            ``pin_memory`` value, or quantized inputs.
        TypeError: If ``dtype`` is neither ``None`` nor a ``torch.dtype``.
    """
    # Only the dense strided case is implemented; other layouts are rejected.
    if (layout is not None and layout != torch.strided) or x.layout != torch.strided:
        raise NotImplementedError(
            "FlagGems to_copy currently supports strided tensors only."
        )
    # NOTE: every explicit pin_memory value (False included) is rejected here.
    if pin_memory is not None:
        raise NotImplementedError(
            "FlagGems to_copy does not yet support pin_memory=True."
        )
    if x.is_quantized:
        raise NotImplementedError(
            "Quantized tensors are not supported in FlagGems to_copy yet."
        )

    target_dtype = _resolve_dtype(x, dtype)
    target_device = _resolve_device(x, device)
    target_memory_format = _normalize_memory_format(memory_format)

    # Cases the Triton path does not cover are handed to PyTorch:
    #   * complex dtypes (Triton on kunlunxin does not support them),
    #   * device transfers (d2h / h2d etc.),
    #   * copies that stay on the CPU.
    needs_fallback = (
        x.dtype.is_complex
        or target_dtype.is_complex
        or target_device != x.device
        or (x.device.type == "cpu" and target_device.type == "cpu")
    )
    if needs_fallback:
        return torch.ops.aten._to_copy.default.redispatch(
            _FALLBACK_KEYSET,
            x,
            dtype=target_dtype,
            layout=layout,
            device=target_device,
            pin_memory=pin_memory,
            non_blocking=non_blocking,
            memory_format=target_memory_format,
        )

    logger.debug("GEMS_KUNLUNXIN _TO_COPY")
    empty_kwargs = {"dtype": target_dtype, "device": target_device}

    # Allocate the output exactly once, using the resolved dtype/device/format.
    if target_memory_format is torch.preserve_format:
        out = _allocate_preserve_format(x, empty_kwargs)
    else:
        out = torch.empty_like(x, memory_format=target_memory_format, **empty_kwargs)

    # bfloat16 inputs use the kernel generated with the close-interleave config.
    if x.dtype == torch.bfloat16:
        to_dtype_fn = _to_copy_func_close_interleave
    else:
        to_dtype_fn = _to_copy_func

    # Environment switches that must be visible while the kernel is launched
    # (presumably read by the Triton XPU backend -- TODO confirm). 8-byte
    # output elements additionally need TRITONXPU_ELEMBYTES.
    env_overrides = {"TRITONXPU_BF16_FAST": "1"}
    if out.element_size() == 8:
        env_overrides["TRITONXPU_ELEMBYTES"] = "8"

    # Remember prior values so user-provided settings survive this call, and
    # restore them even if the kernel launch raises.
    saved_env = {name: os.environ.get(name) for name in env_overrides}
    os.environ.update(env_overrides)
    try:
        return to_dtype_fn(x, out0=out)
    finally:
        for name, previous in saved_env.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous