Coverage for src/flag_gems/ops/hadamard_transform.py: 10%
412 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1"""Fast Hadamard Transform in Triton.
3Drop-in replacement for Dao-AILab/fast-hadamard-transform with identical interface:
4 - hadamard_transform(x, scale=1.0) with autograd support
5 - hadamard_transform_12N/20N/28N/40N(x, scale=1.0) for non-power-of-2 dims
6 - Input: (..., dim), fp32/fp16/bf16
7 - Output: (..., dim), same dtype as input
8 - Padding: to next multiple of 8 (matching CUDA impl)
 - dim <= 65536 (standard), dim <= M*2^10 (XXN variants)
11Reference: https://github.com/Dao-AILab/fast-hadamard-transform
12"""
14import math
16import torch
17import torch.nn.functional as F
18import triton
19import triton.language as tl
21# ============================================================
22# Triton kernel — v43: remove evict_first from loads + warps=2 for dim=256
23# ============================================================
24# v35 best: dim=256 0.9302x (no evict_first on loads, warps=1)
25# v42: dim=256 0.8950x (evict_first on loads hurt — L2 thrashing)
26#
27# v43 strategy:
28# 1. Remove evict_first from all loads. v42 proved it hurts dim=256
29# (0.8950x vs v35's 0.9302x). The 256-element rows (512B fp16)
30# are small enough that L2 caching of nearby rows helps prefetch.
31# 2. Try num_warps=2 for dim=256 4-row ILP kernel. With 4 rows of
32# 256 elements each, the workload can benefit from 2-warp occupancy:
33# each warp handles the compute for its assigned instructions,
34# and the scheduler can overlap loads from one warp with compute
35# from the other. This targets the memory latency hiding gap.
36# 3. Keep evict_first on stores (write-once streaming pattern).
37# 4. Keep all other kernels unchanged from v42 baseline.
40# ============================================================
41# Butterfly stages
42# ============================================================
@triton.jit
def _butterfly_stage_1d(x, BLOCK_SIZE: tl.constexpr, STRIDE: tl.constexpr):
    """Apply one add/subtract butterfly stage to a 1D vector.

    ``x`` is viewed as ``BLOCK_SIZE // (2 * STRIDE)`` groups of ``2 * STRIDE``
    elements.  Inside every group, element ``i`` of the first half is paired
    with element ``i`` of the second half; the sums are written to the first
    half and the differences to the second half.

    Args:
        x: vector of ``BLOCK_SIZE`` elements.
        BLOCK_SIZE: compile-time length of ``x``.
        STRIDE: compile-time distance between the two partners of a pair.

    Returns:
        A ``(BLOCK_SIZE,)`` vector holding the stage output.
    """
    N_GROUPS: tl.constexpr = BLOCK_SIZE // (2 * STRIDE)
    if STRIDE == 1:
        # Partners are adjacent, so the pair axis is already the last one.
        pairs = tl.reshape(x, (N_GROUPS, 2))
        lo, hi = tl.split(pairs)
        merged = tl.join(lo + hi, lo - hi)
        return tl.reshape(merged, (BLOCK_SIZE,))
    else:
        # Bring the "which half" axis to the end so split/join can act on it,
        # then undo the permutation to restore the original element order.
        grouped = tl.reshape(x, (N_GROUPS, 2, STRIDE))
        grouped = tl.permute(grouped, (0, 2, 1))
        lo, hi = tl.split(grouped)
        grouped = tl.join(lo + hi, lo - hi)
        grouped = tl.permute(grouped, (0, 2, 1))
        return tl.reshape(grouped, (BLOCK_SIZE,))
@triton.jit
def _butterfly_stage_2d(
    x, ROWS: tl.constexpr, BLOCK_SIZE: tl.constexpr, STRIDE: tl.constexpr
):
    """Apply one add/subtract butterfly stage to every row of a 2D tile.

    Each of the ``ROWS`` rows is processed independently: the row is viewed as
    ``BLOCK_SIZE // (2 * STRIDE)`` groups of ``2 * STRIDE`` elements, and in
    every group the first half receives the pairwise sums while the second
    half receives the pairwise differences.

    Args:
        x: tile of shape ``(ROWS, BLOCK_SIZE)``.
        ROWS: compile-time number of rows in the tile.
        BLOCK_SIZE: compile-time row length.
        STRIDE: compile-time distance between the two partners of a pair.

    Returns:
        A ``(ROWS, BLOCK_SIZE)`` tile holding the stage output.
    """
    N_GROUPS: tl.constexpr = BLOCK_SIZE // (2 * STRIDE)
    if STRIDE == 1:
        # Partners are adjacent, so the pair axis is already the last one.
        pairs = tl.reshape(x, (ROWS, N_GROUPS, 2))
        lo, hi = tl.split(pairs)
        merged = tl.join(lo + hi, lo - hi)
        return tl.reshape(merged, (ROWS, BLOCK_SIZE))
    else:
        # Bring the "which half" axis to the end so split/join can act on it,
        # then undo the permutation to restore the original element order.
        tile = tl.reshape(x, (ROWS, N_GROUPS, 2, STRIDE))
        tile = tl.permute(tile, (0, 1, 3, 2))
        lo, hi = tl.split(tile)
        tile = tl.join(lo + hi, lo - hi)
        tile = tl.permute(tile, (0, 1, 3, 2))
        return tl.reshape(tile, (ROWS, BLOCK_SIZE))
81# ============================================================
82# 4-row ILP 1D native kernel for dim=256 (8 hardcoded stages)
83# v43: remove evict_first from loads, keep on stores
84# ============================================================
@triton.jit
def _fht_kernel_256_4row_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    N_ROWS,
    SCALE: tl.constexpr,
):
    """Transform four consecutive 256-wide rows per program.

    Program ``p`` handles rows ``4p .. 4p + 3``.  The first of these is read
    and written without a mask; the other three are masked against
    ``N_ROWS`` (masked loads yield 0.0, masked stores are skipped).  The four
    rows go through the same eight butterfly stages (strides 128 down to 1),
    written out stage by stage so the four independent chains are interleaved.
    No dtype conversion is performed on the loaded values.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        stride_x_row: offset between consecutive input rows, in elements.
        stride_out_row: offset between consecutive output rows, in elements.
        N_ROWS: total number of rows in the batch.
        SCALE: compile-time factor applied to every output element.
    """
    block_id = tl.program_id(0)
    cols = tl.arange(0, 256)

    r0 = block_id * 4
    r1 = r0 + 1
    r2 = r0 + 2
    r3 = r0 + 3

    # Only the trailing rows of the last program can fall outside the batch.
    keep1 = r1 < N_ROWS
    keep2 = r2 < N_ROWS
    keep3 = r3 < N_ROWS

    src0 = X_ptr + r0 * stride_x_row + cols
    src1 = X_ptr + r1 * stride_x_row + cols
    src2 = X_ptr + r2 * stride_x_row + cols
    src3 = X_ptr + r3 * stride_x_row + cols

    # Loads use the default eviction policy (no evict_first).
    v0 = tl.load(src0)
    v1 = tl.load(src1, mask=keep1, other=0.0)
    v2 = tl.load(src2, mask=keep2, other=0.0)
    v3 = tl.load(src3, mask=keep3, other=0.0)

    # Stage 1 of 8: stride 128
    v0 = _butterfly_stage_1d(v0, 256, 128)
    v1 = _butterfly_stage_1d(v1, 256, 128)
    v2 = _butterfly_stage_1d(v2, 256, 128)
    v3 = _butterfly_stage_1d(v3, 256, 128)
    # Stage 2 of 8: stride 64
    v0 = _butterfly_stage_1d(v0, 256, 64)
    v1 = _butterfly_stage_1d(v1, 256, 64)
    v2 = _butterfly_stage_1d(v2, 256, 64)
    v3 = _butterfly_stage_1d(v3, 256, 64)
    # Stage 3 of 8: stride 32
    v0 = _butterfly_stage_1d(v0, 256, 32)
    v1 = _butterfly_stage_1d(v1, 256, 32)
    v2 = _butterfly_stage_1d(v2, 256, 32)
    v3 = _butterfly_stage_1d(v3, 256, 32)
    # Stage 4 of 8: stride 16
    v0 = _butterfly_stage_1d(v0, 256, 16)
    v1 = _butterfly_stage_1d(v1, 256, 16)
    v2 = _butterfly_stage_1d(v2, 256, 16)
    v3 = _butterfly_stage_1d(v3, 256, 16)
    # Stage 5 of 8: stride 8
    v0 = _butterfly_stage_1d(v0, 256, 8)
    v1 = _butterfly_stage_1d(v1, 256, 8)
    v2 = _butterfly_stage_1d(v2, 256, 8)
    v3 = _butterfly_stage_1d(v3, 256, 8)
    # Stage 6 of 8: stride 4
    v0 = _butterfly_stage_1d(v0, 256, 4)
    v1 = _butterfly_stage_1d(v1, 256, 4)
    v2 = _butterfly_stage_1d(v2, 256, 4)
    v3 = _butterfly_stage_1d(v3, 256, 4)
    # Stage 7 of 8: stride 2
    v0 = _butterfly_stage_1d(v0, 256, 2)
    v1 = _butterfly_stage_1d(v1, 256, 2)
    v2 = _butterfly_stage_1d(v2, 256, 2)
    v3 = _butterfly_stage_1d(v3, 256, 2)
    # Stage 8 of 8: stride 1
    v0 = _butterfly_stage_1d(v0, 256, 1)
    v1 = _butterfly_stage_1d(v1, 256, 1)
    v2 = _butterfly_stage_1d(v2, 256, 1)
    v3 = _butterfly_stage_1d(v3, 256, 1)

    v0 = v0 * SCALE
    v1 = v1 * SCALE
    v2 = v2 * SCALE
    v3 = v3 * SCALE

    dst0 = OUT_ptr + r0 * stride_out_row + cols
    dst1 = OUT_ptr + r1 * stride_out_row + cols
    dst2 = OUT_ptr + r2 * stride_out_row + cols
    dst3 = OUT_ptr + r3 * stride_out_row + cols

    # Stores keep the evict_first policy.
    tl.store(dst0, v0, eviction_policy="evict_first")
    tl.store(dst1, v1, mask=keep1, eviction_policy="evict_first")
    tl.store(dst2, v2, mask=keep2, eviction_policy="evict_first")
    tl.store(dst3, v3, mask=keep3, eviction_policy="evict_first")
173# ============================================================
174# Fallback: single-row 1D native kernel for dim=256
175# ============================================================
@triton.jit
def _fht_kernel_256_1d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    SCALE: tl.constexpr,
):
    """Transform one 256-wide row per program with eight hardcoded stages.

    The row is loaded without a mask and without any dtype conversion, run
    through butterfly stages with strides 128, 64, ..., 1, multiplied by
    ``SCALE`` and stored with the ``evict_first`` policy.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        stride_x_row: offset between consecutive input rows, in elements.
        stride_out_row: offset between consecutive output rows, in elements.
        SCALE: compile-time factor applied to every output element.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, 256)

    src = X_ptr + row * stride_x_row + cols
    dst = OUT_ptr + row * stride_out_row + cols

    v = tl.load(src)

    # Largest stride first: 128, 64, 32, 16, 8, 4, 2, 1.
    v = _butterfly_stage_1d(v, 256, 128)
    v = _butterfly_stage_1d(v, 256, 64)
    v = _butterfly_stage_1d(v, 256, 32)
    v = _butterfly_stage_1d(v, 256, 16)
    v = _butterfly_stage_1d(v, 256, 8)
    v = _butterfly_stage_1d(v, 256, 4)
    v = _butterfly_stage_1d(v, 256, 2)
    v = _butterfly_stage_1d(v, 256, 1)

    v = v * SCALE
    tl.store(dst, v, eviction_policy="evict_first")
208# ============================================================
209# 1D hardcoded native kernel for dim=512 (9 stages)
210# Restored from v31/v35: single-row (best: 1.1193x in v35)
211# ============================================================
@triton.jit
def _fht_kernel_512_1d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    SCALE: tl.constexpr,
):
    """Transform one 512-wide row per program with nine hardcoded stages.

    The row is loaded without a mask and without any dtype conversion, run
    through butterfly stages with strides 256, 128, ..., 1, multiplied by
    ``SCALE`` and stored with the ``evict_first`` policy.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        stride_x_row: offset between consecutive input rows, in elements.
        stride_out_row: offset between consecutive output rows, in elements.
        SCALE: compile-time factor applied to every output element.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, 512)

    src = X_ptr + row * stride_x_row + cols
    dst = OUT_ptr + row * stride_out_row + cols

    v = tl.load(src)

    # Largest stride first: 256, 128, 64, 32, 16, 8, 4, 2, 1.
    v = _butterfly_stage_1d(v, 512, 256)
    v = _butterfly_stage_1d(v, 512, 128)
    v = _butterfly_stage_1d(v, 512, 64)
    v = _butterfly_stage_1d(v, 512, 32)
    v = _butterfly_stage_1d(v, 512, 16)
    v = _butterfly_stage_1d(v, 512, 8)
    v = _butterfly_stage_1d(v, 512, 4)
    v = _butterfly_stage_1d(v, 512, 2)
    v = _butterfly_stage_1d(v, 512, 1)

    v = v * SCALE
    tl.store(dst, v, eviction_policy="evict_first")
245# ============================================================
246# Generic 1D native-dtype butterfly kernel (for other small dims)
247# ============================================================
@triton.jit
def _fht_kernel_1d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    SCALE: tl.constexpr,
):
    """Transform one row per program, with no dtype conversion.

    Runs ``LOG_N`` butterfly stages with strides ``2**(LOG_N - 1)`` down to 1
    on a row of ``BLOCK_SIZE`` elements, then multiplies by ``SCALE``.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        stride_x_row: offset between consecutive input rows, in elements.
        stride_out_row: offset between consecutive output rows, in elements.
        DIM: compile-time row width; not read by the kernel body.
        LOG_N: compile-time number of butterfly stages.
        BLOCK_SIZE: compile-time number of elements loaded per row.
        SCALE: compile-time factor applied to every output element.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)

    src = X_ptr + row * stride_x_row + cols
    dst = OUT_ptr + row * stride_out_row + cols

    v = tl.load(src)

    # Stage k uses stride 2**(LOG_N - 1 - k): largest stride first.
    for stage in tl.static_range(LOG_N):
        v = _butterfly_stage_1d(v, BLOCK_SIZE, 1 << (LOG_N - 1 - stage))

    v = v * SCALE
    tl.store(dst, v, eviction_policy="evict_first")
277# ============================================================
278# 2D native-dtype butterfly kernel (for dim=1024 with fp16/bf16)
279# ============================================================
@triton.jit
def _fht_kernel_2d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    N_ROWS,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    ROWS_PER_PROGRAM: tl.constexpr,
    SCALE: tl.constexpr,
):
    """Transform a tile of ``ROWS_PER_PROGRAM`` rows, with no dtype conversion.

    Rows whose index is ``>= N_ROWS`` are masked: they load as 0.0 and are
    not stored.  The tile goes through ``LOG_N`` butterfly stages with strides
    ``2**(LOG_N - 1)`` down to 1 and is then multiplied by ``SCALE``.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        stride_x_row: offset between consecutive input rows, in elements.
        stride_out_row: offset between consecutive output rows, in elements.
        N_ROWS: total number of rows in the batch.
        DIM: compile-time row width; not read by the kernel body.
        LOG_N: compile-time number of butterfly stages.
        BLOCK_SIZE: compile-time number of elements per row.
        ROWS_PER_PROGRAM: compile-time number of rows handled by one program.
        SCALE: compile-time factor applied to every output element.
    """
    tile_id = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    local_rows = tl.arange(0, ROWS_PER_PROGRAM)

    first_row = tile_id * ROWS_PER_PROGRAM
    rows = first_row + local_rows
    rows_valid = rows < N_ROWS

    src = X_ptr + rows[:, None] * stride_x_row + cols[None, :]
    dst = OUT_ptr + rows[:, None] * stride_out_row + cols[None, :]
    # One mask value per row, broadcast across the columns.
    tile_mask = rows_valid[:, None]

    tile = tl.load(src, mask=tile_mask, other=0.0)

    # Stage k uses stride 2**(LOG_N - 1 - k): largest stride first.
    for stage in tl.static_range(LOG_N):
        tile = _butterfly_stage_2d(
            tile, ROWS_PER_PROGRAM, BLOCK_SIZE, 1 << (LOG_N - 1 - stage)
        )

    tile = tile * SCALE
    tl.store(dst, tile, mask=tile_mask, eviction_policy="evict_first")
320# ============================================================
321# 1D butterfly kernel (fp32, for fp32 inputs)
322# ============================================================
@triton.jit
def _fht_kernel_1d(
    X_ptr,
    OUT_ptr,
    scale,
    stride_x_row,
    stride_out_row,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    INPUT_IS_FP16: tl.constexpr,
    INPUT_IS_BF16: tl.constexpr,
):
    """Transform one row per program, computing in float32.

    The row is cast to float32 after loading, run through ``LOG_N`` butterfly
    stages with strides ``2**(LOG_N - 1)`` down to 1, multiplied by ``scale``
    and cast back to float16 / bfloat16 when the matching flag is set.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        scale: run-time factor applied to every output element.
        stride_x_row: offset between consecutive input rows, in elements.
        stride_out_row: offset between consecutive output rows, in elements.
        DIM: compile-time row width; not read by the kernel body.
        LOG_N: compile-time number of butterfly stages.
        BLOCK_SIZE: compile-time number of elements loaded per row.
        INPUT_IS_FP16: compile-time flag; store the result as float16.
        INPUT_IS_BF16: compile-time flag; store the result as bfloat16.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)

    src = X_ptr + row * stride_x_row + cols
    dst = OUT_ptr + row * stride_out_row + cols

    acc = tl.load(src).to(tl.float32)

    # Stage k uses stride 2**(LOG_N - 1 - k): largest stride first.
    for stage in tl.static_range(LOG_N):
        acc = _butterfly_stage_1d(acc, BLOCK_SIZE, 1 << (LOG_N - 1 - stage))

    acc = acc * scale

    # Narrow back to the requested half-precision type, if any, then store.
    if INPUT_IS_FP16:
        acc = acc.to(tl.float16)
    elif INPUT_IS_BF16:
        acc = acc.to(tl.bfloat16)
    tl.store(dst, acc, eviction_policy="evict_first")
360# ============================================================
361# 2D butterfly kernel (fp32, for dim>=1024 and fp32 inputs)
362# ============================================================
@triton.jit
def _fht_kernel_2d(
    X_ptr,
    OUT_ptr,
    scale,
    stride_x_row,
    stride_out_row,
    N_ROWS,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    ROWS_PER_PROGRAM: tl.constexpr,
    INPUT_IS_FP16: tl.constexpr,
    INPUT_IS_BF16: tl.constexpr,
):
    """Transform a tile of ``ROWS_PER_PROGRAM`` rows, computing in float32.

    Rows whose index is ``>= N_ROWS`` are masked: they load as 0.0 and are
    not stored.  The tile is cast to float32, run through ``LOG_N`` butterfly
    stages with strides ``2**(LOG_N - 1)`` down to 1, multiplied by ``scale``
    and cast back to float16 / bfloat16 when the matching flag is set.

    Args:
        X_ptr: base pointer of the input rows.
        OUT_ptr: base pointer of the output rows.
        scale: run-time factor applied to every output element.
        stride_x_row: offset between consecutive input rows, in elements.
        stride_out_row: offset between consecutive output rows, in elements.
        N_ROWS: total number of rows in the batch.
        DIM: compile-time row width; not read by the kernel body.
        LOG_N: compile-time number of butterfly stages.
        BLOCK_SIZE: compile-time number of elements per row.
        ROWS_PER_PROGRAM: compile-time number of rows handled by one program.
        INPUT_IS_FP16: compile-time flag; store the result as float16.
        INPUT_IS_BF16: compile-time flag; store the result as bfloat16.
    """
    tile_id = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    local_rows = tl.arange(0, ROWS_PER_PROGRAM)

    first_row = tile_id * ROWS_PER_PROGRAM
    rows = first_row + local_rows
    rows_valid = rows < N_ROWS

    src = X_ptr + rows[:, None] * stride_x_row + cols[None, :]
    dst = OUT_ptr + rows[:, None] * stride_out_row + cols[None, :]
    # One mask value per row, broadcast across the columns.
    tile_mask = rows_valid[:, None]

    acc = tl.load(src, mask=tile_mask, other=0.0).to(tl.float32)

    # Stage k uses stride 2**(LOG_N - 1 - k): largest stride first.
    for stage in tl.static_range(LOG_N):
        acc = _butterfly_stage_2d(
            acc, ROWS_PER_PROGRAM, BLOCK_SIZE, 1 << (LOG_N - 1 - stage)
        )

    acc = acc * scale

    # Narrow back to the requested half-precision type, if any, then store.
    if INPUT_IS_FP16:
        acc = acc.to(tl.float16)
    elif INPUT_IS_BF16:
        acc = acc.to(tl.bfloat16)
    tl.store(dst, acc, mask=tile_mask, eviction_policy="evict_first")
414# ============================================================
415# M×N fused kernels: H_M column transform + FHT in registers
416# No intermediate DRAM write, no padding to next power of 2.
417# ============================================================
@triton.jit
def _h3_fht_kernel(
    X_ptr,
    OUT_ptr,
    stride_batch,
    stride_row,
    SCALE: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
    N_COLS: tl.constexpr,
    LOG_N: tl.constexpr,
):
    """Mix three rows with a fixed +/-1 pattern, then butterfly each result.

    One program handles one batch item made of 3 rows of ``N_COLS`` elements.
    The rows are loaded as float32 and combined with the sign rows
    ``(+,+,+)``, ``(+,-,+)``, ``(+,+,-)``.  Each combination then goes through
    ``LOG_N`` butterfly stages (largest stride first), is multiplied by
    ``SCALE`` and is cast to float16 / bfloat16 when the matching flag is set.
    The same strides are used for the input and output buffers.

    NOTE(review): the three sign rows are not mutually orthogonal (rows 0 and
    1 have dot product 1), so the 3x3 mixing matrix is not a Hadamard matrix.
    Confirm that this is the intended transform for the 12N variant.
    """
    item = tl.program_id(0)
    cols = tl.arange(0, N_COLS)
    item_base = item * stride_batch

    r0 = tl.load(X_ptr + item_base + 0 * stride_row + cols).to(tl.float32)
    r1 = tl.load(X_ptr + item_base + 1 * stride_row + cols).to(tl.float32)
    r2 = tl.load(X_ptr + item_base + 2 * stride_row + cols).to(tl.float32)

    # Fixed sign pattern across the three rows.
    t0 = r0 + r1 + r2
    t1 = r0 - r1 + r2
    t2 = r0 + r1 - r2

    # Stage k uses stride 2**(LOG_N - 1 - k): largest stride first.
    for stage in tl.static_range(LOG_N):
        t0 = _butterfly_stage_1d(t0, N_COLS, 1 << (LOG_N - 1 - stage))
        t1 = _butterfly_stage_1d(t1, N_COLS, 1 << (LOG_N - 1 - stage))
        t2 = _butterfly_stage_1d(t2, N_COLS, 1 << (LOG_N - 1 - stage))

    t0 = t0 * SCALE
    t1 = t1 * SCALE
    t2 = t2 * SCALE

    if IS_FP16:
        t0 = t0.to(tl.float16)
        t1 = t1.to(tl.float16)
        t2 = t2.to(tl.float16)
    elif IS_BF16:
        t0 = t0.to(tl.bfloat16)
        t1 = t1.to(tl.bfloat16)
        t2 = t2.to(tl.bfloat16)

    out_base = OUT_ptr + item_base
    tl.store(out_base + 0 * stride_row + cols, t0, eviction_policy="evict_first")
    tl.store(out_base + 1 * stride_row + cols, t1, eviction_policy="evict_first")
    tl.store(out_base + 2 * stride_row + cols, t2, eviction_policy="evict_first")
@triton.jit
def _h5_fht_kernel(
    X_ptr,
    OUT_ptr,
    stride_batch,
    stride_row,
    SCALE: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
    N_COLS: tl.constexpr,
    LOG_N: tl.constexpr,
):
    """Mix five rows with a fixed +/-1 pattern, then butterfly each result.

    One program handles one batch item made of 5 rows of ``N_COLS`` elements.
    The rows are loaded as float32 and combined with five fixed sign rows.
    Each combination then goes through ``LOG_N`` butterfly stages (largest
    stride first), is multiplied by ``SCALE`` and is cast to float16 /
    bfloat16 when the matching flag is set.  The same strides are used for
    the input and output buffers.

    NOTE(review): the sign rows are not mutually orthogonal (rows 0 and 1 have
    dot product 1), so the 5x5 mixing matrix is not a Hadamard matrix.
    Confirm that this is the intended transform for the 20N/40N variants.
    """
    item = tl.program_id(0)
    cols = tl.arange(0, N_COLS)
    item_base = item * stride_batch

    r0 = tl.load(X_ptr + item_base + 0 * stride_row + cols).to(tl.float32)
    r1 = tl.load(X_ptr + item_base + 1 * stride_row + cols).to(tl.float32)
    r2 = tl.load(X_ptr + item_base + 2 * stride_row + cols).to(tl.float32)
    r3 = tl.load(X_ptr + item_base + 3 * stride_row + cols).to(tl.float32)
    r4 = tl.load(X_ptr + item_base + 4 * stride_row + cols).to(tl.float32)

    # Fixed sign pattern across the five rows.
    t0 = r0 + r1 + r2 + r3 + r4
    t1 = r0 - r1 + r2 - r3 + r4
    t2 = r0 + r1 - r2 + r3 - r4
    t3 = r0 - r1 - r2 - r3 - r4
    t4 = r0 + r1 + r2 - r3 - r4

    # Stage k uses stride 2**(LOG_N - 1 - k): largest stride first.
    for stage in tl.static_range(LOG_N):
        t0 = _butterfly_stage_1d(t0, N_COLS, 1 << (LOG_N - 1 - stage))
        t1 = _butterfly_stage_1d(t1, N_COLS, 1 << (LOG_N - 1 - stage))
        t2 = _butterfly_stage_1d(t2, N_COLS, 1 << (LOG_N - 1 - stage))
        t3 = _butterfly_stage_1d(t3, N_COLS, 1 << (LOG_N - 1 - stage))
        t4 = _butterfly_stage_1d(t4, N_COLS, 1 << (LOG_N - 1 - stage))

    t0 = t0 * SCALE
    t1 = t1 * SCALE
    t2 = t2 * SCALE
    t3 = t3 * SCALE
    t4 = t4 * SCALE

    if IS_FP16:
        t0 = t0.to(tl.float16)
        t1 = t1.to(tl.float16)
        t2 = t2.to(tl.float16)
        t3 = t3.to(tl.float16)
        t4 = t4.to(tl.float16)
    elif IS_BF16:
        t0 = t0.to(tl.bfloat16)
        t1 = t1.to(tl.bfloat16)
        t2 = t2.to(tl.bfloat16)
        t3 = t3.to(tl.bfloat16)
        t4 = t4.to(tl.bfloat16)

    out_base = OUT_ptr + item_base
    tl.store(out_base + 0 * stride_row + cols, t0, eviction_policy="evict_first")
    tl.store(out_base + 1 * stride_row + cols, t1, eviction_policy="evict_first")
    tl.store(out_base + 2 * stride_row + cols, t2, eviction_policy="evict_first")
    tl.store(out_base + 3 * stride_row + cols, t3, eviction_policy="evict_first")
    tl.store(out_base + 4 * stride_row + cols, t4, eviction_policy="evict_first")
@triton.jit
def _h7_fht_kernel(
    X_ptr,
    OUT_ptr,
    stride_batch,
    stride_row,
    SCALE: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
    N_COLS: tl.constexpr,
    LOG_N: tl.constexpr,
):
    """Mix seven rows with a fixed +/-1 pattern, then butterfly each result.

    One program handles one batch item made of 7 rows of ``N_COLS`` elements.
    The rows are loaded as float32 and combined with seven fixed sign rows.
    Each combination then goes through ``LOG_N`` butterfly stages (largest
    stride first), is multiplied by ``SCALE`` and is cast to float16 /
    bfloat16 when the matching flag is set.  The same strides are used for
    the input and output buffers.

    NOTE(review): the sign rows are not mutually orthogonal (rows 0 and 1 have
    dot product 1), so the 7x7 mixing matrix is not a Hadamard matrix.
    Confirm that this is the intended transform for the 28N variant.
    """
    item = tl.program_id(0)
    cols = tl.arange(0, N_COLS)
    item_base = item * stride_batch

    r0 = tl.load(X_ptr + item_base + 0 * stride_row + cols).to(tl.float32)
    r1 = tl.load(X_ptr + item_base + 1 * stride_row + cols).to(tl.float32)
    r2 = tl.load(X_ptr + item_base + 2 * stride_row + cols).to(tl.float32)
    r3 = tl.load(X_ptr + item_base + 3 * stride_row + cols).to(tl.float32)
    r4 = tl.load(X_ptr + item_base + 4 * stride_row + cols).to(tl.float32)
    r5 = tl.load(X_ptr + item_base + 5 * stride_row + cols).to(tl.float32)
    r6 = tl.load(X_ptr + item_base + 6 * stride_row + cols).to(tl.float32)

    # Fixed sign pattern across the seven rows.
    t0 = r0 + r1 + r2 + r3 + r4 + r5 + r6
    t1 = r0 - r1 + r2 - r3 + r4 - r5 + r6
    t2 = r0 + r1 - r2 + r3 - r4 + r5 - r6
    t3 = r0 - r1 - r2 - r3 - r4 - r5 - r6
    t4 = r0 + r1 + r2 - r3 - r4 - r5 - r6
    t5 = r0 - r1 + r2 + r3 - r4 + r5 + r6
    t6 = r0 + r1 - r2 - r3 + r4 + r5 - r6

    # Stage k uses stride 2**(LOG_N - 1 - k): largest stride first.
    for stage in tl.static_range(LOG_N):
        t0 = _butterfly_stage_1d(t0, N_COLS, 1 << (LOG_N - 1 - stage))
        t1 = _butterfly_stage_1d(t1, N_COLS, 1 << (LOG_N - 1 - stage))
        t2 = _butterfly_stage_1d(t2, N_COLS, 1 << (LOG_N - 1 - stage))
        t3 = _butterfly_stage_1d(t3, N_COLS, 1 << (LOG_N - 1 - stage))
        t4 = _butterfly_stage_1d(t4, N_COLS, 1 << (LOG_N - 1 - stage))
        t5 = _butterfly_stage_1d(t5, N_COLS, 1 << (LOG_N - 1 - stage))
        t6 = _butterfly_stage_1d(t6, N_COLS, 1 << (LOG_N - 1 - stage))

    t0 = t0 * SCALE
    t1 = t1 * SCALE
    t2 = t2 * SCALE
    t3 = t3 * SCALE
    t4 = t4 * SCALE
    t5 = t5 * SCALE
    t6 = t6 * SCALE

    if IS_FP16:
        t0 = t0.to(tl.float16)
        t1 = t1.to(tl.float16)
        t2 = t2.to(tl.float16)
        t3 = t3.to(tl.float16)
        t4 = t4.to(tl.float16)
        t5 = t5.to(tl.float16)
        t6 = t6.to(tl.float16)
    elif IS_BF16:
        t0 = t0.to(tl.bfloat16)
        t1 = t1.to(tl.bfloat16)
        t2 = t2.to(tl.bfloat16)
        t3 = t3.to(tl.bfloat16)
        t4 = t4.to(tl.bfloat16)
        t5 = t5.to(tl.bfloat16)
        t6 = t6.to(tl.bfloat16)

    out_base = OUT_ptr + item_base
    tl.store(out_base + 0 * stride_row + cols, t0, eviction_policy="evict_first")
    tl.store(out_base + 1 * stride_row + cols, t1, eviction_policy="evict_first")
    tl.store(out_base + 2 * stride_row + cols, t2, eviction_policy="evict_first")
    tl.store(out_base + 3 * stride_row + cols, t3, eviction_policy="evict_first")
    tl.store(out_base + 4 * stride_row + cols, t4, eviction_policy="evict_first")
    tl.store(out_base + 5 * stride_row + cols, t5, eviction_policy="evict_first")
    tl.store(out_base + 6 * stride_row + cols, t6, eviction_policy="evict_first")
def _launch_mn_fused_kernel(x: torch.Tensor, M: int, scale: float) -> torch.Tensor:
    """Launch the fused M-row kernel for inputs whose last dim is ``M * 2**k``.

    The input is reshaped to ``(batch, M, dim // M)``, made contiguous, and
    handed to the kernel matching ``M``; one program is launched per batch
    item.  The result is reshaped back to the shape of ``x``.

    Args:
        x: tensor of shape ``(..., dim)``.
        M: number of rows mixed by the kernel; must be 3, 5 or 7.
        scale: factor applied to every output element (passed to the kernel
            as a compile-time constant).

    Returns:
        A tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if ``M`` is unsupported, if ``dim`` is not a positive
            multiple of ``M``, or if ``dim // M`` is not a power of two.
    """
    # Validate before reshaping or allocating anything, so bad arguments fail
    # with a clear message instead of a reshape or kernel-compilation error.
    if M not in (3, 5, 7):
        raise ValueError(f"Unsupported M={M}")

    *leading, dim = x.shape
    if dim <= 0 or dim % M != 0:
        raise ValueError(
            f"Last dimension {dim} must be a positive multiple of M={M}"
        )
    n_cols = dim // M
    if n_cols & (n_cols - 1):
        # The kernels index with tl.arange(0, N_COLS) and run log2(N_COLS)
        # butterfly stages, so the per-row width has to be a power of two.
        raise ValueError(
            f"Last dimension {dim} must be M={M} times a power of two, "
            f"got {M} * {n_cols}"
        )
    log_n = n_cols.bit_length() - 1

    batch = x.numel() // dim
    dtype = x.dtype
    xm = x.reshape(batch, M, n_cols).contiguous()
    out = torch.empty_like(xm)

    if n_cols <= 1024:
        num_warps = 2
    elif n_cols <= 2048:
        num_warps = 4
    else:
        num_warps = 8

    launch_kwargs = dict(
        SCALE=scale,
        IS_FP16=(dtype == torch.float16),
        IS_BF16=(dtype == torch.bfloat16),
        N_COLS=n_cols,
        LOG_N=log_n,
        num_warps=num_warps,
        num_stages=1,
    )

    if M == 3:
        kernel = _h3_fht_kernel
    elif M == 5:
        kernel = _h5_fht_kernel
    else:
        kernel = _h7_fht_kernel
    # The kernel applies the same strides to both buffers; `out` comes from
    # empty_like on the contiguous `xm`.
    kernel[(batch,)](xm, out, xm.stride(0), xm.stride(1), **launch_kwargs)
    return out.reshape(*leading, dim)
615# ============================================================
616# Precomputed lookup tables for fast dispatch
617# ============================================================
# Widths served by the no-padding fast path: the 14 powers of two
# 8, 16, 32, ..., 65536.
_POW2_DIMS = frozenset(8 << shift for shift in range(14))
623# ============================================================
624# Core forward
625# ============================================================
def _hadamard_transform_fwd(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Flatten ``x`` to rows, zero-pad if needed, launch a kernel, restore shape.

    Widths listed in ``_POW2_DIMS`` are passed to the kernel unchanged.  Any
    other width is zero-padded to the next power of two (at least 8); the
    extra output columns are dropped again before returning.

    Args:
        x: tensor of shape ``(..., dim)`` with dtype float32, float16 or
            bfloat16.
        scale: factor applied to every output element.

    Returns:
        A tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: if the dtype is unsupported; on the padded path also
            if ``x`` is not a CUDA tensor or if ``dim`` rounded up to a
            multiple of 8 exceeds 65536.
    """
    input_dtype = x.dtype
    # Checked before choosing a path so power-of-two widths are held to the
    # same dtype contract as padded widths.
    assert input_dtype in (
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ), f"hadamard_transform not implemented for input type '{input_dtype}'"

    shapes_og = x.shape
    dim_og = x.shape[-1]
    x_flat = x.reshape(-1, dim_og)
    if x_flat.stride(-1) != 1:
        # The kernels address columns with unit stride.
        x_flat = x_flat.contiguous()
    batch_size = x_flat.shape[0]

    if dim_og in _POW2_DIMS:
        # Fast path: already a supported power of two, no padding required.
        n = dim_og
        log_n = n.bit_length() - 1
    else:
        # The device check stays on this path only, as before, so the fast
        # path's behavior with respect to the tensor's device is unchanged.
        assert x.is_cuda, "hadamard_transform requires CUDA tensor"

        # The size limit applies to the width rounded up to a multiple of 8.
        dim = dim_og + (-dim_og) % 8
        assert (
            dim <= 65536
        ), "fast_hadamard_transform only supports hidden dimension at most 65536 for now"

        # Next power of two >= dim, computed with integer arithmetic.
        log_n = (dim - 1).bit_length()
        n = 1 << log_n
        # A single zero-pad straight to the kernel width (n > dim_og here).
        x_flat = F.pad(x_flat, (0, n - dim_og))

    # Freshly allocated, hence contiguous: its row stride is exactly n.
    out = torch.empty(batch_size, n, dtype=input_dtype, device=x_flat.device)
    _launch_kernel(
        x_flat, out, scale, input_dtype, batch_size, n, log_n, x_flat.stride(0), n
    )

    if n != dim_og:
        # Drop the columns that only exist because of padding.
        out = out[:, :dim_og]
    return out.reshape(shapes_og)
def _launch_kernel(
    x, out, scale, input_dtype, batch_size, n, log_n, stride_x, stride_out
):
    """Pick and launch the Triton kernel for an ``n``-wide transform.

    Dispatch table:
      * fp16/bf16, n == 256, batch >= 4: 4-rows-per-program kernel (2 warps)
      * fp16/bf16, n == 256, batch < 4:  single-row 256 kernel (2 warps)
      * fp16/bf16, n == 512:             single-row 512 kernel (1 warp)
      * fp16/bf16, n <= 128:             generic single-row kernel (1 warp)
      * fp16/bf16, otherwise up to 1024: 2D kernel, 2 rows/program (4 warps)
      * other dtypes, n <= 512:          float32 single-row kernel (1 warp)
      * everything else (n > 512):       float32 2D kernel

    Args:
        x: input rows, shape ``(batch_size, n)``.
        out: output buffer, shape ``(batch_size, n)``.
        scale: factor applied to every output element.
        input_dtype: dtype of ``x``.
        batch_size: number of rows.
        n: row width; the callers in this file always pass a power of two.
        log_n: ``log2(n)``.
        stride_x: row stride of ``x``, in elements.
        stride_out: row stride of ``out``, in elements.
    """
    # Keyword arguments shared by every launch below.
    common = dict(stride_x_row=stride_x, stride_out_row=stride_out, num_stages=1)
    is_half = input_dtype in (torch.float16, torch.bfloat16)

    if n <= 1024 and is_half:
        if n == 256:
            if batch_size >= 4:
                # Four rows per program; the last program may be partial.
                _fht_kernel_256_4row_native[((batch_size + 3) // 4,)](
                    x,
                    out,
                    N_ROWS=batch_size,
                    SCALE=scale,
                    num_warps=2,
                    **common,
                )
            else:
                _fht_kernel_256_1d_native[(batch_size,)](
                    x,
                    out,
                    SCALE=scale,
                    num_warps=2,
                    **common,
                )
        elif n == 512:
            _fht_kernel_512_1d_native[(batch_size,)](
                x,
                out,
                SCALE=scale,
                num_warps=1,
                **common,
            )
        elif n <= 128:
            _fht_kernel_1d_native[(batch_size,)](
                x,
                out,
                DIM=n,
                LOG_N=log_n,
                BLOCK_SIZE=n,
                SCALE=scale,
                num_warps=1,
                **common,
            )
        else:
            # Remaining half-precision case (n == 1024 for power-of-two n).
            rows_per_program = 2
            n_programs = (batch_size + rows_per_program - 1) // rows_per_program
            _fht_kernel_2d_native[(n_programs,)](
                x,
                out,
                N_ROWS=batch_size,
                DIM=n,
                LOG_N=log_n,
                BLOCK_SIZE=n,
                ROWS_PER_PROGRAM=rows_per_program,
                SCALE=scale,
                num_warps=4,
                **common,
            )
    elif n <= 512:
        _fht_kernel_1d[(batch_size,)](
            x,
            out,
            scale,
            DIM=n,
            LOG_N=log_n,
            BLOCK_SIZE=n,
            INPUT_IS_FP16=(input_dtype == torch.float16),
            INPUT_IS_BF16=(input_dtype == torch.bfloat16),
            num_warps=1,
            **common,
        )
    else:
        # Only reached with n > 512, so smaller-width tunings are not needed.
        if n <= 1024:
            num_warps = 4
            rows_per_program = 2
        elif n <= 4096:
            num_warps = 4
            rows_per_program = 1
        else:
            num_warps = 8
            rows_per_program = 1

        n_programs = (batch_size + rows_per_program - 1) // rows_per_program
        _fht_kernel_2d[(n_programs,)](
            x,
            out,
            scale,
            N_ROWS=batch_size,
            DIM=n,
            LOG_N=log_n,
            BLOCK_SIZE=n,
            ROWS_PER_PROGRAM=rows_per_program,
            INPUT_IS_FP16=(input_dtype == torch.float16),
            INPUT_IS_BF16=(input_dtype == torch.bfloat16),
            num_warps=num_warps,
            **common,
        )
831# ============================================================
832# Autograd Function
833# ============================================================
class HadamardTransformFn(torch.autograd.Function):
    """Autograd wrapper around ``_hadamard_transform_fwd``."""

    @staticmethod
    def forward(ctx, x, scale=1.0):
        """Run the transform and remember ``scale`` for the backward pass."""
        ctx._hadamard_transform_scale = scale
        result = _hadamard_transform_fwd(x, scale)
        return result

    @staticmethod
    def backward(ctx, dout):
        """Return the gradient for ``x`` and ``None`` for ``scale``.

        The gradient is obtained by running the forward transform on ``dout``
        with the saved scale (the Hadamard matrix is symmetric).
        """
        saved_scale = ctx._hadamard_transform_scale
        grad_x = _hadamard_transform_fwd(dout, saved_scale)
        return grad_x, None
848# ============================================================
849# Public API
850# ============================================================
def hadamard_transform(x, scale=1.0):
    """Multiply each row of ``x`` by the Hadamard matrix, with autograd support.

    Equivalent to
    ``F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale``.
    When ``dim`` is not a power of 2, ``x`` is implicitly zero-padded so that
    ``dim`` becomes the next power of 2.

    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)
    """
    out = HadamardTransformFn.apply(x, scale)
    return out
869# ============================================================
870# XXN variants (non-power-of-2 dims)
871#
872# Decomposes dim = M * 2^k via H_M ⊗ H_{2^k}:
873# 1. Reshape to (batch, M, 2^k)
874# 2. Apply H_M column transform + FHT in a single fused kernel
875# No padding to next power of 2, no intermediate DRAM write.
876# ============================================================
def hadamard_transform_12N(x, scale=1.0):
    """Run the fused M=3 kernel on ``x``, whose last dim is ``3 * 2^k``.

    Example widths: 1536, 3072, 6144, 12288.  ``scale`` multiplies the output.
    """
    return _launch_mn_fused_kernel(x, 3, scale)
def hadamard_transform_20N(x, scale=1.0):
    """Run the fused M=5 kernel on ``x``, whose last dim is ``5 * 2^k``.

    Example widths: 5120, 10240, 20480.  ``scale`` multiplies the output.
    """
    return _launch_mn_fused_kernel(x, 5, scale)
def hadamard_transform_28N(x, scale=1.0):
    """Run the fused M=7 kernel on ``x``, whose last dim is ``7 * 2^k``.

    Example widths: 7168, 14336, 28672.  ``scale`` multiplies the output.
    """
    return _launch_mn_fused_kernel(x, 7, scale)
def hadamard_transform_40N(x, scale=1.0):
    """Run the fused M=5 kernel on ``x``, whose last dim is ``5 * 2^k``.

    Uses the same kernel and M as ``hadamard_transform_20N``.  Example
    widths: 10240, 20480, 40960.  ``scale`` multiplies the output.
    """
    return _launch_mn_fused_kernel(x, 5, scale)