Coverage for src/flag_gems/runtime/backend/_ascend/ops/hadamard_transform.py: 0%
123 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1"""Fast Hadamard Transform in Triton (Ascend NPU).
3v1: Single-kernel fused butterfly with chained buffers + fused scale/cast.
4All 7 butterfly stages + scale + dtype cast in one kernel launch.
5Uses a unique scratch buffer for each stage to avoid NPU stale-read issues.
6Eliminates 7 kernel launch overheads and the separate scale/cast kernel from v0.
7"""
9import math
11import torch
12import torch.nn.functional as F
13import triton
14import triton.language as tl
16MAX_GRID = 65535
19# ============================================================
20# Fused 7-stage butterfly kernel (dim=128 specialized)
21# Uses 6 scratch buffer segments (B0..B5) in a contiguous allocation.
22# Chain: IN -> B0 -> B1 -> B2 -> B3 -> B4 -> B5 -> OUT
23# ============================================================
@triton.jit
def _fht_fused_7stage(
    IN_ptr,
    SCRATCH_ptr,
    OUT_ptr,
    stride_row,
    stride_out_row,
    seg_stride,
    scale,
    N_ROWS,
    ROWS_PER_PROGRAM: tl.constexpr,
    DIM: tl.constexpr,
    OUTPUT_BF16: tl.constexpr,
    OUTPUT_FP16: tl.constexpr,
):
    """Run seven butterfly stages, the scale and the output cast in one launch.

    Each program handles ROWS_PER_PROGRAM consecutive rows; rows at or past
    N_ROWS are skipped. For every stage with bit b in (1, 2, 4, ..., 64),
    element i is combined with its partner at index i ^ b: lanes with the
    bit clear receive x + p, lanes with the bit set receive p - x.

    Stages 0..5 write to scratch segments 0..5 respectively; segment k
    starts at SCRATCH_ptr + k * seg_stride and a row lives at row_id * DIM
    inside its segment. Every stage reads the segment written by the
    previous stage, so no segment is both read and written by one stage.
    Stage 6 multiplies by `scale` and stores to OUT_ptr, casting to
    bfloat16 / float16 when OUTPUT_BF16 / OUTPUT_FP16 is set.

    Chain: IN -> seg0 -> seg1 -> seg2 -> seg3 -> seg4 -> seg5 -> OUT
    """
    prog = tl.program_id(0)
    lane = tl.arange(0, DIM)

    for k in tl.static_range(ROWS_PER_PROGRAM):
        row = prog * ROWS_PER_PROGRAM + k
        if row < N_ROWS:
            # Element offsets of this row in the input, the output and each
            # of the six scratch segments.
            in_off = row * stride_row
            out_off = row * stride_out_row
            seg0 = row * DIM
            seg1 = seg_stride + seg0
            seg2 = 2 * seg_stride + seg0
            seg3 = 3 * seg_stride + seg0
            seg4 = 4 * seg_stride + seg0
            seg5 = 5 * seg_stride + seg0

            # Stage 0 (bit 1): input -> segment 0
            a = tl.load(IN_ptr + in_off + lane)
            b = tl.load(IN_ptr + in_off + (lane ^ 1))
            y = tl.where((lane & 1) != 0, b - a, a + b)
            tl.store(SCRATCH_ptr + seg0 + lane, y)

            # Stage 1 (bit 2): segment 0 -> segment 1
            a = tl.load(SCRATCH_ptr + seg0 + lane)
            b = tl.load(SCRATCH_ptr + seg0 + (lane ^ 2))
            y = tl.where((lane & 2) != 0, b - a, a + b)
            tl.store(SCRATCH_ptr + seg1 + lane, y)

            # Stage 2 (bit 4): segment 1 -> segment 2
            a = tl.load(SCRATCH_ptr + seg1 + lane)
            b = tl.load(SCRATCH_ptr + seg1 + (lane ^ 4))
            y = tl.where((lane & 4) != 0, b - a, a + b)
            tl.store(SCRATCH_ptr + seg2 + lane, y)

            # Stage 3 (bit 8): segment 2 -> segment 3
            a = tl.load(SCRATCH_ptr + seg2 + lane)
            b = tl.load(SCRATCH_ptr + seg2 + (lane ^ 8))
            y = tl.where((lane & 8) != 0, b - a, a + b)
            tl.store(SCRATCH_ptr + seg3 + lane, y)

            # Stage 4 (bit 16): segment 3 -> segment 4
            a = tl.load(SCRATCH_ptr + seg3 + lane)
            b = tl.load(SCRATCH_ptr + seg3 + (lane ^ 16))
            y = tl.where((lane & 16) != 0, b - a, a + b)
            tl.store(SCRATCH_ptr + seg4 + lane, y)

            # Stage 5 (bit 32): segment 4 -> segment 5
            a = tl.load(SCRATCH_ptr + seg4 + lane)
            b = tl.load(SCRATCH_ptr + seg4 + (lane ^ 32))
            y = tl.where((lane & 32) != 0, b - a, a + b)
            tl.store(SCRATCH_ptr + seg5 + lane, y)

            # Stage 6 (bit 64): segment 5 -> output, with scale and cast fused
            a = tl.load(SCRATCH_ptr + seg5 + lane)
            b = tl.load(SCRATCH_ptr + seg5 + (lane ^ 64))
            y = tl.where((lane & 64) != 0, b - a, a + b)

            y = y * scale
            if OUTPUT_BF16:
                tl.store(OUT_ptr + out_off + lane, y.to(tl.bfloat16))
            elif OUTPUT_FP16:
                tl.store(OUT_ptr + out_off + lane, y.to(tl.float16))
            else:
                tl.store(OUT_ptr + out_off + lane, y)
113# ============================================================
114# Generic fused butterfly kernel (any power-of-2 dim)
115# ============================================================
@triton.jit
def _fht_fused_generic(
    IN_ptr,
    SCRATCH_ptr,
    OUT_ptr,
    stride_row,
    stride_out_row,
    seg_stride,
    scale,
    N_ROWS,
    ROWS_PER_PROGRAM: tl.constexpr,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    OUTPUT_BF16: tl.constexpr,
    OUTPUT_FP16: tl.constexpr,
):
    """Fused butterfly kernel with a compile-time number of stages (LOG_N).

    Each program handles ROWS_PER_PROGRAM consecutive rows; rows at or past
    N_ROWS are skipped. Stage s pairs element i with element i ^ (1 << s):
    lanes with that bit clear receive x + p, lanes with it set receive p - x.

    Stage 0 reads from IN_ptr; stage s > 0 reads scratch segment s - 1.
    Every stage but the last writes scratch segment s (segment k starts at
    SCRATCH_ptr + k * seg_stride, rows are DIM elements apart). The last
    stage multiplies by `scale` and stores to OUT_ptr, casting to bfloat16
    / float16 when OUTPUT_BF16 / OUTPUT_FP16 is set. Reading and writing
    distinct segments is what avoids the NPU stale-read issue.
    """
    prog = tl.program_id(0)
    lane = tl.arange(0, DIM)

    for k in tl.static_range(ROWS_PER_PROGRAM):
        row = prog * ROWS_PER_PROGRAM + k
        if row < N_ROWS:
            in_off = row * stride_row
            seg_row = row * DIM

            for stage in tl.static_range(LOG_N):
                bit: tl.constexpr = 1 << stage
                bit_set = (lane & bit) != 0

                if stage != 0:
                    # Later stages consume what the previous stage produced.
                    rd_off = (stage - 1) * seg_stride + seg_row
                    a = tl.load(SCRATCH_ptr + rd_off + lane)
                    b = tl.load(SCRATCH_ptr + rd_off + (lane ^ bit))
                else:
                    # The first stage reads straight from the input row.
                    a = tl.load(IN_ptr + in_off + lane)
                    b = tl.load(IN_ptr + in_off + (lane ^ bit))

                y = tl.where(bit_set, b - a, a + b)

                if stage != LOG_N - 1:
                    # Intermediate result goes to this stage's own segment.
                    wr_off = stage * seg_stride + seg_row
                    tl.store(SCRATCH_ptr + wr_off + lane, y)
                else:
                    # Final stage: scale, cast and write the output row.
                    y = y * scale
                    out_off = row * stride_out_row
                    if OUTPUT_BF16:
                        tl.store(OUT_ptr + out_off + lane, y.to(tl.bfloat16))
                    elif OUTPUT_FP16:
                        tl.store(OUT_ptr + out_off + lane, y.to(tl.float16))
                    else:
                        tl.store(OUT_ptr + out_off + lane, y)
177# ============================================================
178# Core forward
179# ============================================================
182def _hadamard_transform_fwd(x: torch.Tensor, scale: float) -> torch.Tensor:
183 """Core forward: handles reshape, padding, kernel launch."""
184 assert x.dtype in (
185 torch.float32,
186 torch.float16,
187 torch.bfloat16,
188 ), f"Unsupported dtype {x.dtype}"
190 orig_shape = x.shape
191 dim = orig_shape[-1]
192 input_dtype = x.dtype
193 x_flat = x.reshape(-1, dim)
194 batch = x_flat.shape[0]
196 # Pad dim to next power of 2
197 log_n = math.ceil(math.log2(max(dim, 2)))
198 dim_padded = 1 << log_n
199 if dim != dim_padded:
200 x_flat = F.pad(x_flat, (0, dim_padded - dim))
202 # Input buffer in fp32
203 inp_fp32 = x_flat.float()
205 # Scratch buffer: (log_n - 1) segments of (batch, dim_padded) in fp32
206 # Stage s writes to segment s (0..log_n-2), last stage writes to output
207 n_scratch = max(log_n - 1, 1)
208 scratch = torch.empty(
209 n_scratch, batch, dim_padded, dtype=torch.float32, device=x.device
210 )
211 seg_stride = batch * dim_padded
213 # Grid calculation
214 rows_per_program = max((batch + MAX_GRID - 1) // MAX_GRID, 1)
215 grid_size = (batch + rows_per_program - 1) // rows_per_program
217 stride_row = dim_padded # contiguous
219 # Output buffer
220 out = torch.empty(batch, dim_padded, dtype=input_dtype, device=x.device)
222 output_bf16 = input_dtype == torch.bfloat16
223 output_fp16 = input_dtype == torch.float16
225 # Use specialized 7-stage kernel for dim=128, generic for others
226 if log_n == 7:
227 _fht_fused_7stage[(grid_size,)](
228 inp_fp32,
229 scratch,
230 out,
231 stride_row,
232 dim_padded,
233 seg_stride,
234 scale,
235 N_ROWS=batch,
236 ROWS_PER_PROGRAM=rows_per_program,
237 DIM=dim_padded,
238 OUTPUT_BF16=output_bf16,
239 OUTPUT_FP16=output_fp16,
240 )
241 else:
242 _fht_fused_generic[(grid_size,)](
243 inp_fp32,
244 scratch,
245 out,
246 stride_row,
247 dim_padded,
248 seg_stride,
249 scale,
250 N_ROWS=batch,
251 ROWS_PER_PROGRAM=rows_per_program,
252 DIM=dim_padded,
253 LOG_N=log_n,
254 OUTPUT_BF16=output_bf16,
255 OUTPUT_FP16=output_fp16,
256 )
258 # Trim padding and restore shape
259 if dim != dim_padded:
260 out = out[:, :dim]
261 return out.reshape(orig_shape)
264# ============================================================
265# Autograd wrapper
266# ============================================================
class HadamardTransformFn(torch.autograd.Function):
    """Autograd wrapper around `_hadamard_transform_fwd`.

    The backward pass applies the same scaled transform to the incoming
    gradient and returns no gradient for `scale`.
    """

    @staticmethod
    def forward(ctx, x, scale):
        """Transform `x` and remember `scale` for the backward pass."""
        # `scale` is a plain number, not a tensor, so keep it as an attribute
        # on ctx. Round-tripping it through torch.tensor(...).item() would
        # round it to float32 and make backward use a slightly different
        # scale than forward.
        ctx.scale = float(scale)
        return _hadamard_transform_fwd(x, scale)

    @staticmethod
    def backward(ctx, grad_output):
        """Return (gradient w.r.t. x, None for scale)."""
        return _hadamard_transform_fwd(grad_output, ctx.scale), None
282# ============================================================
283# Public API
284# ============================================================
def hadamard_transform(x, scale=1.0):
    """Compute the Fast Hadamard Transform of `x` along its last dimension.

    Arguments:
        x: (..., dim), device=npu, fp32/fp16/bf16
        scale: factor multiplied into the result (default 1.0)

    Returns:
        out: (..., dim), same dtype as x

    Equivalent to F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale.
    When dim is not a power of 2, x is implicitly zero-padded so that dim
    becomes the next power of 2.
    """
    transformed = HadamardTransformFn.apply(x, scale)
    return transformed