Coverage for src/flag_gems/runtime/backend/_sunrise/ops/vdot.py: 0%

147 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from torch import Tensor 

7 

8from flag_gems import runtime 

9from flag_gems.utils import libentry, tensor_wrapper 

10 

11logger = logging.getLogger(__name__) 

12 

13 

14def _view_as_complex_ptpu_safe(x: torch.Tensor) -> torch.Tensor: 

15 """`torch.view_as_complex(x)` with a CPU bounce when x is on PTPU.""" 

16 try: 

17 return torch.view_as_complex(x) 

18 except NotImplementedError: 

19 if x.device.type != "ptpu": 

20 raise 

21 return torch.view_as_complex(x.cpu()).to(x.device) 

22 

23 

24@triton.jit 

25def compute_vdot( 

26 inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj 

27): 

28 # # Given inp storage: [inp_real, inp_imag], other: [other_real, other_imag] 

29 

30 # # Case 1: inp_is_conj = False, other_is_conj = False 

31 # out_real = inp_real * other_real + inp_imag * other_imag 

32 # out_imag = inp_real * other_imag - inp_imag * other_real 

33 

34 # # Case 2: inp_is_conj = True, other_is_conj = False 

35 # out_real = inp_real * other_real - inp_imag * other_imag 

36 # out_imag = inp_real * other_imag + inp_imag * other_real 

37 

38 # # Case 3: inp_is_conj = False, other_is_conj = True 

39 # out_real = inp_real * other_real - inp_imag * other_imag 

40 # out_imag = -inp_real * other_imag - inp_imag * other_real 

41 

42 # # Case 4: inp_is_conj = True, other_is_conj = True 

43 # out_real = inp_real * other_real + inp_imag * other_imag 

44 # out_imag = inp_real * other_imag - inp_imag * other_real 

45 if not inp_is_conj and not other_is_conj: # Case 1 

46 out_real = tl.sum(inp_real * other_real + inp_imag * other_imag) 

47 out_imag = tl.sum(inp_real * other_imag - inp_imag * other_real) 

48 elif inp_is_conj and not other_is_conj: # Case 2 

49 out_real = tl.sum(inp_real * other_real - inp_imag * other_imag) 

50 out_imag = tl.sum(inp_real * other_imag + inp_imag * other_real) 

51 elif not inp_is_conj and other_is_conj: # Case 3 

52 out_real = tl.sum(inp_real * other_real - inp_imag * other_imag) 

53 out_imag = tl.sum(-inp_real * other_imag - inp_imag * other_real) 

54 else: # Case 4 

55 out_real = tl.sum(inp_real * other_real + inp_imag * other_imag) 

56 out_imag = tl.sum(-inp_real * other_imag + inp_imag * other_real) 

57 

58 return out_real, out_imag 

59 

60 

61# support old version triton which do not support tl.split 

62@libentry() 

63@triton.jit() 

64def vdot_kernel_complex( 

65 inp_ptr, 

66 other_ptr, 

67 out_ptr, 

68 n_elements, 

69 inp_is_conj: tl.constexpr, 

70 other_is_conj: tl.constexpr, 

71 inp_stride: tl.constexpr, 

72 other_stride: tl.constexpr, 

73 BLOCK_SIZE: tl.constexpr, 

74): 

75 pid = tl.program_id(0) 

76 num_progs = tl.num_programs(0) 

77 

78 grid_stride = num_progs * BLOCK_SIZE 

79 

80 acc_real = tl.zeros([], dtype=tl.float32) 

81 acc_imag = tl.zeros([], dtype=tl.float32) 

82 

83 for current_start in range(0, n_elements // 2, grid_stride): 

84 complex_idx = current_start + pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

85 mask = complex_idx < n_elements // 2 

86 

87 real_offset = complex_idx * 2 

88 

89 inp_real = tl.load(inp_ptr + real_offset * inp_stride, mask=mask, other=0.0) 

90 inp_imag = tl.load(inp_ptr + real_offset * inp_stride + 1, mask=mask, other=0.0) 

91 

92 other_real = tl.load( 

93 other_ptr + real_offset * other_stride, mask=mask, other=0.0 

94 ) 

95 other_imag = tl.load( 

96 other_ptr + real_offset * other_stride + 1, mask=mask, other=0.0 

97 ) 

98 

99 out_real, out_imag = compute_vdot( 

100 inp_real, inp_imag, other_real, other_imag, inp_is_conj, other_is_conj 

101 ) 

102 acc_real += out_real 

103 acc_imag += out_imag 

104 

105 temp_offset = pid * 2 

106 tl.store(out_ptr + temp_offset, acc_real) 

107 tl.store(out_ptr + temp_offset + 1, acc_imag) 

108 

109 

110@libentry() 

111@triton.jit() 

112def reduce_kernel_complex(input_ptr, out_ptr, n_blocks, BLOCK_SIZE: tl.constexpr): 

113 pid = tl.program_id(0) 

114 base_offset = tl.arange(0, BLOCK_SIZE) 

115 mask = base_offset < n_blocks 

116 

117 inp_real = tl.load(input_ptr + base_offset * 2, mask=mask, other=0.0) 

118 inp_imag = tl.load(input_ptr + base_offset * 2 + 1, mask=mask, other=0.0) 

119 final_out_real = tl.sum(inp_real) 

120 final_out_imag = tl.sum(inp_imag) 

121 if pid == 0: 

122 tl.store(out_ptr, final_out_real) 

123 tl.store(out_ptr + 1, final_out_imag) 

124 

125 

126# only support real number 

127@libentry() 

128@triton.heuristics(runtime.get_heuristic_config("vdot")) 

129@triton.jit() 

130def dot_kernel( 

131 inp_ptr, 

132 other_ptr, 

133 out_ptr, 

134 n_elements, 

135 inp_stride: tl.constexpr, 

136 other_stride: tl.constexpr, 

137 BLOCK_SIZE: tl.constexpr, 

138): 

139 pid = tl.program_id(0) 

140 num_progs = tl.num_programs(0) 

141 grid_stride = num_progs * BLOCK_SIZE 

142 

143 acc = tl.zeros([], dtype=tl.float32) 

144 

145 offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

146 

147 for current_start in range(0, n_elements, grid_stride): 

148 cur_offsets = current_start + offsets 

149 mask = cur_offsets < n_elements 

150 

151 inp = tl.load(inp_ptr + inp_stride * cur_offsets, mask=mask, other=0.0).to( 

152 tl.float32 

153 ) 

154 other = tl.load( 

155 other_ptr + other_stride * cur_offsets, mask=mask, other=0.0 

156 ).to(tl.float32) 

157 

158 acc += tl.sum(inp * other) 

159 

160 tl.store(out_ptr + pid, acc) 

161 

162 

163@libentry() 

164@triton.jit() 

165def reduce_kernel( 

166 partial_sums_ptr, 

167 output_ptr, 

168 n_blocks, 

169 BLOCK_SIZE: tl.constexpr, 

170): 

171 offset = tl.arange(0, BLOCK_SIZE) 

172 mask = offset < n_blocks 

173 

174 partial_sums = tl.load(partial_sums_ptr + offset, mask=mask, other=0.0) 

175 final_sum = tl.sum(partial_sums) 

176 

177 if tl.program_id(0) == 0: 

178 tl.store(output_ptr, final_sum) 

179 

180 

181@libentry() 

182@triton.heuristics(runtime.get_heuristic_config("vdot")) 

183@triton.jit() 

184def dot_kernel_fp32( 

185 inp_ptr, 

186 other_ptr, 

187 out_ptr, 

188 n_elements, 

189 inp_stride: tl.constexpr, 

190 other_stride: tl.constexpr, 

191 BLOCK_SIZE: tl.constexpr, 

192): 

193 pid = tl.program_id(0) 

194 offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) 

195 mask = offset < n_elements 

196 

197 inp = tl.load(inp_ptr + inp_stride * offset, mask=mask) 

198 other = tl.load(other_ptr + other_stride * offset, mask=mask) 

199 

200 out = tl.sum(inp * other) 

201 tl.atomic_add(out_ptr, out) 

202 

203 

204def vdot(input: Tensor, other: Tensor): 

205 logger.debug("GEMS VDOT") 

206 

207 assert ( 

208 input.dtype == other.dtype 

209 ), f"Input tensors must have the same dtype. Got {input.dtype} and {other.dtype}." 

210 assert ( 

211 input.ndim == 1 and other.ndim == 1 

212 ), f"Input tensors must be 1D. Got {input.ndim}D and {other.ndim}D." 

213 assert ( 

214 input.size() == other.size() 

215 ), f"Input tensors must have the same size. Got {input.size()} and {other.size()}." 

216 

217 inp = input 

218 inp_stride = inp.stride()[0] 

219 other_stride = other.stride()[0] 

220 

221 if inp.is_complex(): 

222 inp_is_conj = False 

223 other_is_conj = False 

224 

225 if inp.is_conj(): 

226 inp_is_conj = True 

227 inp = inp.conj() 

228 

229 if other.is_conj(): 

230 other_is_conj = True 

231 other = other.conj() 

232 

233 inp_real = tensor_wrapper.TypedPtr.reinterpret_tensor(inp, inp.dtype.to_real()) 

234 other_real = tensor_wrapper.TypedPtr.reinterpret_tensor( 

235 other, other.dtype.to_real() 

236 ) 

237 

238 n_elements = inp.numel() * 2 

239 n_complex = inp.numel() 

240 

241 block_size = runtime.get_heuristic_config("vdot")["BLOCK_SIZE"]( 

242 {"n_elements": n_elements} 

243 ) 

244 num_blocks = triton.cdiv(n_complex, block_size) 

245 

246 grid_size = min(num_blocks, 1024) 

247 

248 partial_real_sums = torch.empty( 

249 grid_size, dtype=inp_real.dtype, device=inp.device 

250 ) 

251 grid = (grid_size,) 

252 vdot_kernel_complex[grid]( 

253 inp_real, 

254 other_real, 

255 partial_real_sums, 

256 n_elements=n_elements, 

257 inp_is_conj=inp_is_conj, 

258 other_is_conj=other_is_conj, 

259 inp_stride=inp_stride, 

260 other_stride=other_stride, 

261 BLOCK_SIZE=block_size, 

262 ) 

263 output_real = torch.empty(2, dtype=inp_real.dtype, device=inp.device) 

264 reduce_kernel_complex[(1,)]( 

265 partial_real_sums, 

266 output_real, 

267 grid_size, 

268 BLOCK_SIZE=triton.next_power_of_2(grid_size), 

269 ) 

270 return _view_as_complex_ptpu_safe(output_real) 

271 elif inp.dtype == torch.float32: 

272 output = torch.zeros([], dtype=torch.float32, device=inp.device) 

273 n_elements = inp.numel() 

274 grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 

275 dot_kernel_fp32[grid]( 

276 inp, 

277 other, 

278 output, 

279 n_elements=n_elements, 

280 inp_stride=inp_stride, 

281 other_stride=other_stride, 

282 ) 

283 return output 

284 else: 

285 n_elements = inp.numel() 

286 block_size = runtime.get_heuristic_config("vdot")["BLOCK_SIZE"]( 

287 {"n_elements": n_elements} 

288 ) 

289 

290 num_blocks = triton.cdiv(n_elements, block_size) 

291 grid_size = min(num_blocks, 1024) 

292 

293 grid = (num_blocks,) 

294 partial_sums = torch.empty(grid_size, dtype=torch.float32, device=inp.device) 

295 dot_kernel[grid]( 

296 inp, 

297 other, 

298 partial_sums, 

299 n_elements=n_elements, 

300 inp_stride=inp_stride, 

301 other_stride=other_stride, 

302 BLOCK_SIZE=block_size, 

303 ) 

304 output = torch.empty([], dtype=input.dtype, device=inp.device) 

305 reduce_bs = min(triton.next_power_of_2(grid_size), 1024) 

306 reduce_kernel[(1,)]( 

307 partial_sums, 

308 output, 

309 num_blocks, 

310 BLOCK_SIZE=reduce_bs, 

311 ) 

312 return output