Coverage for src/flag_gems/runtime/backend/_ascend/ops/hadamard_transform.py: 0%

123 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1"""Fast Hadamard Transform in Triton (Ascend NPU). 

2 

3v1: Single-kernel fused butterfly with chained buffers + fused scale/cast. 

4All 7 butterfly stages + scale + dtype cast in one kernel launch. 

5Uses unique buffer for each stage to avoid NPU stale-read issues. 

6Eliminates 7 kernel launch overheads and the separate scale/cast kernel from v0. 

7""" 

8 

9import math 

10 

11import torch 

12import torch.nn.functional as F 

13import triton 

14import triton.language as tl 

15 

16MAX_GRID = 65535 

17 

18 

19# ============================================================ 

20# Fused 7-stage butterfly kernel (dim=128 specialized) 

21# Uses 6 scratch buffer segments (B0..B5) in a contiguous allocation. 

22# Chain: IN -> B0 -> B1 -> B2 -> B3 -> B4 -> B5 -> OUT 

23# ============================================================ 

24 

25 

@triton.jit
def _fht_fused_7stage(
    IN_ptr,
    SCRATCH_ptr,
    OUT_ptr,
    stride_row,
    stride_out_row,
    seg_stride,
    scale,
    N_ROWS,
    ROWS_PER_PROGRAM: tl.constexpr,
    DIM: tl.constexpr,
    OUTPUT_BF16: tl.constexpr,
    OUTPUT_FP16: tl.constexpr,
):
    """Seven chained butterfly stages plus scale and cast, in a single launch.

    The stage distances are hard-coded as 1, 2, 4, ..., 64, so the kernel is
    only meaningful for DIM == 128.

    Each program handles up to ROWS_PER_PROGRAM consecutive rows; rows at or
    beyond N_ROWS are skipped. For one row the data flows

        IN -> seg0 -> seg1 -> seg2 -> seg3 -> seg4 -> seg5 -> OUT

    where seg<k> is the row's slot inside scratch segment k, located at
    ``SCRATCH_ptr + k * seg_stride + row * DIM``. Every stage writes to a
    segment that no earlier stage has touched.

    Args:
        IN_ptr: input rows; row r starts at ``r * stride_row``.
        SCRATCH_ptr: scratch buffer holding at least six segments.
        OUT_ptr: output rows; row r starts at ``r * stride_out_row``.
        stride_row: element distance between consecutive input rows.
        stride_out_row: element distance between consecutive output rows.
        seg_stride: element distance between consecutive scratch segments.
        scale: factor multiplied into the result of the last stage.
        N_ROWS: number of valid rows.
        ROWS_PER_PROGRAM: rows processed by each program (unrolled).
        DIM: row length.
        OUTPUT_BF16: cast the result to bfloat16 before the final store.
        OUTPUT_FP16: cast the result to float16 before the final store
            (only consulted when OUTPUT_BF16 is false).

    NOTE(review): offsets such as ``5 * seg_stride + row_off`` are plain
    integer products of the kernel arguments; verify they stay in range for
    the largest batch this kernel is expected to serve.
    """
    pid = tl.program_id(0)
    lane = tl.arange(0, DIM)

    for local_row in tl.static_range(ROWS_PER_PROGRAM):
        row = pid * ROWS_PER_PROGRAM + local_row
        if row < N_ROWS:
            # Base pointers for this row: the input and its six scratch slots.
            src = IN_ptr + row * stride_row
            row_off = row * DIM  # position of the row inside each segment
            seg0 = SCRATCH_ptr + row_off
            seg1 = SCRATCH_ptr + (seg_stride + row_off)
            seg2 = SCRATCH_ptr + (2 * seg_stride + row_off)
            seg3 = SCRATCH_ptr + (3 * seg_stride + row_off)
            seg4 = SCRATCH_ptr + (4 * seg_stride + row_off)
            seg5 = SCRATCH_ptr + (5 * seg_stride + row_off)

            # In every stage lane i is paired with lane (i ^ distance). The
            # lane whose distance bit is clear keeps the sum, the other one
            # keeps (partner - own).

            # Stage 0, distance 1: IN -> seg0
            own = tl.load(src + lane)
            other = tl.load(src + (lane ^ 1))
            tl.store(seg0 + lane, tl.where((lane & 1) == 0, own + other, other - own))

            # Stage 1, distance 2: seg0 -> seg1
            own = tl.load(seg0 + lane)
            other = tl.load(seg0 + (lane ^ 2))
            tl.store(seg1 + lane, tl.where((lane & 2) == 0, own + other, other - own))

            # Stage 2, distance 4: seg1 -> seg2
            own = tl.load(seg1 + lane)
            other = tl.load(seg1 + (lane ^ 4))
            tl.store(seg2 + lane, tl.where((lane & 4) == 0, own + other, other - own))

            # Stage 3, distance 8: seg2 -> seg3
            own = tl.load(seg2 + lane)
            other = tl.load(seg2 + (lane ^ 8))
            tl.store(seg3 + lane, tl.where((lane & 8) == 0, own + other, other - own))

            # Stage 4, distance 16: seg3 -> seg4
            own = tl.load(seg3 + lane)
            other = tl.load(seg3 + (lane ^ 16))
            tl.store(seg4 + lane, tl.where((lane & 16) == 0, own + other, other - own))

            # Stage 5, distance 32: seg4 -> seg5
            own = tl.load(seg4 + lane)
            other = tl.load(seg4 + (lane ^ 32))
            tl.store(seg5 + lane, tl.where((lane & 32) == 0, own + other, other - own))

            # Stage 6, distance 64: seg5 -> OUT, with the scale folded in.
            own = tl.load(seg5 + lane)
            other = tl.load(seg5 + (lane ^ 64))
            scaled = tl.where((lane & 64) == 0, own + other, other - own) * scale

            dst = OUT_ptr + row * stride_out_row + lane
            if OUTPUT_BF16:
                tl.store(dst, scaled.to(tl.bfloat16))
            elif OUTPUT_FP16:
                tl.store(dst, scaled.to(tl.float16))
            else:
                tl.store(dst, scaled)

111 

112 

113# ============================================================ 

114# Generic fused butterfly kernel (any power-of-2 dim) 

115# ============================================================ 

116 

117 

@triton.jit
def _fht_fused_generic(
    IN_ptr,
    SCRATCH_ptr,
    OUT_ptr,
    stride_row,
    stride_out_row,
    seg_stride,
    scale,
    N_ROWS,
    ROWS_PER_PROGRAM: tl.constexpr,
    DIM: tl.constexpr,
    LOG_N: tl.constexpr,
    OUTPUT_BF16: tl.constexpr,
    OUTPUT_FP16: tl.constexpr,
):
    """LOG_N chained butterfly stages plus scale and cast, in a single launch.

    Stage s pairs lane i with lane (i ^ (1 << s)). Stage 0 reads the input
    row; stage s > 0 reads scratch segment s - 1. Every stage except the last
    writes scratch segment s; the last stage multiplies by ``scale``,
    optionally casts, and writes the output row. A stage therefore never
    reads the segment it writes.

    Each program handles up to ROWS_PER_PROGRAM consecutive rows; rows at or
    beyond N_ROWS are skipped.

    Args:
        IN_ptr: input rows; row r starts at ``r * stride_row``.
        SCRATCH_ptr: scratch buffer; segment k of row r starts at
            ``k * seg_stride + r * DIM``.
        OUT_ptr: output rows; row r starts at ``r * stride_out_row``.
        stride_row: element distance between consecutive input rows.
        stride_out_row: element distance between consecutive output rows.
        seg_stride: element distance between consecutive scratch segments.
        scale: factor multiplied into the result of the last stage.
        N_ROWS: number of valid rows.
        ROWS_PER_PROGRAM: rows processed by each program (unrolled).
        DIM: row length.
        LOG_N: number of butterfly stages (unrolled).
        OUTPUT_BF16: cast the result to bfloat16 before the final store.
        OUTPUT_FP16: cast the result to float16 before the final store
            (only consulted when OUTPUT_BF16 is false).
    """
    pid = tl.program_id(0)
    lane = tl.arange(0, DIM)

    for local_row in tl.static_range(ROWS_PER_PROGRAM):
        row = pid * ROWS_PER_PROGRAM + local_row
        if row < N_ROWS:
            row_off = row * DIM  # position of the row inside each segment

            for stage in tl.static_range(LOG_N):
                dist: tl.constexpr = 1 << stage

                # Pick where this stage reads from.
                if stage == 0:
                    src = IN_ptr + row * stride_row
                else:
                    src = SCRATCH_ptr + ((stage - 1) * seg_stride + row_off)

                own = tl.load(src + lane)
                other = tl.load(src + (lane ^ dist))

                # The lane whose distance bit is clear keeps the sum, its
                # partner keeps (partner - own).
                mixed = tl.where((lane & dist) == 0, own + other, other - own)

                if stage == LOG_N - 1:
                    # Last stage: scale, cast if requested, write the output.
                    mixed = mixed * scale
                    dst = OUT_ptr + row * stride_out_row + lane
                    if OUTPUT_BF16:
                        tl.store(dst, mixed.to(tl.bfloat16))
                    elif OUTPUT_FP16:
                        tl.store(dst, mixed.to(tl.float16))
                    else:
                        tl.store(dst, mixed)
                else:
                    # Intermediate stage: hand the row to the next segment.
                    tl.store(SCRATCH_ptr + (stage * seg_stride + row_off) + lane, mixed)

175 

176 

177# ============================================================ 

178# Core forward 

179# ============================================================ 

180 

181 

182def _hadamard_transform_fwd(x: torch.Tensor, scale: float) -> torch.Tensor: 

183 """Core forward: handles reshape, padding, kernel launch.""" 

184 assert x.dtype in ( 

185 torch.float32, 

186 torch.float16, 

187 torch.bfloat16, 

188 ), f"Unsupported dtype {x.dtype}" 

189 

190 orig_shape = x.shape 

191 dim = orig_shape[-1] 

192 input_dtype = x.dtype 

193 x_flat = x.reshape(-1, dim) 

194 batch = x_flat.shape[0] 

195 

196 # Pad dim to next power of 2 

197 log_n = math.ceil(math.log2(max(dim, 2))) 

198 dim_padded = 1 << log_n 

199 if dim != dim_padded: 

200 x_flat = F.pad(x_flat, (0, dim_padded - dim)) 

201 

202 # Input buffer in fp32 

203 inp_fp32 = x_flat.float() 

204 

205 # Scratch buffer: (log_n - 1) segments of (batch, dim_padded) in fp32 

206 # Stage s writes to segment s (0..log_n-2), last stage writes to output 

207 n_scratch = max(log_n - 1, 1) 

208 scratch = torch.empty( 

209 n_scratch, batch, dim_padded, dtype=torch.float32, device=x.device 

210 ) 

211 seg_stride = batch * dim_padded 

212 

213 # Grid calculation 

214 rows_per_program = max((batch + MAX_GRID - 1) // MAX_GRID, 1) 

215 grid_size = (batch + rows_per_program - 1) // rows_per_program 

216 

217 stride_row = dim_padded # contiguous 

218 

219 # Output buffer 

220 out = torch.empty(batch, dim_padded, dtype=input_dtype, device=x.device) 

221 

222 output_bf16 = input_dtype == torch.bfloat16 

223 output_fp16 = input_dtype == torch.float16 

224 

225 # Use specialized 7-stage kernel for dim=128, generic for others 

226 if log_n == 7: 

227 _fht_fused_7stage[(grid_size,)]( 

228 inp_fp32, 

229 scratch, 

230 out, 

231 stride_row, 

232 dim_padded, 

233 seg_stride, 

234 scale, 

235 N_ROWS=batch, 

236 ROWS_PER_PROGRAM=rows_per_program, 

237 DIM=dim_padded, 

238 OUTPUT_BF16=output_bf16, 

239 OUTPUT_FP16=output_fp16, 

240 ) 

241 else: 

242 _fht_fused_generic[(grid_size,)]( 

243 inp_fp32, 

244 scratch, 

245 out, 

246 stride_row, 

247 dim_padded, 

248 seg_stride, 

249 scale, 

250 N_ROWS=batch, 

251 ROWS_PER_PROGRAM=rows_per_program, 

252 DIM=dim_padded, 

253 LOG_N=log_n, 

254 OUTPUT_BF16=output_bf16, 

255 OUTPUT_FP16=output_fp16, 

256 ) 

257 

258 # Trim padding and restore shape 

259 if dim != dim_padded: 

260 out = out[:, :dim] 

261 return out.reshape(orig_shape) 

262 

263 

264# ============================================================ 

265# Autograd wrapper 

266# ============================================================ 

267 

268 

class HadamardTransformFn(torch.autograd.Function):
    """Autograd wrapper around ``_hadamard_transform_fwd``.

    The backward pass applies the same scaled transform to the incoming
    gradient; ``scale`` itself receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, scale):
        """Transform ``x`` along its last dimension and remember ``scale``."""
        # scale is a plain number, not a tensor that takes part in autograd,
        # so it is kept directly on ctx instead of going through
        # save_for_backward (which is meant for tensors).
        ctx.scale = scale
        return _hadamard_transform_fwd(x, scale)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``(grad_x, None)`` for the inputs ``(x, scale)``."""
        # Autograd may pass an expanded or otherwise strided gradient; make it
        # dense before it is handed to the kernel launcher.
        grad_input = _hadamard_transform_fwd(grad_output.contiguous(), ctx.scale)
        return grad_input, None

280 

281 

282# ============================================================ 

283# Public API 

284# ============================================================ 

285 

286 

def hadamard_transform(x, scale=1.0):
    """Differentiable Fast Hadamard Transform over the last dimension.

    Args:
        x: tensor of shape (..., dim) on an NPU device, dtype fp32, fp16 or
            bf16.
        scale: multiplier applied to the transformed values (default 1.0).

    Returns:
        Tensor of shape (..., dim) with the same dtype as ``x``.

    The result matches
    ``F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale``.
    When dim is not a power of 2, x is implicitly zero-padded up to the next
    power of 2 before the transform.
    """
    transformed = HadamardTransformFn.apply(x, scale)
    return transformed

297 return HadamardTransformFn.apply(x, scale)