Coverage for src/flag_gems/ops/fft.py: 13%
692 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
2import math
3from typing import Tuple
5import torch
6import triton
7import triton.language as tl
9from flag_gems.runtime import torch_device_fn
10from flag_gems.utils.triton_version_utils import HAS_TLE
logger = logging.getLogger(__name__)

# `tle` is bound to the experimental TLE language module only when the
# installed Triton provides it; otherwise it stays None and the TLE kernels
# further down are never defined.
tle = None
if HAS_TLE:
    import triton.experimental.tle.language as tle

PI = math.pi

# Row length that fft() compares N against when choosing the register-only
# TLE kernel.
_FFT_REG_THRESHOLD = 256

# Per-(N, device) memoisation of the host-built lookup tables, so repeated
# calls with the same size do not rebuild them.
_BITREV_CACHE: dict[Tuple[int, torch.device], torch.Tensor] = {}
_TWIDDLE_CACHE: dict[Tuple[int, torch.device], Tuple[torch.Tensor, torch.Tensor]] = {}
def _is_power_of_two(n: int) -> bool:
    """Return True when *n* is a positive integer with exactly one bit set."""
    if n <= 0:
        return False
    # Clearing the lowest set bit leaves zero only for powers of two.
    return not (n & (n - 1))
def _log2(n: int) -> int:
    """Return floor(log2(n)) for positive *n*, derived from its bit length."""
    bit_count = n.bit_length()
    return bit_count - 1
def _bitrev_indices(n: int, device: torch.device) -> torch.Tensor:
    """Return the length-*n* int32 bit-reversal permutation on *device*.

    Entry ``i`` holds ``i`` with its lowest ``log2(n)`` bits reversed. The
    table is built once per ``(n, device)`` pair and then served from
    ``_BITREV_CACHE``.
    """
    cache_key = (n, device)
    table = _BITREV_CACHE.get(cache_key)
    if table is None:
        remaining = torch.arange(n, device=device, dtype=torch.int32)
        table = torch.zeros_like(remaining)
        # Peel bits off the low end of the index and push them onto the low
        # end of the result, one bit per iteration.
        for _ in range(_log2(n)):
            table = (table << 1) | (remaining & 1)
            remaining = remaining >> 1
        _BITREV_CACHE[cache_key] = table
    return table
def _twiddle_tables(n: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the stage-packed twiddle tables ``(cos, sin)`` for a size-*n* FFT.

    Both tensors are float32 with ``n - 1`` entries. Stage ``s`` (butterfly
    size ``2 ** (s + 1)``) owns the ``2 ** s`` entries starting at index
    ``2 ** s - 1``, holding ``exp(-2j * pi * k / 2 ** (s + 1))`` for
    ``k = 0 .. 2 ** s - 1``. Results are memoised in ``_TWIDDLE_CACHE``.
    """
    cache_key = (n, device)
    tables = _TWIDDLE_CACHE.get(cache_key)
    if tables is not None:
        return tables

    cos_table = torch.empty((n - 1,), device=device, dtype=torch.float32)
    sin_table = torch.empty((n - 1,), device=device, dtype=torch.float32)
    for stage in range(_log2(n)):
        span = 1 << (stage + 1)  # butterfly size of this stage
        count = span >> 1  # number of distinct twiddles in this stage
        start = count - 1  # sum of the counts of all earlier stages
        stop = start + count
        k = torch.arange(count, device=device, dtype=torch.float32)
        phase = (-2.0 * PI / span) * k
        cos_table[start:stop] = torch.cos(phase)
        sin_table[start:stop] = torch.sin(phase)

    tables = (cos_table, sin_table)
    _TWIDDLE_CACHE[cache_key] = tables
    return tables
def _prepare_input(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split *x* into contiguous float32 real and imaginary planes.

    Complex inputs (complex64 / complex128) are converted to complex64 and
    split; real inputs (float16 / float32 / bfloat16) are converted to
    float32 and paired with an all-zero imaginary plane.

    Raises:
        ValueError: if the dtype of *x* is not one of those listed above.
    """
    if x.is_complex():
        if x.dtype not in (torch.complex64, torch.complex128):
            raise ValueError(f"unsupported complex dtype: {x.dtype}")
        as_c64 = x.to(torch.complex64)
        # .real / .imag are strided views; the kernels need dense rows.
        return as_c64.real.contiguous(), as_c64.imag.contiguous()

    if x.dtype not in (torch.float16, torch.float32, torch.bfloat16):
        raise ValueError(f"unsupported dtype: {x.dtype}")
    real_plane = x.to(torch.float32).contiguous()
    return real_plane, torch.zeros_like(real_plane)
@triton.jit
def fft_kernel_triton(
    in_real,
    in_imag,
    bitrev,
    twiddle_real,
    twiddle_imag,
    buf0_real,
    buf0_imag,
    buf1_real,
    buf1_imag,
    stride_in,
    stride_buf,
    n_rows,
    N: tl.constexpr,
    LOG_N: tl.constexpr,
):
    """Transform one row per program, ping-ponging between two global buffers.

    Args:
        in_real, in_imag: pointers to the real / imaginary input planes; row
            ``r`` starts at element ``r * stride_in``.
        bitrev: pointer to ``N`` gather indices applied to the input row.
        twiddle_real, twiddle_imag: pointers to the stage-packed twiddle
            tables; stage ``s`` starts at index ``2 ** s - 1``.
        buf0_real, buf0_imag, buf1_real, buf1_imag: scratch planes with row
            pitch ``stride_buf``. Every pass reads one pair and writes the
            other, so the result ends in buf1 when the number of passes,
            ``(LOG_N + 1) // 2``, is odd and in buf0 otherwise.
        n_rows: number of valid rows; programs beyond it are masked off.
        N: row length (compile-time).
        LOG_N: log2(N) (compile-time).

    When LOG_N is odd a single radix-2 pass runs first; the remaining stages
    are processed two at a time as radix-4 passes.
    """
    row = tl.program_id(0)
    offs = tl.arange(0, N)
    mask = (row < n_rows) & (offs < N)

    # Bit-reversal copy: working element i is input element bitrev[i].
    rev = tl.load(bitrev + offs, mask=offs < N, other=0)
    vals_real = tl.load(in_real + row * stride_in + rev, mask=mask, other=0.0)
    vals_imag = tl.load(in_imag + row * stride_in + rev, mask=mask, other=0.0)
    tl.store(buf0_real + row * stride_buf + offs, vals_real, mask=mask)
    tl.store(buf0_imag + row * stride_buf + offs, vals_imag, mask=mask)
    # The first butterfly pass reads elements that other threads of this
    # program just stored, so those stores must be complete before any load.
    tl.debug_barrier()

    buf_a_real = buf0_real
    buf_a_imag = buf0_imag
    buf_b_real = buf1_real
    buf_b_imag = buf1_imag

    if LOG_N % 2 == 1:
        # Lone radix-2 pass (butterfly size 2, one twiddle at table index 0).
        pos = offs & 1
        j = pos & 0
        even_idx = offs - pos + j
        odd_idx = even_idx + 1

        u_real = tl.load(
            buf_a_real + row * stride_buf + even_idx, mask=mask, other=0.0
        )
        u_imag = tl.load(
            buf_a_imag + row * stride_buf + even_idx, mask=mask, other=0.0
        )
        v_real = tl.load(
            buf_a_real + row * stride_buf + odd_idx, mask=mask, other=0.0
        )
        v_imag = tl.load(
            buf_a_imag + row * stride_buf + odd_idx, mask=mask, other=0.0
        )

        tw_real = tl.load(twiddle_real + j, mask=mask, other=1.0)
        tw_imag = tl.load(twiddle_imag + j, mask=mask, other=0.0)

        v_tw_real = v_real * tw_real - v_imag * tw_imag
        v_tw_imag = v_real * tw_imag + v_imag * tw_real

        # Lower element of each pair takes the sum, upper the difference.
        add_mask = pos < 1
        out_real = tl.where(add_mask, u_real + v_tw_real, u_real - v_tw_real)
        out_imag = tl.where(add_mask, u_imag + v_tw_imag, u_imag - v_tw_imag)

        tl.store(buf_b_real + row * stride_buf + offs, out_real, mask=mask)
        tl.store(buf_b_imag + row * stride_buf + offs, out_imag, mask=mask)
        tl.debug_barrier()

        buf_a_real, buf_b_real = buf_b_real, buf_a_real
        buf_a_imag, buf_b_imag = buf_b_imag, buf_a_imag

    # Radix-4 passes. Each fuses stage_s - 1 and stage_s; the first fused
    # pair starts at stage 1 for even LOG_N and at stage 2 after the radix-2
    # pass for odd LOG_N. LOG_N // 2 passes are needed in both cases.
    for r4 in tl.static_range(LOG_N // 2):
        stage_s = 1 + (LOG_N % 2) + r4 * 2
        m = 1 << (stage_s + 1)
        quarter = m >> 2
        half = m >> 1
        three_quarter = quarter + half

        pos = offs & (m - 1)
        j = pos & (quarter - 1)
        i0 = offs - pos + j
        i1 = i0 + quarter
        i2 = i1 + quarter
        i3 = i2 + quarter

        x0_real = tl.load(buf_a_real + row * stride_buf + i0, mask=mask, other=0.0)
        x0_imag = tl.load(buf_a_imag + row * stride_buf + i0, mask=mask, other=0.0)
        x1_real = tl.load(buf_a_real + row * stride_buf + i1, mask=mask, other=0.0)
        x1_imag = tl.load(buf_a_imag + row * stride_buf + i1, mask=mask, other=0.0)
        x2_real = tl.load(buf_a_real + row * stride_buf + i2, mask=mask, other=0.0)
        x2_imag = tl.load(buf_a_imag + row * stride_buf + i2, mask=mask, other=0.0)
        x3_real = tl.load(buf_a_real + row * stride_buf + i3, mask=mask, other=0.0)
        x3_imag = tl.load(buf_a_imag + row * stride_buf + i3, mask=mask, other=0.0)

        # tw1 belongs to stage_s - 1, tw2 to stage_s.
        tw1_idx = ((1 << (stage_s - 1)) - 1) + j
        tw2_idx = ((1 << stage_s) - 1) + j
        tw1_real = tl.load(twiddle_real + tw1_idx, mask=mask, other=1.0)
        tw1_imag = tl.load(twiddle_imag + tw1_idx, mask=mask, other=0.0)
        tw2_real = tl.load(twiddle_real + tw2_idx, mask=mask, other=1.0)
        tw2_imag = tl.load(twiddle_imag + tw2_idx, mask=mask, other=0.0)

        # Inner stage: two size-`half` butterflies.
        t1_real = x1_real * tw1_real - x1_imag * tw1_imag
        t1_imag = x1_real * tw1_imag + x1_imag * tw1_real
        t3_real = x3_real * tw1_real - x3_imag * tw1_imag
        t3_imag = x3_real * tw1_imag + x3_imag * tw1_real

        u0_real = x0_real + t1_real
        u0_imag = x0_imag + t1_imag
        u1_real = x0_real - t1_real
        u1_imag = x0_imag - t1_imag
        v0_real = x2_real + t3_real
        v0_imag = x2_imag + t3_imag
        v1_real = x2_real - t3_real
        v1_imag = x2_imag - t3_imag

        # Outer stage: the twiddle for index j + quarter is tw2 * (-i).
        v0_tw_real = v0_real * tw2_real - v0_imag * tw2_imag
        v0_tw_imag = v0_real * tw2_imag + v0_imag * tw2_real
        w3_real = tw2_imag
        w3_imag = -tw2_real
        v1_tw_real = v1_real * w3_real - v1_imag * w3_imag
        v1_tw_imag = v1_real * w3_imag + v1_imag * w3_real

        o0_real = u0_real + v0_tw_real
        o0_imag = u0_imag + v0_tw_imag
        o2_real = u0_real - v0_tw_real
        o2_imag = u0_imag - v0_tw_imag
        o1_real = u1_real + v1_tw_real
        o1_imag = u1_imag + v1_tw_imag
        o3_real = u1_real - v1_tw_real
        o3_imag = u1_imag - v1_tw_imag

        # Every lane computed all four outputs; keep the one for its quarter.
        m0 = pos < quarter
        m1 = (pos >= quarter) & (pos < half)
        m2 = (pos >= half) & (pos < three_quarter)
        out_real = tl.where(
            m0, o0_real, tl.where(m1, o1_real, tl.where(m2, o2_real, o3_real))
        )
        out_imag = tl.where(
            m0, o0_imag, tl.where(m1, o1_imag, tl.where(m2, o2_imag, o3_imag))
        )

        tl.store(buf_b_real + row * stride_buf + offs, out_real, mask=mask)
        tl.store(buf_b_imag + row * stride_buf + offs, out_imag, mask=mask)
        # Publish this pass before the next one reads it.
        tl.debug_barrier()

        buf_a_real, buf_b_real = buf_b_real, buf_a_real
        buf_a_imag, buf_b_imag = buf_b_imag, buf_a_imag
if HAS_TLE:

    @triton.jit
    def fft_kernel_tle(
        in_real,
        in_imag,
        bitrev,
        twiddle_real,
        twiddle_imag,
        out_real,
        out_imag,
        stride_in,
        stride_out,
        n_rows,
        N: tl.constexpr,
        LOG_N: tl.constexpr,
    ):
        """Transform one row per program using two TLE `smem`-scope buffers.

        The row is gathered through `bitrev` into the first buffer pair, each
        butterfly pass reads one pair and writes the other, and the pair
        written last is copied to `out_real` / `out_imag` (row pitch
        `stride_out`). A radix-2 pass runs first when LOG_N is odd; all other
        stages are fused two at a time into radix-4 passes.
        """
        row = tl.program_id(0)
        offs = tl.arange(0, N)
        mask = (row < n_rows) & (offs < N)

        ping_real = tle.gpu.alloc(
            [N],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        ping_imag = tle.gpu.alloc(
            [N],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        pong_real = tle.gpu.alloc(
            [N],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        pong_imag = tle.gpu.alloc(
            [N],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )

        # Bit-reversal copy from the input row into the first buffer pair.
        rev = tl.load(bitrev + offs, mask=offs < N, other=0)
        vals_real = tl.load(in_real + row * stride_in + rev, mask=mask, other=0.0)
        vals_imag = tl.load(in_imag + row * stride_in + rev, mask=mask, other=0.0)
        tl.store(tle.gpu.local_ptr(ping_real, (offs,)), vals_real, mask=mask)
        tl.store(tle.gpu.local_ptr(ping_imag, (offs,)), vals_imag, mask=mask)
        tl.debug_barrier()

        src_real = ping_real
        src_imag = ping_imag
        dst_real = pong_real
        dst_imag = pong_imag

        if LOG_N % 2 == 1:
            # Lone radix-2 pass (butterfly size 2, twiddle table index 0).
            lane = offs & 1
            tw_off = lane & 0
            lo_idx = offs - lane + tw_off
            hi_idx = lo_idx + 1

            u_real = tl.load(
                tle.gpu.local_ptr(src_real, (lo_idx,)), mask=mask, other=0.0
            )
            u_imag = tl.load(
                tle.gpu.local_ptr(src_imag, (lo_idx,)), mask=mask, other=0.0
            )
            v_real = tl.load(
                tle.gpu.local_ptr(src_real, (hi_idx,)), mask=mask, other=0.0
            )
            v_imag = tl.load(
                tle.gpu.local_ptr(src_imag, (hi_idx,)), mask=mask, other=0.0
            )

            w_real = tl.load(twiddle_real + tw_off, mask=mask, other=1.0)
            w_imag = tl.load(twiddle_imag + tw_off, mask=mask, other=0.0)

            vw_real = v_real * w_real - v_imag * w_imag
            vw_imag = v_real * w_imag + v_imag * w_real

            is_lower = lane < 1
            res_real = tl.where(is_lower, u_real + vw_real, u_real - vw_real)
            res_imag = tl.where(is_lower, u_imag + vw_imag, u_imag - vw_imag)

            tl.store(tle.gpu.local_ptr(dst_real, (offs,)), res_real, mask=mask)
            tl.store(tle.gpu.local_ptr(dst_imag, (offs,)), res_imag, mask=mask)
            tl.debug_barrier()

            src_real, dst_real = dst_real, src_real
            src_imag, dst_imag = dst_imag, src_imag

        # Radix-4 passes: LOG_N // 2 of them, starting at stage 1 (even
        # LOG_N) or stage 2 (odd LOG_N, after the radix-2 pass).
        for r4 in tl.static_range(LOG_N // 2):
            stage_s = 1 + (LOG_N % 2) + r4 * 2
            m = 1 << (stage_s + 1)
            quarter = m >> 2
            half = m >> 1
            three_quarter = quarter + half

            pos = offs & (m - 1)
            j = pos & (quarter - 1)
            i0 = offs - pos + j
            i1 = i0 + quarter
            i2 = i1 + quarter
            i3 = i2 + quarter

            x0_real = tl.load(tle.gpu.local_ptr(src_real, (i0,)), mask=mask, other=0.0)
            x0_imag = tl.load(tle.gpu.local_ptr(src_imag, (i0,)), mask=mask, other=0.0)
            x1_real = tl.load(tle.gpu.local_ptr(src_real, (i1,)), mask=mask, other=0.0)
            x1_imag = tl.load(tle.gpu.local_ptr(src_imag, (i1,)), mask=mask, other=0.0)
            x2_real = tl.load(tle.gpu.local_ptr(src_real, (i2,)), mask=mask, other=0.0)
            x2_imag = tl.load(tle.gpu.local_ptr(src_imag, (i2,)), mask=mask, other=0.0)
            x3_real = tl.load(tle.gpu.local_ptr(src_real, (i3,)), mask=mask, other=0.0)
            x3_imag = tl.load(tle.gpu.local_ptr(src_imag, (i3,)), mask=mask, other=0.0)

            # tw1 belongs to stage_s - 1, tw2 to stage_s.
            tw1_idx = ((1 << (stage_s - 1)) - 1) + j
            tw2_idx = ((1 << stage_s) - 1) + j
            tw1_real = tl.load(twiddle_real + tw1_idx, mask=mask, other=1.0)
            tw1_imag = tl.load(twiddle_imag + tw1_idx, mask=mask, other=0.0)
            tw2_real = tl.load(twiddle_real + tw2_idx, mask=mask, other=1.0)
            tw2_imag = tl.load(twiddle_imag + tw2_idx, mask=mask, other=0.0)

            t1_real = x1_real * tw1_real - x1_imag * tw1_imag
            t1_imag = x1_real * tw1_imag + x1_imag * tw1_real
            t3_real = x3_real * tw1_real - x3_imag * tw1_imag
            t3_imag = x3_real * tw1_imag + x3_imag * tw1_real

            u0_real = x0_real + t1_real
            u0_imag = x0_imag + t1_imag
            u1_real = x0_real - t1_real
            u1_imag = x0_imag - t1_imag
            v0_real = x2_real + t3_real
            v0_imag = x2_imag + t3_imag
            v1_real = x2_real - t3_real
            v1_imag = x2_imag - t3_imag

            # The twiddle for index j + quarter equals tw2 * (-i).
            v0_tw_real = v0_real * tw2_real - v0_imag * tw2_imag
            v0_tw_imag = v0_real * tw2_imag + v0_imag * tw2_real
            w3_real = tw2_imag
            w3_imag = -tw2_real
            v1_tw_real = v1_real * w3_real - v1_imag * w3_imag
            v1_tw_imag = v1_real * w3_imag + v1_imag * w3_real

            o0_real = u0_real + v0_tw_real
            o0_imag = u0_imag + v0_tw_imag
            o2_real = u0_real - v0_tw_real
            o2_imag = u0_imag - v0_tw_imag
            o1_real = u1_real + v1_tw_real
            o1_imag = u1_imag + v1_tw_imag
            o3_real = u1_real - v1_tw_real
            o3_imag = u1_imag - v1_tw_imag

            # Keep the output that matches this lane's quarter of the block.
            m0 = pos < quarter
            m1 = (pos >= quarter) & (pos < half)
            m2 = (pos >= half) & (pos < three_quarter)
            res_real = tl.where(
                m0, o0_real, tl.where(m1, o1_real, tl.where(m2, o2_real, o3_real))
            )
            res_imag = tl.where(
                m0, o0_imag, tl.where(m1, o1_imag, tl.where(m2, o2_imag, o3_imag))
            )

            tl.store(tle.gpu.local_ptr(dst_real, (offs,)), res_real, mask=mask)
            tl.store(tle.gpu.local_ptr(dst_imag, (offs,)), res_imag, mask=mask)
            tl.debug_barrier()

            src_real, dst_real = dst_real, src_real
            src_imag, dst_imag = dst_imag, src_imag

        # After the last swap the finished row sits in the `src` pair.
        final_real = tl.load(tle.gpu.local_ptr(src_real, (offs,)), mask=mask, other=0.0)
        final_imag = tl.load(tle.gpu.local_ptr(src_imag, (offs,)), mask=mask, other=0.0)
        tl.store(out_real + row * stride_out + offs, final_real, mask=mask)
        tl.store(out_imag + row * stride_out + offs, final_imag, mask=mask)

    @triton.jit
    def fft_kernel_tle_reg(
        in_real,
        in_imag,
        bitrev,
        twiddle_real,
        twiddle_imag,
        out_real,
        out_imag,
        stride_in,
        stride_out,
        n_rows,
        N: tl.constexpr,
        LOG_N: tl.constexpr,
    ):
        """Transform one row per program, exchanging values with `tl.gather`.

        Same pass structure as `fft_kernel_tle`, but the row is held in the
        `cur_real` / `cur_imag` tensors and butterfly partners are fetched
        with `tl.gather` instead of through intermediate buffers. The result
        is stored to `out_real` / `out_imag` (row pitch `stride_out`).
        """
        row = tl.program_id(0)
        offs = tl.arange(0, N)
        mask = (row < n_rows) & (offs < N)

        # Load the row already permuted into bit-reversed order.
        rev = tl.load(bitrev + offs, mask=offs < N, other=0)
        cur_real = tl.load(in_real + row * stride_in + rev, mask=mask, other=0.0)
        cur_imag = tl.load(in_imag + row * stride_in + rev, mask=mask, other=0.0)

        if LOG_N % 2 == 1:
            # Lone radix-2 pass (butterfly size 2, twiddle table index 0).
            lane = offs & 1
            tw_off = lane & 0
            lo_idx = offs - lane + tw_off
            hi_idx = lo_idx + 1

            u_real = tl.gather(cur_real, lo_idx, axis=0)
            u_imag = tl.gather(cur_imag, lo_idx, axis=0)
            v_real = tl.gather(cur_real, hi_idx, axis=0)
            v_imag = tl.gather(cur_imag, hi_idx, axis=0)

            w_real = tl.load(twiddle_real + tw_off, mask=mask, other=1.0)
            w_imag = tl.load(twiddle_imag + tw_off, mask=mask, other=0.0)

            vw_real = v_real * w_real - v_imag * w_imag
            vw_imag = v_real * w_imag + v_imag * w_real

            is_lower = lane < 1
            cur_real = tl.where(is_lower, u_real + vw_real, u_real - vw_real)
            cur_imag = tl.where(is_lower, u_imag + vw_imag, u_imag - vw_imag)

        # Radix-4 passes: LOG_N // 2 of them, starting at stage 1 (even
        # LOG_N) or stage 2 (odd LOG_N, after the radix-2 pass).
        for r4 in tl.static_range(LOG_N // 2):
            stage_s = 1 + (LOG_N % 2) + r4 * 2
            m = 1 << (stage_s + 1)
            quarter = m >> 2
            half = m >> 1
            three_quarter = quarter + half

            pos = offs & (m - 1)
            j = pos & (quarter - 1)
            i0 = offs - pos + j
            i1 = i0 + quarter
            i2 = i1 + quarter
            i3 = i2 + quarter

            x0_real = tl.gather(cur_real, i0, axis=0)
            x0_imag = tl.gather(cur_imag, i0, axis=0)
            x1_real = tl.gather(cur_real, i1, axis=0)
            x1_imag = tl.gather(cur_imag, i1, axis=0)
            x2_real = tl.gather(cur_real, i2, axis=0)
            x2_imag = tl.gather(cur_imag, i2, axis=0)
            x3_real = tl.gather(cur_real, i3, axis=0)
            x3_imag = tl.gather(cur_imag, i3, axis=0)

            # tw1 belongs to stage_s - 1, tw2 to stage_s.
            tw1_idx = ((1 << (stage_s - 1)) - 1) + j
            tw2_idx = ((1 << stage_s) - 1) + j
            tw1_real = tl.load(twiddle_real + tw1_idx, mask=mask, other=1.0)
            tw1_imag = tl.load(twiddle_imag + tw1_idx, mask=mask, other=0.0)
            tw2_real = tl.load(twiddle_real + tw2_idx, mask=mask, other=1.0)
            tw2_imag = tl.load(twiddle_imag + tw2_idx, mask=mask, other=0.0)

            t1_real = x1_real * tw1_real - x1_imag * tw1_imag
            t1_imag = x1_real * tw1_imag + x1_imag * tw1_real
            t3_real = x3_real * tw1_real - x3_imag * tw1_imag
            t3_imag = x3_real * tw1_imag + x3_imag * tw1_real

            u0_real = x0_real + t1_real
            u0_imag = x0_imag + t1_imag
            u1_real = x0_real - t1_real
            u1_imag = x0_imag - t1_imag
            v0_real = x2_real + t3_real
            v0_imag = x2_imag + t3_imag
            v1_real = x2_real - t3_real
            v1_imag = x2_imag - t3_imag

            # The twiddle for index j + quarter equals tw2 * (-i).
            v0_tw_real = v0_real * tw2_real - v0_imag * tw2_imag
            v0_tw_imag = v0_real * tw2_imag + v0_imag * tw2_real
            w3_real = tw2_imag
            w3_imag = -tw2_real
            v1_tw_real = v1_real * w3_real - v1_imag * w3_imag
            v1_tw_imag = v1_real * w3_imag + v1_imag * w3_real

            o0_real = u0_real + v0_tw_real
            o0_imag = u0_imag + v0_tw_imag
            o2_real = u0_real - v0_tw_real
            o2_imag = u0_imag - v0_tw_imag
            o1_real = u1_real + v1_tw_real
            o1_imag = u1_imag + v1_tw_imag
            o3_real = u1_real - v1_tw_real
            o3_imag = u1_imag - v1_tw_imag

            # Keep the output that matches this lane's quarter of the block.
            m0 = pos < quarter
            m1 = (pos >= quarter) & (pos < half)
            m2 = (pos >= half) & (pos < three_quarter)
            cur_real = tl.where(
                m0, o0_real, tl.where(m1, o1_real, tl.where(m2, o2_real, o3_real))
            )
            cur_imag = tl.where(
                m0, o0_imag, tl.where(m1, o1_imag, tl.where(m2, o2_imag, o3_imag))
            )

        tl.store(out_real + row * stride_out + offs, cur_real, mask=mask)
        tl.store(out_imag + row * stride_out + offs, cur_imag, mask=mask)
def fft(x: torch.Tensor) -> torch.Tensor:
    """Compute the 1D FFT of every row of a 2D tensor.

    The transform runs over the last dimension of an ``(M, N)`` input with an
    iterative bit-reversal + butterfly scheme; twiddle factors and the
    bit-reversal permutation are built on the host and cached per
    ``(N, device)``.

    Args:
        x: 2D tensor on a CUDA device. Supported dtypes are float16, float32,
            bfloat16 (treated as purely real) and complex64 / complex128.
            All arithmetic is done in float32.

    Returns:
        A complex64 tensor of shape ``(M, N)``.

    Raises:
        AssertionError: if ``x`` is not on CUDA or is not 2D.
        ValueError: if ``N`` is not a power of two, exceeds 1024, or the
            dtype of ``x`` is unsupported.

    Notes:
        - With TLE available, rows of length ``N <= _FFT_REG_THRESHOLD`` use
          the register-only kernel and longer rows use the buffered TLE
          kernel.
        - Without TLE, the plain Triton kernel ping-pongs between two
          global-memory scratch buffers.
    """
    logger.debug("GEMS FFT")
    # Kept as assertions so callers keep seeing AssertionError here.
    assert x.is_cuda, "input must be on CUDA"
    assert x.ndim == 2, "input must be 2D (M, N)"
    m, n = x.shape
    if not _is_power_of_two(n):
        raise ValueError(f"N={n} must be a power-of-two")
    if n > 1024:
        raise ValueError(f"N={n} too large for this kernel (max 1024)")

    in_real, in_imag = _prepare_input(x)
    bitrev = _bitrev_indices(n, x.device)
    tw_real, tw_imag = _twiddle_tables(n, x.device)
    log_n = _log2(n)

    with torch_device_fn.device(x.device):
        grid = (m,)  # one program per row

        if HAS_TLE:
            out_real = torch.empty((m, n), device=x.device, dtype=torch.float32)
            out_imag = torch.empty((m, n), device=x.device, dtype=torch.float32)

            # Small rows take the register-only kernel; both kernels share
            # the same signature, so only the choice of kernel differs.
            if n <= _FFT_REG_THRESHOLD:
                kernel = fft_kernel_tle_reg
            else:
                kernel = fft_kernel_tle
            kernel[grid](
                in_real,
                in_imag,
                bitrev,
                tw_real,
                tw_imag,
                out_real,
                out_imag,
                in_real.stride(0),
                out_real.stride(0),
                m,
                N=n,
                LOG_N=log_n,
                num_warps=4,
                num_stages=1,
            )
            return torch.complex(out_real, out_imag)

        buf0_real = torch.empty((m, n), device=x.device, dtype=torch.float32)
        buf0_imag = torch.empty((m, n), device=x.device, dtype=torch.float32)
        buf1_real = torch.empty((m, n), device=x.device, dtype=torch.float32)
        buf1_imag = torch.empty((m, n), device=x.device, dtype=torch.float32)

        fft_kernel_triton[grid](
            in_real,
            in_imag,
            bitrev,
            tw_real,
            tw_imag,
            buf0_real,
            buf0_imag,
            buf1_real,
            buf1_imag,
            in_real.stride(0),
            buf0_real.stride(0),
            m,
            N=n,
            LOG_N=log_n,
            num_warps=4,
            num_stages=1,
        )

        # Kernel swaps buf_a/buf_b after each stage write.
        # Total swaps = (log_n + 1) // 2 (1 radix-2 if odd, then radix-4 pairs).
        # Result lands in buf0 when total_swaps is even, buf1 when odd.
        total_swaps = (log_n + 1) // 2
        if total_swaps % 2 == 0:
            out_real = buf0_real
            out_imag = buf0_imag
        else:
            out_real = buf1_real
            out_imag = buf1_imag

    return torch.complex(out_real, out_imag)