Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/hadamard_transform.py: 0%
74 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1"""Fast Hadamard Transform in Triton (KunlunXin).
3v0: Multi-pass butterfly via global memory, 1 kernel launch per stage.
4Simple baseline for correctness. Each butterfly stage reads from IN, writes to OUT.
5"""
7import math
9import torch
10import torch.nn.functional as F
11import triton
12import triton.language as tl
# Upper bound on the 1-D launch grid size. The host wrapper doubles the number
# of rows handled per program until the program count no longer exceeds it.
MAX_GRID = 65535
17# ============================================================
18# Single butterfly stage kernel
19# ============================================================
@triton.jit
def _butterfly_stage(
    IN_ptr,
    OUT_ptr,
    stride_row,
    N_ROWS,
    ROWS_PER_PROGRAM: tl.constexpr,
    STRIDE_S: tl.constexpr,
    DIM: tl.constexpr,
):
    """One butterfly stage: read from IN, write to OUT.

    For each element at position i of a row:
        partner = i ^ STRIDE_S
        if (i & STRIDE_S) == 0: out[i] = in[i] + in[partner]
        else:                   out[i] = in[partner] - in[i]

    Arguments:
        IN_ptr: stage input; each row holds DIM elements.
        OUT_ptr: stage output, addressed with the same row stride as IN_ptr.
        stride_row: distance, in elements, between consecutive rows.
        N_ROWS: number of rows; row ids at or beyond it are skipped.
        ROWS_PER_PROGRAM: rows processed by each program (statically unrolled).
        STRIDE_S: butterfly distance of this stage.
        DIM: row length, used as the extent of the per-row offsets.
    """
    pid = tl.program_id(0)
    offsets = tl.arange(0, DIM)

    # These depend only on the compile-time stage parameters, not on the row.
    partner_offsets = offsets ^ STRIDE_S
    is_upper = (offsets & STRIDE_S) == 0

    for row_idx in tl.static_range(ROWS_PER_PROGRAM):
        row_id = pid * ROWS_PER_PROGRAM + row_idx
        if row_id < N_ROWS:
            # Widen before multiplying: with 32-bit arithmetic the product
            # row_id * stride_row wraps once the buffer reaches 2**31 elements.
            base = row_id.to(tl.int64) * stride_row

            x = tl.load(IN_ptr + base + offsets).to(tl.float32)
            x_partner = tl.load(IN_ptr + base + partner_offsets).to(tl.float32)

            result = tl.where(is_upper, x + x_partner, x_partner - x)

            tl.store(OUT_ptr + base + offsets, result)
57# ============================================================
58# Scale + cast kernel
59# ============================================================
@triton.jit
def _scale_cast(
    IN_ptr,
    OUT_ptr,
    stride_in_row,
    stride_out_row,
    scale,
    N_ROWS,
    ROWS_PER_PROGRAM: tl.constexpr,
    DIM: tl.constexpr,
):
    """Scale fp32 buffer and cast to output dtype.

    Arguments:
        IN_ptr: input buffer; each row holds DIM elements.
        OUT_ptr: output buffer that receives x * scale row by row.
        stride_in_row: distance, in elements, between consecutive input rows.
        stride_out_row: distance, in elements, between consecutive output rows.
        scale: multiplier applied to every element.
        N_ROWS: number of rows; row ids at or beyond it are skipped.
        ROWS_PER_PROGRAM: rows processed by each program (statically unrolled).
        DIM: row length, used as the extent of the per-row offsets.
    """
    pid = tl.program_id(0)
    offsets = tl.arange(0, DIM)

    for row_idx in tl.static_range(ROWS_PER_PROGRAM):
        row_id = pid * ROWS_PER_PROGRAM + row_idx
        if row_id < N_ROWS:
            # Widen before multiplying: with 32-bit arithmetic the row offset
            # wraps once a buffer reaches 2**31 elements.
            row64 = row_id.to(tl.int64)
            x = tl.load(IN_ptr + row64 * stride_in_row + offsets)
            tl.store(OUT_ptr + row64 * stride_out_row + offsets, x * scale)
84# ============================================================
85# Forward implementation
86# ============================================================
def _hadamard_transform_fwd(x, scale):
    """Apply the scaled Hadamard transform along the last dimension of x.

    The last dimension is zero-padded to the next power of two, the rows are
    transformed by log2(padded dim) butterfly stages that ping-pong between two
    fp32 scratch buffers, and the result is scaled, cast back to the input
    dtype and truncated to the original last-dimension size.

    Arguments:
        x: (..., dim) tensor.
        scale: multiplier applied to the transformed values.
    Returns:
        Tensor with the same shape and dtype as x.
    """
    orig_shape = x.shape
    dim = x.shape[-1]
    input_dtype = x.dtype

    # Nothing to transform; this also avoids launching kernels on an empty grid.
    if x.numel() == 0:
        return torch.empty(orig_shape, dtype=input_dtype, device=x.device)

    # Pad to next power of 2. bit_length() keeps this in exact integer
    # arithmetic instead of rounding a floating-point log2.
    log_dim = (dim - 1).bit_length()
    dim_padded = 1 << log_dim
    if dim != dim_padded:
        x = F.pad(x, (0, dim_padded - dim))

    x_flat = x.reshape(-1, dim_padded).contiguous()
    n_rows = x_flat.shape[0]
    n_stages = log_dim  # log2(dim_padded)

    # Determine ROWS_PER_PROGRAM to stay within grid limit. Doubling keeps the
    # value a power of two, which bounds the number of distinct constexpr
    # values the kernels are specialized for.
    rows_per_prog = 1
    while (n_rows + rows_per_prog - 1) // rows_per_prog > MAX_GRID:
        rows_per_prog *= 2
    grid_size = (n_rows + rows_per_prog - 1) // rows_per_prog

    # Allocate two fp32 scratch buffers for ping-pong.
    # copy=True is critical: for fp32 input a plain conversion would return the
    # same tensor, and the butterfly stages would then overwrite the input.
    # A single .to(..., copy=True) also avoids copying twice for other dtypes.
    buf_a = x_flat.to(torch.float32, copy=True)
    buf_b = torch.empty_like(buf_a)

    stride_row = dim_padded

    # Run butterfly stages; after each launch the freshly written buffer
    # becomes the input of the next stage.
    for s in range(n_stages):
        stride_s = 1 << s
        _butterfly_stage[(grid_size,)](
            buf_a,
            buf_b,
            stride_row,
            n_rows,
            ROWS_PER_PROGRAM=rows_per_prog,
            STRIDE_S=stride_s,
            DIM=dim_padded,
        )
        buf_a, buf_b = buf_b, buf_a

    # Result is in buf_a; scale and cast back
    out = torch.empty(n_rows, dim_padded, dtype=input_dtype, device=x.device)
    _scale_cast[(grid_size,)](
        buf_a,
        out,
        stride_row,
        dim_padded,
        scale,
        n_rows,
        ROWS_PER_PROGRAM=rows_per_prog,
        DIM=dim_padded,
    )

    # Drop the zero-padding columns before restoring the caller's shape.
    if dim != dim_padded:
        out = out[:, :dim]
    return out.reshape(orig_shape)
150# ============================================================
151# Autograd wrapper
152# ============================================================
class HadamardTransformFn(torch.autograd.Function):
    """Autograd function that runs the forward transform in both directions."""

    @staticmethod
    def forward(ctx, x, scale):
        # Keep the scale on the context so backward can reuse it.
        ctx._hadamard_transform_scale = scale
        transformed = _hadamard_transform_fwd(x, scale)
        return transformed

    @staticmethod
    def backward(ctx, grad_output):
        # The Hadamard matrix is symmetric, so the gradient with respect to x
        # is the same transform (same scale) applied to the incoming gradient.
        saved_scale = ctx._hadamard_transform_scale
        grad_x = _hadamard_transform_fwd(grad_output.contiguous(), saved_scale)
        # No gradient is produced for scale.
        return grad_x, None
172# ============================================================
173# Public API
174# ============================================================
def hadamard_transform(x, scale=1.0):
    """Fast Hadamard Transform.

    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)

    Multiply each row of x by the Hadamard transform matrix.
    Equivalent to F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale.
    If dim is not a power of 2, we implicitly pad x with zero so that dim is
    the next power of 2; the padded columns are dropped again from the output.

    The call goes through HadamardTransformFn, so it takes part in autograd.
    """
    return HadamardTransformFn.apply(x, scale)
194# ============================================================
195# XXN variants (non-power-of-2 dims)
196# ============================================================
def hadamard_transform_12N(x, scale=1.0):
    """Hadamard transform for dim = 12 * 2^k (e.g. 12*512 = 6144).

    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)

    This dispatches to the same HadamardTransformFn as hadamard_transform, so
    the last dimension is zero-padded to the next power of 2 and the output is
    truncated back to dim.
    NOTE(review): no dedicated 12 x 12 base matrix is used here -- confirm this
    is the intended behaviour for 12N sizes.
    """
    return HadamardTransformFn.apply(x, scale)
def hadamard_transform_20N(x, scale=1.0):
    """Hadamard transform for dim = 20 * 2^k (e.g. 20*1024 = 20480).

    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)

    This dispatches to the same HadamardTransformFn as hadamard_transform, so
    the last dimension is zero-padded to the next power of 2 and the output is
    truncated back to dim.
    NOTE(review): no dedicated 20 x 20 base matrix is used here -- confirm this
    is the intended behaviour for 20N sizes.
    """
    return HadamardTransformFn.apply(x, scale)
def hadamard_transform_28N(x, scale=1.0):
    """Hadamard transform for dim = 28 * 2^k (e.g. 28*1024 = 28672).

    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)

    This dispatches to the same HadamardTransformFn as hadamard_transform, so
    the last dimension is zero-padded to the next power of 2 and the output is
    truncated back to dim.
    NOTE(review): no dedicated 28 x 28 base matrix is used here -- confirm this
    is the intended behaviour for 28N sizes.
    """
    return HadamardTransformFn.apply(x, scale)
def hadamard_transform_40N(x, scale=1.0):
    """Hadamard transform for dim = 40 * 2^k (e.g. 40*1024 = 40960).

    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)

    This dispatches to the same HadamardTransformFn as hadamard_transform, so
    the last dimension is zero-padded to the next power of 2 and the output is
    truncated back to dim.
    NOTE(review): no dedicated 40 x 40 base matrix is used here -- confirm this
    is the intended behaviour for 40N sizes.
    """
    return HadamardTransformFn.apply(x, scale)