Coverage for src/flag_gems/ops/hadamard_transform.py: 14%
265 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
"""Fast Hadamard Transform in Triton.

Drop-in replacement for Dao-AILab/fast-hadamard-transform with identical interface:
  - hadamard_transform(x, scale=1.0) with autograd support
  - hadamard_transform_12N/20N/28N/40N(x, scale=1.0) for non-power-of-2 dims.
    These currently reuse the power-of-2 transform with implicit zero-padding,
    so their numerical results differ from the reference M x 2^k kernels.
  - Input: (..., dim), fp32/fp16/bf16
  - Output: (..., dim), same dtype as input
  - Padding: rows are zero-padded to a multiple of 8 (matching the CUDA impl)
    and then to the next power of 2 for the butterfly; the padding is trimmed
    off the output.
  - dim <= 65536 (after padding) for every entry point.
  - Precision: fp16/bf16 rows whose padded length is <= 1024 are transformed
    in the input dtype (no fp32 accumulation); fp32 inputs and longer rows
    accumulate in fp32. Half-precision rounding error is therefore larger for
    small dims than it would be with fp32 accumulation.

Reference: https://github.com/Dao-AILab/fast-hadamard-transform
"""
14import math
16import torch
17import torch.nn.functional as F
18import triton
19import triton.language as tl
21# ============================================================
22# Triton kernel — v43: remove evict_first from loads + warps=2 for dim=256
23# ============================================================
24# v35 best: dim=256 0.9302x (no evict_first on loads, warps=1)
25# v42: dim=256 0.8950x (evict_first on loads hurt — L2 thrashing)
26#
27# v43 strategy:
28# 1. Remove evict_first from all loads. v42 proved it hurts dim=256
29# (0.8950x vs v35's 0.9302x). The 256-element rows (512B fp16)
30# are small enough that L2 caching of nearby rows helps prefetch.
31# 2. Try num_warps=2 for dim=256 4-row ILP kernel. With 4 rows of
32# 256 elements each, the workload can benefit from 2-warp occupancy:
33# each warp handles the compute for its assigned instructions,
34# and the scheduler can overlap loads from one warp with compute
35# from the other. This targets the memory latency hiding gap.
36# 3. Keep evict_first on stores (write-once streaming pattern).
37# 4. Keep all other kernels unchanged from v42 baseline.
40# ============================================================
41# Butterfly stages
42# ============================================================
@triton.jit
def _butterfly_stage_1d(x, BLOCK_SIZE: tl.constexpr, STRIDE: tl.constexpr):
    """Apply a single radix-2 Walsh-Hadamard butterfly stage to a 1D block.

    Inside every run of ``2 * STRIDE`` consecutive elements, element ``i`` is
    paired with element ``i + STRIDE`` and the pair ``(a, b)`` is replaced by
    ``(a + b, a - b)``. ``BLOCK_SIZE`` is expected to be a multiple of
    ``2 * STRIDE``.
    """
    n_groups: tl.constexpr = BLOCK_SIZE // (2 * STRIDE)
    if STRIDE == 1:
        # Partners are neighbours: view as (n_groups, 2) and split the pair axis.
        lo, hi = tl.split(tl.reshape(x, (n_groups, 2)))
        return tl.reshape(tl.join(lo + hi, lo - hi), (BLOCK_SIZE,))
    else:
        # tl.split / tl.join operate on a trailing axis of size 2, so move the
        # pair axis to the end, combine, then move it back before flattening.
        grouped = tl.reshape(x, (n_groups, 2, STRIDE))
        grouped = tl.permute(grouped, (0, 2, 1))
        lo, hi = tl.split(grouped)
        grouped = tl.permute(tl.join(lo + hi, lo - hi), (0, 2, 1))
        return tl.reshape(grouped, (BLOCK_SIZE,))
@triton.jit
def _butterfly_stage_2d(
    x, ROWS: tl.constexpr, BLOCK_SIZE: tl.constexpr, STRIDE: tl.constexpr
):
    """Apply a single radix-2 butterfly stage to every row of a (ROWS, BLOCK_SIZE) tile.

    Row-wise counterpart of ``_butterfly_stage_1d``: within each row, element
    ``i`` is paired with element ``i + STRIDE`` inside every run of
    ``2 * STRIDE`` elements and ``(a, b)`` becomes ``(a + b, a - b)``. The row
    axis stays leading throughout, so rows never mix.
    """
    n_groups: tl.constexpr = BLOCK_SIZE // (2 * STRIDE)
    if STRIDE == 1:
        lo, hi = tl.split(tl.reshape(x, (ROWS, n_groups, 2)))
        return tl.reshape(tl.join(lo + hi, lo - hi), (ROWS, BLOCK_SIZE))
    else:
        # Move the size-2 pair axis last for tl.split / tl.join, then restore it.
        grouped = tl.reshape(x, (ROWS, n_groups, 2, STRIDE))
        grouped = tl.permute(grouped, (0, 1, 3, 2))
        lo, hi = tl.split(grouped)
        grouped = tl.permute(tl.join(lo + hi, lo - hi), (0, 1, 3, 2))
        return tl.reshape(grouped, (ROWS, BLOCK_SIZE))
81# ============================================================
82# 4-row ILP 1D native kernel for dim=256 (8 hardcoded stages)
83# v43: remove evict_first from loads, keep on stores
84# ============================================================
@triton.jit
def _fht_kernel_256_4row_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    N_ROWS,
    SCALE: tl.constexpr,
):
    """FHT for dim=256, 4-row ILP: four independent 1D butterflies per program.

    Program ``pid`` handles rows ``4*pid .. 4*pid + 3`` and computes in the
    input dtype (no fp32 upcast). Rows 1-3 of the group are masked against
    ``N_ROWS``; row 0 is loaded and stored unmasked, so the launch grid is
    expected to be ``ceil(N_ROWS / 4)`` programs.

    Args:
        X_ptr / OUT_ptr: input / output base pointers; columns are read and
            written with unit stride.
        stride_x_row / stride_out_row: element stride between consecutive rows.
        N_ROWS: total number of rows.
        SCALE: multiplier applied after the transform. NOTE: it is a
            tl.constexpr, so each distinct value compiles its own kernel.
    """
    pid = tl.program_id(0)
    col_offs = tl.arange(0, 256)

    row0 = pid * 4
    row1 = row0 + 1
    row2 = row0 + 2
    row3 = row0 + 3

    # Load all 4 rows (no evict_first: L2 caching helps for nearby rows)
    x0 = tl.load(X_ptr + row0 * stride_x_row + col_offs)
    x1 = tl.load(X_ptr + row1 * stride_x_row + col_offs, mask=row1 < N_ROWS, other=0.0)
    x2 = tl.load(X_ptr + row2 * stride_x_row + col_offs, mask=row2 < N_ROWS, other=0.0)
    x3 = tl.load(X_ptr + row3 * stride_x_row + col_offs, mask=row3 < N_ROWS, other=0.0)

    # Interleaved hardcoded reversed butterfly stages for 4-way ILP
    # (strides 128 -> 1). The row-interleaved order is a deliberate tuning
    # choice giving the scheduler four independent streams; keep it as is.
    x0 = _butterfly_stage_1d(x0, 256, 128)
    x1 = _butterfly_stage_1d(x1, 256, 128)
    x2 = _butterfly_stage_1d(x2, 256, 128)
    x3 = _butterfly_stage_1d(x3, 256, 128)
    x0 = _butterfly_stage_1d(x0, 256, 64)
    x1 = _butterfly_stage_1d(x1, 256, 64)
    x2 = _butterfly_stage_1d(x2, 256, 64)
    x3 = _butterfly_stage_1d(x3, 256, 64)
    x0 = _butterfly_stage_1d(x0, 256, 32)
    x1 = _butterfly_stage_1d(x1, 256, 32)
    x2 = _butterfly_stage_1d(x2, 256, 32)
    x3 = _butterfly_stage_1d(x3, 256, 32)
    x0 = _butterfly_stage_1d(x0, 256, 16)
    x1 = _butterfly_stage_1d(x1, 256, 16)
    x2 = _butterfly_stage_1d(x2, 256, 16)
    x3 = _butterfly_stage_1d(x3, 256, 16)
    x0 = _butterfly_stage_1d(x0, 256, 8)
    x1 = _butterfly_stage_1d(x1, 256, 8)
    x2 = _butterfly_stage_1d(x2, 256, 8)
    x3 = _butterfly_stage_1d(x3, 256, 8)
    x0 = _butterfly_stage_1d(x0, 256, 4)
    x1 = _butterfly_stage_1d(x1, 256, 4)
    x2 = _butterfly_stage_1d(x2, 256, 4)
    x3 = _butterfly_stage_1d(x3, 256, 4)
    x0 = _butterfly_stage_1d(x0, 256, 2)
    x1 = _butterfly_stage_1d(x1, 256, 2)
    x2 = _butterfly_stage_1d(x2, 256, 2)
    x3 = _butterfly_stage_1d(x3, 256, 2)
    x0 = _butterfly_stage_1d(x0, 256, 1)
    x1 = _butterfly_stage_1d(x1, 256, 1)
    x2 = _butterfly_stage_1d(x2, 256, 1)
    x3 = _butterfly_stage_1d(x3, 256, 1)

    x0 = x0 * SCALE
    x1 = x1 * SCALE
    x2 = x2 * SCALE
    x3 = x3 * SCALE

    # Stores use evict_first: outputs are written once and not re-read here.
    tl.store(
        OUT_ptr + row0 * stride_out_row + col_offs, x0, eviction_policy="evict_first"
    )
    tl.store(
        OUT_ptr + row1 * stride_out_row + col_offs,
        x1,
        mask=row1 < N_ROWS,
        eviction_policy="evict_first",
    )
    tl.store(
        OUT_ptr + row2 * stride_out_row + col_offs,
        x2,
        mask=row2 < N_ROWS,
        eviction_policy="evict_first",
    )
    tl.store(
        OUT_ptr + row3 * stride_out_row + col_offs,
        x3,
        mask=row3 < N_ROWS,
        eviction_policy="evict_first",
    )
173# ============================================================
174# Fallback: single-row 1D native kernel for dim=256
175# ============================================================
@triton.jit
def _fht_kernel_256_1d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    SCALE: tl.constexpr,
):
    """Single-row FHT for dim=256, computed in the input dtype.

    Program ``pid`` transforms row ``pid``. The eight butterfly stages are
    spelled out with strides 128 down to 1. There is no masking, so the grid
    is expected to have exactly one program per row.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, 256)
    src = X_ptr + row * stride_x_row + cols
    dst = OUT_ptr + row * stride_out_row + cols

    vals = tl.load(src)
    vals = _butterfly_stage_1d(vals, 256, 128)
    vals = _butterfly_stage_1d(vals, 256, 64)
    vals = _butterfly_stage_1d(vals, 256, 32)
    vals = _butterfly_stage_1d(vals, 256, 16)
    vals = _butterfly_stage_1d(vals, 256, 8)
    vals = _butterfly_stage_1d(vals, 256, 4)
    vals = _butterfly_stage_1d(vals, 256, 2)
    vals = _butterfly_stage_1d(vals, 256, 1)

    tl.store(dst, vals * SCALE, eviction_policy="evict_first")
208# ============================================================
209# 1D hardcoded native kernel for dim=512 (9 stages)
210# Restored from v31/v35: single-row (best: 1.1193x in v35)
211# ============================================================
@triton.jit
def _fht_kernel_512_1d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    SCALE: tl.constexpr,
):
    """Single-row FHT for dim=512, computed in the input dtype.

    Program ``pid`` transforms row ``pid``. The nine butterfly stages are
    spelled out with strides 256 down to 1. There is no masking, so the grid
    is expected to have exactly one program per row.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, 512)
    src = X_ptr + row * stride_x_row + cols
    dst = OUT_ptr + row * stride_out_row + cols

    vals = tl.load(src)
    vals = _butterfly_stage_1d(vals, 512, 256)
    vals = _butterfly_stage_1d(vals, 512, 128)
    vals = _butterfly_stage_1d(vals, 512, 64)
    vals = _butterfly_stage_1d(vals, 512, 32)
    vals = _butterfly_stage_1d(vals, 512, 16)
    vals = _butterfly_stage_1d(vals, 512, 8)
    vals = _butterfly_stage_1d(vals, 512, 4)
    vals = _butterfly_stage_1d(vals, 512, 2)
    vals = _butterfly_stage_1d(vals, 512, 1)

    tl.store(dst, vals * SCALE, eviction_policy="evict_first")
245# ============================================================
246# Generic 1D native-dtype butterfly kernel (for other small dims)
247# ============================================================
@triton.jit
def _fht_kernel_1d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    SCALE: tl.constexpr,
):
    """Generic single-row FHT computed in the input dtype.

    Runs ``LOG_N`` butterfly stages with strides ``2**(LOG_N-1)`` down to 1
    (the caller passes ``BLOCK_SIZE == 2**LOG_N``). One program per row, no
    masking. ``DIM`` is not read by the kernel body.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)
    vals = tl.load(X_ptr + row * stride_x_row + cols)

    for stage in tl.static_range(LOG_N):
        # static_range unrolls at compile time, so the stride is a constexpr.
        vals = _butterfly_stage_1d(vals, BLOCK_SIZE, 1 << (LOG_N - 1 - stage))

    tl.store(
        OUT_ptr + row * stride_out_row + cols,
        vals * SCALE,
        eviction_policy="evict_first",
    )
277# ============================================================
278# 2D native-dtype butterfly kernel (for dim=1024 with fp16/bf16)
279# ============================================================
@triton.jit
def _fht_kernel_2d_native(
    X_ptr,
    OUT_ptr,
    stride_x_row,
    stride_out_row,
    N_ROWS,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    ROWS_PER_PROGRAM: tl.constexpr,
    SCALE: tl.constexpr,
):
    """Batched FHT in the input dtype, ``ROWS_PER_PROGRAM`` rows per program.

    Loads a (ROWS_PER_PROGRAM, BLOCK_SIZE) tile, runs ``LOG_N`` butterfly
    stages on each row, scales, and stores it. Rows at or beyond ``N_ROWS``
    are masked: they load as 0 and are not stored. ``DIM`` is not read by the
    kernel body.
    """
    rows = tl.program_id(0) * ROWS_PER_PROGRAM + tl.arange(0, ROWS_PER_PROGRAM)
    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = (rows < N_ROWS)[:, None]

    tile = tl.load(
        X_ptr + rows[:, None] * stride_x_row + cols[None, :],
        mask=in_bounds,
        other=0.0,
    )
    for stage in tl.static_range(LOG_N):
        tile = _butterfly_stage_2d(
            tile, ROWS_PER_PROGRAM, BLOCK_SIZE, 1 << (LOG_N - 1 - stage)
        )

    tl.store(
        OUT_ptr + rows[:, None] * stride_out_row + cols[None, :],
        tile * SCALE,
        mask=in_bounds,
        eviction_policy="evict_first",
    )
320# ============================================================
# 1D butterfly kernel (fp32 compute; dispatched for non-fp16/bf16 inputs with dim<=512)
322# ============================================================
@triton.jit
def _fht_kernel_1d(
    X_ptr,
    OUT_ptr,
    scale,
    stride_x_row,
    stride_out_row,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    INPUT_IS_FP16: tl.constexpr,
    INPUT_IS_BF16: tl.constexpr,
):
    """Single-row FHT with fp32 accumulation.

    The row is upcast to fp32 and transformed with ``LOG_N`` butterfly stages
    (strides ``2**(LOG_N-1)`` down to 1). It is then multiplied by the runtime
    ``scale`` and cast to fp16/bf16 when the matching flag is set; otherwise
    the fp32 values are stored as-is. One program per row, no masking. ``DIM``
    is not read by the kernel body.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, BLOCK_SIZE)

    acc = tl.load(X_ptr + row * stride_x_row + cols).to(tl.float32)
    for stage in tl.static_range(LOG_N):
        acc = _butterfly_stage_1d(acc, BLOCK_SIZE, 1 << (LOG_N - 1 - stage))
    acc = acc * scale

    # Both flags are constexpr, so only one branch is compiled.
    if INPUT_IS_FP16:
        result = acc.to(tl.float16)
    elif INPUT_IS_BF16:
        result = acc.to(tl.bfloat16)
    else:
        result = acc
    tl.store(
        OUT_ptr + row * stride_out_row + cols, result, eviction_policy="evict_first"
    )
360# ============================================================
# 2D butterfly kernel (fp32 compute; non-fp16/bf16 inputs with dim>=1024, fp16/bf16 inputs with dim>=2048)
362# ============================================================
@triton.jit
def _fht_kernel_2d(
    X_ptr,
    OUT_ptr,
    scale,
    stride_x_row,
    stride_out_row,
    N_ROWS,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    ROWS_PER_PROGRAM: tl.constexpr,
    INPUT_IS_FP16: tl.constexpr,
    INPUT_IS_BF16: tl.constexpr,
):
    """Batched FHT with fp32 accumulation, ``ROWS_PER_PROGRAM`` rows per program.

    Loads a (ROWS_PER_PROGRAM, BLOCK_SIZE) tile (rows at or beyond ``N_ROWS``
    read as 0), upcasts it to fp32 and runs ``LOG_N`` butterfly stages per row.
    It then multiplies by the runtime ``scale`` and stores the in-range rows,
    cast to fp16/bf16 when the matching flag is set. ``DIM`` is not read by
    the kernel body.
    """
    rows = tl.program_id(0) * ROWS_PER_PROGRAM + tl.arange(0, ROWS_PER_PROGRAM)
    cols = tl.arange(0, BLOCK_SIZE)
    in_bounds = (rows < N_ROWS)[:, None]

    src = X_ptr + rows[:, None] * stride_x_row + cols[None, :]
    acc = tl.load(src, mask=in_bounds, other=0.0).to(tl.float32)
    for stage in tl.static_range(LOG_N):
        acc = _butterfly_stage_2d(
            acc, ROWS_PER_PROGRAM, BLOCK_SIZE, 1 << (LOG_N - 1 - stage)
        )
    acc = acc * scale

    # Both flags are constexpr, so only one branch is compiled.
    if INPUT_IS_FP16:
        result = acc.to(tl.float16)
    elif INPUT_IS_BF16:
        result = acc.to(tl.bfloat16)
    else:
        result = acc
    dst = OUT_ptr + rows[:, None] * stride_out_row + cols[None, :]
    tl.store(dst, result, mask=in_bounds, eviction_policy="evict_first")
414# ============================================================
415# Precomputed lookup tables for fast dispatch
416# ============================================================
# Row lengths served by the no-padding fast path: every power of 2 from
# 2**3 = 8 through 2**16 = 65536 (all of them multiples of 8).
_POW2_DIMS = frozenset(8 << shift for shift in range(14))
422# ============================================================
423# Core forward
424# ============================================================
def _padded_fht_size(dim):
    """Return ``(n, log_n)``: the butterfly length used for rows of ``dim`` elements.

    ``n`` is the smallest power of 2 that is >= ``dim`` and >= 8. This equals
    rounding ``dim`` up to a multiple of 8 (the CUDA implementation's padding
    rule) and then up to the next power of 2.
    """
    log_n = max(3, (dim - 1).bit_length())
    return 1 << log_n, log_n


def _hadamard_transform_fwd(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Apply the scaled Walsh-Hadamard transform along the last dim of ``x``.

    Flattens ``x`` to ``(rows, dim)``. When ``dim`` is not a power of 2 in
    ``_POW2_DIMS``, each row is zero-padded to the next power of 2 (at least 8).
    The kernel chosen by ``_launch_kernel`` then runs, and the padding is
    trimmed off again.

    Args:
        x: input of shape ``(..., dim)``; fp32, fp16 or bf16.
        scale: multiplier applied to the transformed values.

    Returns:
        A new tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: unsupported dtype; non-CUDA tensor on the padding path;
            or a padded length above 65536.
    """
    shapes_og = x.shape
    dim_og = x.shape[-1]
    input_dtype = x.dtype
    # Validate before the fast path so power-of-2 dims cannot silently push
    # unsupported dtypes (float64, ints, ...) through the fp32 kernels.
    assert input_dtype in (
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ), f"hadamard_transform not implemented for input type '{input_dtype}'"

    # Nothing to launch for empty input; this also sidesteps reshape(-1, 0),
    # which raises when dim == 0.
    if x.numel() == 0:
        return torch.empty(shapes_og, dtype=input_dtype, device=x.device)

    x_flat = x.reshape(-1, dim_og)
    if x_flat.stride(-1) != 1:
        # Kernels address columns with unit stride; the row stride is passed in.
        x_flat = x_flat.contiguous()
    batch_size = x_flat.shape[0]

    if dim_og in _POW2_DIMS:
        # Fast path: power-of-2 dim >= 8 needs no padding.
        n = dim_og
        log_n = n.bit_length() - 1
    else:
        assert x.is_cuda, "hadamard_transform requires CUDA tensor"
        n, log_n = _padded_fht_size(dim_og)
        assert (
            n <= 65536
        ), "fast_hadamard_transform only supports hidden dimension at most 65536 for now"
        # A single pad straight to n: same values as padding to a multiple of
        # 8 and then to the next power of 2, but only one copy.
        x_flat = F.pad(x_flat, (0, n - dim_og))

    # Freshly allocated, hence contiguous: its row stride is exactly n.
    out = torch.empty(batch_size, n, dtype=input_dtype, device=x_flat.device)
    _launch_kernel(
        x_flat, out, scale, input_dtype, batch_size, n, log_n, x_flat.stride(0), n
    )

    if n != dim_og:
        # Drop the zero-padded tail columns.
        out = out[:, :dim_og]
    return out.reshape(shapes_og)
def _launch_kernel(
    x, out, scale, input_dtype, batch_size, n, log_n, stride_x, stride_out
):
    """Dispatch to the appropriate Triton kernel (shared by both forward paths).

    Args:
        x: ``(batch_size, n)`` input whose columns are expected to have unit
            stride; ``stride_x`` is its row stride.
        out: ``(batch_size, n)`` output buffer; ``stride_out`` is its row stride.
        scale: multiplier applied to the transformed rows.
        input_dtype: dtype of ``x`` / ``out``.
        batch_size: number of rows.
        n: row length, expected to be ``2 ** log_n``.
        log_n: number of butterfly stages.
        stride_x, stride_out: element stride between consecutive rows.

    Dispatch strategy (v43):
      - fp16/bf16, n=256: 4-row ILP native kernel (warps=2) when there are at
        least 4 rows, otherwise the single-row native kernel (warps=2).
      - fp16/bf16, n=512: single-row native kernel (warps=1) - v35 best.
      - fp16/bf16, n<=128: generic single-row native kernel.
      - fp16/bf16, n=1024: 2D native kernel (rows=2, warps=4).
      - other dtypes, n<=512: fp32-accumulating single-row kernel.
      - everything else (fp32 n>=1024, fp16/bf16 n>=2048): fp32-accumulating
        2D kernel.
    """
    if batch_size == 0:
        # Nothing to transform; avoid launching an empty grid.
        return

    if n <= 1024 and input_dtype in (torch.float16, torch.bfloat16):
        if n == 256:
            if batch_size >= 4:
                # 4 rows per program; the kernel masks the tail group.
                _fht_kernel_256_4row_native[((batch_size + 3) // 4,)](
                    x,
                    out,
                    stride_x_row=stride_x,
                    stride_out_row=stride_out,
                    N_ROWS=batch_size,
                    SCALE=scale,
                    num_warps=2,
                    num_stages=1,
                )
            else:
                _fht_kernel_256_1d_native[(batch_size,)](
                    x,
                    out,
                    stride_x_row=stride_x,
                    stride_out_row=stride_out,
                    SCALE=scale,
                    num_warps=2,
                    num_stages=1,
                )
        elif n == 512:
            # Single-row 1D hardcoded: v35 achieved 1.1193x (best)
            _fht_kernel_512_1d_native[(batch_size,)](
                x,
                out,
                stride_x_row=stride_x,
                stride_out_row=stride_out,
                SCALE=scale,
                num_warps=1,
                num_stages=1,
            )
        elif n <= 128:
            _fht_kernel_1d_native[(batch_size,)](
                x,
                out,
                stride_x_row=stride_x,
                stride_out_row=stride_out,
                DIM=n,
                LOG_N=log_n,
                BLOCK_SIZE=n,
                SCALE=scale,
                num_warps=1,
                num_stages=1,
            )
        else:
            # dim=1024: 2D native with 2 rows/program
            rows_per_program = 2
            n_programs = (batch_size + rows_per_program - 1) // rows_per_program
            _fht_kernel_2d_native[(n_programs,)](
                x,
                out,
                stride_x_row=stride_x,
                stride_out_row=stride_out,
                N_ROWS=batch_size,
                DIM=n,
                LOG_N=log_n,
                BLOCK_SIZE=n,
                ROWS_PER_PROGRAM=rows_per_program,
                SCALE=scale,
                num_warps=4,
                num_stages=1,
            )
    elif n <= 512:
        # fp32 1D kernel
        _fht_kernel_1d[(batch_size,)](
            x,
            out,
            scale,
            stride_x_row=stride_x,
            stride_out_row=stride_out,
            DIM=n,
            LOG_N=log_n,
            BLOCK_SIZE=n,
            INPUT_IS_FP16=(input_dtype == torch.float16),
            INPUT_IS_BF16=(input_dtype == torch.bfloat16),
            num_warps=1,
            num_stages=1,
        )
    else:
        # fp32 2D butterfly for fp32 inputs and large dims. Every n <= 512 was
        # dispatched above, so only these tile shapes are reachable here.
        if n <= 1024:
            num_warps, rows_per_program = 4, 2
        elif n <= 4096:
            num_warps, rows_per_program = 4, 1
        else:
            num_warps, rows_per_program = 8, 1

        n_programs = (batch_size + rows_per_program - 1) // rows_per_program
        _fht_kernel_2d[(n_programs,)](
            x,
            out,
            scale,
            stride_x_row=stride_x,
            stride_out_row=stride_out,
            N_ROWS=batch_size,
            DIM=n,
            LOG_N=log_n,
            BLOCK_SIZE=n,
            ROWS_PER_PROGRAM=rows_per_program,
            INPUT_IS_FP16=(input_dtype == torch.float16),
            INPUT_IS_BF16=(input_dtype == torch.bfloat16),
            num_warps=num_warps,
            num_stages=1,
        )
630# ============================================================
631# Autograd Function
632# ============================================================
class HadamardTransformFn(torch.autograd.Function):
    """Autograd wrapper around ``_hadamard_transform_fwd``.

    Forward applies the scaled Hadamard transform to the last dimension, with
    zero-padding and trimming when needed. The Sylvester Hadamard matrix is
    symmetric, and the pad and trim steps are transposes of each other. The
    backward pass is therefore the very same operation, with the same scale,
    applied to the incoming gradient.
    """

    @staticmethod
    def forward(ctx, x, scale=1.0):
        # Remember the multiplier; backward reuses it unchanged.
        ctx._hadamard_transform_scale = scale
        return _hadamard_transform_fwd(x, scale)

    @staticmethod
    def backward(ctx, dout):
        grad_x = _hadamard_transform_fwd(dout, ctx._hadamard_transform_scale)
        # One entry per forward input; `scale` receives no gradient.
        return grad_x, None
647# ============================================================
648# Public API
649# ============================================================
def hadamard_transform(x, scale=1.0):
    """
    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)

    Multiply each row of x by the Hadamard transform matrix.
    For power-of-2 dim this is equivalent to
    F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale.
    If dim is not a power of 2, x is implicitly zero-padded to the next power
    of 2 (at least 8) and the result is truncated back to dim.

    Differentiable: the backward pass is the same transform with the same
    scale. Inputs are expected to be fp32, fp16 or bf16, with the padded dim
    at most 65536.
    """
    transform = HadamardTransformFn.apply
    return transform(x, scale)
# ============================================================
# XXN variants (non-power-of-2 dims)
#
# Dao-AILab decomposes dim = M * 2^k, applying a small M×M
# Hadamard-like matrix then a standard 2^k FHT.
# For now these fall back to the standard FHT with implicit zero-padding
# to the next power of 2 followed by truncation (e.g. 6144 -> 8192 -> 6144).
# That is NOT the same linear map as the reference M×2^k transform (nor is
# it orthogonal), so outputs differ from Dao-AILab's XXN kernels.
# TODO: implement the proper M×2^k decomposition to match the reference.
# ============================================================
def hadamard_transform_12N(x, scale=1.0):
    """Hadamard transform intended for dim = 12 * 2^k (e.g. 12*512 = 6144).

    NOTE: this is not yet the reference 12 x 2^k transform. It calls the same
    autograd function as ``hadamard_transform``: x is zero-padded to the next
    power of 2 (6144 -> 8192), transformed, and truncated back to dim. The
    result is a different (non-orthogonal) linear map from Dao-AILab's
    ``hadamard_transform_12N``, and the padded length must not exceed 65536.
    """
    return HadamardTransformFn.apply(x, scale)
def hadamard_transform_20N(x, scale=1.0):
    """Hadamard transform intended for dim = 20 * 2^k (e.g. 20*1024 = 20480).

    NOTE: this is not yet the reference 20 x 2^k transform. It calls the same
    autograd function as ``hadamard_transform``: x is zero-padded to the next
    power of 2 (20480 -> 32768), transformed, and truncated back to dim. The
    result is a different (non-orthogonal) linear map from Dao-AILab's
    ``hadamard_transform_20N``, and the padded length must not exceed 65536.
    """
    return HadamardTransformFn.apply(x, scale)
def hadamard_transform_28N(x, scale=1.0):
    """Hadamard transform intended for dim = 28 * 2^k (e.g. 28*1024 = 28672).

    NOTE: this is not yet the reference 28 x 2^k transform. It calls the same
    autograd function as ``hadamard_transform``: x is zero-padded to the next
    power of 2 (28672 -> 32768), transformed, and truncated back to dim. The
    result is a different (non-orthogonal) linear map from Dao-AILab's
    ``hadamard_transform_28N``, and the padded length must not exceed 65536.
    """
    return HadamardTransformFn.apply(x, scale)
def hadamard_transform_40N(x, scale=1.0):
    """Hadamard transform intended for dim = 40 * 2^k (e.g. 40*1024 = 40960).

    NOTE: this is not yet the reference 40 x 2^k transform. It calls the same
    autograd function as ``hadamard_transform``: x is zero-padded to the next
    power of 2 (40960 -> 65536), transformed, and truncated back to dim. The
    result is a different (non-orthogonal) linear map from Dao-AILab's
    ``hadamard_transform_40N``. The padded length must not exceed 65536, so
    40 * 2^k is rejected for k >= 11.
    """
    return HadamardTransformFn.apply(x, scale)