Coverage for src/flag_gems/runtime/backend/_hygon/ops/hadamard_transform.py: 0%
87 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1"""Fast Hadamard Transform in Triton.
3Drop-in replacement for Dao-AILab/fast-hadamard-transform with identical interface:
4 - hadamard_transform(x, scale=1.0) with autograd support
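 - hadamard_transform_12N/20N/28N/40N(x, scale=1.0) for dims of the form M*2^k
   (currently routed through the same zero-padded power-of-2 transform, so
   results do NOT match the reference M x 2^k decomposition yet)
 - Input: (..., dim), fp32/fp16/bf16
 - Output: (..., dim), same dtype as input
 - Padding: to next multiple of 8 (matching CUDA impl), then implicitly with
   zeros to the next power of 2 for the butterfly; output is trimmed to dim
 - dim <= 65536 (enforced by an assert; applies to all variants)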
11Reference: https://github.com/Dao-AILab/fast-hadamard-transform
12"""
14import math
16import torch
17import triton
18import triton.language as tl
20# ============================================================
# Triton kernel — v1: single reusable scratch buffer, batch rows per program
22# ============================================================
23# v0 bottleneck analysis:
24# 1. Separate float32 scratch buffer in global memory — extra allocation + bandwidth
25# 2. One row per program — low occupancy for small dims
26# 3. Extra tl.load at the end just to get the dtype for casting
27#
28# v1 optimizations:
29# 1. Use a float32 scratch buffer but only 1 allocation (reuse out for final store)
30# 2. Process multiple rows per block for better GPU utilization
31# 3. Track dtype as constexpr to avoid extra load
32# 4. Tuned num_warps per dim size
@triton.jit
def _fht_kernel(
    X_ptr,
    OUT_ptr,
    SCRATCH_ptr,
    scale,
    stride_x_row,
    stride_out_row,
    stride_scratch_row,
    N_ROWS,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    ROWS_PER_PROGRAM: tl.constexpr,
    INPUT_IS_FP16: tl.constexpr,
    INPUT_IS_BF16: tl.constexpr,
):
    """FHT butterfly kernel. Each program processes ROWS_PER_PROGRAM rows.

    For each row ``batch_id`` handled by this program (rows with
    ``batch_id >= N_ROWS`` are skipped), the kernel:
      1. loads DIM elements from ``X_ptr + batch_id * stride_x_row`` and
         upcasts them to float32;
      2. runs LOG_N butterfly stages;
      3. multiplies by ``scale``;
      4. stores the result to ``OUT_ptr + batch_id * stride_out_row``, cast to
         fp16 / bf16 when INPUT_IS_FP16 / INPUT_IS_BF16 is set (otherwise the
         float32 value is stored as-is).

    Each butterfly stage exchanges values through the row's slice of the
    scratch buffer (``SCRATCH_ptr + batch_id * stride_scratch_row``):
    store own values -> barrier -> load partner at ``offsets ^ (1 << s)``
    -> barrier. The second barrier ensures every thread has finished reading
    this stage's scratch values before any thread overwrites them in the
    next stage (write-after-read hazard).

    The launcher in this file passes DIM == BLOCK_SIZE == 2**LOG_N.
    """
    pid = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < DIM

    for r in tl.static_range(ROWS_PER_PROGRAM):
        batch_id = pid * ROWS_PER_PROGRAM + r
        # batch_id depends only on pid and r, so every thread in the program
        # takes the same branch and the barriers below stay uniform.
        if batch_id < N_ROWS:
            base_in = X_ptr + batch_id * stride_x_row
            base_out = OUT_ptr + batch_id * stride_out_row
            base_scratch = SCRATCH_ptr + batch_id * stride_scratch_row

            # Load in float32
            x = tl.load(base_in + offsets, mask=mask, other=0.0).to(tl.float32)

            # Butterfly stages using scratch for exchange
            for s in tl.static_range(LOG_N):
                stride = 1 << s
                tl.store(base_scratch + offsets, x, mask=mask)
                # Make all stores of this stage visible before partners are read.
                tl.debug_barrier()
                partner = offsets ^ stride
                x_partner = tl.load(
                    base_scratch + partner, mask=partner < DIM, other=0.0
                )
                # Every thread must finish reading before the next stage's
                # store overwrites scratch; without this barrier a fast thread
                # can clobber a slot another thread has not loaded yet.
                tl.debug_barrier()
                # Element with bit `s` clear gets a + b, its partner gets a - b.
                is_upper = (offsets & stride) == 0
                x = tl.where(is_upper, x + x_partner, x_partner - x)

            # Scale and cast back to input dtype
            x = x * scale
            if INPUT_IS_FP16:
                tl.store(base_out + offsets, x.to(tl.float16), mask=mask)
            elif INPUT_IS_BF16:
                tl.store(base_out + offsets, x.to(tl.bfloat16), mask=mask)
            else:
                tl.store(base_out + offsets, x, mask=mask)
89# ============================================================
90# Core forward
91# ============================================================
def _hadamard_transform_fwd(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Core forward: handles reshape, padding, and kernel launch.

    Flattens ``x`` to ``(rows, dim_og)`` and zero-pads every row to the
    butterfly size ``n``. ``n`` is the next power of two of dim_og rounded up
    to a multiple of 8, matching the CUDA implementation. The function then
    launches ``_fht_kernel`` over all rows, trims the padding and restores
    the original shape.

    Args:
        x: Tensor of shape ``(..., dim)``; must be fp32/fp16/bf16 with
            ``x.is_cuda`` true.
        scale: Multiplier applied to the transformed output.

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Raises:
        AssertionError: on unsupported dtype, non-CUDA input, or a dim
            (rounded up to a multiple of 8) above 65536.
    """
    assert x.dtype in (
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ), f"hadamard_transform not implemented for input type '{x.dtype}'"
    assert x.is_cuda, "hadamard_transform requires CUDA tensor"

    shapes_og = x.shape
    dim_og = x.shape[-1]
    input_dtype = x.dtype
    x = x.reshape(-1, dim_og)
    # Rows may be strided, but elements within a row must be contiguous.
    if x.stride(-1) != 1:
        x = x.contiguous()
    batch_size = x.shape[0]

    # Logical dim rounded up to a multiple of 8 (matching CUDA implementation).
    dim = dim_og + (-dim_og) % 8

    assert (
        dim % 8 == 0
    ), "fast_hadamard_transform only supports hidden dimension divisible by 8 for now"
    assert (
        dim <= 65536
    ), "fast_hadamard_transform only supports hidden dimension at most 65536 for now"

    # Butterfly size: next power of 2 >= dim, computed with exact integer math.
    log_n = (dim - 1).bit_length() if dim > 1 else 1
    n = 1 << log_n

    # Zero-pad straight to the butterfly size in a single copy. Padding first
    # to a multiple of 8 and then to a power of 2 would allocate and copy twice.
    if n != dim_og:
        x = torch.nn.functional.pad(x, (0, n - dim_og))

    out = torch.empty_like(x)

    # Small dims: several rows per program to improve occupancy.
    # num_warps is kept conservative; the original author observed problems
    # with debug_barrier synchronization at large BLOCK_SIZE with many warps.
    if n <= 256:
        rows_per_program, num_warps = 8, 1
    elif n <= 1024:
        rows_per_program, num_warps = 4, 2
    elif n <= 4096:
        rows_per_program, num_warps = 2, 4
    else:
        rows_per_program, num_warps = 1, 4

    n_programs = (batch_size + rows_per_program - 1) // rows_per_program

    # Float32 scratch buffer: one n-element row per input row, reused by every
    # butterfly stage of that row.
    scratch = torch.empty(batch_size, n, dtype=torch.float32, device=x.device)

    BLOCK_SIZE = triton.next_power_of_2(n)

    _fht_kernel[(n_programs,)](
        x,
        out,
        scratch,
        scale,
        stride_x_row=x.stride(0),
        stride_out_row=out.stride(0),
        stride_scratch_row=scratch.stride(0),
        N_ROWS=batch_size,
        DIM=n,
        LOG_N=log_n,
        BLOCK_SIZE=BLOCK_SIZE,
        ROWS_PER_PROGRAM=rows_per_program,
        INPUT_IS_FP16=(input_dtype == torch.float16),
        INPUT_IS_BF16=(input_dtype == torch.bfloat16),
        num_warps=num_warps,
    )

    # Trim padding back to original dim
    if n != dim_og:
        out = out[:, :dim_og]
    return out.reshape(shapes_og)
184# ============================================================
185# Autograd Function
186# ============================================================
class HadamardTransformFn(torch.autograd.Function):
    """Autograd wrapper around ``_hadamard_transform_fwd``.

    The forward scale is stashed on ``ctx`` and reused in backward. The
    Sylvester Hadamard matrix is symmetric, and the zero-pad / truncate steps
    are transposes of each other. Therefore the gradient w.r.t. ``x`` is the
    same scaled transform applied to the upstream gradient. ``scale`` is
    treated as non-differentiable.
    """

    @staticmethod
    def forward(ctx, x, scale=1.0):
        ctx._hadamard_transform_scale = scale
        transformed = _hadamard_transform_fwd(x, scale)
        return transformed

    @staticmethod
    def backward(ctx, dout):
        saved_scale = ctx._hadamard_transform_scale
        grad_x = _hadamard_transform_fwd(dout, saved_scale)
        grad_scale = None  # no gradient is produced for `scale`
        return grad_x, grad_scale
201# ============================================================
202# Public API
203# ============================================================
def hadamard_transform(x, scale=1.0):
    """
    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)

    Apply the Hadamard transform matrix to every row (last axis) of x, then
    multiply by scale. For power-of-2 dim this equals
    F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale.
    When dim is not a power of 2, x is implicitly zero-padded up to the next
    power of 2 and the result is truncated back to dim.
    Gradients flow to x (not to scale) through HadamardTransformFn.
    """
    transform = HadamardTransformFn.apply
    return transform(x, scale)
222# ============================================================
223# XXN variants (non-power-of-2 dims)
224#
225# Dao-AILab decomposes dim = M * 2^k, applying a small M×M
226# Hadamard-like matrix then a standard 2^k FHT.
227# For now these use the standard FHT with implicit zero-padding
228# to the next power of 2, which is correct but not optimal.
229# TODO: implement proper M×N decomposition for better efficiency.
230# ============================================================
def hadamard_transform_12N(x, scale=1.0):
    """Hadamard transform entry point for dim = 12 * 2^k (e.g. 12*512 = 6144).

    NOTE: currently identical to ``hadamard_transform``. x is zero-padded to
    the next power of 2, transformed with the Sylvester Hadamard matrix, and
    truncated back to dim. This is NOT the 12 x 2^k decomposition of the
    reference ``hadamard_transform_12N``, so results differ from it.
    TODO: implement the M x 2^k decomposition.
    """
    return HadamardTransformFn.apply(x, scale)
def hadamard_transform_20N(x, scale=1.0):
    """Hadamard transform entry point for dim = 20 * 2^k (e.g. 20*1024 = 20480).

    NOTE: currently identical to ``hadamard_transform``. x is zero-padded to
    the next power of 2, transformed with the Sylvester Hadamard matrix, and
    truncated back to dim. This is NOT the 20 x 2^k decomposition of the
    reference ``hadamard_transform_20N``, so results differ from it.
    TODO: implement the M x 2^k decomposition.
    """
    return HadamardTransformFn.apply(x, scale)
def hadamard_transform_28N(x, scale=1.0):
    """Hadamard transform entry point for dim = 28 * 2^k (e.g. 28*1024 = 28672).

    NOTE: currently identical to ``hadamard_transform``. x is zero-padded to
    the next power of 2, transformed with the Sylvester Hadamard matrix, and
    truncated back to dim. This is NOT the 28 x 2^k decomposition of the
    reference ``hadamard_transform_28N``, so results differ from it.
    TODO: implement the M x 2^k decomposition.
    """
    return HadamardTransformFn.apply(x, scale)
def hadamard_transform_40N(x, scale=1.0):
    """Hadamard transform entry point for dim = 40 * 2^k (e.g. 40*1024 = 40960).

    NOTE: currently identical to ``hadamard_transform``. x is zero-padded to
    the next power of 2, transformed with the Sylvester Hadamard matrix, and
    truncated back to dim. This is NOT the 40 x 2^k decomposition of the
    reference ``hadamard_transform_40N``, so results differ from it.
    TODO: implement the M x 2^k decomposition.
    """
    return HadamardTransformFn.apply(x, scale)