Coverage for src/flag_gems/ops/as_strided_copy.py: 65%

161 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8from flag_gems.utils.shape_utils import MemOverlap, has_internal_overlapping 

9 

# Module-level logger; the public entry points emit one debug line per call.
logger = logging.getLogger(__name__)

# Dispatch key set passed to `copy_.default.redispatch` by `_native_copy_`, so
# the fallback copy is routed through the CompositeExplicitAutograd key.
_FALLBACK_KEYSET = torch._C.DispatchKeySet(
    torch._C.DispatchKey.CompositeExplicitAutograd
)
# Largest output element count accepted by the Triton paths (int32 maximum).
_MAX_TRITON_ELEMENTS = torch.iinfo(torch.int32).max
# Elements handled per program by the flat (1-D and 3-D) kernels.
_BLOCK_SIZE = 512
# Tile height / width handled per program by the 2-D kernel.
_BLOCK_M = 16
_BLOCK_N = 16

19 

20 

# Identity scalar function wrapped by `pointwise_dynamic`. The launcher calls
# it as `_as_strided_copy_kernel(view, out0=out)` for views whose rank is not
# covered by the hand-written 0-D to 3-D kernels.
@pointwise_dynamic(is_tensor=[True], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def _as_strided_copy_kernel(x):
    return x

25 

26 

@triton.jit
def _as_strided_copy_1d_kernel(
    input,
    out,
    input_stride_0,
    out_stride_0,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program copies BLOCK_SIZE consecutive logical elements of a 1-D view.
    lanes = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < n_elements
    # Widen before multiplying by the strides so the address math is 64-bit.
    idx = lanes.to(tl.int64)
    src_ptrs = input + idx * input_stride_0
    dst_ptrs = out + idx * out_stride_0
    tl.store(dst_ptrs, tl.load(src_ptrs, mask=in_bounds), mask=in_bounds)

41 

42 

@triton.jit
def _as_strided_copy_2d_kernel(
    input,
    out,
    input_stride_0,
    input_stride_1,
    out_stride_0,
    out_stride_1,
    dim_0,
    dim_1,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Each program copies one BLOCK_M x BLOCK_N tile; grid axis 0 walks the
    # first dimension and grid axis 1 walks the second.
    row_base = tl.program_id(0) * BLOCK_M
    col_base = tl.program_id(1) * BLOCK_N
    # 64-bit indices, shaped as a column vector and a row vector so that the
    # expressions below broadcast to the full tile.
    rows = (row_base + tl.arange(0, BLOCK_M)).to(tl.int64)[:, None]
    cols = (col_base + tl.arange(0, BLOCK_N)).to(tl.int64)[None, :]
    in_bounds = (rows < dim_0) & (cols < dim_1)
    src_ptrs = input + (rows * input_stride_0 + cols * input_stride_1)
    dst_ptrs = out + (rows * out_stride_0 + cols * out_stride_1)
    tile = tl.load(src_ptrs, mask=in_bounds)
    tl.store(dst_ptrs, tile, mask=in_bounds)

65 

66 

@triton.jit
def _as_strided_copy_3d_kernel(
    input,
    out,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    out_stride_0,
    out_stride_1,
    out_stride_2,
    dim_1,
    dim_2,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Each program copies BLOCK_SIZE logical elements, addressed by a flat
    # index that is unravelled into (i0, i1, i2) with the last index fastest.
    flat = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = flat < n_elements
    flat = flat.to(tl.int64)

    i2 = flat % dim_2
    outer = flat // dim_2
    i1 = outer % dim_1
    i0 = outer // dim_1

    src_off = i0 * input_stride_0 + i1 * input_stride_1 + i2 * input_stride_2
    dst_off = i0 * out_stride_0 + i1 * out_stride_1 + i2 * out_stride_2
    tl.store(out + dst_off, tl.load(input + src_off, mask=in_bounds), mask=in_bounds)

97 

98 

99def _is_float8(dtype: torch.dtype) -> bool: 

100 return str(dtype).startswith("torch.float8_") 

101 

102 

103def _has_lazy_metadata(tensor: torch.Tensor) -> bool: 

104 is_neg = getattr(tensor, "is_neg", lambda: False) 

105 return tensor.is_conj() or is_neg() 

106 

107 

108def _make_as_strided_view( 

109 input: torch.Tensor, 

110 size, 

111 stride, 

112 storage_offset, 

113) -> torch.Tensor: 

114 # Reuse PyTorch's view construction to match its validation and None-offset semantics. 

115 if storage_offset is None: 

116 return torch.as_strided(input, size, stride) 

117 return torch.as_strided(input, size, stride, storage_offset) 

118 

119 

def _native_copy_(out: torch.Tensor, src: torch.Tensor):
    """Copy ``src`` into ``out`` by redispatching ``aten::copy_`` with the fallback key set.

    The trailing ``False`` is passed positionally as the op's last argument.
    """
    copy_op = torch.ops.aten.copy_.default
    return copy_op.redispatch(_FALLBACK_KEYSET, out, src, False)

122 

123 

def _fallback_as_strided_copy(input, size, stride, storage_offset=None):
    """Materialise the strided view into a new tensor using the native copy.

    The view is constructed first, so invalid arguments are rejected even when
    the result is empty.
    """
    source = _make_as_strided_view(input, size, stride, storage_offset)
    result = torch.empty(tuple(size), dtype=input.dtype, device=input.device)
    if result.numel() == 0:
        return result
    # Call native copy_ directly so unsupported CUDA dtypes do not re-enter
    # FlagGems copy kernels through the composite as_strided_copy fallback.
    _native_copy_(result, source)
    return result

132 

133 

def _fallback_as_strided_copy_out(input, size, stride, storage_offset=None, *, out):
    """Copy the strided view into ``out`` using the native copy.

    When ``out`` aliases ``input`` or overlaps itself, the data is first staged
    in a freshly allocated tensor and copied from there.
    """
    source = _make_as_strided_view(input, size, stride, storage_offset)
    shares_storage = torch._C._is_alias_of(input, out)
    if shares_storage or has_internal_overlapping(out) != MemOverlap.No:
        staging = torch.empty(tuple(size), dtype=input.dtype, device=input.device)
        if staging.numel() != 0:
            _native_copy_(staging, source)
        source = staging
    _native_copy_(out, source)
    return out

146 

147 

def _can_use_triton(input: torch.Tensor, out: torch.Tensor) -> bool:
    """Decide whether the typed Triton copy kernels may handle this pair.

    Args:
        input: Source tensor (the strided view to be copied).
        out: Destination tensor.

    Returns:
        True only if both tensors are strided, share device and dtype, are not
        quantized, complex or float8, carry no conjugate/negative bit, and the
        output holds at most ``_MAX_TRITON_ELEMENTS`` elements.
    """
    if input.layout != torch.strided or out.layout != torch.strided:
        return False
    if input.device != out.device or input.dtype != out.dtype:
        return False
    if input.is_quantized or out.is_quantized:
        return False
    if input.is_complex() or _is_float8(input.dtype):
        return False
    # The kernels move raw storage values, so a pending negative (or
    # conjugate) bit would be dropped silently; leave such tensors to the
    # native copy, as the byte path already does.
    if _has_lazy_metadata(input) or _has_lazy_metadata(out):
        return False
    if out.numel() > _MAX_TRITON_ELEMENTS:
        return False
    return True

160 

161 

def _can_use_byte_triton(input: torch.Tensor, out: torch.Tensor) -> bool:
    """Decide whether the byte-reinterpreting Triton path may handle this pair.

    True only for strided float8 tensors of one-byte element size that share
    device and dtype, carry no conjugate/negative bit, and whose output holds
    at most ``_MAX_TRITON_ELEMENTS`` elements.
    """
    return (
        input.layout == torch.strided
        and out.layout == torch.strided
        and input.device == out.device
        and input.dtype == out.dtype
        and _is_float8(input.dtype)
        and input.element_size() == 1
        and out.element_size() == 1
        and not _has_lazy_metadata(input)
        and not _has_lazy_metadata(out)
        and out.numel() <= _MAX_TRITON_ELEMENTS
    )

176 

177 

def _launch_as_strided_copy(view: torch.Tensor, out: torch.Tensor):
    """Copy ``view`` into ``out`` with the kernel matching the view's rank.

    Rank 0 and 1 use the 1-D kernel, rank 2 the tiled 2-D kernel, rank 3 the
    flat 3-D kernel, and any other rank the ``pointwise_dynamic`` identity
    kernel.

    Args:
        view: Source strided view.
        out: Destination tensor; its strides are read per dimension of ``view``.

    Returns:
        ``out`` for ranks 0-3, otherwise whatever the generic kernel returns.
    """
    # CUDA allows up to 2**31 - 1 blocks along grid axis 0 but only 65535
    # along axes 1 and 2.
    max_grid_axis_1 = 65535

    dim = view.dim()
    if dim == 0:
        # A scalar is a one-element copy: zero strides, a single program.
        _as_strided_copy_1d_kernel[(1,)](
            view,
            out,
            0,
            0,
            1,
            BLOCK_SIZE=1,
        )
    elif dim == 1:
        n_elements = view.numel()
        grid = (triton.cdiv(n_elements, _BLOCK_SIZE),)
        _as_strided_copy_1d_kernel[grid](
            view,
            out,
            view.stride(0),
            out.stride(0),
            n_elements,
            BLOCK_SIZE=_BLOCK_SIZE,
        )
    elif dim == 2:
        dim_0, dim_1 = view.shape
        grid_n = triton.cdiv(dim_1, _BLOCK_N)
        if grid_n <= max_grid_axis_1:
            grid = (triton.cdiv(dim_0, _BLOCK_M), grid_n)
            _as_strided_copy_2d_kernel[grid](
                view,
                out,
                view.stride(0),
                view.stride(1),
                out.stride(0),
                out.stride(1),
                dim_0,
                dim_1,
                BLOCK_M=_BLOCK_M,
                BLOCK_N=_BLOCK_N,
            )
        else:
            # Too many column tiles for grid axis 1: treat the view as
            # (1, dim_0, dim_1) and use the flat kernel, whose grid only
            # grows along axis 0. The leading stride is never used because
            # its index is always 0 for in-range elements.
            n_elements = view.numel()
            grid = (triton.cdiv(n_elements, _BLOCK_SIZE),)
            _as_strided_copy_3d_kernel[grid](
                view,
                out,
                0,
                view.stride(0),
                view.stride(1),
                0,
                out.stride(0),
                out.stride(1),
                dim_0,
                dim_1,
                n_elements,
                BLOCK_SIZE=_BLOCK_SIZE,
            )
    elif dim == 3:
        n_elements = view.numel()
        grid = (triton.cdiv(n_elements, _BLOCK_SIZE),)
        _as_strided_copy_3d_kernel[grid](
            view,
            out,
            view.stride(0),
            view.stride(1),
            view.stride(2),
            out.stride(0),
            out.stride(1),
            out.stride(2),
            view.shape[1],
            view.shape[2],
            n_elements,
            BLOCK_SIZE=_BLOCK_SIZE,
        )
    else:
        return _as_strided_copy_kernel(view, out0=out)
    return out

235 

236 

def _launch_byte_as_strided_copy(view: torch.Tensor, out: torch.Tensor):
    """Copy ``view`` into ``out`` by reinterpreting both tensors as ``uint8``.

    Returns the original ``out`` rather than its byte reinterpretation.
    """

    def _as_bytes(tensor: torch.Tensor) -> torch.Tensor:
        # The dtype-view API requires at least one logical dimension on some builds.
        if tensor.dim() == 0:
            tensor = tensor.reshape(1)
        return tensor.view(torch.uint8)

    # Copy one-byte dtypes through uint8 views to avoid Triton fp8 scalar codegen.
    src_bytes = _as_bytes(view)
    dst_bytes = _as_bytes(out)
    _launch_as_strided_copy(src_bytes, dst_bytes)
    return out

248 

249 

def as_strided_copy(input, size, stride, storage_offset=None):
    """Return a new tensor holding the contents of a strided view of ``input``.

    Args:
        input: Source tensor.
        size: Shape of the view and of the result.
        stride: Strides of the view.
        storage_offset: Storage offset of the view; ``None`` leaves the choice
            to ``torch.as_strided``.

    Returns:
        A freshly allocated tensor of shape ``size`` with ``input``'s dtype and
        device.

    Raises:
        Whatever ``torch.empty`` or ``torch.as_strided`` raise for invalid
        arguments; the view is built even when the result is empty.
    """
    logger.debug("GEMS AS_STRIDED_COPY")
    if input.device.type != "cuda":
        view = _make_as_strided_view(input, size, stride, storage_offset)
        return view.clone(memory_format=torch.contiguous_format)

    out = torch.empty(size, dtype=input.dtype, device=input.device)
    # Build the view once: this validates the arguments for the empty case
    # and supplies the copy source for every other path.
    view = _make_as_strided_view(input, size, stride, storage_offset)
    if out.numel() == 0:
        return out

    if _can_use_triton(view, out):
        return _launch_as_strided_copy(view, out)
    if _can_use_byte_triton(view, out):
        return _launch_byte_as_strided_copy(view, out)
    # Fallback: reuse the output and view already built instead of allocating
    # a second output and constructing the view again.
    _native_copy_(out, view)
    return out

267 

268 

def as_strided_copy_out(input, size, stride, storage_offset=None, *, out):
    """Write the contents of a strided view of ``input`` into ``out``.

    ``out`` must already have ``input``'s dtype; it is resized when its shape
    differs from ``size``. Returns ``out``.

    Raises:
        RuntimeError: if ``out.dtype`` differs from ``input.dtype``.
    """
    logger.debug("GEMS AS_STRIDED_COPY_OUT")
    if out.dtype != input.dtype:
        # Match PyTorch's strict out-dtype contract without measuring native fallback.
        raise RuntimeError(
            f"Expected out tensor to have dtype {input.dtype}, but got {out.dtype} instead"
        )

    wanted_shape = tuple(size)
    if tuple(out.shape) != wanted_shape:
        out.resize_(wanted_shape)

    if out.numel() == 0:
        # Nothing to copy; the view is built only to validate the arguments.
        _make_as_strided_view(input, size, stride, storage_offset)
        return out

    if input.device.type != "cuda":
        source = _make_as_strided_view(input, size, stride, storage_offset)
        hazard = (
            torch._C._is_alias_of(input, out)
            or has_internal_overlapping(out) != MemOverlap.No
        )
        if hazard:
            # Detach the source from the destination before writing.
            source = source.clone(memory_format=torch.contiguous_format)
        out.copy_(source)
        return out

    hazard = (
        torch._C._is_alias_of(input, out)
        or has_internal_overlapping(out) != MemOverlap.No
    )
    if not hazard:
        source = _make_as_strided_view(input, size, stride, storage_offset)
        if _can_use_triton(source, out):
            return _launch_as_strided_copy(source, out)
        if _can_use_byte_triton(source, out):
            return _launch_byte_as_strided_copy(source, out)
    # Aliased/overlapping destinations and unsupported inputs use the native copy.
    return _fallback_as_strided_copy_out(input, size, stride, storage_offset, out=out)