Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/hadamard_transform.py: 0%

74 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1"""Fast Hadamard Transform in Triton (KunlunXin). 

2 

3v0: Multi-pass butterfly via global memory, 1 kernel launch per stage. 

4Simple baseline for correctness. Each butterfly stage reads from IN, writes to OUT. 

5""" 

6 

7import math 

8 

9import torch 

10import torch.nn.functional as F 

11import triton 

12import triton.language as tl 

13 

# Upper bound on the number of programs launched along grid axis 0. When a
# tensor has more rows than this, each program processes several rows.
MAX_GRID = 0xFFFF  # == 65535 == 2**16 - 1

15 

16 

17# ============================================================ 

18# Single butterfly stage kernel 

19# ============================================================ 

20 

21 

@triton.jit
def _butterfly_stage(
    IN_ptr,
    OUT_ptr,
    stride_row,
    N_ROWS,
    ROWS_PER_PROGRAM: tl.constexpr,
    STRIDE_S: tl.constexpr,
    DIM: tl.constexpr,
):
    """One butterfly stage: read from IN, write to OUT.

    Each program handles ROWS_PER_PROGRAM consecutive rows of DIM elements.
    For each element at position i within a row:
        partner = i ^ STRIDE_S
        if (i & STRIDE_S) == 0: out[i] = in[i] + in[partner]
        else:                   out[i] = in[partner] - in[i]

    Arithmetic is done in fp32. IN and OUT must be distinct buffers, because
    every element reads its partner from IN while OUT is being written.
    """
    # Widen to int64 before multiplying by the row stride: in int32,
    # row_id * stride_row overflows once a buffer holds >= 2**31 elements.
    pid = tl.program_id(0).to(tl.int64)
    offsets = tl.arange(0, DIM)
    # Partner index and upper/lower selector depend only on the column, so
    # compute them once for all rows handled by this program.
    partner_offsets = offsets ^ STRIDE_S
    is_upper = (offsets & STRIDE_S) == 0

    for row_idx in tl.static_range(ROWS_PER_PROGRAM):
        row_id = pid * ROWS_PER_PROGRAM + row_idx
        if row_id < N_ROWS:
            base = row_id * stride_row

            x = tl.load(IN_ptr + base + offsets).to(tl.float32)
            x_partner = tl.load(IN_ptr + base + partner_offsets).to(tl.float32)

            result = tl.where(is_upper, x + x_partner, x_partner - x)

            tl.store(OUT_ptr + base + offsets, result)

55 

56 

57# ============================================================ 

58# Scale + cast kernel 

59# ============================================================ 

60 

61 

@triton.jit
def _scale_cast(
    IN_ptr,
    OUT_ptr,
    stride_in_row,
    stride_out_row,
    scale,
    N_ROWS,
    ROWS_PER_PROGRAM: tl.constexpr,
    DIM: tl.constexpr,
):
    """Scale fp32 buffer and cast to output dtype.

    Each program handles ROWS_PER_PROGRAM consecutive rows of DIM elements:
    out[row, :] = in[row, :] * scale, converted to OUT's element type on store.
    """
    # int64 row index so row_id * stride cannot overflow on very large buffers.
    pid = tl.program_id(0).to(tl.int64)
    offsets = tl.arange(0, DIM)

    for row_idx in tl.static_range(ROWS_PER_PROGRAM):
        row_id = pid * ROWS_PER_PROGRAM + row_idx
        if row_id < N_ROWS:
            x = tl.load(IN_ptr + row_id * stride_in_row + offsets)
            tl.store(OUT_ptr + row_id * stride_out_row + offsets, x * scale)

82 

83 

84# ============================================================ 

85# Forward implementation 

86# ============================================================ 

87 

88 

89def _hadamard_transform_fwd(x, scale): 

90 orig_shape = x.shape 

91 dim = x.shape[-1] 

92 input_dtype = x.dtype 

93 

94 # Pad to next power of 2 

95 log_dim = math.ceil(math.log2(dim)) if dim > 0 else 0 

96 dim_padded = 1 << log_dim 

97 if dim != dim_padded: 

98 x = F.pad(x, (0, dim_padded - dim)) 

99 

100 x_flat = x.reshape(-1, dim_padded).contiguous() 

101 n_rows = x_flat.shape[0] 

102 n_stages = log_dim # log2(dim_padded) 

103 

104 # Determine ROWS_PER_PROGRAM to stay within grid limit 

105 rows_per_prog = 1 

106 while (n_rows + rows_per_prog - 1) // rows_per_prog > MAX_GRID: 

107 rows_per_prog *= 2 

108 grid_size = (n_rows + rows_per_prog - 1) // rows_per_prog 

109 

110 # Allocate two fp32 scratch buffers for ping-pong 

111 # .clone() is critical: for fp32 input, .float() is a no-op returning 

112 # the same tensor, which would cause butterfly stages to overwrite the input 

113 buf_a = x_flat.float().clone() 

114 buf_b = torch.empty_like(buf_a) 

115 

116 stride_row = dim_padded 

117 

118 # Run butterfly stages 

119 for s in range(n_stages): 

120 stride_s = 1 << s 

121 _butterfly_stage[(grid_size,)]( 

122 buf_a, 

123 buf_b, 

124 stride_row, 

125 n_rows, 

126 ROWS_PER_PROGRAM=rows_per_prog, 

127 STRIDE_S=stride_s, 

128 DIM=dim_padded, 

129 ) 

130 buf_a, buf_b = buf_b, buf_a 

131 

132 # Result is in buf_a; scale and cast back 

133 out = torch.empty(n_rows, dim_padded, dtype=input_dtype, device=x.device) 

134 _scale_cast[(grid_size,)]( 

135 buf_a, 

136 out, 

137 stride_row, 

138 dim_padded, 

139 scale, 

140 n_rows, 

141 ROWS_PER_PROGRAM=rows_per_prog, 

142 DIM=dim_padded, 

143 ) 

144 

145 if dim != dim_padded: 

146 out = out[:, :dim] 

147 return out.reshape(orig_shape) 

148 

149 

150# ============================================================ 

151# Autograd wrapper 

152# ============================================================ 

153 

154 

class HadamardTransformFn(torch.autograd.Function):
    """Autograd wrapper around ``_hadamard_transform_fwd``.

    The transform is linear, and its matrix (zero-pad, Hadamard, truncate,
    scale) is symmetric. The gradient w.r.t. ``x`` is therefore the same
    transform applied to ``grad_output``. ``scale`` receives no gradient.
    """

    @staticmethod
    def forward(ctx, x, scale):
        ctx._hadamard_transform_scale = scale
        return _hadamard_transform_fwd(x, scale)

    @staticmethod
    def backward(ctx, grad_output):
        # Hadamard matrix is symmetric: backward = forward with same scale.
        # Routing through ``apply`` (rather than calling the raw forward)
        # records the op on the graph when ``create_graph=True``, so
        # higher-order gradients are correct. With grad mode off (the usual
        # case) it simply runs the forward computation.
        grad_x = HadamardTransformFn.apply(
            grad_output.contiguous(), ctx._hadamard_transform_scale
        )
        return grad_x, None

170 

171 

172# ============================================================ 

173# Public API 

174# ============================================================ 

175 

176 

def hadamard_transform(x, scale=1.0):
    """Fast Hadamard Transform along the last dimension.

    Arguments:
        x: (..., dim)
        scale: float. The transformed output is multiplied by this value.
    Returns:
        out: (..., dim)

    Each row of x is multiplied by the Hadamard matrix, i.e. the result
    matches F.linear(x, torch.tensor(scipy.linalg.hadamard(dim))) * scale.
    A dim that is not a power of 2 is implicitly zero-padded up to the next
    power of 2, and the result is truncated back to dim columns.
    """
    transform = HadamardTransformFn.apply
    result = transform(x, scale)
    return result

192 

193 

194# ============================================================ 

195# XXN variants (non-power-of-2 dims) 

196# ============================================================ 

197 

198 

def hadamard_transform_12N(x, scale=1.0):
    """Transform intended for dim = 12 * 2^k (e.g. 12*512 = 6144).

    NOTE(review): this does NOT use a 12-point Hadamard factor. It runs the
    same power-of-2 path as ``hadamard_transform``:
    - the last dim is zero-padded to the next power of 2 (e.g. 6144 -> 8192),
    - the padded rows are transformed,
    - the result is truncated back to dim.
    The resulting dim x dim operator is not an orthogonal Hadamard matrix.
    Confirm this is acceptable for callers expecting true 12N semantics.

    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)
    """
    return HadamardTransformFn.apply(x, scale)

202 

203 

def hadamard_transform_20N(x, scale=1.0):
    """Transform intended for dim = 20 * 2^k (e.g. 20*1024 = 20480).

    NOTE(review): this does NOT use a 20-point Hadamard factor. It runs the
    same power-of-2 path as ``hadamard_transform``:
    - the last dim is zero-padded to the next power of 2 (e.g. 20480 -> 32768),
    - the padded rows are transformed,
    - the result is truncated back to dim.
    The resulting dim x dim operator is not an orthogonal Hadamard matrix.
    Confirm this is acceptable for callers expecting true 20N semantics.

    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)
    """
    return HadamardTransformFn.apply(x, scale)

207 

208 

def hadamard_transform_28N(x, scale=1.0):
    """Transform intended for dim = 28 * 2^k (e.g. 28*1024 = 28672).

    NOTE(review): this does NOT use a 28-point Hadamard factor. It runs the
    same power-of-2 path as ``hadamard_transform``:
    - the last dim is zero-padded to the next power of 2 (e.g. 28672 -> 32768),
    - the padded rows are transformed,
    - the result is truncated back to dim.
    The resulting dim x dim operator is not an orthogonal Hadamard matrix.
    Confirm this is acceptable for callers expecting true 28N semantics.

    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)
    """
    return HadamardTransformFn.apply(x, scale)

212 

213 

def hadamard_transform_40N(x, scale=1.0):
    """Transform intended for dim = 40 * 2^k (e.g. 40*1024 = 40960).

    NOTE(review): this does NOT use a 40-point Hadamard factor. It runs the
    same power-of-2 path as ``hadamard_transform``:
    - the last dim is zero-padded to the next power of 2 (e.g. 40960 -> 65536),
    - the padded rows are transformed,
    - the result is truncated back to dim.
    The resulting dim x dim operator is not an orthogonal Hadamard matrix.
    Confirm this is acceptable for callers expecting true 40N semantics.

    Arguments:
        x: (..., dim)
        scale: float. Multiply the output by this number.
    Returns:
        out: (..., dim)
    """
    return HadamardTransformFn.apply(x, scale)
216 return HadamardTransformFn.apply(x, scale)