Coverage for src/flag_gems/runtime/backend/_hygon/ops/hadamard_transform.py: 0%

87 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1"""Fast Hadamard Transform in Triton. 

2 

3Drop-in replacement for Dao-AILab/fast-hadamard-transform with identical interface: 

4 - hadamard_transform(x, scale=1.0) with autograd support 

5 - hadamard_transform_12N/20N/28N/40N(x, scale=1.0) for non-power-of-2 dims 

6 - Input: (..., dim), fp32/fp16/bf16 

7 - Output: (..., dim), same dtype as input 

8 - Padding: to next multiple of 8 (matching CUDA impl) 

9 - dim <= 65536 after padding to a multiple of 8 (the XXN variants share the same path and limit) 

10 

11Reference: https://github.com/Dao-AILab/fast-hadamard-transform 

12""" 

13 

14import math 

15 

16import torch 

17import triton 

18import triton.language as tl 

19 

20# ============================================================ 

21# Triton kernel — v1: single float32 scratch allocation, batch rows per block 

22# ============================================================ 

23# v0 bottleneck analysis: 

24# 1. Separate float32 scratch buffer in global memory — extra allocation + bandwidth 

25# 2. One row per program — low occupancy for small dims 

26# 3. Extra tl.load at the end just to get the dtype for casting 

27# 

28# v1 optimizations: 

29# 1. Use a float32 scratch buffer but only 1 allocation (reuse out for final store) 

30# 2. Process multiple rows per block for better GPU utilization 

31# 3. Track dtype as constexpr to avoid extra load 

32# 4. Tuned num_warps per dim size 

33 

34 

@triton.jit
def _fht_kernel(
    X_ptr,
    OUT_ptr,
    SCRATCH_ptr,
    scale,
    stride_x_row,
    stride_out_row,
    stride_scratch_row,
    N_ROWS,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    ROWS_PER_PROGRAM: tl.constexpr,
    INPUT_IS_FP16: tl.constexpr,
    INPUT_IS_BF16: tl.constexpr,
):
    """FHT butterfly kernel. Each program processes ROWS_PER_PROGRAM rows.

    Arguments:
        X_ptr: input rows; element ``j`` of row ``i`` is read from
            ``X_ptr + i * stride_x_row + j``.
        OUT_ptr: output rows, addressed with ``stride_out_row``.
        SCRATCH_ptr: per-row workspace (addressed with ``stride_scratch_row``)
            used to exchange values between butterfly partners. Values are
            written to it as float32.
        scale: multiplier applied to every output element.
        N_ROWS: number of valid rows; programs skip rows past this bound.
        DIM: number of elements per row taking part in the transform.
        LOG_N: number of butterfly stages executed.
        BLOCK_SIZE: width of the offset vector; lanes >= DIM are masked off.
        ROWS_PER_PROGRAM: rows handled sequentially by one program.
        INPUT_IS_FP16 / INPUT_IS_BF16: select the dtype of the final store;
            when both are false the float32 value is stored unchanged.
    """
    pid = tl.program_id(0)
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < DIM

    for r in tl.static_range(ROWS_PER_PROGRAM):
        # Promote the row index to 64-bit before multiplying by the row
        # strides so the element offset cannot wrap around in 32-bit
        # arithmetic when rows * row_stride reaches 2**31.
        batch_id = (pid * ROWS_PER_PROGRAM + r).to(tl.int64)
        if batch_id < N_ROWS:
            base_in = X_ptr + batch_id * stride_x_row
            base_out = OUT_ptr + batch_id * stride_out_row
            base_scratch = SCRATCH_ptr + batch_id * stride_scratch_row

            # Load in float32
            x = tl.load(base_in + offsets, mask=mask, other=0.0).to(tl.float32)

            # Butterfly stages using scratch for exchange
            for s in tl.static_range(LOG_N):
                stride = 1 << s
                tl.store(base_scratch + offsets, x, mask=mask)
                # Every lane's value for this stage must be written before
                # any lane reads its partner.
                tl.debug_barrier()
                partner = offsets ^ stride
                x_partner = tl.load(
                    base_scratch + partner, mask=partner < DIM, other=0.0
                )
                # Every lane must have finished reading this stage before the
                # next stage's store overwrites the scratch row; without this
                # barrier a fast thread can clobber a value a slower thread
                # has not read yet (write-after-read race).
                tl.debug_barrier()
                is_upper = (offsets & stride) == 0
                x = tl.where(is_upper, x + x_partner, x_partner - x)

            # Scale and cast back to input dtype
            x = x * scale
            if INPUT_IS_FP16:
                tl.store(base_out + offsets, x.to(tl.float16), mask=mask)
            elif INPUT_IS_BF16:
                tl.store(base_out + offsets, x.to(tl.bfloat16), mask=mask)
            else:
                tl.store(base_out + offsets, x, mask=mask)

87 

88 

89# ============================================================ 

90# Core forward 

91# ============================================================ 

92 

93 

def _hadamard_transform_fwd(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Core forward: handles reshape, padding, kernel launch.

    Arguments:
        x: (..., dim) CUDA tensor of dtype float32, float16 or bfloat16.
        scale: multiplier applied to the transformed values.
    Returns:
        Tensor with the same shape and dtype as ``x``.
    Raises:
        AssertionError: if the dtype is unsupported, ``x`` is not a CUDA
            tensor, or the last dimension rounded up to a multiple of 8
            exceeds 65536.

    The last dimension is zero-padded up to the next power of two (at least
    the next multiple of 8) for the butterfly kernel, and the padding is
    trimmed from the result.
    """
    assert x.dtype in (
        torch.float32,
        torch.float16,
        torch.bfloat16,
    ), f"hadamard_transform not implemented for input type '{x.dtype}'"
    assert x.is_cuda, "hadamard_transform requires CUDA tensor"

    shapes_og = x.shape
    dim_og = x.shape[-1]
    input_dtype = x.dtype
    x = x.reshape(-1, dim_og)
    if x.stride(-1) != 1:
        x = x.contiguous()
    batch_size = x.shape[0]

    # Width after rounding up to a multiple of 8 (matching CUDA
    # implementation). Computed arithmetically so the tensor is padded only
    # once below; it is a multiple of 8 by construction, so no separate
    # divisibility check is needed.
    dim = dim_og + (-dim_og) % 8
    assert (
        dim <= 65536
    ), "fast_hadamard_transform only supports hidden dimension at most 65536 for now"

    # The butterfly needs a power-of-two width. Integer bit_length gives the
    # exact ceil(log2(dim)) without going through floating point; at least
    # one stage is always run.
    log_n = max(1, (dim - 1).bit_length())
    n = 1 << log_n

    # Single zero-pad straight to the kernel width (covers both the
    # multiple-of-8 padding and the power-of-two padding in one copy).
    if n != dim_og:
        x = torch.nn.functional.pad(x, (0, n - dim_og))

    out = torch.empty_like(x)

    # Small rows: several rows per program to improve occupancy, and few
    # warps. num_warps is kept conservative for large rows as well.
    if n <= 256:
        rows_per_program, num_warps = 8, 1
    elif n <= 1024:
        rows_per_program, num_warps = 4, 2
    elif n <= 4096:
        rows_per_program, num_warps = 2, 4
    else:
        rows_per_program, num_warps = 1, 4

    n_programs = (batch_size + rows_per_program - 1) // rows_per_program

    # Float32 scratch buffer — one row per input row, reused across stages
    scratch = torch.empty(batch_size, n, dtype=torch.float32, device=x.device)

    _fht_kernel[(n_programs,)](
        x,
        out,
        scratch,
        scale,
        stride_x_row=x.stride(0),
        stride_out_row=out.stride(0),
        stride_scratch_row=scratch.stride(0),
        N_ROWS=batch_size,
        DIM=n,
        LOG_N=log_n,
        # n is already a power of two, so it is its own block size.
        BLOCK_SIZE=n,
        ROWS_PER_PROGRAM=rows_per_program,
        INPUT_IS_FP16=(input_dtype == torch.float16),
        INPUT_IS_BF16=(input_dtype == torch.bfloat16),
        num_warps=num_warps,
    )

    # Trim padding back to original dim
    if n != dim_og:
        out = out[:, :dim_og]
    return out.reshape(shapes_og)

182 

183 

184# ============================================================ 

185# Autograd Function 

186# ============================================================ 

187 

188 

class HadamardTransformFn(torch.autograd.Function):
    """Autograd function wrapping the Triton Hadamard transform forward."""

    @staticmethod
    def forward(ctx, x, scale=1.0):
        """Remember ``scale`` on ``ctx`` and return the transform of ``x``."""
        ctx._hadamard_transform_scale = scale
        transformed = _hadamard_transform_fwd(x, scale)
        return transformed

    @staticmethod
    def backward(ctx, dout):
        """Return the gradient w.r.t. ``x``; ``scale`` receives no gradient."""
        saved_scale = ctx._hadamard_transform_scale
        # The Hadamard matrix is symmetric, so the backward pass is the
        # forward transform applied to the incoming gradient with the same
        # scale.
        grad_input = _hadamard_transform_fwd(dout, saved_scale)
        return grad_input, None

199 

200 

201# ============================================================ 

202# Public API 

203# ============================================================ 

204 

205 

def hadamard_transform(x, scale=1.0):
    """
    Multiply each row of x by the Hadamard transform matrix.

    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)

    Equivalent to F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale.
    When dim is not a power of 2, x is implicitly zero-padded so that dim
    becomes the next power of 2.
    """
    out = HadamardTransformFn.apply(x, scale)
    return out

220 

221 

222# ============================================================ 

223# XXN variants (non-power-of-2 dims) 

224# 

225# Dao-AILab decomposes dim = M * 2^k, applying a small M×M 

226# Hadamard-like matrix then a standard 2^k FHT. 

227# For now these use the standard FHT with implicit zero-padding 

228# to the next power of 2, which is correct but not optimal. 

229# TODO: implement proper M×N decomposition for better efficiency. 

230# ============================================================ 

231 

232 

def hadamard_transform_12N(x, scale=1.0):
    """Hadamard transform entry point for dim = 12 * 2^k (e.g. 12*512 = 6144).

    Dispatches to the same autograd function as the standard transform.
    """
    out = HadamardTransformFn.apply(x, scale)
    return out

236 

237 

def hadamard_transform_20N(x, scale=1.0):
    """Hadamard transform entry point for dim = 20 * 2^k (e.g. 20*1024 = 20480).

    Dispatches to the same autograd function as the standard transform.
    """
    out = HadamardTransformFn.apply(x, scale)
    return out

241 

242 

def hadamard_transform_28N(x, scale=1.0):
    """Hadamard transform entry point for dim = 28 * 2^k (e.g. 28*1024 = 28672).

    Dispatches to the same autograd function as the standard transform.
    """
    out = HadamardTransformFn.apply(x, scale)
    return out

246 

247 

def hadamard_transform_40N(x, scale=1.0):
    """Hadamard transform entry point for dim = 40 * 2^k (e.g. 40*1024 = 40960).

    Dispatches to the same autograd function as the standard transform.
    """
    out = HadamardTransformFn.apply(x, scale)
    return out