Coverage for src/flag_gems/runtime/backend/_sunrise/ops/topk.py: 0%

179 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7import triton.language.core as core 

8 

9try: 

10 # TODO: Triton 2.1 does not implement _log2. 

11 # Remove the try-catch block once all vendors upgrade to a newer version of Triton. 

12 from triton.language.standard import _log2 

13except ImportError: 

14 pass 

15 

16from flag_gems.runtime import torch_device_fn 

17from flag_gems.utils import libentry 

18from flag_gems.utils import triton_lang_extension as ext 

19from flag_gems.utils.limits import get_dtype_max, get_dtype_min 

20 

logger = logging.getLogger(__name__)

# Range limits of every dtype this module may need, wrapped in
# ``tl.constexpr`` so that ``@triton.jit`` helpers can read them as
# compile-time constants. Each assignment unpacks a (min, max) pair; the
# floating-point limits are the finite extremes reported by ``torch.finfo``.
_MIN_FLOAT32_VAL, _MAX_FLOAT32_VAL = (
    tl.constexpr(torch.finfo(torch.float32).min),
    tl.constexpr(torch.finfo(torch.float32).max),
)
_MIN_FLOAT16_VAL, _MAX_FLOAT16_VAL = (
    tl.constexpr(torch.finfo(torch.float16).min),
    tl.constexpr(torch.finfo(torch.float16).max),
)
_MIN_BFLOAT16_VAL, _MAX_BFLOAT16_VAL = (
    tl.constexpr(torch.finfo(torch.bfloat16).min),
    tl.constexpr(torch.finfo(torch.bfloat16).max),
)
_MIN_INT8_VAL, _MAX_INT8_VAL = (
    tl.constexpr(torch.iinfo(torch.int8).min),
    tl.constexpr(torch.iinfo(torch.int8).max),
)
_MIN_INT16_VAL, _MAX_INT16_VAL = (
    tl.constexpr(torch.iinfo(torch.int16).min),
    tl.constexpr(torch.iinfo(torch.int16).max),
)
_MIN_INT32_VAL, _MAX_INT32_VAL = (
    tl.constexpr(torch.iinfo(torch.int32).min),
    tl.constexpr(torch.iinfo(torch.int32).max),
)
_MIN_INT64_VAL, _MAX_INT64_VAL = (
    tl.constexpr(torch.iinfo(torch.int64).min),
    tl.constexpr(torch.iinfo(torch.int64).max),
)

36 

37 

@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    """Return the largest or smallest finite value of a floating-point dtype.

    Args:
        dtype: Triton dtype. Only ``tl.float32``, ``tl.float16`` and
            ``tl.bfloat16`` are handled.
        return_max: if truthy, return the dtype's largest finite value;
            otherwise return its most negative finite value.

    The values come from the module-level ``torch.finfo`` constants.

    NOTE(review): no branch matches any other dtype (e.g. integer types), so
    the function falls off the end for them and presumably yields None.
    Callers pass the result as ``tl.load(..., other=...)``, so padded lanes
    would then get no defined fill value. Confirm whether non-float inputs
    are meant to be supported.
    """
    if dtype is tl.float32:
        if return_max:
            return _MAX_FLOAT32_VAL
        else:
            return _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        if return_max:
            return _MAX_FLOAT16_VAL
        else:
            return _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        if return_max:
            return _MAX_BFLOAT16_VAL
        else:
            return _MIN_BFLOAT16_VAL

58 

59 

@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return ``get_dtype_max(dtype)`` or ``get_dtype_min(dtype)``.

    Args:
        dtype: Triton dtype forwarded to the ``flag_gems.utils.limits`` helpers.
        return_max: if truthy return the max limit, otherwise the min limit.

    NOTE(review): nothing in the visible part of this file calls this helper.
    It may have been meant as the integer counterpart of ``_get_finfo_val``;
    confirm before relying on it.
    """
    if return_max:
        return get_dtype_max(dtype)
    else:
        return get_dtype_min(dtype)

69 

70 

@libentry()
@triton.jit
def topk_stage1_kernel(
    y_ptr,
    index_ptr,
    x_ptr,
    k,
    N: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Select the top-``k`` entries of one ``CHUNK_SIZE`` slice of one row.

    Launch grid: axis 0 is the row (``cur_batch``) and axis 1 is the chunk
    within that row. Rows are read at ``x_ptr + row * N``, so ``x`` must be
    laid out as contiguous rows of length ``N``.

    Each (row, chunk) program writes ``k`` values to ``y_ptr`` and ``k``
    column indices to ``index_ptr``, at offset
    ``(row * chunk_num + chunk) * k``. The column index is relative to the
    start of the row, not the chunk. Picks are written in selection order:
    descending when ``DESCENDING`` is set, otherwise ascending.

    Args:
        y_ptr: output values, ``k`` slots per (row, chunk).
        index_ptr: output column indices, ``k`` slots per (row, chunk).
        x_ptr: input values.
        k: number of entries to pick per chunk.
        N: row length.
        CHUNK_SIZE: number of lanes per program; used with ``tl.arange``.
        DESCENDING: pick the largest values if true, else the smallest.

    NOTE(review): if fewer than ``k`` lanes of a chunk are in range (for
    example a ragged last chunk), the later picks presumably return either a
    padding lane (column >= N) or a lane that was already picked. Which one
    depends on whether the dtype sentinel or the float32 replacement
    sentinel compares higher. Likewise, genuine -inf (or +inf when
    ascending) values lose to the finite sentinels, so an already-picked
    lane can be re-picked ahead of them.
    """
    cur_batch = ext.program_id(0)
    cur_chunk_idx = ext.program_id(1)
    chunk_num = ext.num_programs(1)

    # k output slots per (row, chunk), laid out row-major.
    y_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k
    index_ptr += cur_batch * chunk_num * k + cur_chunk_idx * k

    chunk_offset = cur_chunk_idx * CHUNK_SIZE
    x_ptr += cur_batch * N + chunk_offset

    cols = tl.arange(0, CHUNK_SIZE)
    mask = (chunk_offset + cols) < N

    # Out-of-range lanes are filled with the dtype's finite extreme on the
    # "losing" side, so they are picked last.
    mask_val = _get_finfo_val(x_ptr.dtype.element_ty, return_max=not DESCENDING)
    # All comparisons happen in float32, regardless of the input dtype.
    x_val = tl.load(x_ptr + cols, mask=mask, other=mask_val).to(tl.float32)
    for k_idx in range(k):
        if DESCENDING:
            chunk_select_val = tl.max(x_val)
            chunk_select_idx = tl.argmax(x_val, axis=0)
        else:
            chunk_select_val = tl.min(x_val)
            chunk_select_idx = tl.argmin(x_val, axis=0)

        tl.store(y_ptr + k_idx, chunk_select_val)
        # Convert the in-chunk lane to a column index within the row.
        tl.store(index_ptr + k_idx, chunk_select_idx + chunk_offset)

        # Knock the picked lane out of later rounds by overwriting it with
        # the float32 extreme on the losing side.
        if DESCENDING:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=False),
                x_val,
            )
        else:
            x_val = tl.where(
                cols == chunk_select_idx,
                _get_finfo_val(tl.float32, return_max=True),
                x_val,
            )

120 

121 

122""" 

Note(Zhengzekang):
Adapted from the official Triton 2.2 `sort` implementation:
https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
The only change is that an index tensor is carried along and permuted together with the values.

127""" 

128 

129 

@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """Run one compare-and-swap round of a bitonic network, permuting ``ids`` too.

    ``x`` and ``ids`` have the same shape and are viewed as
    ``[n_outer * 2**i, 2, 2**(n_dims - i - 1)]``, where
    ``n_outer = x.numel >> n_dims``. The two entries along the middle axis
    form a pair whose members sit ``2**(n_dims - i - 1)`` elements apart.
    Each pair is put into ascending order where ``flip`` is 0 and descending
    order where it is non-zero. ``ids`` receives exactly the same
    permutation.

    Args:
        x: values being sorted.
        ids: payload (e.g. indices) moved together with ``x``.
        flip: 0/1 scalar, or a 0/1 tensor of ``x.shape``. Both members of a
            pair must see the same flip value.
        i: round number; selects the pair stride.
        n_dims: log2 of the sorted extent.

    Returns:
        ``(x, ids)`` after the round.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    y = core.reshape(x, shape)
    y_idx = core.reshape(ids, shape)

    # Slice left/right with 'stride' 2**(n_dims - i - 1). Multiplying by a 0/1
    # mask and summing over the pair axis extracts one member of every pair;
    # it is then broadcast back to both lanes.
    # NOTE(review): for floats this turns +-inf into NaN (inf * 0 == nan), so
    # inputs containing infinities are not ordered correctly. Consider
    # bitcasting to a same-width integer type before this select.
    mask = core.arange(0, 2)[None, :, None]
    left = core.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape).to(x.dtype)
    right = core.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape).to(x.dtype)
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    left_idx = core.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape).to(
        ids.dtype
    )
    right_idx = core.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape).to(
        ids.dtype
    )
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)

    # is_right indicator: 0 for the left, 1 for the right element of each pair.
    is_right = core.reshape(
        core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
    )

    # Partner of each lane: a left lane pairs with the right member and
    # vice versa.
    paired_val = core.where(is_right, left, right)
    paired_idx = core.where(is_right, left_idx, right_idx)

    # Swap decision, as in Triton's official sort: swap iff (left > right) != flip.
    # It is evaluated on (left, right) in a fixed order, so both lanes of a
    # pair always agree. Comparing each lane against its partner instead
    # (x > paired) disagrees on ties: only the right lane swaps. Both lanes
    # then carry the left index and the right index is lost, which shows up
    # as duplicated indices in the topk output.
    swap = (left > right) != (flip != 0)
    x = core.where(swap, paired_val, x)
    ids = core.where(swap, paired_idx, ids)

    return x, ids

174 

175 

@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """Run bitonic-merge stage ``stage`` on ``x``, permuting ``ids`` alongside.

    order_type 0 == ascending
    order_type 1 == descending
    order_type 2 == alternating: the direction flips every ``2**stage``
    elements, producing the bitonic sequences consumed by the next stage.

    The stage consists of ``stage`` compare-and-swap rounds. Their pair
    strides go from ``2**(stage - 1)`` down to 1.

    Returns:
        ``(x, ids)`` after the stage.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
    # descending order.
    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
    # a stride of 2) at this stage
    if order == 2:
        # 0/1 pattern that changes every 2**stage elements.
        shape: core.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = core.reshape(
            core.broadcast_to(core.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        # Uniform direction for the whole tensor (0/False or 1/True).
        flip = order
    # perform `stage` rounds of `compare-and-swap`
    for i in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids

203 

204 

@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic-sort ``x`` and apply the same permutation to ``ids``.

    Args:
        x: values to sort. ``x.shape[dim]`` is assumed to be a power of two,
            since ``n_dims`` is derived from it with ``_log2``.
        ids: payload permuted together with ``x``.
        dim: dimension whose length determines the network size.
        descending: final sort direction.

    Returns:
        ``(sorted_x, permuted_ids)``.

    Requires ``_log2``, which only exists if the guarded import at the top
    of the file succeeded.
    """
    # NOTE(review): despite its original wording, no check is made here that
    # `dim` is the most minor dimension. The helpers reshape using x.numel,
    # so this presumably behaves as intended only for 1-D inputs (which is
    # what topk_stage2_kernel passes) or when `dim` is the last dimension.
    _dim: core.constexpr = dim
    n_dims: core.constexpr = _log2(x.shape[_dim])
    # Stages 1..n_dims-1 build alternating bitonic runs; the last stage merges
    # everything in the requested direction.
    for i in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids

213 

214 

@libentry()
@triton.jit
def topk_stage2_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    sort_dim: tl.constexpr,
    k: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Sort one row of ``N`` top-k candidates and write the first ``k``.

    There is one program per row (``program_id(0)``). Candidates are read at
    ``row * N`` and results are written at ``row * k``.

    Args:
        y_ptr: output values, ``k`` per row.
        index_ptr: output indices, ``k`` per row.
        chunk_x: candidate values, ``N`` per row.
        chunk_index: candidate indices (into the original row), ``N`` per row.
        sort_dim: not used in the body.
        k: number of results per row.
        N: candidates per row.
        BLOCK_SIZE: lanes per program. Assumed to be a power of two
            (``tl.arange`` and the bitonic network rely on it) with
            ``BLOCK_SIZE >= N``; only lanes ``< k`` are stored.
        DESCENDING: sort direction.

    NOTE(review): indices are narrowed to int32 before sorting, so any
    index >= 2**31 would not survive the round trip.
    """
    cur_batch = ext.program_id(0)
    chunk_x += cur_batch * N
    chunk_index += cur_batch * N
    y_ptr += cur_batch * k
    index_ptr += cur_batch * k

    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N

    # Padding lanes get the dtype's finite extreme on the losing side, so
    # they sort to the end.
    mask_val = _get_finfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    mask_index_val = _MIN_INT32_VAL if DESCENDING else _MAX_INT32_VAL

    chunk_x_val = tl.load(chunk_x + cols, mask=mask, other=mask_val).to(tl.float32)
    chunk_index_val = tl.load(chunk_index + cols, mask=mask, other=mask_index_val).to(
        tl.int32
    )

    sorted_chunk_x, sorted_chunk_index = argsort(
        chunk_x_val, chunk_index_val, 0, descending=DESCENDING
    )
    # Keep only the first k sorted lanes.
    tl.store(y_ptr + cols, sorted_chunk_x, mask=cols < k)
    tl.store(index_ptr + cols, sorted_chunk_index, mask=cols < k)

250 

251 

252def topk(x, k, dim=-1, largest=True, sorted=True): 

253 logger.debug("GEMS TOPK") 

254 # If dim equals to last dim, we set it to -1. 

255 if dim < 0: 

256 dim = dim + x.ndim 

257 

258 assert dim == x.ndim - 1, "Currently only support topk in last dimension" 

259 # assert sorted, "Currently only support sorted == True" 

260 

261 # Early return for k=0 to avoid Triton kernel compilation error. 

262 # Triton's tl.arange(0, BLOCK_SIZE) requires BLOCK_SIZE > 0. 

263 # When k=0, stage2_elem_cnt becomes 0, leading to BLOCK_SIZE=0. 

264 if k == 0: 

265 out_shape = list(x.shape[:-1]) + [0] 

266 return ( 

267 torch.empty(out_shape, device=x.device, dtype=x.dtype), 

268 torch.empty(out_shape, device=x.device, dtype=torch.int64), 

269 ) 

270 

271 descending = True 

272 if not largest: 

273 descending = False 

274 

275 topk_elem_cnt = x.shape[dim] 

276 batch_size = math.prod(x.shape) // topk_elem_cnt 

277 

278 # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a proper chunk size. 

279 if topk_elem_cnt < 1024: 

280 chunk_size = 256 

281 else: 

282 chunk_size = 1024 

283 

284 # Note(Zhengzekang): We should promise chunk_size is larger than k. 

285 if chunk_size < k: 

286 chunk_size = triton.next_power_of_2(k) 

287 

288 chunk_num = triton.cdiv(topk_elem_cnt, chunk_size) 

289 

290 stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype) 

291 stage1_out_idx = torch.empty( 

292 batch_size * chunk_num * k, device=x.device, dtype=torch.int64 

293 ) 

294 

295 out_shape = x.shape[:-1] + (k,) 

296 stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype) 

297 stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64) 

298 

299 with torch_device_fn.device(x.device): 

300 topk_stage1_kernel[ 

301 batch_size, 

302 chunk_num, 

303 ]( 

304 stage1_out, # pointer to the output 

305 stage1_out_idx, # pointer to the output 

306 x, # pointer to the input 

307 k, 

308 topk_elem_cnt, 

309 chunk_size, 

310 descending, 

311 ) 

312 stage2_elem_cnt = chunk_num * k 

313 

314 candidate_vals = stage1_out.view(batch_size, stage2_elem_cnt) 

315 candidate_indices = stage1_out_idx.view(batch_size, stage2_elem_cnt) 

316 # [sunrise fix] hits incorrect results once the stage2 bitonic sort spills 

317 # into the multi-warp path (BLOCK_SIZE >= 512). Reduce the candidate set 

318 # with additional stage1 passes until the final sort stays within 256 lanes. 

319 """ 

320 1. topk_stage2_kernel设置num_warps=1可以绕过 ptpu 后端 multi-warp reduction 的共享内存 path bug, 

321 根因是 ptpu 后端在 ReduceOpToLLVM.cpp 的 cross-warp reduction 路径存在共享内存线性化偏移计算问题 

322 — 官方 tl.sort() 在 N≥512 时也有同样的错误。 

323 2. 问题不只是“通用 inter-warp reduce lowering”这一处;至少在 topk 的完整 bitonic sort 路径里,还有别的 multi-warp 交互在出错。 

324 最可能的下一步不是继续硬改通用 ReduceOp,而是针对 topk_stage2_kernel 的某个具体 stage 做精确复现, 

325 直接盯 _compare_and_swap 后几轮的 TTIR/LLVM IR 

326 """ 

327 safe_stage2_elem_cnt = 256 

328 reduction_chunk_size = max(256, triton.next_power_of_2(k + 1)) 

329 while ( 

330 k <= safe_stage2_elem_cnt 

331 and stage2_elem_cnt > safe_stage2_elem_cnt 

332 and triton.next_power_of_2(stage2_elem_cnt) > safe_stage2_elem_cnt 

333 ): 

334 round_chunk_size = min(stage2_elem_cnt, reduction_chunk_size) 

335 round_chunk_num = triton.cdiv(stage2_elem_cnt, round_chunk_size) 

336 reduced_elem_cnt = round_chunk_num * k 

337 

338 reduced_vals = torch.empty( 

339 batch_size * reduced_elem_cnt, device=x.device, dtype=x.dtype 

340 ) 

341 reduced_local_indices = torch.empty( 

342 batch_size * reduced_elem_cnt, device=x.device, dtype=torch.int64 

343 ) 

344 

345 with torch_device_fn.device(x.device): 

346 topk_stage1_kernel[ 

347 batch_size, 

348 round_chunk_num, 

349 ]( 

350 reduced_vals, 

351 reduced_local_indices, 

352 candidate_vals, 

353 k, 

354 stage2_elem_cnt, 

355 round_chunk_size, 

356 descending, 

357 ) 

358 

359 candidate_indices = torch.gather( 

360 candidate_indices, 

361 1, 

362 reduced_local_indices.view(batch_size, reduced_elem_cnt).to(torch.int64), 

363 ).contiguous() 

364 candidate_vals = reduced_vals.view(batch_size, reduced_elem_cnt) 

365 stage2_elem_cnt = reduced_elem_cnt 

366 

367 BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt) 

368 

369 with torch_device_fn.device(x.device): 

370 topk_stage2_kernel[batch_size,]( 

371 stage2_out, 

372 stage2_out_idx, 

373 candidate_vals, 

374 candidate_indices, 

375 dim, 

376 k, 

377 stage2_elem_cnt, 

378 BLOCK_SIZE, 

379 descending, 

380 ) 

381 

382 return (stage2_out, stage2_out_idx)