Coverage for src/flag_gems/ops/fft.py: 13%

692 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3from typing import Tuple 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils.triton_version_utils import HAS_TLE 

11 

12logger = logging.getLogger(__name__) 

13 

14if HAS_TLE: 

15 import triton.experimental.tle.language as tle 

16else: 

17 tle = None 

18 

# Shorthand for pi, used when building the twiddle-factor angles.
PI = math.pi
# Row length that fft() compares against (with ==) to pick the register-only
# TLE kernel instead of the shared-buffer TLE kernel.
_FFT_REG_THRESHOLD = 256

# Memoized bit-reversal index tensors, keyed by (n, device).
_BITREV_CACHE: dict[Tuple[int, torch.device], torch.Tensor] = {}
# Memoized (real, imag) twiddle-factor tables, keyed by (n, device).
_TWIDDLE_CACHE: dict[Tuple[int, torch.device], Tuple[torch.Tensor, torch.Tensor]] = {}

24 

25 

def _is_power_of_two(n: int) -> bool:
    """Return True when ``n`` is positive and has exactly one bit set."""
    if n <= 0:
        return False
    # Clearing the lowest set bit leaves zero only when a single bit was set.
    return (n & (n - 1)) == 0

28 

29 

def _log2(n: int) -> int:
    """Return the index of the highest set bit of ``n``.

    For a power of two this is the exact base-2 logarithm; for other positive
    integers it is the floor of it. ``n == 0`` yields -1.
    """
    bit_count = n.bit_length()
    return bit_count - 1

32 

33 

def _bitrev_indices(n: int, device: torch.device) -> torch.Tensor:
    """Return the bit-reversal permutation of ``range(n)`` as an int32 tensor.

    Element ``i`` holds ``i`` with its lowest ``_log2(n)`` bits reversed. The
    tensor is created on ``device`` and memoized in ``_BITREV_CACHE`` under
    ``(n, device)``, so repeated calls return the same tensor object.
    """
    cache_key = (n, device)
    hit = _BITREV_CACHE.get(cache_key)
    if hit is not None:
        return hit

    remaining = torch.arange(n, device=device, dtype=torch.int32)
    reversed_bits = torch.zeros_like(remaining)
    # Peel bits off the low end of `remaining` and push them onto the low end
    # of `reversed_bits`; after log2(n) rounds the bit order is mirrored.
    for _ in range(_log2(n)):
        reversed_bits = (reversed_bits << 1) | (remaining & 1)
        remaining = remaining >> 1

    _BITREV_CACHE[cache_key] = reversed_bits
    return reversed_bits

48 

49 

def _twiddle_tables(n: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return the (real, imag) twiddle-factor tables for a length-``n`` FFT.

    Both tables are float32 tensors of length ``n - 1`` on ``device``. The
    factors of butterfly stage ``s`` (span ``2 ** (s + 1)``) occupy the
    ``2 ** s`` entries starting at offset ``2 ** s - 1`` and hold
    ``cos``/``sin`` of ``-2 * pi * j / span`` for ``j`` in ``range(2 ** s)``.
    Results are memoized in ``_TWIDDLE_CACHE`` under ``(n, device)``.
    """
    cache_key = (n, device)
    hit = _TWIDDLE_CACHE.get(cache_key)
    if hit is not None:
        return hit

    cos_table = torch.empty((n - 1,), device=device, dtype=torch.float32)
    sin_table = torch.empty((n - 1,), device=device, dtype=torch.float32)

    start = 0
    for stage in range(_log2(n)):
        span = 1 << (stage + 1)
        count = span >> 1
        k = torch.arange(count, device=device, dtype=torch.float32)
        theta = (-2.0 * PI / span) * k
        stop = start + count
        cos_table[start:stop] = torch.cos(theta)
        sin_table[start:stop] = torch.sin(theta)
        start = stop

    tables = (cos_table, sin_table)
    _TWIDDLE_CACHE[cache_key] = tables
    return tables

69 

70 

def _prepare_input(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Split ``x`` into contiguous float32 real and imaginary planes.

    Real inputs (float16, bfloat16, float32) are converted to float32 and
    paired with an all-zero imaginary plane. Complex inputs (complex64,
    complex128) are converted to complex64 and split into their parts.

    Raises:
        ValueError: If the dtype of ``x`` is not one of the dtypes above.
    """
    if not x.is_complex():
        if x.dtype not in (torch.float16, torch.float32, torch.bfloat16):
            raise ValueError(f"unsupported dtype: {x.dtype}")
        real_plane = x.to(torch.float32).contiguous()
        return real_plane, torch.zeros_like(real_plane)

    if x.dtype not in (torch.complex64, torch.complex128):
        raise ValueError(f"unsupported complex dtype: {x.dtype}")
    as_c64 = x.to(torch.complex64)
    # .real/.imag are strided views into the complex storage; copy them out so
    # each plane is densely packed.
    return as_c64.real.contiguous(), as_c64.imag.contiguous()

85 

86 

@triton.jit
def fft_kernel_triton(
    in_real,
    in_imag,
    bitrev,
    twiddle_real,
    twiddle_imag,
    buf0_real,
    buf0_imag,
    buf1_real,
    buf1_imag,
    stride_in,
    stride_buf,
    n_rows,
    N: tl.constexpr,
    LOG_N: tl.constexpr,
):
    """Forward FFT of one row per program, ping-ponging between two buffers.

    The row is first copied into ``buf0`` in bit-reversed order. If ``LOG_N``
    is odd, one radix-2 pass runs first; the remaining stages run as
    ``LOG_N // 2`` radix-4 passes, each fusing two consecutive radix-2
    stages. Every pass reads one buffer pair, writes the other, and then the
    roles are swapped, so the result ends up in ``buf0`` when the number of
    passes ``(LOG_N + 1) // 2`` is even and in ``buf1`` when it is odd.

    Args:
        in_real, in_imag: Input planes; row ``r`` starts at ``r * stride_in``.
        bitrev: Bit-reversal index table with ``N`` entries.
        twiddle_real, twiddle_imag: Twiddle tables; the factors of stage ``s``
            start at offset ``2 ** s - 1``.
        buf0_real, buf0_imag, buf1_real, buf1_imag: Scratch/output planes;
            row ``r`` starts at ``r * stride_buf``.
        stride_in, stride_buf: Row strides, in elements.
        n_rows: Number of valid rows; programs past it are fully masked.
        N: Row length.
        LOG_N: Number of radix-2 stages; the caller passes log2(N).
    """
    row = tl.program_id(0)
    offs = tl.arange(0, N)
    mask = (row < n_rows) & (offs < N)
    in_row = row * stride_in
    buf_row = row * stride_buf

    # Bit-reversal copy: output slot i receives input element bitrev[i].
    rev = tl.load(bitrev + offs, mask=offs < N, other=0)
    vals_real = tl.load(in_real + in_row + rev, mask=mask, other=0.0)
    vals_imag = tl.load(in_imag + in_row + rev, mask=mask, other=0.0)
    tl.store(buf0_real + buf_row + offs, vals_real, mask=mask)
    tl.store(buf0_imag + buf_row + offs, vals_imag, mask=mask)
    # The first butterfly pass reads buf0 slots that were written by other
    # threads of this program, so those stores must be complete first.
    tl.debug_barrier()

    cur_real = buf0_real
    cur_imag = buf0_imag
    nxt_real = buf1_real
    nxt_imag = buf1_imag

    if LOG_N % 2 == 1:
        # Odd stage count: run stage 0 (span 2) as a standalone radix-2 pass.
        half = 1
        pos = offs & 1
        j = pos & (half - 1)
        even_idx = offs - pos + j
        odd_idx = even_idx + half

        u_real = tl.load(cur_real + buf_row + even_idx, mask=mask, other=0.0)
        u_imag = tl.load(cur_imag + buf_row + even_idx, mask=mask, other=0.0)
        v_real = tl.load(cur_real + buf_row + odd_idx, mask=mask, other=0.0)
        v_imag = tl.load(cur_imag + buf_row + odd_idx, mask=mask, other=0.0)

        # Stage 0 twiddles start at table offset 0.
        tw_real = tl.load(twiddle_real + j, mask=mask, other=1.0)
        tw_imag = tl.load(twiddle_imag + j, mask=mask, other=0.0)

        v_tw_real = v_real * tw_real - v_imag * tw_imag
        v_tw_imag = v_real * tw_imag + v_imag * tw_real

        # The lower element of each pair takes the sum, the upper one the difference.
        add_mask = pos < half
        out_real = tl.where(add_mask, u_real + v_tw_real, u_real - v_tw_real)
        out_imag = tl.where(add_mask, u_imag + v_tw_imag, u_imag - v_tw_imag)

        tl.store(nxt_real + buf_row + offs, out_real, mask=mask)
        tl.store(nxt_imag + buf_row + offs, out_imag, mask=mask)
        tl.debug_barrier()

        cur_real, nxt_real = nxt_real, cur_real
        cur_imag, nxt_imag = nxt_imag, cur_imag

    # Radix-4 passes. Each one fuses stages (stage_s - 1) and stage_s; the
    # first fused pair starts after the optional radix-2 pass above.
    for r4 in tl.static_range(LOG_N // 2):
        stage_s = 1 + (LOG_N % 2) + r4 * 2
        m = 1 << (stage_s + 1)
        quarter = m >> 2
        half = m >> 1
        three_quarter = quarter + half

        pos = offs & (m - 1)
        j = pos & (quarter - 1)
        i0 = offs - pos + j
        i1 = i0 + quarter
        i2 = i1 + quarter
        i3 = i2 + quarter

        x0_real = tl.load(cur_real + buf_row + i0, mask=mask, other=0.0)
        x0_imag = tl.load(cur_imag + buf_row + i0, mask=mask, other=0.0)
        x1_real = tl.load(cur_real + buf_row + i1, mask=mask, other=0.0)
        x1_imag = tl.load(cur_imag + buf_row + i1, mask=mask, other=0.0)
        x2_real = tl.load(cur_real + buf_row + i2, mask=mask, other=0.0)
        x2_imag = tl.load(cur_imag + buf_row + i2, mask=mask, other=0.0)
        x3_real = tl.load(cur_real + buf_row + i3, mask=mask, other=0.0)
        x3_imag = tl.load(cur_imag + buf_row + i3, mask=mask, other=0.0)

        # Table offsets of the inner (stage_s - 1) and outer (stage_s) stages.
        tw1_idx = ((1 << (stage_s - 1)) - 1) + j
        tw2_idx = ((1 << stage_s) - 1) + j
        tw1_real = tl.load(twiddle_real + tw1_idx, mask=mask, other=1.0)
        tw1_imag = tl.load(twiddle_imag + tw1_idx, mask=mask, other=0.0)
        tw2_real = tl.load(twiddle_real + tw2_idx, mask=mask, other=1.0)
        tw2_imag = tl.load(twiddle_imag + tw2_idx, mask=mask, other=0.0)

        # Inner stage: two radix-2 butterflies, (x0, x1) and (x2, x3).
        t1_real = x1_real * tw1_real - x1_imag * tw1_imag
        t1_imag = x1_real * tw1_imag + x1_imag * tw1_real
        t3_real = x3_real * tw1_real - x3_imag * tw1_imag
        t3_imag = x3_real * tw1_imag + x3_imag * tw1_real

        u0_real = x0_real + t1_real
        u0_imag = x0_imag + t1_imag
        u1_real = x0_real - t1_real
        u1_imag = x0_imag - t1_imag
        v0_real = x2_real + t3_real
        v0_imag = x2_imag + t3_imag
        v1_real = x2_real - t3_real
        v1_imag = x2_imag - t3_imag

        # Outer stage: the twiddle at j + quarter is tw2[j] multiplied by -i,
        # so it is derived from tw2 instead of being loaded.
        v0_tw_real = v0_real * tw2_real - v0_imag * tw2_imag
        v0_tw_imag = v0_real * tw2_imag + v0_imag * tw2_real
        w3_real = tw2_imag
        w3_imag = -tw2_real
        v1_tw_real = v1_real * w3_real - v1_imag * w3_imag
        v1_tw_imag = v1_real * w3_imag + v1_imag * w3_real

        o0_real = u0_real + v0_tw_real
        o0_imag = u0_imag + v0_tw_imag
        o2_real = u0_real - v0_tw_real
        o2_imag = u0_imag - v0_tw_imag
        o1_real = u1_real + v1_tw_real
        o1_imag = u1_imag + v1_tw_imag
        o3_real = u1_real - v1_tw_real
        o3_imag = u1_imag - v1_tw_imag

        # Every lane computed all four outputs of its group; keep the one
        # belonging to the quarter of the span this lane sits in.
        m0 = pos < quarter
        m1 = (pos >= quarter) & (pos < half)
        m2 = (pos >= half) & (pos < three_quarter)
        out_real = tl.where(
            m0, o0_real, tl.where(m1, o1_real, tl.where(m2, o2_real, o3_real))
        )
        out_imag = tl.where(
            m0, o0_imag, tl.where(m1, o1_imag, tl.where(m2, o2_imag, o3_imag))
        )

        tl.store(nxt_real + buf_row + offs, out_real, mask=mask)
        tl.store(nxt_imag + buf_row + offs, out_imag, mask=mask)
        tl.debug_barrier()

        cur_real, nxt_real = nxt_real, cur_real
        cur_imag, nxt_imag = nxt_imag, cur_imag

351 

352 

if HAS_TLE:

    @triton.jit
    def fft_kernel_tle(
        in_real,
        in_imag,
        bitrev,
        twiddle_real,
        twiddle_imag,
        out_real,
        out_imag,
        stride_in,
        stride_out,
        n_rows,
        N: tl.constexpr,
        LOG_N: tl.constexpr,
    ):
        """Forward FFT of one row per program using TLE-allocated buffers.

        Two (real, imag) buffer pairs of ``N`` float32 values are obtained
        from ``tle.gpu.alloc`` with ``scope=tle.gpu.smem``. The row is copied
        into the first pair in bit-reversed order; each butterfly pass then
        reads one pair and writes the other, followed by a barrier and a swap.
        If ``LOG_N`` is odd a radix-2 pass runs first; the remaining stages
        run as ``LOG_N // 2`` radix-4 passes. The final contents are copied to
        ``out_real``/``out_imag`` at row offset ``row * stride_out``.

        Twiddle factors of stage ``s`` are read from offset ``2 ** s - 1`` of
        ``twiddle_real``/``twiddle_imag``.
        """
        row = tl.program_id(0)
        offs = tl.arange(0, N)
        mask = (row < n_rows) & (offs < N)

        ping_real = tle.gpu.alloc(
            [N],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        ping_imag = tle.gpu.alloc(
            [N],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        pong_real = tle.gpu.alloc(
            [N],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )
        pong_imag = tle.gpu.alloc(
            [N],
            dtype=tl.float32,
            layout=None,
            scope=tle.gpu.smem,
            nv_mma_shared_layout=False,
        )

        # Bit-reversal copy: buffer slot i receives input element bitrev[i].
        rev = tl.load(bitrev + offs, mask=offs < N, other=0)
        in_row = row * stride_in
        first_real = tl.load(in_real + in_row + rev, mask=mask, other=0.0)
        first_imag = tl.load(in_imag + in_row + rev, mask=mask, other=0.0)
        tl.store(tle.gpu.local_ptr(ping_real, (offs,)), first_real, mask=mask)
        tl.store(tle.gpu.local_ptr(ping_imag, (offs,)), first_imag, mask=mask)
        tl.debug_barrier()

        src_real = ping_real
        src_imag = ping_imag
        dst_real = pong_real
        dst_imag = pong_imag

        if LOG_N % 2 == 1:
            # Odd stage count: run stage 0 (span 2) as a standalone radix-2 pass.
            half = 1
            pos = offs & 1
            j = pos & (half - 1)
            even_idx = offs - pos + j
            odd_idx = even_idx + half

            u_real = tl.load(
                tle.gpu.local_ptr(src_real, (even_idx,)), mask=mask, other=0.0
            )
            u_imag = tl.load(
                tle.gpu.local_ptr(src_imag, (even_idx,)), mask=mask, other=0.0
            )
            v_real = tl.load(
                tle.gpu.local_ptr(src_real, (odd_idx,)), mask=mask, other=0.0
            )
            v_imag = tl.load(
                tle.gpu.local_ptr(src_imag, (odd_idx,)), mask=mask, other=0.0
            )

            # Stage 0 twiddles start at table offset 0.
            tw_real = tl.load(twiddle_real + j, mask=mask, other=1.0)
            tw_imag = tl.load(twiddle_imag + j, mask=mask, other=0.0)

            v_tw_real = v_real * tw_real - v_imag * tw_imag
            v_tw_imag = v_real * tw_imag + v_imag * tw_real

            add_mask = pos < half
            out_real_val = tl.where(add_mask, u_real + v_tw_real, u_real - v_tw_real)
            out_imag_val = tl.where(add_mask, u_imag + v_tw_imag, u_imag - v_tw_imag)

            tl.store(tle.gpu.local_ptr(dst_real, (offs,)), out_real_val, mask=mask)
            tl.store(tle.gpu.local_ptr(dst_imag, (offs,)), out_imag_val, mask=mask)
            tl.debug_barrier()

            src_real, dst_real = dst_real, src_real
            src_imag, dst_imag = dst_imag, src_imag

        # Radix-4 passes, each fusing stages (stage_s - 1) and stage_s. For odd
        # LOG_N, (LOG_N - 1) // 2 equals LOG_N // 2, so one loop covers both
        # parities; only the first fused stage index differs.
        for r4 in tl.static_range(LOG_N // 2):
            stage_s = 1 + (LOG_N % 2) + r4 * 2
            m = 1 << (stage_s + 1)
            quarter = m >> 2
            half = m >> 1
            three_quarter = quarter + half

            pos = offs & (m - 1)
            j = pos & (quarter - 1)
            i0 = offs - pos + j
            i1 = i0 + quarter
            i2 = i1 + quarter
            i3 = i2 + quarter

            x0_real = tl.load(tle.gpu.local_ptr(src_real, (i0,)), mask=mask, other=0.0)
            x0_imag = tl.load(tle.gpu.local_ptr(src_imag, (i0,)), mask=mask, other=0.0)
            x1_real = tl.load(tle.gpu.local_ptr(src_real, (i1,)), mask=mask, other=0.0)
            x1_imag = tl.load(tle.gpu.local_ptr(src_imag, (i1,)), mask=mask, other=0.0)
            x2_real = tl.load(tle.gpu.local_ptr(src_real, (i2,)), mask=mask, other=0.0)
            x2_imag = tl.load(tle.gpu.local_ptr(src_imag, (i2,)), mask=mask, other=0.0)
            x3_real = tl.load(tle.gpu.local_ptr(src_real, (i3,)), mask=mask, other=0.0)
            x3_imag = tl.load(tle.gpu.local_ptr(src_imag, (i3,)), mask=mask, other=0.0)

            # Table offsets of the inner (stage_s - 1) and outer (stage_s) stages.
            tw1_idx = ((1 << (stage_s - 1)) - 1) + j
            tw2_idx = ((1 << stage_s) - 1) + j
            tw1_real = tl.load(twiddle_real + tw1_idx, mask=mask, other=1.0)
            tw1_imag = tl.load(twiddle_imag + tw1_idx, mask=mask, other=0.0)
            tw2_real = tl.load(twiddle_real + tw2_idx, mask=mask, other=1.0)
            tw2_imag = tl.load(twiddle_imag + tw2_idx, mask=mask, other=0.0)

            # Inner stage: two radix-2 butterflies, (x0, x1) and (x2, x3).
            t1_real = x1_real * tw1_real - x1_imag * tw1_imag
            t1_imag = x1_real * tw1_imag + x1_imag * tw1_real
            t3_real = x3_real * tw1_real - x3_imag * tw1_imag
            t3_imag = x3_real * tw1_imag + x3_imag * tw1_real

            u0_real = x0_real + t1_real
            u0_imag = x0_imag + t1_imag
            u1_real = x0_real - t1_real
            u1_imag = x0_imag - t1_imag
            v0_real = x2_real + t3_real
            v0_imag = x2_imag + t3_imag
            v1_real = x2_real - t3_real
            v1_imag = x2_imag - t3_imag

            # Outer stage: the twiddle at j + quarter is tw2[j] multiplied by -i.
            v0_tw_real = v0_real * tw2_real - v0_imag * tw2_imag
            v0_tw_imag = v0_real * tw2_imag + v0_imag * tw2_real
            w3_real = tw2_imag
            w3_imag = -tw2_real
            v1_tw_real = v1_real * w3_real - v1_imag * w3_imag
            v1_tw_imag = v1_real * w3_imag + v1_imag * w3_real

            o0_real = u0_real + v0_tw_real
            o0_imag = u0_imag + v0_tw_imag
            o2_real = u0_real - v0_tw_real
            o2_imag = u0_imag - v0_tw_imag
            o1_real = u1_real + v1_tw_real
            o1_imag = u1_imag + v1_tw_imag
            o3_real = u1_real - v1_tw_real
            o3_imag = u1_imag - v1_tw_imag

            # Keep the output that belongs to this lane's quarter of the span.
            m0 = pos < quarter
            m1 = (pos >= quarter) & (pos < half)
            m2 = (pos >= half) & (pos < three_quarter)
            out_real_val = tl.where(
                m0, o0_real, tl.where(m1, o1_real, tl.where(m2, o2_real, o3_real))
            )
            out_imag_val = tl.where(
                m0, o0_imag, tl.where(m1, o1_imag, tl.where(m2, o2_imag, o3_imag))
            )

            tl.store(tle.gpu.local_ptr(dst_real, (offs,)), out_real_val, mask=mask)
            tl.store(tle.gpu.local_ptr(dst_imag, (offs,)), out_imag_val, mask=mask)
            tl.debug_barrier()

            src_real, dst_real = dst_real, src_real
            src_imag, dst_imag = dst_imag, src_imag

        # After the last swap the finished row is in the "source" pair.
        out_row = row * stride_out
        final_real = tl.load(tle.gpu.local_ptr(src_real, (offs,)), mask=mask, other=0.0)
        final_imag = tl.load(tle.gpu.local_ptr(src_imag, (offs,)), mask=mask, other=0.0)
        tl.store(out_real + out_row + offs, final_real, mask=mask)
        tl.store(out_imag + out_row + offs, final_imag, mask=mask)

    @triton.jit
    def fft_kernel_tle_reg(
        in_real,
        in_imag,
        bitrev,
        twiddle_real,
        twiddle_imag,
        out_real,
        out_imag,
        stride_in,
        stride_out,
        n_rows,
        N: tl.constexpr,
        LOG_N: tl.constexpr,
    ):
        """Forward FFT of one row per program without intermediate buffers.

        The row is loaded in bit-reversed order into a pair of tensors, and
        every butterfly pass fetches its operands from those tensors with
        ``tl.gather`` and replaces them with the pass output. If ``LOG_N`` is
        odd a radix-2 pass runs first; the remaining stages run as
        ``LOG_N // 2`` radix-4 passes. The result is stored to
        ``out_real``/``out_imag`` at row offset ``row * stride_out``.

        Twiddle factors of stage ``s`` are read from offset ``2 ** s - 1`` of
        ``twiddle_real``/``twiddle_imag``.
        """
        row = tl.program_id(0)
        offs = tl.arange(0, N)
        mask = (row < n_rows) & (offs < N)

        # Bit-reversal load: lane i holds input element bitrev[i].
        rev = tl.load(bitrev + offs, mask=offs < N, other=0)
        in_row = row * stride_in
        x_real = tl.load(in_real + in_row + rev, mask=mask, other=0.0)
        x_imag = tl.load(in_imag + in_row + rev, mask=mask, other=0.0)

        if LOG_N % 2 == 1:
            # Odd stage count: run stage 0 (span 2) as a standalone radix-2 pass.
            half = 1
            pos = offs & 1
            j = pos & (half - 1)
            even_idx = offs - pos + j
            odd_idx = even_idx + half

            u_real = tl.gather(x_real, even_idx, axis=0)
            u_imag = tl.gather(x_imag, even_idx, axis=0)
            v_real = tl.gather(x_real, odd_idx, axis=0)
            v_imag = tl.gather(x_imag, odd_idx, axis=0)

            tw_real = tl.load(twiddle_real + j, mask=mask, other=1.0)
            tw_imag = tl.load(twiddle_imag + j, mask=mask, other=0.0)

            v_tw_real = v_real * tw_real - v_imag * tw_imag
            v_tw_imag = v_real * tw_imag + v_imag * tw_real

            add_mask = pos < half
            x_real = tl.where(add_mask, u_real + v_tw_real, u_real - v_tw_real)
            x_imag = tl.where(add_mask, u_imag + v_tw_imag, u_imag - v_tw_imag)

        # Radix-4 passes, each fusing stages (stage_s - 1) and stage_s. For odd
        # LOG_N, (LOG_N - 1) // 2 equals LOG_N // 2, so one loop covers both
        # parities; only the first fused stage index differs.
        for r4 in tl.static_range(LOG_N // 2):
            stage_s = 1 + (LOG_N % 2) + r4 * 2
            m = 1 << (stage_s + 1)
            quarter = m >> 2
            half = m >> 1
            three_quarter = quarter + half

            pos = offs & (m - 1)
            j = pos & (quarter - 1)
            i0 = offs - pos + j
            i1 = i0 + quarter
            i2 = i1 + quarter
            i3 = i2 + quarter

            x0_real = tl.gather(x_real, i0, axis=0)
            x0_imag = tl.gather(x_imag, i0, axis=0)
            x1_real = tl.gather(x_real, i1, axis=0)
            x1_imag = tl.gather(x_imag, i1, axis=0)
            x2_real = tl.gather(x_real, i2, axis=0)
            x2_imag = tl.gather(x_imag, i2, axis=0)
            x3_real = tl.gather(x_real, i3, axis=0)
            x3_imag = tl.gather(x_imag, i3, axis=0)

            # Table offsets of the inner (stage_s - 1) and outer (stage_s) stages.
            tw1_idx = ((1 << (stage_s - 1)) - 1) + j
            tw2_idx = ((1 << stage_s) - 1) + j
            tw1_real = tl.load(twiddle_real + tw1_idx, mask=mask, other=1.0)
            tw1_imag = tl.load(twiddle_imag + tw1_idx, mask=mask, other=0.0)
            tw2_real = tl.load(twiddle_real + tw2_idx, mask=mask, other=1.0)
            tw2_imag = tl.load(twiddle_imag + tw2_idx, mask=mask, other=0.0)

            # Inner stage: two radix-2 butterflies, (x0, x1) and (x2, x3).
            t1_real = x1_real * tw1_real - x1_imag * tw1_imag
            t1_imag = x1_real * tw1_imag + x1_imag * tw1_real
            t3_real = x3_real * tw1_real - x3_imag * tw1_imag
            t3_imag = x3_real * tw1_imag + x3_imag * tw1_real

            u0_real = x0_real + t1_real
            u0_imag = x0_imag + t1_imag
            u1_real = x0_real - t1_real
            u1_imag = x0_imag - t1_imag
            v0_real = x2_real + t3_real
            v0_imag = x2_imag + t3_imag
            v1_real = x2_real - t3_real
            v1_imag = x2_imag - t3_imag

            # Outer stage: the twiddle at j + quarter is tw2[j] multiplied by -i.
            v0_tw_real = v0_real * tw2_real - v0_imag * tw2_imag
            v0_tw_imag = v0_real * tw2_imag + v0_imag * tw2_real
            w3_real = tw2_imag
            w3_imag = -tw2_real
            v1_tw_real = v1_real * w3_real - v1_imag * w3_imag
            v1_tw_imag = v1_real * w3_imag + v1_imag * w3_real

            o0_real = u0_real + v0_tw_real
            o0_imag = u0_imag + v0_tw_imag
            o2_real = u0_real - v0_tw_real
            o2_imag = u0_imag - v0_tw_imag
            o1_real = u1_real + v1_tw_real
            o1_imag = u1_imag + v1_tw_imag
            o3_real = u1_real - v1_tw_real
            o3_imag = u1_imag - v1_tw_imag

            # Keep the output that belongs to this lane's quarter of the span.
            m0 = pos < quarter
            m1 = (pos >= quarter) & (pos < half)
            m2 = (pos >= half) & (pos < three_quarter)
            x_real = tl.where(
                m0, o0_real, tl.where(m1, o1_real, tl.where(m2, o2_real, o3_real))
            )
            x_imag = tl.where(
                m0, o0_imag, tl.where(m1, o1_imag, tl.where(m2, o2_imag, o3_imag))
            )

        out_row = row * stride_out
        tl.store(out_real + out_row + offs, x_real, mask=mask)
        tl.store(out_imag + out_row + offs, x_imag, mask=mask)

868 

869 

def fft(x: torch.Tensor) -> torch.Tensor:
    """Compute the forward 1D FFT of every row of a 2D CUDA tensor.

    Each row is transformed independently by one kernel program: a
    bit-reversal copy followed by radix-4 butterfly passes, preceded by a
    single radix-2 pass when log2(N) is odd. Real and imaginary parts are
    processed as separate float32 planes, so float16/bfloat16 inputs are
    promoted and complex128 inputs are reduced to single precision. Twiddle
    factors and bit-reversal indices are precomputed and cached per
    (N, device).

    When ``HAS_TLE`` is true one of the TLE kernels is launched; otherwise
    the plain Triton kernel is used with two scratch buffer pairs.

    Args:
        x: Tensor of shape (M, N) on a CUDA device. N must be a power of two
            no larger than 1024. Supported dtypes are float16, bfloat16,
            float32, complex64 and complex128.

    Returns:
        A complex tensor of shape (M, N) assembled from the float32 real and
        imaginary output planes.

    Raises:
        AssertionError: If ``x`` is not on CUDA or is not 2D.
        ValueError: If N is not a power of two, N exceeds 1024, or the dtype
            of ``x`` is unsupported.
    """
    logger.debug("GEMS FFT")
    assert x.is_cuda, "input must be on CUDA"
    assert x.ndim == 2, "input must be 2D (M, N)"
    m, n = x.shape
    if not _is_power_of_two(n):
        raise ValueError(f"N={n} must be a power-of-two")
    if n > 1024:
        raise ValueError(f"N={n} too large for this kernel (max 1024)")

    in_real, in_imag = _prepare_input(x)
    bitrev = _bitrev_indices(n, x.device)
    tw_real, tw_imag = _twiddle_tables(n, x.device)
    log_n = _log2(n)

    def _new_plane() -> torch.Tensor:
        # One uninitialized float32 (M, N) plane on the input's device.
        return torch.empty((m, n), device=x.device, dtype=torch.float32)

    # One program per row.
    grid = (m,)

    with torch_device_fn.device(x.device):
        if HAS_TLE:
            out_real = _new_plane()
            out_imag = _new_plane()

            # NOTE(review): the register-only kernel is selected only when n is
            # exactly _FFT_REG_THRESHOLD. The constant's name suggests an upper
            # bound (n <= threshold) may have been intended; confirm before
            # changing, since the register kernel is not exercised for other
            # sizes today.
            if n == _FFT_REG_THRESHOLD:
                kernel = fft_kernel_tle_reg
            else:
                kernel = fft_kernel_tle
            kernel[grid](
                in_real,
                in_imag,
                bitrev,
                tw_real,
                tw_imag,
                out_real,
                out_imag,
                in_real.stride(0),
                out_real.stride(0),
                m,
                N=n,
                LOG_N=log_n,
                num_warps=4,
                num_stages=1,
            )
            return torch.complex(out_real, out_imag)

        buf0_real = _new_plane()
        buf0_imag = _new_plane()
        buf1_real = _new_plane()
        buf1_imag = _new_plane()

        fft_kernel_triton[grid](
            in_real,
            in_imag,
            bitrev,
            tw_real,
            tw_imag,
            buf0_real,
            buf0_imag,
            buf1_real,
            buf1_imag,
            in_real.stride(0),
            buf0_real.stride(0),
            m,
            N=n,
            LOG_N=log_n,
            num_warps=4,
            num_stages=1,
        )

        # The kernel swaps its source/destination buffers after every pass:
        # one radix-2 pass if log_n is odd plus log_n // 2 radix-4 passes,
        # i.e. (log_n + 1) // 2 swaps. An even count leaves the result in
        # buf0, an odd count in buf1.
        total_swaps = (log_n + 1) // 2
        if total_swaps % 2 == 0:
            out_real, out_imag = buf0_real, buf0_imag
        else:
            out_real, out_imag = buf1_real, buf1_imag

        return torch.complex(out_real, out_imag)