Coverage for src/flag_gems/ops/diff.py: 63%
94 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
2from functools import reduce
3from typing import Optional
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils import libentry
11from flag_gems.utils import triton_lang_extension as tle
13logger = logging.getLogger(__name__)
15Tensor = torch.Tensor
@libentry()
@triton.jit
def diff_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """First-order forward difference along the last axis of a 2-D view.

    The input is addressed as a dense row-major (M, N) buffer and the output
    as a dense row-major (M, N - 1) buffer, with

        output[m, n] = input[m, n + 1] - input[m, n]

    Each program covers BLOCK_M consecutive rows and walks the N - 1 output
    columns in tiles of BLOCK_N.
    """
    # Rows owned by this program; trailing rows past M are masked off.
    rows = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    rows_in_range = rows < M

    # The output has one fewer column than the input.
    out_cols = N - 1

    # Per-row base offsets as column vectors, so they broadcast against tiles.
    in_row_base = rows[:, None] * N
    out_row_base = rows[:, None] * out_cols

    for tile_start in range(0, out_cols, BLOCK_N):
        cols = tile_start + tl.arange(0, BLOCK_N)
        cols_in_range = cols < out_cols
        valid = rows_in_range[:, None] & cols_in_range[None, :]

        # Left operand is input[m, n]; right operand is one element further.
        left_offsets = in_row_base + cols[None, :]
        right_offsets = left_offsets + 1

        right = tl.load(input_ptr + right_offsets, mask=valid, other=0.0)
        left = tl.load(input_ptr + left_offsets, mask=valid, other=0.0)

        out_offsets = out_row_base + cols[None, :]
        tl.store(output_ptr + out_offsets, right - left, mask=valid)
@libentry()
@triton.jit
def diff_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    """First-order forward difference along the middle axis of a 3-D view.

    The input is addressed as a dense (M, N, K) buffer and the output as a
    dense (M, N - 1, K) buffer, with

        output[m, n, k] = input[m, n + 1, k] - input[m, n, k]

    Program axis 0 selects m; program axis 1 selects a BLOCK_K-wide tile of k.
    Each program then iterates over all N - 1 output positions along the
    middle axis. BLOCK_M is accepted but not referenced in the body.
    """
    m_idx = tle.program_id(0)

    # k positions owned by this program; the tail past K is masked off.
    k_idx = tle.program_id(1) * BLOCK_K + tl.arange(0, BLOCK_K)
    k_in_range = k_idx < K

    # The output has one fewer entry than the input along the diffed axis.
    out_n = N - 1

    # Start of the m-th slab in the input and in the output.
    # NOTE(review): these offsets are plain integer products (m * N * K,
    # n * K); confirm the index width is sufficient for very large inputs.
    in_slab = m_idx * N * K
    out_slab = m_idx * out_n * K

    for step in range(out_n):
        upper_offsets = in_slab + (step + 1) * K + k_idx
        lower_offsets = in_slab + step * K + k_idx

        upper = tl.load(input_ptr + upper_offsets, mask=k_in_range, other=0.0)
        lower = tl.load(input_ptr + lower_offsets, mask=k_in_range, other=0.0)

        out_offsets = out_slab + step * K + k_idx
        tl.store(output_ptr + out_offsets, upper - lower, mask=k_in_range)
115def _diff_once(inp: Tensor, dim: int) -> Tensor:
116 """Compute single forward difference along specified dimension.
118 Args:
119 inp: Input tensor (must be contiguous)
120 dim: Dimension to compute difference along
122 Returns:
123 Tensor with shape reduced by 1 along dim
124 """
125 shape = list(inp.shape)
126 ndim = inp.ndim
127 dim = dim % ndim
129 N = shape[dim] # Size along diff dimension
130 if N < 2:
131 raise RuntimeError(
132 f"diff requires at least 2 elements along dim {dim}, got {N}"
133 )
135 # Compute M (product of dims before dim) and K (product of dims after dim)
136 M = reduce(lambda x, y: x * y, shape[:dim], 1)
137 K = reduce(lambda x, y: x * y, shape[dim + 1 :], 1)
139 # Output shape has dim reduced by 1
140 out_shape = list(shape)
141 out_shape[dim] = N - 1
142 out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
144 with torch_device_fn.device(inp.device):
145 if K == 1:
146 # Inner dimension case
147 # Block sizes must be powers of 2 for triton
148 BLOCK_M = triton.next_power_of_2(min(32, M))
149 BLOCK_N = triton.next_power_of_2(min(256, N - 1))
150 grid = (triton.cdiv(M, BLOCK_M),)
151 diff_kernel_inner[grid](
152 out,
153 inp,
154 M,
155 N,
156 BLOCK_M=BLOCK_M,
157 BLOCK_N=BLOCK_N,
158 )
159 else:
160 # Non-inner dimension case
161 BLOCK_K = triton.next_power_of_2(min(256, K))
162 grid = (M, triton.cdiv(K, BLOCK_K))
163 diff_kernel_non_inner[grid](
164 out,
165 inp,
166 M,
167 N,
168 K,
169 BLOCK_M=1,
170 BLOCK_K=BLOCK_K,
171 )
173 return out
def diff(
    inp: Tensor,
    n: int = 1,
    dim: int = -1,
    prepend: Optional[Tensor] = None,
    append: Optional[Tensor] = None,
) -> Tensor:
    """Compute the n-th forward difference along the given dimension.

    The first-order differences are given by out[i] = input[i + 1] - input[i].
    Higher-order differences are calculated by applying the first-order
    difference repeatedly.

    Args:
        inp: Input tensor (at least one-dimensional when n > 0).
        n: Number of times to recursively compute the difference.
        dim: Dimension to compute the difference along (default: -1).
        prepend: Values to prepend to input along dim before computing diff.
        append: Values to append to input along dim before computing diff.

    Returns:
        Tensor containing the n-th order differences. For n == 0 this is a
        copy of ``inp`` (``prepend``/``append`` are ignored). If fewer than
        2 elements remain along ``dim`` before all n passes are done, the
        result is empty along ``dim`` (matching ``torch.diff``) rather than
        an error.

    Raises:
        RuntimeError: If ``n`` is negative, or ``inp`` is zero-dimensional
            with n > 0.
    """
    logger.debug("GEMS DIFF")

    if n == 0:
        return inp.clone()

    if n < 0:
        raise RuntimeError(f"diff expects n >= 0, got {n}")

    ndim = inp.ndim
    if ndim == 0:
        raise RuntimeError("diff requires input to be at least one-dimensional")

    dim = dim % ndim

    # Handle prepend and append by concatenating along dim.
    tensors_to_cat = []
    if prepend is not None:
        tensors_to_cat.append(prepend)
    tensors_to_cat.append(inp)
    if append is not None:
        tensors_to_cat.append(append)

    if len(tensors_to_cat) > 1:
        inp = torch.cat(tensors_to_cat, dim=dim)

    # The single-step helper indexes its input as a dense buffer.
    result = inp.contiguous()

    # Apply the first-order difference n times.
    for _ in range(n):
        if result.shape[dim] < 2:
            # No adjacent pair is left to subtract, so every further pass
            # yields nothing. Return a tensor that is empty along dim
            # instead of raising; this also stops early for very large n.
            empty_shape = list(result.shape)
            empty_shape[dim] = 0
            return result.new_empty(empty_shape)
        result = _diff_once(result, dim)

    return result