Coverage for src/flag_gems/ops/hadamard_transform.py: 10%

412 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1"""Fast Hadamard Transform in Triton. 

2 

3Drop-in replacement for Dao-AILab/fast-hadamard-transform with identical interface: 

4 - hadamard_transform(x, scale=1.0) with autograd support 

5 - hadamard_transform_12N/20N/28N/40N(x, scale=1.0) for non-power-of-2 dims 

6 - Input: (..., dim), fp32/fp16/bf16 

7 - Output: (..., dim), same dtype as input 

8 - Padding: to next multiple of 8 (matching CUDA impl) 

9 - dim <= 65536 (standard; this is the limit the forward pass asserts), dim <= M*2^10 (XXN variants) 

10 

11Reference: https://github.com/Dao-AILab/fast-hadamard-transform 

12""" 

13 

14import math 

15 

16import torch 

17import torch.nn.functional as F 

18import triton 

19import triton.language as tl 

20 

21# ============================================================ 

22# Triton kernel — v43: remove evict_first from loads + warps=2 for dim=256 

23# ============================================================ 

24# v35 best: dim=256 0.9302x (no evict_first on loads, warps=1) 

25# v42: dim=256 0.8950x (evict_first on loads hurt — L2 thrashing) 

26# 

27# v43 strategy: 

28# 1. Remove evict_first from all loads. v42 proved it hurts dim=256 

29# (0.8950x vs v35's 0.9302x). The 256-element rows (512B fp16) 

30# are small enough that L2 caching of nearby rows helps prefetch. 

31# 2. Try num_warps=2 for dim=256 4-row ILP kernel. With 4 rows of 

32# 256 elements each, the workload can benefit from 2-warp occupancy: 

33# each warp handles the compute for its assigned instructions, 

34# and the scheduler can overlap loads from one warp with compute 

35# from the other. This targets the memory latency hiding gap. 

36# 3. Keep evict_first on stores (write-once streaming pattern). 

37# 4. Keep all other kernels unchanged from v42 baseline. 

38 

39 

40# ============================================================ 

41# Butterfly stages 

42# ============================================================ 

43 

44 

@triton.jit
def _butterfly_stage_1d(x, BLOCK_SIZE: tl.constexpr, STRIDE: tl.constexpr):
    """Apply a single add/subtract butterfly stage to a 1D vector.

    The vector of length ``BLOCK_SIZE`` is viewed as
    ``BLOCK_SIZE // (2 * STRIDE)`` groups of ``2 * STRIDE`` elements. Inside
    each group, the element at offset ``i`` and the one at ``i + STRIDE`` are
    replaced by their sum and their difference respectively.

    Args:
        x: 1D block of ``BLOCK_SIZE`` values.
        BLOCK_SIZE: compile-time length of ``x``.
        STRIDE: compile-time distance between butterfly partners.

    Returns:
        A 1D block of ``BLOCK_SIZE`` values after the stage.
    """
    N_GROUPS: tl.constexpr = BLOCK_SIZE // (2 * STRIDE)
    if STRIDE == 1:
        # Partners are adjacent, so the pair axis is already the last axis.
        pairs = tl.reshape(x, (N_GROUPS, 2))
        lo, hi = tl.split(pairs)
        merged = tl.join(lo + hi, lo - hi)
        return tl.reshape(merged, (BLOCK_SIZE,))
    else:
        # Expose the pair axis, move it last so tl.split can separate the
        # two halves, then undo the permutation after recombining.
        grouped = tl.reshape(x, (N_GROUPS, 2, STRIDE))
        pair_last = tl.permute(grouped, (0, 2, 1))
        lo, hi = tl.split(pair_last)
        merged = tl.join(lo + hi, lo - hi)
        restored = tl.permute(merged, (0, 2, 1))
        return tl.reshape(restored, (BLOCK_SIZE,))

60 

61 

@triton.jit
def _butterfly_stage_2d(
    x, ROWS: tl.constexpr, BLOCK_SIZE: tl.constexpr, STRIDE: tl.constexpr
):
    """Apply a single add/subtract butterfly stage to every row of a 2D block.

    Each row of the ``(ROWS, BLOCK_SIZE)`` block is processed independently:
    it is viewed as ``BLOCK_SIZE // (2 * STRIDE)`` groups of ``2 * STRIDE``
    elements, and inside each group the elements ``STRIDE`` apart are
    replaced by their sum and their difference.

    Args:
        x: 2D block of shape ``(ROWS, BLOCK_SIZE)``.
        ROWS: compile-time number of rows in ``x``.
        BLOCK_SIZE: compile-time row length.
        STRIDE: compile-time distance between butterfly partners.

    Returns:
        A ``(ROWS, BLOCK_SIZE)`` block after the stage.
    """
    N_GROUPS: tl.constexpr = BLOCK_SIZE // (2 * STRIDE)
    if STRIDE == 1:
        # Partners are adjacent, so the pair axis is already the last axis.
        pairs = tl.reshape(x, (ROWS, N_GROUPS, 2))
        lo, hi = tl.split(pairs)
        merged = tl.join(lo + hi, lo - hi)
        return tl.reshape(merged, (ROWS, BLOCK_SIZE))
    else:
        # Expose the pair axis, move it last for tl.split, then restore the
        # original axis order after recombining.
        grouped = tl.reshape(x, (ROWS, N_GROUPS, 2, STRIDE))
        pair_last = tl.permute(grouped, (0, 1, 3, 2))
        lo, hi = tl.split(pair_last)
        merged = tl.join(lo + hi, lo - hi)
        restored = tl.permute(merged, (0, 1, 3, 2))
        return tl.reshape(restored, (ROWS, BLOCK_SIZE))

79 

80 

81# ============================================================ 

82# 4-row ILP 1D native kernel for dim=256 (8 hardcoded stages) 

83# v43: remove evict_first from loads, keep on stores 

84# ============================================================ 

85 

86 

@triton.jit
def _fht_kernel_256_4row_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    N_ROWS,
    SCALE: tl.constexpr,
):
    """FHT for dim=256, 4-row ILP: four independent 1D butterflies per program.

    Each program handles rows ``4 * pid`` .. ``4 * pid + 3``. Every row gets
    the eight butterfly stages (stride 128 down to 1), is multiplied by
    ``SCALE`` and is written to the output.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        stride_x_row: element stride between consecutive input rows.
        stride_out_row: element stride between consecutive output rows.
        N_ROWS: total number of rows; masks rows 1..3 of the last program.
        SCALE: compile-time multiplier applied after the butterflies.

    Note:
        No dtype conversion is performed, so the arithmetic runs in the dtype
        of the loaded values. The first row of each program is loaded and
        stored without a mask, so the launch grid must guarantee
        ``4 * pid < N_ROWS``. Columns are never masked: every row is read and
        written as exactly 256 contiguous elements.
    """
    pid = tl.program_id(0)
    col_offs = tl.arange(0, 256)

    # The four consecutive rows owned by this program.
    row0 = pid * 4
    row1 = row0 + 1
    row2 = row0 + 2
    row3 = row0 + 3

    # Load all 4 rows (no evict_first: L2 caching helps for nearby rows)
    # Rows past N_ROWS are replaced by zeros; row0 is assumed to be in range.
    x0 = tl.load(X_ptr + row0 * stride_x_row + col_offs)
    x1 = tl.load(X_ptr + row1 * stride_x_row + col_offs, mask=row1 < N_ROWS, other=0.0)
    x2 = tl.load(X_ptr + row2 * stride_x_row + col_offs, mask=row2 < N_ROWS, other=0.0)
    x3 = tl.load(X_ptr + row3 * stride_x_row + col_offs, mask=row3 < N_ROWS, other=0.0)

    # Interleaved hardcoded reversed butterfly stages for 4-way ILP
    # (the same stage is issued for all four rows before moving to the next
    # stride; the rows do not depend on each other).
    x0 = _butterfly_stage_1d(x0, 256, 128)
    x1 = _butterfly_stage_1d(x1, 256, 128)
    x2 = _butterfly_stage_1d(x2, 256, 128)
    x3 = _butterfly_stage_1d(x3, 256, 128)
    x0 = _butterfly_stage_1d(x0, 256, 64)
    x1 = _butterfly_stage_1d(x1, 256, 64)
    x2 = _butterfly_stage_1d(x2, 256, 64)
    x3 = _butterfly_stage_1d(x3, 256, 64)
    x0 = _butterfly_stage_1d(x0, 256, 32)
    x1 = _butterfly_stage_1d(x1, 256, 32)
    x2 = _butterfly_stage_1d(x2, 256, 32)
    x3 = _butterfly_stage_1d(x3, 256, 32)
    x0 = _butterfly_stage_1d(x0, 256, 16)
    x1 = _butterfly_stage_1d(x1, 256, 16)
    x2 = _butterfly_stage_1d(x2, 256, 16)
    x3 = _butterfly_stage_1d(x3, 256, 16)
    x0 = _butterfly_stage_1d(x0, 256, 8)
    x1 = _butterfly_stage_1d(x1, 256, 8)
    x2 = _butterfly_stage_1d(x2, 256, 8)
    x3 = _butterfly_stage_1d(x3, 256, 8)
    x0 = _butterfly_stage_1d(x0, 256, 4)
    x1 = _butterfly_stage_1d(x1, 256, 4)
    x2 = _butterfly_stage_1d(x2, 256, 4)
    x3 = _butterfly_stage_1d(x3, 256, 4)
    x0 = _butterfly_stage_1d(x0, 256, 2)
    x1 = _butterfly_stage_1d(x1, 256, 2)
    x2 = _butterfly_stage_1d(x2, 256, 2)
    x3 = _butterfly_stage_1d(x3, 256, 2)
    x0 = _butterfly_stage_1d(x0, 256, 1)
    x1 = _butterfly_stage_1d(x1, 256, 1)
    x2 = _butterfly_stage_1d(x2, 256, 1)
    x3 = _butterfly_stage_1d(x3, 256, 1)

    x0 = x0 * SCALE
    x1 = x1 * SCALE
    x2 = x2 * SCALE
    x3 = x3 * SCALE

    # Stores use the same row masks as the loads, so padding rows are never
    # written. evict_first is kept on stores only.
    tl.store(
        OUT_ptr + row0 * stride_out_row + col_offs, x0, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + row1 * stride_out_row + col_offs,
        x1,
        mask=row1 < N_ROWS,
        eviction_policy="evict_first",
    )
    tl.store(
        OUT_ptr + row2 * stride_out_row + col_offs,
        x2,
        mask=row2 < N_ROWS,
        eviction_policy="evict_first",
    )
    tl.store(
        OUT_ptr + row3 * stride_out_row + col_offs,
        x3,
        mask=row3 < N_ROWS,
        eviction_policy="evict_first",
    )

171 

172 

173# ============================================================ 

174# Fallback: single-row 1D native kernel for dim=256 

175# ============================================================ 

176 

177 

@triton.jit
def _fht_kernel_256_1d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    SCALE: tl.constexpr,
):
    """Single-row FHT for dim=256 with eight hardcoded butterfly stages.

    One program transforms one row: it loads 256 contiguous elements, runs
    the stages with strides 128, 64, ..., 1, multiplies by ``SCALE`` and
    stores the result. No dtype conversion is performed, so the arithmetic
    runs in the dtype of the loaded values.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        stride_x_row: element stride between consecutive input rows.
        stride_out_row: element stride between consecutive output rows.
        SCALE: compile-time multiplier applied after the butterflies.
    """
    row_idx = tl.program_id(0)
    cols = tl.arange(0, 256)

    vec = tl.load(X_ptr + row_idx * stride_x_row + cols)

    # Stages run from the largest stride down to 1.
    vec = _butterfly_stage_1d(vec, 256, 128)
    vec = _butterfly_stage_1d(vec, 256, 64)
    vec = _butterfly_stage_1d(vec, 256, 32)
    vec = _butterfly_stage_1d(vec, 256, 16)
    vec = _butterfly_stage_1d(vec, 256, 8)
    vec = _butterfly_stage_1d(vec, 256, 4)
    vec = _butterfly_stage_1d(vec, 256, 2)
    vec = _butterfly_stage_1d(vec, 256, 1)

    vec = vec * SCALE
    # The output is written once, hence the evict_first hint.
    tl.store(
        OUT_ptr + row_idx * stride_out_row + cols, vec, eviction_policy="evict_first"
    )

206 

207 

208# ============================================================ 

209# 1D hardcoded native kernel for dim=512 (9 stages) 

210# Restored from v31/v35: single-row (best: 1.1193x in v35) 

211# ============================================================ 

212 

213 

@triton.jit
def _fht_kernel_512_1d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    SCALE: tl.constexpr,
):
    """Single-row FHT for dim=512 with nine hardcoded butterfly stages.

    One program transforms one row: it loads 512 contiguous elements, runs
    the stages with strides 256, 128, ..., 1, multiplies by ``SCALE`` and
    stores the result. No dtype conversion is performed, so the arithmetic
    runs in the dtype of the loaded values.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        stride_x_row: element stride between consecutive input rows.
        stride_out_row: element stride between consecutive output rows.
        SCALE: compile-time multiplier applied after the butterflies.
    """
    row_idx = tl.program_id(0)
    cols = tl.arange(0, 512)

    vec = tl.load(X_ptr + row_idx * stride_x_row + cols)

    # Stages run from the largest stride down to 1.
    vec = _butterfly_stage_1d(vec, 512, 256)
    vec = _butterfly_stage_1d(vec, 512, 128)
    vec = _butterfly_stage_1d(vec, 512, 64)
    vec = _butterfly_stage_1d(vec, 512, 32)
    vec = _butterfly_stage_1d(vec, 512, 16)
    vec = _butterfly_stage_1d(vec, 512, 8)
    vec = _butterfly_stage_1d(vec, 512, 4)
    vec = _butterfly_stage_1d(vec, 512, 2)
    vec = _butterfly_stage_1d(vec, 512, 1)

    vec = vec * SCALE
    # The output is written once, hence the evict_first hint.
    tl.store(
        OUT_ptr + row_idx * stride_out_row + cols, vec, eviction_policy="evict_first"
    )

243 

244 

245# ============================================================ 

246# Generic 1D native-dtype butterfly kernel (for other small dims) 

247# ============================================================ 

248 

249 

@triton.jit
def _fht_kernel_1d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    SCALE: tl.constexpr,
):
    """Generic single-row FHT that computes in the dtype of the loaded values.

    One program transforms one row of ``BLOCK_SIZE`` contiguous elements
    using ``LOG_N`` butterfly stages, multiplies by ``SCALE`` and stores the
    result.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        stride_x_row: element stride between consecutive input rows.
        stride_out_row: element stride between consecutive output rows.
        DIM: not referenced in the kernel body.
        LOG_N: number of butterfly stages to run.
        BLOCK_SIZE: number of elements loaded and stored per row.
        SCALE: compile-time multiplier applied after the butterflies.
    """
    row_idx = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)

    vec = tl.load(X_ptr + row_idx * stride_x_row + cols)

    # Stage k uses stride 2**(LOG_N - 1 - k): largest stride first, 1 last.
    for stage in tl.static_range(LOG_N):
        vec = _butterfly_stage_1d(vec, BLOCK_SIZE, 1 << (LOG_N - 1 - stage))

    vec = vec * SCALE
    # The output is written once, hence the evict_first hint.
    tl.store(
        OUT_ptr + row_idx * stride_out_row + cols, vec, eviction_policy="evict_first"
    )

275 

276 

277# ============================================================ 

278# 2D native-dtype butterfly kernel (for dim=1024 with fp16/bf16) 

279# ============================================================ 

280 

281 

@triton.jit
def _fht_kernel_2d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    N_ROWS,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    ROWS_PER_PROGRAM: tl.constexpr,
    SCALE: tl.constexpr,
):
    """Batched FHT that computes in the dtype of the loaded values.

    Each program transforms ``ROWS_PER_PROGRAM`` consecutive rows of
    ``BLOCK_SIZE`` contiguous elements using ``LOG_N`` butterfly stages,
    multiplies by ``SCALE`` and stores the rows that exist.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        stride_x_row: element stride between consecutive input rows.
        stride_out_row: element stride between consecutive output rows.
        N_ROWS: total number of rows; rows at or beyond it are masked out.
        DIM: not referenced in the kernel body.
        LOG_N: number of butterfly stages to run.
        BLOCK_SIZE: number of elements loaded and stored per row.
        ROWS_PER_PROGRAM: number of rows handled by one program.
        SCALE: compile-time multiplier applied after the butterflies.
    """
    prog = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    local_rows = tl.arange(0, ROWS_PER_PROGRAM)

    rows = prog * ROWS_PER_PROGRAM + local_rows
    # Only the row dimension is masked; columns are always fully accessed.
    valid = (rows < N_ROWS)[:, None]

    src = X_ptr + rows[:, None] * stride_x_row + cols[None, :]
    dst = OUT_ptr + rows[:, None] * stride_out_row + cols[None, :]

    tile = tl.load(src, mask=valid, other=0.0)

    # Stage k uses stride 2**(LOG_N - 1 - k): largest stride first, 1 last.
    for stage in tl.static_range(LOG_N):
        tile = _butterfly_stage_2d(
            tile, ROWS_PER_PROGRAM, BLOCK_SIZE, 1 << (LOG_N - 1 - stage)
        )

    tile = tile * SCALE
    tl.store(dst, tile, mask=valid, eviction_policy="evict_first")

318 

319 

320# ============================================================ 

321# 1D butterfly kernel (fp32, for fp32 inputs) 

322# ============================================================ 

323 

324 

@triton.jit
def _fht_kernel_1d(
    X_ptr,
    OUT_ptr,
    scale,
    stride_x_row,
    stride_out_row,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    INPUT_IS_FP16: tl.constexpr,
    INPUT_IS_BF16: tl.constexpr,
):
    """Single-row FHT that accumulates in float32.

    One program loads one row of ``BLOCK_SIZE`` contiguous elements, converts
    it to float32, runs ``LOG_N`` butterfly stages, multiplies by ``scale``
    and stores the result, cast to float16 or bfloat16 when the matching flag
    is set.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        scale: runtime multiplier applied after the butterflies.
        stride_x_row: element stride between consecutive input rows.
        stride_out_row: element stride between consecutive output rows.
        DIM: not referenced in the kernel body.
        LOG_N: number of butterfly stages to run.
        BLOCK_SIZE: number of elements loaded and stored per row.
        INPUT_IS_FP16: store the result as float16.
        INPUT_IS_BF16: store the result as bfloat16 (checked after FP16).
    """
    row_idx = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)

    src = X_ptr + row_idx * stride_x_row + cols
    dst = OUT_ptr + row_idx * stride_out_row + cols

    acc = tl.load(src).to(tl.float32)

    # Stage k uses stride 2**(LOG_N - 1 - k): largest stride first, 1 last.
    for stage in tl.static_range(LOG_N):
        acc = _butterfly_stage_1d(acc, BLOCK_SIZE, 1 << (LOG_N - 1 - stage))

    acc = acc * scale

    # Pick the stored representation, then issue one store.
    if INPUT_IS_FP16:
        result = acc.to(tl.float16)
    elif INPUT_IS_BF16:
        result = acc.to(tl.bfloat16)
    else:
        result = acc
    tl.store(dst, result, eviction_policy="evict_first")

358 

359 

360# ============================================================ 

361# 2D butterfly kernel (fp32, for dim>=1024 and fp32 inputs) 

362# ============================================================ 

363 

364 

@triton.jit
def _fht_kernel_2d(
    X_ptr,
    OUT_ptr,
    scale,
    stride_x_row,
    stride_out_row,
    N_ROWS,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    ROWS_PER_PROGRAM: tl.constexpr,
    INPUT_IS_FP16: tl.constexpr,
    INPUT_IS_BF16: tl.constexpr,
):
    """Batched FHT that accumulates in float32.

    Each program loads ``ROWS_PER_PROGRAM`` consecutive rows of
    ``BLOCK_SIZE`` contiguous elements, converts them to float32, runs
    ``LOG_N`` butterfly stages, multiplies by ``scale`` and stores the rows
    that exist, cast to float16 or bfloat16 when the matching flag is set.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        scale: runtime multiplier applied after the butterflies.
        stride_x_row: element stride between consecutive input rows.
        stride_out_row: element stride between consecutive output rows.
        N_ROWS: total number of rows; rows at or beyond it are masked out.
        DIM: not referenced in the kernel body.
        LOG_N: number of butterfly stages to run.
        BLOCK_SIZE: number of elements loaded and stored per row.
        ROWS_PER_PROGRAM: number of rows handled by one program.
        INPUT_IS_FP16: store the result as float16.
        INPUT_IS_BF16: store the result as bfloat16 (checked after FP16).
    """
    prog = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    local_rows = tl.arange(0, ROWS_PER_PROGRAM)

    rows = prog * ROWS_PER_PROGRAM + local_rows
    # Only the row dimension is masked; columns are always fully accessed.
    valid = (rows < N_ROWS)[:, None]

    src = X_ptr + rows[:, None] * stride_x_row + cols[None, :]
    dst = OUT_ptr + rows[:, None] * stride_out_row + cols[None, :]

    acc = tl.load(src, mask=valid, other=0.0).to(tl.float32)

    # Stage k uses stride 2**(LOG_N - 1 - k): largest stride first, 1 last.
    for stage in tl.static_range(LOG_N):
        acc = _butterfly_stage_2d(
            acc, ROWS_PER_PROGRAM, BLOCK_SIZE, 1 << (LOG_N - 1 - stage)
        )

    acc = acc * scale

    # Pick the stored representation, then issue one masked store.
    if INPUT_IS_FP16:
        result = acc.to(tl.float16)
    elif INPUT_IS_BF16:
        result = acc.to(tl.bfloat16)
    else:
        result = acc
    tl.store(dst, result, mask=valid, eviction_policy="evict_first")

412 

413 

414# ============================================================ 

415# M×N fused kernels: H_M column transform + FHT in registers 

416# No intermediate DRAM write, no padding to next power of 2. 

417# ============================================================ 

418 

419 

@triton.jit
def _h3_fht_kernel(
    X_ptr,
    OUT_ptr,
    stride_batch,
    stride_row,
    SCALE: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
    N_COLS: tl.constexpr,
    LOG_N: tl.constexpr,
):
    """Fused 3-row mix + FHT for one ``(3, N_COLS)`` batch item.

    The three input rows are combined with +/-1 coefficients into three
    float32 rows, each row then goes through ``LOG_N`` butterfly stages, is
    multiplied by ``SCALE`` and is stored (cast to float16/bfloat16 when the
    matching flag is set). Input and output share the same strides.

    NOTE(review): the coefficient rows used here, (1, 1, 1), (1, -1, 1) and
    (1, 1, -1), are not mutually orthogonal (rows 0 and 1 have dot product
    1) -- confirm this matches the intended reference transform.

    Args:
        X_ptr: base pointer of the input.
        OUT_ptr: base pointer of the output.
        stride_batch: element stride between batch items.
        stride_row: element stride between the three rows of an item.
        SCALE: compile-time multiplier applied after the butterflies.
        IS_FP16: cast the result to float16 before storing.
        IS_BF16: cast the result to bfloat16 before storing.
        N_COLS: number of elements per row.
        LOG_N: number of butterfly stages to run.
    """
    item = tl.program_id(0)
    cols = tl.arange(0, N_COLS)
    item_off = item * stride_batch

    a = tl.load(X_ptr + item_off + 0 * stride_row + cols).to(tl.float32)
    b = tl.load(X_ptr + item_off + 1 * stride_row + cols).to(tl.float32)
    c = tl.load(X_ptr + item_off + 2 * stride_row + cols).to(tl.float32)

    # Row mixing with fixed +/-1 coefficients.
    y0 = a + b + c
    y1 = a - b + c
    y2 = a + b - c

    # Stage k uses stride 2**(LOG_N - 1 - k); the rows are independent.
    for stage in tl.static_range(LOG_N):
        y0 = _butterfly_stage_1d(y0, N_COLS, 1 << (LOG_N - 1 - stage))
        y1 = _butterfly_stage_1d(y1, N_COLS, 1 << (LOG_N - 1 - stage))
        y2 = _butterfly_stage_1d(y2, N_COLS, 1 << (LOG_N - 1 - stage))

    y0 = y0 * SCALE
    y1 = y1 * SCALE
    y2 = y2 * SCALE

    if IS_FP16:
        y0 = y0.to(tl.float16)
        y1 = y1.to(tl.float16)
        y2 = y2.to(tl.float16)
    elif IS_BF16:
        y0 = y0.to(tl.bfloat16)
        y1 = y1.to(tl.bfloat16)
        y2 = y2.to(tl.bfloat16)

    tl.store(
        OUT_ptr + item_off + 0 * stride_row + cols, y0, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + item_off + 1 * stride_row + cols, y1, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + item_off + 2 * stride_row + cols, y2, eviction_policy="evict_first"
    )

459 

460 

@triton.jit
def _h5_fht_kernel(
    X_ptr,
    OUT_ptr,
    stride_batch,
    stride_row,
    SCALE: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
    N_COLS: tl.constexpr,
    LOG_N: tl.constexpr,
):
    """Fused 5-row mix + FHT for one ``(5, N_COLS)`` batch item.

    The five input rows are combined with +/-1 coefficients into five float32
    rows, each row then goes through ``LOG_N`` butterfly stages, is
    multiplied by ``SCALE`` and is stored (cast to float16/bfloat16 when the
    matching flag is set). Input and output share the same strides.

    NOTE(review): the coefficient rows used here are not mutually orthogonal
    (e.g. rows 0 and 1 have dot product 1) -- confirm this matches the
    intended reference transform.

    Args:
        X_ptr: base pointer of the input.
        OUT_ptr: base pointer of the output.
        stride_batch: element stride between batch items.
        stride_row: element stride between the five rows of an item.
        SCALE: compile-time multiplier applied after the butterflies.
        IS_FP16: cast the result to float16 before storing.
        IS_BF16: cast the result to bfloat16 before storing.
        N_COLS: number of elements per row.
        LOG_N: number of butterfly stages to run.
    """
    item = tl.program_id(0)
    cols = tl.arange(0, N_COLS)
    item_off = item * stride_batch

    a = tl.load(X_ptr + item_off + 0 * stride_row + cols).to(tl.float32)
    b = tl.load(X_ptr + item_off + 1 * stride_row + cols).to(tl.float32)
    c = tl.load(X_ptr + item_off + 2 * stride_row + cols).to(tl.float32)
    d = tl.load(X_ptr + item_off + 3 * stride_row + cols).to(tl.float32)
    e = tl.load(X_ptr + item_off + 4 * stride_row + cols).to(tl.float32)

    # Row mixing with fixed +/-1 coefficients.
    y0 = a + b + c + d + e
    y1 = a - b + c - d + e
    y2 = a + b - c + d - e
    y3 = a - b - c - d - e
    y4 = a + b + c - d - e

    # Stage k uses stride 2**(LOG_N - 1 - k); the rows are independent.
    for stage in tl.static_range(LOG_N):
        y0 = _butterfly_stage_1d(y0, N_COLS, 1 << (LOG_N - 1 - stage))
        y1 = _butterfly_stage_1d(y1, N_COLS, 1 << (LOG_N - 1 - stage))
        y2 = _butterfly_stage_1d(y2, N_COLS, 1 << (LOG_N - 1 - stage))
        y3 = _butterfly_stage_1d(y3, N_COLS, 1 << (LOG_N - 1 - stage))
        y4 = _butterfly_stage_1d(y4, N_COLS, 1 << (LOG_N - 1 - stage))

    y0 = y0 * SCALE
    y1 = y1 * SCALE
    y2 = y2 * SCALE
    y3 = y3 * SCALE
    y4 = y4 * SCALE

    if IS_FP16:
        y0 = y0.to(tl.float16)
        y1 = y1.to(tl.float16)
        y2 = y2.to(tl.float16)
        y3 = y3.to(tl.float16)
        y4 = y4.to(tl.float16)
    elif IS_BF16:
        y0 = y0.to(tl.bfloat16)
        y1 = y1.to(tl.bfloat16)
        y2 = y2.to(tl.bfloat16)
        y3 = y3.to(tl.bfloat16)
        y4 = y4.to(tl.bfloat16)

    tl.store(
        OUT_ptr + item_off + 0 * stride_row + cols, y0, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + item_off + 1 * stride_row + cols, y1, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + item_off + 2 * stride_row + cols, y2, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + item_off + 3 * stride_row + cols, y3, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + item_off + 4 * stride_row + cols, y4, eviction_policy="evict_first"
    )

514 

515 

@triton.jit
def _h7_fht_kernel(
    X_ptr,
    OUT_ptr,
    stride_batch,
    stride_row,
    SCALE: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
    N_COLS: tl.constexpr,
    LOG_N: tl.constexpr,
):
    """Fused 7-row mix + FHT for one ``(7, N_COLS)`` batch item.

    The seven input rows are combined with +/-1 coefficients into seven
    float32 rows, each row then goes through ``LOG_N`` butterfly stages, is
    multiplied by ``SCALE`` and is stored (cast to float16/bfloat16 when the
    matching flag is set). Input and output share the same strides.

    NOTE(review): the coefficient rows used here are not mutually orthogonal
    (e.g. rows 0 and 1 have dot product 1) -- confirm this matches the
    intended reference transform.

    Args:
        X_ptr: base pointer of the input.
        OUT_ptr: base pointer of the output.
        stride_batch: element stride between batch items.
        stride_row: element stride between the seven rows of an item.
        SCALE: compile-time multiplier applied after the butterflies.
        IS_FP16: cast the result to float16 before storing.
        IS_BF16: cast the result to bfloat16 before storing.
        N_COLS: number of elements per row.
        LOG_N: number of butterfly stages to run.
    """
    item = tl.program_id(0)
    cols = tl.arange(0, N_COLS)
    item_off = item * stride_batch

    a = tl.load(X_ptr + item_off + 0 * stride_row + cols).to(tl.float32)
    b = tl.load(X_ptr + item_off + 1 * stride_row + cols).to(tl.float32)
    c = tl.load(X_ptr + item_off + 2 * stride_row + cols).to(tl.float32)
    d = tl.load(X_ptr + item_off + 3 * stride_row + cols).to(tl.float32)
    e = tl.load(X_ptr + item_off + 4 * stride_row + cols).to(tl.float32)
    f = tl.load(X_ptr + item_off + 5 * stride_row + cols).to(tl.float32)
    g = tl.load(X_ptr + item_off + 6 * stride_row + cols).to(tl.float32)

    # Row mixing with fixed +/-1 coefficients.
    y0 = a + b + c + d + e + f + g
    y1 = a - b + c - d + e - f + g
    y2 = a + b - c + d - e + f - g
    y3 = a - b - c - d - e - f - g
    y4 = a + b + c - d - e - f - g
    y5 = a - b + c + d - e + f + g
    y6 = a + b - c - d + e + f - g

    # Stage k uses stride 2**(LOG_N - 1 - k); the rows are independent.
    for stage in tl.static_range(LOG_N):
        y0 = _butterfly_stage_1d(y0, N_COLS, 1 << (LOG_N - 1 - stage))
        y1 = _butterfly_stage_1d(y1, N_COLS, 1 << (LOG_N - 1 - stage))
        y2 = _butterfly_stage_1d(y2, N_COLS, 1 << (LOG_N - 1 - stage))
        y3 = _butterfly_stage_1d(y3, N_COLS, 1 << (LOG_N - 1 - stage))
        y4 = _butterfly_stage_1d(y4, N_COLS, 1 << (LOG_N - 1 - stage))
        y5 = _butterfly_stage_1d(y5, N_COLS, 1 << (LOG_N - 1 - stage))
        y6 = _butterfly_stage_1d(y6, N_COLS, 1 << (LOG_N - 1 - stage))

    y0 = y0 * SCALE
    y1 = y1 * SCALE
    y2 = y2 * SCALE
    y3 = y3 * SCALE
    y4 = y4 * SCALE
    y5 = y5 * SCALE
    y6 = y6 * SCALE

    if IS_FP16:
        y0 = y0.to(tl.float16)
        y1 = y1.to(tl.float16)
        y2 = y2.to(tl.float16)
        y3 = y3.to(tl.float16)
        y4 = y4.to(tl.float16)
        y5 = y5.to(tl.float16)
        y6 = y6.to(tl.float16)
    elif IS_BF16:
        y0 = y0.to(tl.bfloat16)
        y1 = y1.to(tl.bfloat16)
        y2 = y2.to(tl.bfloat16)
        y3 = y3.to(tl.bfloat16)
        y4 = y4.to(tl.bfloat16)
        y5 = y5.to(tl.bfloat16)
        y6 = y6.to(tl.bfloat16)

    tl.store(
        OUT_ptr + item_off + 0 * stride_row + cols, y0, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + item_off + 1 * stride_row + cols, y1, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + item_off + 2 * stride_row + cols, y2, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + item_off + 3 * stride_row + cols, y3, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + item_off + 4 * stride_row + cols, y4, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + item_off + 5 * stride_row + cols, y5, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + item_off + 6 * stride_row + cols, y6, eviction_policy="evict_first"
    )

583 

584 

def _launch_mn_fused_kernel(x: torch.Tensor, M: int, scale: float) -> torch.Tensor:
    """Launch the fused H_M + FHT kernel for inputs whose last dim is M * 2^k.

    The input is viewed as ``(batch, M, dim // M)`` and one program is
    launched per batch item.

    Args:
        x: tensor of shape ``(..., dim)``.
        M: row-mix factor; only 3, 5 and 7 have a kernel.
        scale: multiplier passed to the kernel as the ``SCALE`` constant.

    Returns:
        A tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if ``M`` is unsupported, or if the last dimension of
            ``x`` is not ``M`` times a power of two.
    """
    # Validate before copying or allocating so bad calls fail fast with a
    # clear message instead of inside reshape / Triton compilation.
    if M not in (3, 5, 7):
        raise ValueError(f"Unsupported M={M}")

    *leading, dim = x.shape
    n_cols, remainder = divmod(dim, M)
    if remainder != 0 or n_cols < 1 or (n_cols & (n_cols - 1)) != 0:
        raise ValueError(
            f"last dimension must be {M} * 2^k for M={M}, got dim={dim}"
        )

    batch = x.numel() // dim
    log_n = n_cols.bit_length() - 1  # exact: n_cols is a power of two
    dtype = x.dtype
    xm = x.reshape(batch, M, n_cols).contiguous()
    out = torch.empty_like(xm)

    # Wider rows get more warps.
    num_warps = 2 if n_cols <= 1024 else (4 if n_cols <= 2048 else 8)
    kwargs = dict(
        SCALE=scale,
        IS_FP16=(dtype == torch.float16),
        IS_BF16=(dtype == torch.bfloat16),
        N_COLS=n_cols,
        LOG_N=log_n,
        num_warps=num_warps,
        num_stages=1,
    )

    if M == 3:
        kernel = _h3_fht_kernel
    elif M == 5:
        kernel = _h5_fht_kernel
    else:
        kernel = _h7_fht_kernel
    # The kernels apply the same strides to input and output; `out` comes
    # from empty_like on the contiguous `xm`.
    kernel[(batch,)](xm, out, xm.stride(0), xm.stride(1), **kwargs)
    return out.reshape(*leading, dim)

613 

614 

615# ============================================================ 

616# Precomputed lookup tables for fast dispatch 

617# ============================================================ 

618 

# Every power of two from 2**3 through 2**16 (8, 16, ..., 65536); membership
# selects the no-padding fast path of the forward pass.
_POW2_DIMS = frozenset(2**exponent for exponent in range(3, 17))

621 

622 

623# ============================================================ 

624# Core forward 

625# ============================================================ 

626 

627 

def _hadamard_transform_fwd(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Core forward: flatten to 2D, pad if needed, launch, restore the shape.

    Args:
        x: tensor of shape ``(..., dim)`` in float32, float16 or bfloat16.
        scale: multiplier applied to the transformed values.

    Returns:
        A tensor with the same shape and dtype as ``x``. When ``dim`` is not
        a power of two the rows are zero-padded to the next power of two for
        the kernel and the result is truncated back to ``dim`` columns.

    Raises:
        AssertionError: for an unsupported dtype, or on the padded path for a
            non-CUDA tensor or a padded dim above 65536.
    """
    input_dtype = x.dtype
    # Checked before the fast path so power-of-two and padded dims reject the
    # same dtypes (previously only the padded path validated the dtype).
    assert input_dtype in (
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ), f"hadamard_transform not implemented for input type '{input_dtype}'"

    shapes_og = x.shape
    dim_og = x.shape[-1]
    x_flat = x.reshape(-1, dim_og)
    # The kernels index columns with unit stride; row stride may be anything.
    if x_flat.stride(-1) != 1:
        x_flat = x_flat.contiguous()
    batch_size = x_flat.shape[0]

    # Fast path for power-of-2 dims (no padding needed)
    if dim_og in _POW2_DIMS:
        n = dim_og
        log_n = n.bit_length() - 1
        # Allocate output directly with explicit args (faster than empty_like)
        out = torch.empty(batch_size, n, dtype=input_dtype, device=x_flat.device)
        stride_x = x_flat.stride(0)
        stride_out = n  # out is freshly allocated, always contiguous

        _launch_kernel(
            x_flat, out, scale, input_dtype, batch_size, n, log_n, stride_x, stride_out
        )

        return out.reshape(shapes_og)

    # General path: handle padding
    # NOTE(review): the fast path above does not check the device; confirm
    # whether skipping this check there is intentional.
    assert x.is_cuda, "hadamard_transform requires CUDA tensor"

    # Pad to multiple of 8 (matching CUDA implementation)
    needs_pad = dim_og % 8 != 0
    if needs_pad:
        x_flat = F.pad(x_flat, (0, 8 - dim_og % 8))
    dim = x_flat.shape[1]

    assert (
        dim % 8 == 0
    ), "fast_hadamard_transform only supports hidden dimension divisible by 8 for now"
    assert (
        dim <= 65536
    ), "fast_hadamard_transform only supports hidden dimension at most 65536 for now"

    # The butterfly needs the next power of two. Integer arithmetic gives
    # ceil(log2(dim)) exactly without going through floating point.
    log_n = (dim - 1).bit_length() if dim > 1 else 1
    n = 1 << log_n

    # If dim (multiple of 8) is not a power of 2, pad further for the kernel
    if n != dim:
        x_flat = F.pad(x_flat, (0, n - dim))

    out = torch.empty(batch_size, n, dtype=input_dtype, device=x_flat.device)
    stride_x = x_flat.stride(0)
    stride_out = n

    _launch_kernel(
        x_flat, out, scale, input_dtype, batch_size, n, log_n, stride_x, stride_out
    )

    # Trim padding back to original dim
    if n != dim_og:
        out = out[:, :dim_og]
    return out.reshape(shapes_og)

694 

695 

def _launch_kernel(
    x, out, scale, input_dtype, batch_size, n, log_n, stride_x, stride_out
):
    """Dispatch to the appropriate kernel. Separated for fast-path sharing.

    Dispatch strategy (v43):
      - n=256, fp16/bf16: 4-row ILP native kernel when batch_size >= 4,
        otherwise the single-row native kernel (both with 2 warps)
      - n=512, fp16/bf16: single-row hardcoded native kernel (1 warp)
      - n<=128, fp16/bf16: generic 1D native kernel
      - remaining n<=1024, fp16/bf16: 2D native kernel (2 rows, 4 warps)
      - n<=512, other dtypes: fp32 1D kernel
      - everything else: fp32 2D kernel

    Args:
        x: 2D input with unit column stride.
        out: 2D output buffer with ``n`` columns.
        scale: multiplier applied to the transformed values.
        input_dtype: dtype of ``x``; selects native vs. fp32 kernels.
        batch_size: number of rows to transform.
        n: row length handed to the kernel as ``BLOCK_SIZE``.
        log_n: number of butterfly stages.
        stride_x: element stride between rows of ``x``.
        stride_out: element stride between rows of ``out``.
    """
    if n <= 1024 and input_dtype in (torch.float16, torch.bfloat16):
        if n == 256:
            if batch_size >= 4:
                _fht_kernel_256_4row_native[((batch_size + 3) // 4,)](
                    x,
                    out,
                    stride_x_row=stride_x,
                    stride_out_row=stride_out,
                    N_ROWS=batch_size,
                    SCALE=scale,
                    num_warps=2,
                    num_stages=1,
                )
            else:
                _fht_kernel_256_1d_native[(batch_size,)](
                    x,
                    out,
                    stride_x_row=stride_x,
                    stride_out_row=stride_out,
                    SCALE=scale,
                    num_warps=2,
                    num_stages=1,
                )
        elif n == 512:
            # Single-row 1D hardcoded: v35 achieved 1.1193x (best)
            _fht_kernel_512_1d_native[(batch_size,)](
                x,
                out,
                stride_x_row=stride_x,
                stride_out_row=stride_out,
                SCALE=scale,
                num_warps=1,
                num_stages=1,
            )
        elif n <= 128:
            _fht_kernel_1d_native[(batch_size,)](
                x,
                out,
                stride_x_row=stride_x,
                stride_out_row=stride_out,
                DIM=n,
                LOG_N=log_n,
                BLOCK_SIZE=n,
                SCALE=scale,
                num_warps=1,
                num_stages=1,
            )
        else:
            # dim=1024: 2D native with 2 rows/program
            rows_per_program = 2
            n_programs = (batch_size + rows_per_program - 1) // rows_per_program
            _fht_kernel_2d_native[(n_programs,)](
                x,
                out,
                stride_x_row=stride_x,
                stride_out_row=stride_out,
                N_ROWS=batch_size,
                DIM=n,
                LOG_N=log_n,
                BLOCK_SIZE=n,
                ROWS_PER_PROGRAM=rows_per_program,
                SCALE=scale,
                num_warps=4,
                num_stages=1,
            )
    elif n <= 512:
        # fp32 1D kernel
        _fht_kernel_1d[(batch_size,)](
            x,
            out,
            scale,
            stride_x_row=stride_x,
            stride_out_row=stride_out,
            DIM=n,
            LOG_N=log_n,
            BLOCK_SIZE=n,
            INPUT_IS_FP16=(input_dtype == torch.float16),
            INPUT_IS_BF16=(input_dtype == torch.bfloat16),
            num_warps=1,
            num_stages=1,
        )
    else:
        # fp32 2D butterfly for fp32 inputs and large dims.
        # Only n > 512 can reach this branch (smaller n is handled above), so
        # the tuning table starts at 1024; the former n <= 256 entries were
        # unreachable.
        if n <= 1024:
            num_warps = 4
            rows_per_program = 2
        elif n <= 4096:
            num_warps = 4
            rows_per_program = 1
        else:
            num_warps = 8
            rows_per_program = 1

        n_programs = (batch_size + rows_per_program - 1) // rows_per_program
        _fht_kernel_2d[(n_programs,)](
            x,
            out,
            scale,
            stride_x_row=stride_x,
            stride_out_row=stride_out,
            N_ROWS=batch_size,
            DIM=n,
            LOG_N=log_n,
            BLOCK_SIZE=n,
            ROWS_PER_PROGRAM=rows_per_program,
            INPUT_IS_FP16=(input_dtype == torch.float16),
            INPUT_IS_BF16=(input_dtype == torch.bfloat16),
            num_warps=num_warps,
            num_stages=1,
        )

829 

830 

831# ============================================================ 

832# Autograd Function 

833# ============================================================ 

834 

835 

class HadamardTransformFn(torch.autograd.Function):
    """Autograd wrapper that uses the forward routine for both passes."""

    @staticmethod
    def forward(ctx, x, scale=1.0):
        """Transform ``x`` and remember ``scale`` on ``ctx`` for backward."""
        result = _hadamard_transform_fwd(x, scale)
        ctx._hadamard_transform_scale = scale
        return result

    @staticmethod
    def backward(ctx, dout):
        """Return the gradient w.r.t. ``x``; ``scale`` receives no gradient."""
        # Hadamard matrix is symmetric: backward = forward with same scale
        saved_scale = ctx._hadamard_transform_scale
        grad_x = _hadamard_transform_fwd(dout, saved_scale)
        return grad_x, None

846 

847 

848# ============================================================ 

849# Public API 

850# ============================================================ 

851 

852 

def hadamard_transform(x, scale=1.0):
    """
    Multiply each row of x by the Hadamard transform matrix, with autograd.

    Equivalent to F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale.
    If dim is not a power of 2, x is implicitly zero-padded so that dim
    becomes the next power of 2.

    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)
    """
    out = HadamardTransformFn.apply(x, scale)
    return out

867 

868 

869# ============================================================ 

870# XXN variants (non-power-of-2 dims) 

871# 

872# Decomposes dim = M * 2^k via H_M ⊗ H_{2^k}: 

873# 1. Reshape to (batch, M, 2^k) 

874# 2. Apply H_M column transform + FHT in a single fused kernel 

875# No padding to next power of 2, no intermediate DRAM write. 

876# ============================================================ 

877 

878 

def hadamard_transform_12N(x, scale=1.0):
    """Fused transform for dim = 3 * 2^k (e.g. 1536, 3072, 6144, 12288).

    Forwards to the M=3 fused kernel launcher.

    NOTE(review): the name suggests a 12 * 2^k contract, while this call
    factors out M=3 -- confirm outputs against the reference implementation.
    """
    return _launch_mn_fused_kernel(x, 3, scale)

882 

883 

def hadamard_transform_20N(x, scale=1.0):
    """Fused transform for dim = 5 * 2^k (e.g. 5120, 10240, 20480).

    Forwards to the M=5 fused kernel launcher.

    NOTE(review): the name suggests a 20 * 2^k contract, while this call
    factors out M=5 -- confirm outputs against the reference implementation.
    """
    return _launch_mn_fused_kernel(x, 5, scale)

887 

888 

def hadamard_transform_28N(x, scale=1.0):
    """Fused transform for dim = 7 * 2^k (e.g. 7168, 14336, 28672).

    Forwards to the M=7 fused kernel launcher.

    NOTE(review): the name suggests a 28 * 2^k contract, while this call
    factors out M=7 -- confirm outputs against the reference implementation.
    """
    return _launch_mn_fused_kernel(x, 7, scale)

892 

893 

def hadamard_transform_40N(x, scale=1.0):
    """Fused transform for dim = 5 * 2^k (e.g. 10240, 20480, 40960).

    Forwards to the M=5 fused kernel launcher.

    NOTE(review): this is the same call as hadamard_transform_20N (M=5) --
    confirm that a separate 40N entry point is meant to behave identically.
    """
    return _launch_mn_fused_kernel(x, 5, scale)
896 return _launch_mn_fused_kernel(x, M=5, scale=scale)