Coverage for src/flag_gems/ops/t_copy.py: 53%

70 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8import flag_gems 

9 

10logger = logging.getLogger(__name__) 

11 

12 

@triton.jit
def t_copy_2d_kernel(
    in_ptr,
    out_ptr,
    in_stride_0,
    in_stride_1,
    out_stride_0,
    out_stride_1,
    M,  # input dim0
    N,  # input dim1
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Write the transpose of an (M, N) input into an (N, M) output.

    Program axis 0 tiles the output rows (``BLOCK_M`` per program, bounded
    by ``N``) and program axis 1 tiles the output columns (``BLOCK_N`` per
    program, bounded by ``M``).  Every element is addressed through the
    strides passed in, so ``out[i, j] = in[j, i]`` for each in-range pair.
    """
    out_rows = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    out_cols = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)

    # Widen to int64 before multiplying by strides, and give the two index
    # vectors orthogonal shapes so they broadcast to a [BLOCK_N, BLOCK_M] tile.
    rows = out_rows.to(tl.int64)[None, :]  # [1, BLOCK_M], valid while < N
    cols = out_cols.to(tl.int64)[:, None]  # [BLOCK_N, 1], valid while < M

    in_bounds = (rows < N) & (cols < M)

    # Input is read at (col, row); output is written at (row, col).
    src = in_ptr + (cols * in_stride_0 + rows * in_stride_1)
    dst = out_ptr + (rows * out_stride_0 + cols * out_stride_1)

    tile = tl.load(src, mask=in_bounds)
    tl.store(dst, tile, mask=in_bounds)

45 

46 

@triton.jit
def copy_1d_strided_kernel(
    in_ptr,
    out_ptr,
    in_stride,
    out_stride,
    N,
    BLOCK_SIZE: tl.constexpr,
):
    """Copy ``N`` elements between two strided one-dimensional buffers.

    Each program covers ``BLOCK_SIZE`` consecutive logical indices.  Index
    ``k`` is loaded from ``in_ptr + k * in_stride`` and stored to
    ``out_ptr + k * out_stride``; indices at or beyond ``N`` are masked off.
    """
    block_start = tl.program_id(0) * BLOCK_SIZE
    idx = block_start + tl.arange(0, BLOCK_SIZE)
    valid = idx < N

    # Widen before scaling by the strides so the offsets are computed in int64.
    idx_wide = idx.to(tl.int64)

    value = tl.load(in_ptr + idx_wide * in_stride, mask=valid)
    tl.store(out_ptr + idx_wide * out_stride, value, mask=valid)

64 

65 

def _launch_t_copy_kernel(inp: torch.Tensor, out: torch.Tensor):
    """Validate ``inp``/``out`` and launch the kernel that fills ``out``.

    ``out`` receives the transpose of ``inp``: a plain element copy for 0-D
    and 1-D inputs, and ``out[i, j] = inp[j, i]`` for 2-D inputs.  Both
    tensors are addressed through their strides, so neither has to be
    contiguous.

    Args:
        inp: Source tensor with at most two dimensions.
        out: Destination tensor.  Must have the same dtype as ``inp`` and
            the transposed shape (one element for a 0-D input, a 1-D tensor
            of equal length for a 1-D input, ``(N, M)`` for an ``(M, N)``
            input).

    Raises:
        ValueError: If either tensor is not on ``flag_gems.device``.
        AssertionError: If the dtypes differ or ``out`` has the wrong shape.
        RuntimeError: If ``inp`` has more than two dimensions.
    """
    if inp.device.type != flag_gems.device or out.device.type != flag_gems.device:
        raise ValueError(f"t_copy kernels require {flag_gems.device} tensors")
    # The checks below raise AssertionError explicitly instead of using
    # ``assert`` so they are not stripped under ``python -O``; a skipped check
    # here would let the kernel read or write out of bounds.
    if inp.dtype != out.dtype:
        raise AssertionError("dtype mismatch between input and output")

    # NOTE(review): nothing here selects inp.device before launching; confirm
    # callers make it the current device when several devices are in use.

    dim = inp.dim()
    if dim > 2:
        raise RuntimeError("t_copy expects a tensor with <= 2 dims")

    if dim == 0:
        # Scalar copy: a single element, both strides irrelevant (passed as 0).
        if out.numel() != 1:
            raise AssertionError("Output size mismatch for 0D t_copy")
        copy_1d_strided_kernel[(1,)](
            inp,
            out,
            0,
            0,
            1,
            BLOCK_SIZE=1,
        )
        return

    if dim == 1:
        n = inp.numel()
        # The kernel addresses ``out`` through ``out.stride(0)`` only, so a
        # matching element count is not enough: ``out`` must be 1-D as well.
        if out.dim() != 1 or out.numel() != n:
            raise AssertionError("Output size mismatch for 1D t_copy")
        if n == 0:
            return  # nothing to copy; avoid a zero-sized grid
        block_size = 1024
        copy_1d_strided_kernel[(triton.cdiv(n, block_size),)](
            inp,
            out,
            inp.stride(0),
            out.stride(0),
            n,
            BLOCK_SIZE=block_size,
        )
        return

    M, N = inp.shape  # input dims
    # out should be (N, M)
    if out.dim() != 2 or out.shape[0] != N or out.shape[1] != M:
        raise AssertionError(
            "Output shape must be (input.size(1), input.size(0)) for t_copy"
        )
    if M == 0 or N == 0:
        return  # nothing to copy; avoid a zero-sized grid

    in_s0, in_s1 = inp.stride()
    out_s0, out_s1 = out.stride()

    block = 32
    # Grid axis 1 walks the input rows.  CUDA limits that axis to 65535
    # programs, so very tall inputs are processed in row chunks; each chunk is
    # a strided view, which keeps the original strides valid.
    max_grid_axis1 = 65535
    rows_per_launch = max_grid_axis1 * block
    for row_start in range(0, M, rows_per_launch):
        rows = min(rows_per_launch, M - row_start)
        src = inp.narrow(0, row_start, rows)
        dst = out.narrow(1, row_start, rows)
        grid = (triton.cdiv(N, block), triton.cdiv(rows, block))
        t_copy_2d_kernel[grid](
            src,
            dst,
            in_s0,
            in_s1,
            out_s0,
            out_s1,
            rows,
            N,
            BLOCK_M=block,
            BLOCK_N=block,
        )

124 

125 

def t_copy_out(
    input: torch.Tensor,
    out: torch.Tensor,
    memory_format: torch.memory_format | None = None,
):
    """Write the transpose of ``input`` into the caller-supplied ``out``.

    Args:
        input: Source tensor with at most two dimensions.
        out: Destination tensor, filled in place.
        memory_format: Accepted for signature compatibility; not used here.

    Returns:
        The same ``out`` tensor that was passed in.
    """
    logger.debug("GEMS T_COPY_OUT")
    # Validation and kernel selection both live in the shared launcher.
    _launch_t_copy_kernel(inp=input, out=out)
    return out

134 

135 

def t_copy(input: torch.Tensor, memory_format: torch.memory_format | None = None):
    """Return a freshly allocated tensor holding the transpose of ``input``.

    Args:
        input: Source tensor with at most two dimensions.
        memory_format: Accepted for signature compatibility; not used here.

    Returns:
        A new tensor with the dtype and device of ``input``: 0-D for a 0-D
        input, the same shape for a 1-D input, and ``(N, M)`` for an
        ``(M, N)`` input.

    Raises:
        RuntimeError: If ``input`` has more than two dimensions.
    """
    logger.debug("GEMS T_COPY")
    ndim = input.dim()
    if ndim > 2:
        raise RuntimeError("t_copy expects a tensor with <= 2 dims")

    if ndim == 0:
        result = torch.empty((), dtype=input.dtype, device=input.device)
    elif ndim == 1:
        # Transposing a vector keeps its shape; only the data is copied.
        result = torch.empty_like(input, memory_format=torch.contiguous_format)
    else:
        n_rows, n_cols = input.shape
        result = torch.empty(
            (n_cols, n_rows),
            dtype=input.dtype,
            device=input.device,
            memory_format=torch.contiguous_format,
        )

    _launch_t_copy_kernel(input, result)
    return result
154 return out