Coverage for src/flag_gems/runtime/backend/_arm/ops/topk.py: 0%

195 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3 

4import numpy as np 

5import torch 

6import triton 

7import triton.language as tl 

8import triton.language.core as core 

9 

10try: 

11 # TODO: Triton 2.1 does not implement _log2. 

12 # Remove the try-catch block once all vendors upgrade to a newer version of Triton. 

13 from triton.language.standard import _log2, zeros_like 

14except ImportError: 

15 pass 

16 

17from flag_gems.utils import triton_lang_extension as tle 

18from flag_gems.utils.limits import get_dtype_max, get_dtype_min 

19 

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Finite limits of the supported floating-point dtypes, wrapped as Triton
# constexprs so that @triton.jit functions can reference them as globals.
_MIN_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).min)
_MAX_FLOAT32_VAL = tl.constexpr(torch.finfo(torch.float32).max)
_MIN_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).min)
_MAX_FLOAT16_VAL = tl.constexpr(torch.finfo(torch.float16).max)
_MIN_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).min)
_MAX_BFLOAT16_VAL = tl.constexpr(torch.finfo(torch.bfloat16).max)

# Limits of the signed integer dtypes, wrapped the same way.
_MIN_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).min)
_MAX_INT8_VAL = tl.constexpr(torch.iinfo(torch.int8).max)
_MIN_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).min)
_MAX_INT16_VAL = tl.constexpr(torch.iinfo(torch.int16).max)
_MIN_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).min)
_MAX_INT32_VAL = tl.constexpr(torch.iinfo(torch.int32).max)
_MIN_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).min)
_MAX_INT64_VAL = tl.constexpr(torch.iinfo(torch.int64).max)

35 

36 

@triton.jit
def _get_finfo_val(
    dtype,
    return_max,
):
    """Return the largest or smallest finite value of a floating-point dtype.

    Args:
        dtype: Triton element type; float32, float16 and bfloat16 have a branch.
        return_max: truthy selects the dtype maximum, falsy the dtype minimum.

    Returns:
        The matching module-level constexpr limit. No branch handles any other
        dtype, so nothing is returned for those.
    """
    if dtype is tl.float32:
        return _MAX_FLOAT32_VAL if return_max else _MIN_FLOAT32_VAL
    elif dtype is tl.float16:
        return _MAX_FLOAT16_VAL if return_max else _MIN_FLOAT16_VAL
    elif dtype is tl.bfloat16:
        return _MAX_BFLOAT16_VAL if return_max else _MIN_BFLOAT16_VAL

57 

58 

@triton.jit
def _get_iinfo_val(
    dtype,
    return_max,
):
    """Return the upper or lower limit of ``dtype`` via the shared limit helpers.

    Args:
        dtype: element type forwarded to ``get_dtype_max`` / ``get_dtype_min``.
        return_max: truthy selects ``get_dtype_max``, falsy ``get_dtype_min``.
    """
    return get_dtype_max(dtype) if return_max else get_dtype_min(dtype)

68 

69 

70# @libentry() 

@triton.jit
def topk_stage1_kernel(
    y_ptr,
    index_ptr,
    x_ptr,
    k,
    N: tl.constexpr,
    CHUNK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 1: pick ``k`` candidates out of one chunk of one row.

    Program axis 0 selects the row, program axis 1 the chunk inside that row.
    Each program writes ``k`` values to ``y_ptr`` and ``k`` row-relative
    indices to ``index_ptr`` at offset ``(row * num_chunks + chunk) * k``.

    Args:
        y_ptr: output buffer for the selected values.
        index_ptr: output buffer for the selected indices.
        x_ptr: input buffer, addressed as rows of ``N`` consecutive elements.
        k: number of elements to select per chunk.
        N: row length.
        CHUNK_SIZE: number of lanes handled by one program.
        DESCENDING: select maxima when true, minima otherwise.
    """
    row = tle.program_id(0)
    chunk_id = tle.program_id(1)
    n_chunks = tle.num_programs(1)

    # Both outputs share the same per-(row, chunk) slot layout.
    out_base = row * n_chunks * k + chunk_id * k
    y_ptr += out_base
    index_ptr += out_base

    chunk_start = chunk_id * CHUNK_SIZE
    x_ptr += row * N + chunk_start

    lanes = tl.arange(0, CHUNK_SIZE)
    in_bounds = (chunk_start + lanes) < N

    # Out-of-range lanes are filled with the value that can never win the
    # selection (dtype max when searching minima, dtype min for maxima).
    # NOTE(review): _get_finfo_val only has branches for floating dtypes, so
    # integer inputs get no fill value here -- confirm against callers.
    pad_val = _get_finfo_val(x_ptr.dtype.element_ty, return_max=not DESCENDING)
    vals = tl.load(x_ptr + lanes, mask=in_bounds, other=pad_val).to(tl.float32)

    for slot in range(k):
        if DESCENDING:
            best_val = tl.max(vals)
            best_idx = tl.argmax(vals, axis=0)
        else:
            best_val = tl.min(vals)
            best_idx = tl.argmin(vals, axis=0)

        tl.store(y_ptr + slot, best_val)
        tl.store(index_ptr + slot, best_idx + chunk_start)

        # Knock the winner out of the running by overwriting it with the
        # float32 limit on the losing side, so the next pass finds the runner-up.
        vals = tl.where(
            lanes == best_idx,
            _get_finfo_val(tl.float32, return_max=not DESCENDING),
            vals,
        )

119 

120 

"""
Note(Zhengzekang):
Adapted from the official Triton 2.2 `sort` implementation:
https://github.com/triton-lang/triton/blob/release/2.2.x/python/triton/language/standard.py#L392-L404
The only change is that an index tensor is carried along and permuted together with the values.
"""

127 

128 

@triton.jit
def _compare_and_swap(x, ids, flip, i: core.constexpr, n_dims: core.constexpr):
    """Run one compare-and-swap round on values and their companion indices.

    ``x`` is viewed as ``[n_outer * 2**i, 2, 2**(n_dims - i - 1)]``; the two
    entries along the middle axis are the partners compared in this round.
    Wherever ``(left > right) ^ flip`` holds, the partners trade places, and
    ``ids`` is permuted with exactly the same decisions.

    Args:
        x: values being sorted.
        ids: indices travelling with ``x``.
        flip: scalar or tensor that inverts the comparison where it is set.
        i: which round this is; fixes the partner stride.
        n_dims: log2 of the sorted length.

    Returns:
        ``(x, ids)`` after the round, in their original dtypes.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    shape: core.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    vals_3d = core.reshape(x, shape)
    ids_3d = core.reshape(ids, shape)

    # Selector along the pair axis: 0 picks the left partner, 1 the right one.
    # Multiplying by the selector and summing over that axis extracts one
    # partner, which is then broadcast back over both slots.
    pick_right = core.arange(0, 2)[None, :, None]
    pick_left = 1 - pick_right

    left = core.broadcast_to(tl.sum(vals_3d * pick_left, 1)[:, None, :], shape).to(
        x.dtype
    )
    right = core.broadcast_to(tl.sum(vals_3d * pick_right, 1)[:, None, :], shape).to(
        x.dtype
    )
    left = core.reshape(left, x.shape)
    right = core.reshape(right, x.shape)

    left_idx = core.broadcast_to(
        tl.sum(ids_3d * pick_left, 1)[:, None, :], shape
    ).to(ids.dtype)
    right_idx = core.broadcast_to(
        tl.sum(ids_3d * pick_right, 1)[:, None, :], shape
    ).to(ids.dtype)
    left_idx = core.reshape(left_idx, ids.shape)
    right_idx = core.reshape(right_idx, ids.shape)

    # Positions whose pair has to be exchanged.
    swap = (left > right) ^ flip

    # Integer type with the same bit width as the values, for bitcasting.
    if core.constexpr(x.dtype.primitive_bitwidth) == 8:
        val_int_ty = core.int8
    elif core.constexpr(x.dtype.primitive_bitwidth) == 16:
        val_int_ty = core.int16
    elif core.constexpr(x.dtype.primitive_bitwidth) == 32:
        val_int_ty = core.int32
    elif core.constexpr(x.dtype.primitive_bitwidth) == 64:
        val_int_ty = core.int64
    else:
        raise ValueError("Unsupported dtype")

    left_bits = left.to(val_int_ty, bitcast=True)
    right_bits = right.to(val_int_ty, bitcast=True)
    x_bits = x.to(val_int_ty, bitcast=True)

    # Each element equals either its left or right partner, so XOR-ing it with
    # (left ^ right) turns it into the other one; XOR with zero keeps it.
    swapped_bits = x_bits ^ core.where(
        swap, left_bits ^ right_bits, zeros_like(x_bits)
    )

    # Same bit-level exchange for the indices.
    if core.constexpr(ids.dtype.primitive_bitwidth) == 8:
        idx_int_ty = core.int8
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 16:
        idx_int_ty = core.int16
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 32:
        idx_int_ty = core.int32
    elif core.constexpr(ids.dtype.primitive_bitwidth) == 64:
        idx_int_ty = core.int64
    else:
        raise ValueError("Unsupported dtype")

    left_idx_bits = left_idx.to(idx_int_ty, bitcast=True)
    right_idx_bits = right_idx.to(idx_int_ty, bitcast=True)
    ids_bits = ids.to(idx_int_ty, bitcast=True)
    swapped_idx_bits = ids_bits ^ core.where(
        swap, left_idx_bits ^ right_idx_bits, zeros_like(ids_bits)
    )

    return (
        swapped_bits.to(x.dtype, bitcast=True),
        swapped_idx_bits.to(ids.dtype, bitcast=True),
    )

190 

191 

@triton.jit
def _bitonic_merge(
    x, ids, stage: core.constexpr, order: core.constexpr, n_dims: core.constexpr
):
    """Apply ``stage`` compare-and-swap rounds to ``x`` and ``ids``.

    ``order`` selects the direction of the merge:
        0 -- ascending
        1 -- descending
        2 -- alternating between neighbouring sub-sequences

    Returns the updated ``(x, ids)`` pair.
    """
    n_outer: core.constexpr = x.numel >> n_dims
    core.static_assert(stage <= n_dims)

    # `flip` tells each compare-and-swap whether to invert its comparison.
    # A scalar applies one direction everywhere; for the alternating order a
    # 0/1 pattern is built whose runs are 2**stage elements long, so adjacent
    # sub-sequences are arranged in opposite directions.
    if order == 2:
        flip_shape: core.constexpr = [
            n_outer * 2 ** (n_dims - 1 - stage),
            2,
            2**stage,
        ]
        pattern = core.broadcast_to(core.arange(0, 2)[None, :, None], flip_shape)
        flip = core.reshape(pattern, x.shape)
    else:
        flip = order

    # The rounds of this stage are the last `stage` ones out of n_dims.
    for step in core.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, n_dims - stage + step, n_dims)
    return x, ids

219 

220 

@triton.jit
def argsort(x, ids, dim: tl.constexpr, descending: core.constexpr):
    """Bitonic-sort ``x`` and permute ``ids`` alongside it.

    Args:
        x: values to sort; the sorted length is ``x.shape[dim]``.
        ids: indices reordered together with ``x``.
        dim: dimension whose length defines the number of merge stages.
        descending: order flag handed to the final merge stage.

    Returns:
        The sorted values and the correspondingly permuted indices.
    """
    axis: core.constexpr = dim
    n_dims: core.constexpr = _log2(x.shape[axis])
    # Every stage but the last merges in alternating order; only the final
    # stage applies the requested direction.
    for stage in core.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(
            x, ids, stage, 2 if stage < n_dims else descending, n_dims
        )
    return x, ids

229 

230 

231# @libentry() 

@triton.jit
def topk_stage2_kernel(
    y_ptr,
    index_ptr,
    chunk_x,
    chunk_index,
    sort_dim: tl.constexpr,
    k: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
):
    """Stage 2: sort one row of stage-1 candidates and keep the first ``k``.

    One program per row (program axis 0).

    Args:
        y_ptr: output values, ``k`` per row.
        index_ptr: output indices, ``k`` per row.
        chunk_x: candidate values, ``N`` per row.
        chunk_index: candidate indices, ``N`` per row.
        sort_dim: not read by this kernel.
        k: number of results stored per row.
        N: number of candidates per row.
        BLOCK_SIZE: number of lanes loaded and sorted.
        DESCENDING: sort direction passed to ``argsort``.
    """
    row = tle.program_id(0)
    chunk_x += row * N
    chunk_index += row * N
    y_ptr += row * k
    index_ptr += row * k

    lanes = tl.arange(0, BLOCK_SIZE)
    valid = lanes < N

    # Lanes past the candidates are padded with the limit on the losing side
    # of the sort direction.
    pad_val = _get_finfo_val(chunk_x.dtype.element_ty, return_max=not DESCENDING)
    pad_idx = _MIN_INT32_VAL if DESCENDING else _MAX_INT32_VAL

    cand_vals = tl.load(chunk_x + lanes, mask=valid, other=pad_val).to(tl.float32)
    cand_idx = tl.load(chunk_index + lanes, mask=valid, other=pad_idx).to(tl.int32)

    sorted_vals, sorted_idx = argsort(cand_vals, cand_idx, 0, descending=DESCENDING)

    # Only the leading k lanes of the sorted block are written out.
    keep = lanes < k
    tl.store(y_ptr + lanes, sorted_vals, mask=keep)
    tl.store(index_ptr + lanes, sorted_idx, mask=keep)

266 

267 

def _topk_empty_result(x, dim):
    """Build the ``(values, indices)`` pair for ``k == 0``: empty along ``dim``."""
    out_shape = list(x.shape)
    out_shape[dim] = 0
    return (
        torch.empty(out_shape, device=x.device, dtype=x.dtype),
        torch.empty(out_shape, device=x.device, dtype=torch.int64),
    )


def _topk_cpu_fallback(x, k, dim, largest, sort_result):
    """Compute top-k on the host with NumPy; ``dim`` is already non-negative."""
    if k == 0:
        return _topk_empty_result(x, dim)

    src = x.detach().cpu()
    # Only widen sub-32-bit floats (e.g. bfloat16, which NumPy cannot
    # represent); that conversion is lossless. Every other dtype is kept as
    # is, so int64 / float64 inputs are neither rounded nor reordered by a
    # trip through float32.
    if src.is_floating_point() and src.element_size() < 4:
        src = src.to(torch.float32)
    x_np = src.numpy()

    n = x_np.shape[dim]
    if largest:
        # After partitioning at n - k, the last k slots hold the k largest.
        part = np.argpartition(x_np, n - k, axis=dim)
        idx = np.take(part, np.arange(n - k, n), axis=dim)
    else:
        # After partitioning at k - 1, the first k slots hold the k smallest.
        part = np.argpartition(x_np, k - 1, axis=dim)
        idx = np.take(part, np.arange(k), axis=dim)
    vals = np.take_along_axis(x_np, idx, axis=dim)

    if sort_result:
        order = np.argsort(vals, axis=dim)
        if largest:
            order = np.flip(order, axis=dim)
        idx = np.take_along_axis(idx, order, axis=dim)
        vals = np.take_along_axis(vals, order, axis=dim)

    vals_t = torch.from_numpy(vals).to(device=x.device, dtype=x.dtype)
    idx_t = torch.from_numpy(idx.astype(np.int64, copy=False)).to(device=x.device)
    return vals_t, idx_t


def topk(x, k, dim=-1, largest=True, sorted=True):
    """Return the ``k`` largest (or smallest) elements of ``x`` along ``dim``.

    CPU tensors are handled by a NumPy fallback that supports any ``dim``.
    Other devices use the two-stage Triton kernels, which only support the
    last dimension and always produce sorted output (``sorted`` is not
    consulted on that path).

    Args:
        x: input tensor.
        k: number of elements to select; must satisfy ``0 <= k <= x.shape[dim]``.
        dim: dimension to select along; negative values count from the end.
        largest: select the largest elements when true, the smallest otherwise.
        sorted: on the CPU path, whether the results are ordered.

    Returns:
        ``(values, indices)``: ``values`` has the dtype of ``x`` and
        ``indices`` is int64; both have size ``k`` along ``dim``.

    Raises:
        RuntimeError: if ``k`` is outside ``[0, x.shape[dim]]``.
        AssertionError: on non-CPU devices when ``dim`` is not the last one.
    """
    logger.debug("GEMS TOPK")
    if dim < 0:
        dim = dim + x.ndim

    # Reject an out-of-range k up front: NumPy would otherwise wrap negative
    # positions and return duplicated entries, and the kernels would emit
    # padding values.
    topk_elem_cnt = x.shape[dim]
    if not 0 <= k <= topk_elem_cnt:
        raise RuntimeError("selected index k out of range")

    if x.device.type == "cpu":
        # CPU Triton topk kernel is unstable for some shapes/dtypes. Use a
        # deterministic host fallback for correctness.
        return _topk_cpu_fallback(x, k, dim, largest, sorted)

    assert dim == x.ndim - 1, "Currently only support topk in last dimension"

    if k == 0:
        return _topk_empty_result(x, dim)

    descending = bool(largest)

    # The kernels address the input as dense rows of `topk_elem_cnt`
    # consecutive elements, so a strided view has to be materialised first.
    x = x.contiguous()

    batch_size = math.prod(x.shape) // topk_elem_cnt

    # Note(Zhengzekang): Maybe we should add a heuristic search in selecting a proper chunk size.
    chunk_size = 256 if topk_elem_cnt < 1024 else 1024

    # Note(Zhengzekang): chunk_size must not be smaller than k.
    if chunk_size < k:
        chunk_size = triton.next_power_of_2(k)

    chunk_num = triton.cdiv(topk_elem_cnt, chunk_size)

    # Stage 1 leaves k candidates per (row, chunk).
    stage1_out = torch.empty(batch_size * chunk_num * k, device=x.device, dtype=x.dtype)
    stage1_out_idx = torch.empty(
        batch_size * chunk_num * k, device=x.device, dtype=torch.int64
    )

    out_shape = x.shape[:-1] + (k,)
    stage2_out = torch.empty(out_shape, device=x.device, dtype=x.dtype)
    stage2_out_idx = torch.empty(out_shape, device=x.device, dtype=torch.int64)

    topk_stage1_kernel[
        batch_size,
        chunk_num,
    ](
        stage1_out,  # pointer to the output values
        stage1_out_idx,  # pointer to the output indices
        x,  # pointer to the input
        k,
        topk_elem_cnt,
        chunk_size,
        descending,
    )

    # Stage 2 sorts all candidates of a row in a single power-of-two block.
    stage2_elem_cnt = chunk_num * k
    BLOCK_SIZE = triton.next_power_of_2(stage2_elem_cnt)

    topk_stage2_kernel[batch_size,](
        stage2_out,
        stage2_out_idx,
        stage1_out,
        stage1_out_idx,
        dim,
        k,
        stage2_elem_cnt,
        BLOCK_SIZE,
        descending,
    )

    return (stage2_out, stage2_out_idx)