Coverage for src/flag_gems/ops/diff.py: 63%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2from functools import reduce 

3from typing import Optional 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12 

# Per-module logger; emits the "GEMS DIFF" debug trace when diff() is called.
logger = logging.getLogger(name=__name__)

# Short alias for the tensor type, used in this module's annotations.
Tensor = torch.Tensor

16 

17 

@libentry()
@triton.jit
def diff_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """First-order difference along the last axis of a row-major (M, N) input.

    Each program owns a strip of ``BLOCK_M`` consecutive rows and sweeps the
    ``N - 1`` output columns of those rows in chunks of ``BLOCK_N``::

        output[m, j] = input[m, j + 1] - input[m, j]    for 0 <= j < N - 1

    The output is written row-major with shape (M, N - 1).
    """
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    rows_ok = rows < M

    out_cols = N - 1
    # Per-row base offsets into the input and output buffers; loop-invariant.
    in_row_base = rows[:, None] * N
    out_row_base = rows[:, None] * out_cols

    for col0 in range(0, out_cols, BLOCK_N):
        cols = col0 + tl.arange(0, BLOCK_N)
        cols_ok = cols < out_cols
        tile_mask = rows_ok[:, None] & cols_ok[None, :]

        left_ptrs = input_ptr + in_row_base + cols[None, :]
        # The right-hand neighbour is one element further along the same row.
        left = tl.load(left_ptrs, mask=tile_mask, other=0.0)
        right = tl.load(left_ptrs + 1, mask=tile_mask, other=0.0)

        tl.store(
            output_ptr + out_row_base + cols[None, :],
            right - left,
            mask=tile_mask,
        )

66 

67 

@libentry()
@triton.jit
def diff_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """Forward difference along the middle axis of a row-major (M, N, K) view.

    Program ``(pid_m, pid_k)`` handles a single index ``m`` and a
    ``BLOCK_K``-wide strip of ``k``, and walks the ``N - 1`` output positions
    in order::

        output[m, n, k] = input[m, n + 1, k] - input[m, n, k]

    The output is written row-major with shape (M, N - 1, K).

    ``BLOCK_M`` is not read by this kernel: each program covers exactly one
    ``m`` (the launcher passes ``BLOCK_M=1``).
    """
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)

    # K indices handled by this program.
    k_offsets = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = k_offsets < K

    # Output has N - 1 entries along the differenced axis.
    output_N = N - 1

    # Pointers to slice n = 0 of this program's (m, k-strip) in input/output.
    in_base = input_ptr + pid_m * N * K + k_offsets
    out_base = output_ptr + pid_m * output_N * K + k_offsets

    # Slice n = 0 is loaded once up front. After that, the "next" slice of one
    # iteration is reused as the "current" slice of the following one, so each
    # input slice is loaded once instead of twice.
    cur = tl.load(in_base, mask=k_mask, other=0.0)
    for n in range(output_N):
        nxt = tl.load(in_base + (n + 1) * K, mask=k_mask, other=0.0)
        tl.store(out_base + n * K, nxt - cur, mask=k_mask)
        cur = nxt

113 

114 

115def _diff_once(inp: Tensor, dim: int) -> Tensor: 

116 """Compute single forward difference along specified dimension. 

117 

118 Args: 

119 inp: Input tensor (must be contiguous) 

120 dim: Dimension to compute difference along 

121 

122 Returns: 

123 Tensor with shape reduced by 1 along dim 

124 """ 

125 shape = list(inp.shape) 

126 ndim = inp.ndim 

127 dim = dim % ndim 

128 

129 N = shape[dim] # Size along diff dimension 

130 if N < 2: 

131 raise RuntimeError( 

132 f"diff requires at least 2 elements along dim {dim}, got {N}" 

133 ) 

134 

135 # Compute M (product of dims before dim) and K (product of dims after dim) 

136 M = reduce(lambda x, y: x * y, shape[:dim], 1) 

137 K = reduce(lambda x, y: x * y, shape[dim + 1 :], 1) 

138 

139 # Output shape has dim reduced by 1 

140 out_shape = list(shape) 

141 out_shape[dim] = N - 1 

142 out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device) 

143 

144 with torch_device_fn.device(inp.device): 

145 if K == 1: 

146 # Inner dimension case 

147 # Block sizes must be powers of 2 for triton 

148 BLOCK_M = triton.next_power_of_2(min(32, M)) 

149 BLOCK_N = triton.next_power_of_2(min(256, N - 1)) 

150 grid = (triton.cdiv(M, BLOCK_M),) 

151 diff_kernel_inner[grid]( 

152 out, 

153 inp, 

154 M, 

155 N, 

156 BLOCK_M=BLOCK_M, 

157 BLOCK_N=BLOCK_N, 

158 ) 

159 else: 

160 # Non-inner dimension case 

161 BLOCK_K = triton.next_power_of_2(min(256, K)) 

162 grid = (M, triton.cdiv(K, BLOCK_K)) 

163 diff_kernel_non_inner[grid]( 

164 out, 

165 inp, 

166 M, 

167 N, 

168 K, 

169 BLOCK_M=1, 

170 BLOCK_K=BLOCK_K, 

171 ) 

172 

173 return out 

174 

175 

def diff(
    inp: Tensor,
    n: int = 1,
    dim: int = -1,
    prepend: Optional[Tensor] = None,
    append: Optional[Tensor] = None,
) -> Tensor:
    """Compute the n-th forward difference along the given dimension.

    Follows ``torch.diff`` semantics:

    - The first-order difference is ``out[i] = input[i + 1] - input[i]``.
      Higher orders apply it repeatedly.
    - ``prepend`` / ``append`` are concatenated onto ``inp`` along ``dim``
      before differencing. They are ignored when ``n == 0``.
    - The order is clamped to the length ``L`` along ``dim``. Differencing
      ``L`` or more times yields an empty result instead of an error.
    - For bool inputs each step is a logical XOR, so the result stays bool.

    Args:
        inp: Input tensor (at least one-dimensional).
        n: Number of times to recursively compute the difference (>= 0).
        dim: Dimension to compute the difference along (default: -1).
        prepend: Values to prepend to input along dim before computing diff.
        append: Values to append to input along dim before computing diff.

    Returns:
        Tensor of n-th order differences. Its size along ``dim`` is
        ``max(L - n, 0)``, where ``L`` is the length after concatenation.

    Raises:
        RuntimeError: if ``inp`` is 0-dimensional or ``n`` is negative.
        IndexError: if ``n > 0`` and ``dim`` is out of range for ``inp``.
    """
    logger.debug("GEMS DIFF")

    ndim = inp.ndim
    if ndim == 0:
        raise RuntimeError("diff requires input to be at least one-dimensional")

    if n < 0:
        raise RuntimeError(f"diff expects n >= 0, got {n}")

    if n == 0:
        return inp.clone()

    # Reject out-of-range dims instead of silently wrapping them with ``%``.
    if not -ndim <= dim < ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-ndim}, {ndim - 1}], but got {dim})"
        )
    dim = dim % ndim

    # Handle prepend and append by concatenating along ``dim``.
    if prepend is not None or append is not None:
        pieces = [t for t in (prepend, inp, append) if t is not None]
        inp = torch.cat(pieces, dim=dim)

    result = inp.contiguous()

    # Clamp the order to the available length: each pass shortens ``dim`` by 1,
    # and differencing a length-L sequence L or more times leaves nothing.
    length = result.shape[dim]
    n = min(n, length)
    out_len = length - n

    # Empty result (fully consumed dim, or a zero-sized dim elsewhere): build
    # it directly instead of launching kernels on empty data.
    if out_len == 0 or result.numel() == 0:
        out_shape = list(result.shape)
        out_shape[dim] = out_len
        return torch.empty(out_shape, dtype=result.dtype, device=result.device)

    # From here every pass sees at least 2 elements along ``dim``, because
    # out_len >= 1.
    is_bool = result.dtype == torch.bool
    for _ in range(n):
        if is_bool:
            # torch.diff uses logical XOR for bool. On 0/1 int8 values,
            # (b - a) != 0 is exactly a XOR b.
            result = _diff_once(result.to(torch.int8), dim).ne(0)
        else:
            result = _diff_once(result, dim)

    return result