Coverage for src/flag_gems/ops/hadamard_transform.py: 14%

265 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1"""Fast Hadamard Transform in Triton. 

2 

3Drop-in replacement for Dao-AILab/fast-hadamard-transform with identical interface: 

4 - hadamard_transform(x, scale=1.0) with autograd support 

5 - hadamard_transform_12N/20N/28N/40N(x, scale=1.0) for non-power-of-2 dims 

6 - Input: (..., dim), fp32/fp16/bf16 

7 - Output: (..., dim), same dtype as input 

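8 - Padding: to next multiple of 8 (matching CUDA impl), then zero-padded to the next power of 2 internally and trimmed back 

9 - dim <= 65536 after rounding up to a multiple of 8 (XXN variants are intended for dim <= M*2^10 but currently reuse the standard padded path) 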

10 

11Reference: https://github.com/Dao-AILab/fast-hadamard-transform 

12""" 

13 

14import math 

15 

16import torch 

17import torch.nn.functional as F 

18import triton 

19import triton.language as tl 

20 

21# ============================================================ 

22# Triton kernel — v43: remove evict_first from loads + warps=2 for dim=256 

23# ============================================================ 

24# v35 best: dim=256 0.9302x (no evict_first on loads, warps=1) 

25# v42: dim=256 0.8950x (evict_first on loads hurt — L2 thrashing) 

26# 

27# v43 strategy: 

28# 1. Remove evict_first from all loads. v42 proved it hurts dim=256 

29# (0.8950x vs v35's 0.9302x). The 256-element rows (512B fp16) 

30# are small enough that L2 caching of nearby rows helps prefetch. 

31# 2. Try num_warps=2 for dim=256 4-row ILP kernel. With 4 rows of 

32# 256 elements each, the workload can benefit from 2-warp occupancy: 

33# each warp handles the compute for its assigned instructions, 

34# and the scheduler can overlap loads from one warp with compute 

35# from the other. This targets the memory latency hiding gap. 

36# 3. Keep evict_first on stores (write-once streaming pattern). 

37# 4. Keep all other kernels unchanged from v42 baseline. 

38 

39 

40# ============================================================ 

41# Butterfly stages 

42# ============================================================ 

43 

44 

@triton.jit
def _butterfly_stage_1d(x, BLOCK_SIZE: tl.constexpr, STRIDE: tl.constexpr):
    """One radix-2 butterfly stage on a 1D vector of length BLOCK_SIZE.

    Pairs element i with element i + STRIDE inside each group of
    2*STRIDE consecutive elements and replaces (a, b) with (a + b, a - b).
    Applying this for every STRIDE in {1, 2, ..., BLOCK_SIZE/2} (the stages
    act on different index bits, so their order does not matter) yields the
    natural-order (Sylvester) Walsh-Hadamard transform.

    Arithmetic happens in x's dtype; no upcast is done here.
    """
    GRP: tl.constexpr = BLOCK_SIZE // (2 * STRIDE)  # number of 2*STRIDE groups
    if STRIDE == 1:
        # Adjacent pairs: view as (GRP, 2) and split the trailing size-2 axis.
        x2 = tl.reshape(x, (GRP, 2))
        a, b = tl.split(x2)
        # join stacks (a+b, a-b) on a new minor axis, restoring pair order.
        return tl.reshape(tl.join(a + b, a - b), (BLOCK_SIZE,))
    else:
        # Index g*2*STRIDE + h*STRIDE + j viewed as (g, h, j). tl.split only
        # splits a trailing size-2 axis, so move h last, split, join, and
        # move it back before flattening.
        x3 = tl.reshape(x, (GRP, 2, STRIDE))
        x3 = tl.permute(x3, (0, 2, 1))
        a, b = tl.split(x3)
        x3 = tl.join(a + b, a - b)
        x3 = tl.permute(x3, (0, 2, 1))
        return tl.reshape(x3, (BLOCK_SIZE,))

60 

61 

@triton.jit
def _butterfly_stage_2d(
    x, ROWS: tl.constexpr, BLOCK_SIZE: tl.constexpr, STRIDE: tl.constexpr
):
    """One radix-2 butterfly stage applied independently to each row.

    ``x`` is a (ROWS, BLOCK_SIZE) tensor. Within every row, element i is
    paired with element i + STRIDE inside each group of 2*STRIDE elements
    and (a, b) becomes (a + b, a - b) — the row-batched counterpart of
    ``_butterfly_stage_1d``. Arithmetic happens in x's dtype.
    """
    GRP: tl.constexpr = BLOCK_SIZE // (2 * STRIDE)  # groups per row
    if STRIDE == 1:
        # Adjacent pairs: (ROWS, GRP, 2), split the trailing size-2 axis.
        x2 = tl.reshape(x, (ROWS, GRP, 2))
        a, b = tl.split(x2)
        return tl.reshape(tl.join(a + b, a - b), (ROWS, BLOCK_SIZE))
    else:
        # (row, g, h, j) view; move the size-2 "h" axis last so tl.split can
        # take it, then restore the original axis order.
        x3 = tl.reshape(x, (ROWS, GRP, 2, STRIDE))
        x3 = tl.permute(x3, (0, 1, 3, 2))
        a, b = tl.split(x3)
        x3 = tl.join(a + b, a - b)
        x3 = tl.permute(x3, (0, 1, 3, 2))
        return tl.reshape(x3, (ROWS, BLOCK_SIZE))

79 

80 

81# ============================================================ 

82# 4-row ILP 1D native kernel for dim=256 (8 hardcoded stages) 

83# v43: remove evict_first from loads, keep on stores 

84# ============================================================ 

85 

86 

@triton.jit
def _fht_kernel_256_4row_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    N_ROWS,
    SCALE: tl.constexpr,
):
    """FHT for dim=256, 4-row ILP: four independent 1D butterflies per program.

    Program ``pid`` handles rows 4*pid .. 4*pid+3. Rows 1-3 are masked
    against ``N_ROWS``; row 0 is loaded/stored unmasked, which relies on the
    launch grid being ceil(N_ROWS / 4).

    Butterflies run in the loaded dtype (no fp32 upcast) and SCALE is applied
    only after all 8 stages. NOTE(review): for fp16 inputs, intermediate
    values can grow up to 256x before scaling, so large-magnitude rows may
    overflow to inf — confirm this trade-off is acceptable.

    SCALE is a constexpr, so each distinct scale value compiles a new variant.
    """
    # 64-bit row index: with int32, row * stride overflows once the tensor
    # has more than 2**31 elements (~8.4M rows of 256).
    pid = tl.program_id(0).to(tl.int64)
    col_offs = tl.arange(0, 256)

    row0 = pid * 4
    row1 = row0 + 1
    row2 = row0 + 2
    row3 = row0 + 3

    # Load all 4 rows (no evict_first: L2 caching helps for nearby rows)
    x0 = tl.load(X_ptr + row0 * stride_x_row + col_offs)
    x1 = tl.load(X_ptr + row1 * stride_x_row + col_offs, mask=row1 < N_ROWS, other=0.0)
    x2 = tl.load(X_ptr + row2 * stride_x_row + col_offs, mask=row2 < N_ROWS, other=0.0)
    x3 = tl.load(X_ptr + row3 * stride_x_row + col_offs, mask=row3 < N_ROWS, other=0.0)

    # Interleaved hardcoded reversed butterfly stages for 4-way ILP
    x0 = _butterfly_stage_1d(x0, 256, 128)
    x1 = _butterfly_stage_1d(x1, 256, 128)
    x2 = _butterfly_stage_1d(x2, 256, 128)
    x3 = _butterfly_stage_1d(x3, 256, 128)
    x0 = _butterfly_stage_1d(x0, 256, 64)
    x1 = _butterfly_stage_1d(x1, 256, 64)
    x2 = _butterfly_stage_1d(x2, 256, 64)
    x3 = _butterfly_stage_1d(x3, 256, 64)
    x0 = _butterfly_stage_1d(x0, 256, 32)
    x1 = _butterfly_stage_1d(x1, 256, 32)
    x2 = _butterfly_stage_1d(x2, 256, 32)
    x3 = _butterfly_stage_1d(x3, 256, 32)
    x0 = _butterfly_stage_1d(x0, 256, 16)
    x1 = _butterfly_stage_1d(x1, 256, 16)
    x2 = _butterfly_stage_1d(x2, 256, 16)
    x3 = _butterfly_stage_1d(x3, 256, 16)
    x0 = _butterfly_stage_1d(x0, 256, 8)
    x1 = _butterfly_stage_1d(x1, 256, 8)
    x2 = _butterfly_stage_1d(x2, 256, 8)
    x3 = _butterfly_stage_1d(x3, 256, 8)
    x0 = _butterfly_stage_1d(x0, 256, 4)
    x1 = _butterfly_stage_1d(x1, 256, 4)
    x2 = _butterfly_stage_1d(x2, 256, 4)
    x3 = _butterfly_stage_1d(x3, 256, 4)
    x0 = _butterfly_stage_1d(x0, 256, 2)
    x1 = _butterfly_stage_1d(x1, 256, 2)
    x2 = _butterfly_stage_1d(x2, 256, 2)
    x3 = _butterfly_stage_1d(x3, 256, 2)
    x0 = _butterfly_stage_1d(x0, 256, 1)
    x1 = _butterfly_stage_1d(x1, 256, 1)
    x2 = _butterfly_stage_1d(x2, 256, 1)
    x3 = _butterfly_stage_1d(x3, 256, 1)

    x0 = x0 * SCALE
    x1 = x1 * SCALE
    x2 = x2 * SCALE
    x3 = x3 * SCALE

    # Outputs are written once and not re-read here: evict_first on stores.
    tl.store(
        OUT_ptr + row0 * stride_out_row + col_offs, x0, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + row1 * stride_out_row + col_offs,
        x1,
        mask=row1 < N_ROWS,
        eviction_policy="evict_first",
    )
    tl.store(
        OUT_ptr + row2 * stride_out_row + col_offs,
        x2,
        mask=row2 < N_ROWS,
        eviction_policy="evict_first",
    )
    tl.store(
        OUT_ptr + row3 * stride_out_row + col_offs,
        x3,
        mask=row3 < N_ROWS,
        eviction_policy="evict_first",
    )

171 

172 

173# ============================================================ 

174# Fallback: single-row 1D native kernel for dim=256 

175# ============================================================ 

176 

177 

@triton.jit
def _fht_kernel_256_1d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    SCALE: tl.constexpr,
):
    """FHT for dim=256, one row per program, 8 hardcoded butterfly stages.

    Fallback for batches too small for the 4-row kernel. Computes in the
    loaded dtype (no fp32 upcast) and applies SCALE after all stages.
    NOTE(review): fp16 intermediates can grow up to 256x before scaling and
    may overflow for large-magnitude inputs.
    """
    # 64-bit row index so pid * stride cannot overflow int32 on huge batches.
    pid = tl.program_id(0).to(tl.int64)
    col_offs = tl.arange(0, 256)

    x = tl.load(X_ptr + pid * stride_x_row + col_offs)

    # Reversed butterfly: stride 128, 64, 32, 16, 8, 4, 2, 1
    x = _butterfly_stage_1d(x, 256, 128)
    x = _butterfly_stage_1d(x, 256, 64)
    x = _butterfly_stage_1d(x, 256, 32)
    x = _butterfly_stage_1d(x, 256, 16)
    x = _butterfly_stage_1d(x, 256, 8)
    x = _butterfly_stage_1d(x, 256, 4)
    x = _butterfly_stage_1d(x, 256, 2)
    x = _butterfly_stage_1d(x, 256, 1)

    x = x * SCALE
    tl.store(
        OUT_ptr + pid * stride_out_row + col_offs, x, eviction_policy="evict_first"
    )

206 

207 

208# ============================================================ 

209# 1D hardcoded native kernel for dim=512 (9 stages) 

210# Restored from v31/v35: single-row (best: 1.1193x in v35) 

211# ============================================================ 

212 

213 

@triton.jit
def _fht_kernel_512_1d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    SCALE: tl.constexpr,
):
    """FHT for dim=512, one row per program, 9 hardcoded butterfly stages.

    Computes in the loaded dtype (no fp32 upcast) and applies SCALE after
    all stages. NOTE(review): fp16 intermediates can grow up to 512x before
    scaling and may overflow for large-magnitude inputs.
    """
    # 64-bit row index so pid * stride cannot overflow int32 on huge batches.
    pid = tl.program_id(0).to(tl.int64)
    col_offs = tl.arange(0, 512)

    x = tl.load(X_ptr + pid * stride_x_row + col_offs)

    # Reversed butterfly: stride 256, 128, 64, 32, 16, 8, 4, 2, 1
    x = _butterfly_stage_1d(x, 512, 256)
    x = _butterfly_stage_1d(x, 512, 128)
    x = _butterfly_stage_1d(x, 512, 64)
    x = _butterfly_stage_1d(x, 512, 32)
    x = _butterfly_stage_1d(x, 512, 16)
    x = _butterfly_stage_1d(x, 512, 8)
    x = _butterfly_stage_1d(x, 512, 4)
    x = _butterfly_stage_1d(x, 512, 2)
    x = _butterfly_stage_1d(x, 512, 1)

    x = x * SCALE
    tl.store(
        OUT_ptr + pid * stride_out_row + col_offs, x, eviction_policy="evict_first"
    )

243 

244 

245# ============================================================ 

246# Generic 1D native-dtype butterfly kernel (for other small dims) 

247# ============================================================ 

248 

249 

@triton.jit
def _fht_kernel_1d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    SCALE: tl.constexpr,
):
    """Generic FHT, one row per program, computed in the loaded dtype.

    BLOCK_SIZE is the power-of-2 row length and LOG_N = log2(BLOCK_SIZE);
    stages run with stride BLOCK_SIZE/2 down to 1, then SCALE is applied.
    DIM is not read by the body (kept for the launch signature).
    """
    # 64-bit row index so pid * stride cannot overflow int32 on huge batches.
    pid = tl.program_id(0).to(tl.int64)
    col_offs = tl.arange(0, BLOCK_SIZE)

    x = tl.load(X_ptr + pid * stride_x_row + col_offs)

    # Reversed butterfly: stride N/2, N/4, ..., 2, 1
    for s_rev in tl.static_range(LOG_N):
        x = _butterfly_stage_1d(x, BLOCK_SIZE, 1 << (LOG_N - 1 - s_rev))

    x = x * SCALE
    tl.store(
        OUT_ptr + pid * stride_out_row + col_offs, x, eviction_policy="evict_first"
    )

275 

276 

277# ============================================================ 

278# 2D native-dtype butterfly kernel (for dim=1024 with fp16/bf16) 

279# ============================================================ 

280 

281 

@triton.jit
def _fht_kernel_2d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    N_ROWS,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    ROWS_PER_PROGRAM: tl.constexpr,
    SCALE: tl.constexpr,
):
    """FHT on a (ROWS_PER_PROGRAM, BLOCK_SIZE) tile, computed in the loaded dtype.

    Rows at or beyond N_ROWS are masked on load (filled with 0) and store.
    Stages run with stride BLOCK_SIZE/2 down to 1, then SCALE is applied.
    DIM is not read by the body (kept for the launch signature).
    """
    pid = tl.program_id(0).to(tl.int64)
    col_offs = tl.arange(0, BLOCK_SIZE)
    row_offs = tl.arange(0, ROWS_PER_PROGRAM)

    base_row = pid * ROWS_PER_PROGRAM
    row_mask = (base_row + row_offs) < N_ROWS

    # Advance the base pointers with 64-bit scalar math so large batches
    # cannot overflow int32; per-element offsets stay small and 32-bit.
    x_base = X_ptr + base_row * stride_x_row
    out_base = OUT_ptr + base_row * stride_out_row
    in_ptrs = x_base + row_offs[:, None] * stride_x_row + col_offs[None, :]
    out_ptrs = out_base + row_offs[:, None] * stride_out_row + col_offs[None, :]
    load_mask = row_mask[:, None]

    x = tl.load(in_ptrs, mask=load_mask, other=0.0)

    # Reversed butterfly: stride N/2, N/4, ..., 2, 1
    for s_rev in tl.static_range(LOG_N):
        x = _butterfly_stage_2d(
            x, ROWS_PER_PROGRAM, BLOCK_SIZE, 1 << (LOG_N - 1 - s_rev)
        )

    x = x * SCALE
    tl.store(out_ptrs, x, mask=load_mask, eviction_policy="evict_first")

318 

319 

320# ============================================================ 

321# 1D butterfly kernel (fp32, for fp32 inputs) 

322# ============================================================ 

323 

324 

@triton.jit
def _fht_kernel_1d(
    X_ptr,
    OUT_ptr,
    scale,
    stride_x_row,
    stride_out_row,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    INPUT_IS_FP16: tl.constexpr,
    INPUT_IS_BF16: tl.constexpr,
):
    """FHT, one row per program, with fp32 compute.

    The row is upcast to fp32, transformed (stride BLOCK_SIZE/2 down to 1),
    multiplied by the runtime ``scale``, and stored back as fp16 / bf16 when
    the matching flag is set, otherwise as fp32. DIM is not read by the body.
    """
    # 64-bit row index so pid * stride cannot overflow int32 on huge batches.
    pid = tl.program_id(0).to(tl.int64)
    col_offs = tl.arange(0, BLOCK_SIZE)

    in_ptr = X_ptr + pid * stride_x_row + col_offs
    out_ptr = OUT_ptr + pid * stride_out_row + col_offs

    x = tl.load(in_ptr).to(tl.float32)

    for s_rev in tl.static_range(LOG_N):
        x = _butterfly_stage_1d(x, BLOCK_SIZE, 1 << (LOG_N - 1 - s_rev))

    x = x * scale

    if INPUT_IS_FP16:
        tl.store(out_ptr, x.to(tl.float16), eviction_policy="evict_first")
    elif INPUT_IS_BF16:
        tl.store(out_ptr, x.to(tl.bfloat16), eviction_policy="evict_first")
    else:
        tl.store(out_ptr, x, eviction_policy="evict_first")

358 

359 

360# ============================================================ 

361# 2D butterfly kernel (fp32, for dim>=1024 and fp32 inputs) 

362# ============================================================ 

363 

364 

@triton.jit
def _fht_kernel_2d(
    X_ptr,
    OUT_ptr,
    scale,
    stride_x_row,
    stride_out_row,
    N_ROWS,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    ROWS_PER_PROGRAM: tl.constexpr,
    INPUT_IS_FP16: tl.constexpr,
    INPUT_IS_BF16: tl.constexpr,
):
    """FHT on a (ROWS_PER_PROGRAM, BLOCK_SIZE) tile with fp32 compute.

    Rows at or beyond N_ROWS are masked. The tile is upcast to fp32,
    transformed, multiplied by the runtime ``scale`` and stored back as
    fp16 / bf16 when the matching flag is set, otherwise as fp32.
    DIM is not read by the body (kept for the launch signature).
    """
    pid = tl.program_id(0).to(tl.int64)
    col_offs = tl.arange(0, BLOCK_SIZE)
    row_offs = tl.arange(0, ROWS_PER_PROGRAM)

    base_row = pid * ROWS_PER_PROGRAM
    row_mask = (base_row + row_offs) < N_ROWS

    # 64-bit scalar base-pointer offset (no int32 overflow on large batches);
    # per-element offsets remain small 32-bit values.
    x_base = X_ptr + base_row * stride_x_row
    out_base = OUT_ptr + base_row * stride_out_row
    in_ptrs = x_base + row_offs[:, None] * stride_x_row + col_offs[None, :]
    out_ptrs = out_base + row_offs[:, None] * stride_out_row + col_offs[None, :]
    load_mask = row_mask[:, None]

    x = tl.load(in_ptrs, mask=load_mask, other=0.0).to(tl.float32)

    for s_rev in tl.static_range(LOG_N):
        x = _butterfly_stage_2d(
            x, ROWS_PER_PROGRAM, BLOCK_SIZE, 1 << (LOG_N - 1 - s_rev)
        )

    x = x * scale

    if INPUT_IS_FP16:
        tl.store(
            out_ptrs, x.to(tl.float16), mask=load_mask, eviction_policy="evict_first"
        )
    elif INPUT_IS_BF16:
        tl.store(
            out_ptrs, x.to(tl.bfloat16), mask=load_mask, eviction_policy="evict_first"
        )
    else:
        tl.store(out_ptrs, x, mask=load_mask, eviction_policy="evict_first")

412 

413 

414# ============================================================ 

415# Precomputed lookup tables for fast dispatch 

416# ============================================================ 

417 

# Every power of 2 from 2**3 through 2**16, i.e. {8, 16, ..., 65536}.
# Dims in this set take the no-padding fast path of the forward.
_POW2_DIMS = frozenset(2**exp for exp in range(3, 17))

420 

421 

422# ============================================================ 

423# Core forward 

424# ============================================================ 

425 

426 

def _hadamard_transform_fwd(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Core forward: validate, flatten to 2D, pad, launch, trim, reshape.

    Args:
        x: Input of shape (..., dim). Must be fp32, fp16 or bf16.
        scale: Multiplier applied to the transformed output.

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: If the dtype is unsupported; or, on the padding path
            only, if ``x`` is not a CUDA tensor or ``dim`` rounded up to a
            multiple of 8 exceeds 65536.
    """
    shapes_og = x.shape
    dim_og = x.shape[-1]
    input_dtype = x.dtype
    # Validated before the fast path as well: power-of-2 dims used to skip
    # this check, so e.g. float64 / integer inputs were silently routed
    # through the fp32 kernels and cast on store.
    assert input_dtype in (
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ), f"hadamard_transform not implemented for input type '{input_dtype}'"

    x_flat = x.reshape(-1, dim_og)
    if x_flat.stride(-1) != 1:
        x_flat = x_flat.contiguous()
    batch_size = x_flat.shape[0]

    # Fast path for power-of-2 dims in [8, 65536]: no padding or trimming.
    if dim_og in _POW2_DIMS:
        n = dim_og
        log_n = n.bit_length() - 1
        # Allocate output directly with explicit args (faster than empty_like);
        # it is fresh and contiguous, so its row stride is exactly n.
        out = torch.empty(batch_size, n, dtype=input_dtype, device=x_flat.device)
        _launch_kernel(
            x_flat, out, scale, input_dtype, batch_size, n, log_n, x_flat.stride(0), n
        )
        return out.reshape(shapes_og)

    assert x.is_cuda, "hadamard_transform requires CUDA tensor"

    # Size limit is checked on dim rounded up to a multiple of 8, matching
    # the CUDA implementation.
    dim = (dim_og + 7) // 8 * 8
    assert (
        dim <= 65536
    ), "fast_hadamard_transform only supports hidden dimension at most 65536 for now"

    # The butterfly needs a power of 2: log_n = ceil(log2(dim)) in exact
    # integer arithmetic. Every power of 2 >= 8 is a multiple of 8, so padding
    # straight to n produces the same zeros as padding to a multiple of 8
    # first — with one copy instead of two.
    log_n = (dim - 1).bit_length()
    n = 1 << log_n
    if n != dim_og:
        x_flat = F.pad(x_flat, (0, n - dim_og))

    out = torch.empty(batch_size, n, dtype=input_dtype, device=x_flat.device)
    _launch_kernel(
        x_flat, out, scale, input_dtype, batch_size, n, log_n, x_flat.stride(0), n
    )

    # Trim padding back to original dim
    if n != dim_og:
        out = out[:, :dim_og]
    return out.reshape(shapes_og)

493 

494 

def _launch_kernel(
    x, out, scale, input_dtype, batch_size, n, log_n, stride_x, stride_out
):
    """Dispatch one forward FHT launch to the kernel best suited to (n, dtype).

    Args:
        x: 2D input (batch_size rows of length n) with a unit last-dim stride.
        out: Output buffer of the same shape and dtype.
        scale: Multiplier applied after the butterfly stages. For the native
            fp16/bf16 kernels it is a constexpr (one compile per distinct
            value); for the fp32-compute kernels it is a runtime argument.
        input_dtype: dtype of ``x`` and ``out``.
        batch_size: Number of rows.
        n: Transform length (a power of 2); log_n = log2(n).
        stride_x, stride_out: Row strides of ``x`` / ``out`` in elements.

    Dispatch strategy (v43):
      - n=256,  fp16/bf16: 4-row ILP native (warps=2); single-row native
        fallback when batch_size < 4
      - n=512,  fp16/bf16: 1D single-row native (warps=1) — v35 best
      - n<=128, fp16/bf16: generic 1D native (warps=1)
      - n=1024, fp16/bf16: 2D native batched (rows=2, warps=4)
      - n<=512, other dtypes: fp32-compute 1D kernel
      - otherwise (fp32 n>=1024, fp16/bf16 n>=2048): fp32-compute 2D kernel
    """
    if n <= 1024 and input_dtype in (torch.float16, torch.bfloat16):
        if n == 256:
            if batch_size >= 4:
                _fht_kernel_256_4row_native[((batch_size + 3) // 4,)](
                    x,
                    out,
                    stride_x_row=stride_x,
                    stride_out_row=stride_out,
                    N_ROWS=batch_size,
                    SCALE=scale,
                    num_warps=2,
                    num_stages=1,
                )
            else:
                _fht_kernel_256_1d_native[(batch_size,)](
                    x,
                    out,
                    stride_x_row=stride_x,
                    stride_out_row=stride_out,
                    SCALE=scale,
                    num_warps=2,
                    num_stages=1,
                )
        elif n == 512:
            # Single-row 1D hardcoded: v35 achieved 1.1193x (best)
            _fht_kernel_512_1d_native[(batch_size,)](
                x,
                out,
                stride_x_row=stride_x,
                stride_out_row=stride_out,
                SCALE=scale,
                num_warps=1,
                num_stages=1,
            )
        elif n <= 128:
            _fht_kernel_1d_native[(batch_size,)](
                x,
                out,
                stride_x_row=stride_x,
                stride_out_row=stride_out,
                DIM=n,
                LOG_N=log_n,
                BLOCK_SIZE=n,
                SCALE=scale,
                num_warps=1,
                num_stages=1,
            )
        else:
            # Only n=1024 remains (n is a power of 2 in [8, 1024]).
            rows_per_program = 2
            n_programs = (batch_size + rows_per_program - 1) // rows_per_program
            _fht_kernel_2d_native[(n_programs,)](
                x,
                out,
                stride_x_row=stride_x,
                stride_out_row=stride_out,
                N_ROWS=batch_size,
                DIM=n,
                LOG_N=log_n,
                BLOCK_SIZE=n,
                ROWS_PER_PROGRAM=rows_per_program,
                SCALE=scale,
                num_warps=4,
                num_stages=1,
            )
    elif n <= 512:
        # fp32 1D kernel
        _fht_kernel_1d[(batch_size,)](
            x,
            out,
            scale,
            stride_x_row=stride_x,
            stride_out_row=stride_out,
            DIM=n,
            LOG_N=log_n,
            BLOCK_SIZE=n,
            INPUT_IS_FP16=(input_dtype == torch.float16),
            INPUT_IS_BF16=(input_dtype == torch.bfloat16),
            num_warps=1,
            num_stages=1,
        )
    else:
        # fp32-compute 2D kernel. Only n >= 1024 reaches here: every dtype
        # with n <= 512 and fp16/bf16 with n <= 1024 are handled above, so the
        # former small-n tunings (n <= 256) in this branch were unreachable.
        if n <= 1024:
            num_warps, rows_per_program = 4, 2
        elif n <= 4096:
            num_warps, rows_per_program = 4, 1
        else:
            num_warps, rows_per_program = 8, 1

        n_programs = (batch_size + rows_per_program - 1) // rows_per_program
        _fht_kernel_2d[(n_programs,)](
            x,
            out,
            scale,
            stride_x_row=stride_x,
            stride_out_row=stride_out,
            N_ROWS=batch_size,
            DIM=n,
            LOG_N=log_n,
            BLOCK_SIZE=n,
            ROWS_PER_PROGRAM=rows_per_program,
            INPUT_IS_FP16=(input_dtype == torch.float16),
            INPUT_IS_BF16=(input_dtype == torch.bfloat16),
            num_warps=num_warps,
            num_stages=1,
        )

628 

629 

630# ============================================================ 

631# Autograd Function 

632# ============================================================ 

633 

634 

class HadamardTransformFn(torch.autograd.Function):
    """Autograd wrapper around the Triton Hadamard-transform forward.

    The forward map is ``scale * H`` with H a symmetric Hadamard matrix
    (restricted to the first ``dim`` coordinates when padding is involved),
    so its vector-Jacobian product is the same transform applied to the
    incoming gradient with the same scale.
    """

    @staticmethod
    def forward(ctx, x, scale=1.0):
        # Only the Python scale is needed for backward; no tensors are saved.
        ctx._hadamard_transform_scale = scale
        transformed = _hadamard_transform_fwd(x, scale)
        return transformed

    @staticmethod
    def backward(ctx, dout):
        # NOTE(review): this launches the kernel directly rather than through
        # autograd, so higher-order gradients (create_graph=True) presumably
        # do not flow through here — confirm if double backward is needed.
        grad_x = _hadamard_transform_fwd(dout, ctx._hadamard_transform_scale)
        # ``scale`` is a plain number and receives no gradient.
        return grad_x, None

645 

646 

647# ============================================================ 

648# Public API 

649# ============================================================ 

650 

651 

def hadamard_transform(x, scale=1.0):
    """Multiply each row of ``x`` by the (scaled) Hadamard matrix.

    Arguments:
        x: (..., dim) tensor; fp32, fp16 or bf16.
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim), same dtype as ``x``.

    For power-of-2 ``dim`` this equals
    ``F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale``.
    Otherwise ``x`` is implicitly zero-padded to the next power of 2, the
    transform is applied, and the result is truncated back to ``dim``.
    Supports autograd; the backward pass applies the same transform.
    """
    apply_fn = HadamardTransformFn.apply
    return apply_fn(x, scale)

666 

667 

668# ============================================================ 

669# XXN variants (non-power-of-2 dims) 

670# 

671# Dao-AILab decomposes dim = M * 2^k, applying a small M×M 

672# Hadamard-like matrix then a standard 2^k FHT. 

673# For now these are aliases of hadamard_transform: implicit zero-padding to the next 

674# power of 2, then truncation. That is a different matrix from the M×2^k transform, so results do not match the CUDA XXN kernels. 

675# TODO: implement proper M×N decomposition for better efficiency. 

676# ============================================================ 

677 

678 

def hadamard_transform_12N(x, scale=1.0):
    """Hadamard transform intended for dim = 12 * 2^k (e.g. 12*512 = 6144).

    NOTE(review): currently identical to ``hadamard_transform(x, scale)``.
    It zero-pads to the next power of 2 and truncates instead of applying a
    12x12 Hadamard factor, so results are expected to differ from
    Dao-AILab's ``hadamard_transform_12N``. Confirm before relying on parity.
    """
    return hadamard_transform(x, scale)

682 

683 

def hadamard_transform_20N(x, scale=1.0):
    """Hadamard transform intended for dim = 20 * 2^k (e.g. 20*1024 = 20480).

    NOTE(review): currently identical to ``hadamard_transform(x, scale)``.
    It zero-pads to the next power of 2 and truncates instead of applying a
    20x20 Hadamard factor, so results are expected to differ from
    Dao-AILab's ``hadamard_transform_20N``. Confirm before relying on parity.
    """
    return hadamard_transform(x, scale)

687 

688 

def hadamard_transform_28N(x, scale=1.0):
    """Hadamard transform intended for dim = 28 * 2^k (e.g. 28*1024 = 28672).

    NOTE(review): currently identical to ``hadamard_transform(x, scale)``.
    It zero-pads to the next power of 2 and truncates instead of applying a
    28x28 Hadamard factor, so results are expected to differ from
    Dao-AILab's ``hadamard_transform_28N``. Confirm before relying on parity.
    """
    return hadamard_transform(x, scale)

692 

693 

def hadamard_transform_40N(x, scale=1.0):
    """Hadamard transform intended for dim = 40 * 2^k (e.g. 40*1024 = 40960).

    NOTE(review): currently identical to ``hadamard_transform(x, scale)``.
    It zero-pads to the next power of 2 and truncates instead of applying a
    40x40 Hadamard factor, so results are expected to differ from
    Dao-AILab's ``hadamard_transform_40N``. Confirm before relying on parity.
    """
    return hadamard_transform(x, scale)