Coverage for src/flag_gems/runtime/backend/_sunrise/ops/sort.py: 0%

242 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7import flag_gems 

8from flag_gems.ops.topk import _get_finfo_val, _get_iinfo_val, argsort 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11 

# Module-level logger; the sort entry points in this file emit debug messages through it.
logger = logging.getLogger(__name__)

13 

14 

def next_power_of_2(n: int) -> int:
    """Return the smallest power of two that is greater than or equal to ``n``.

    The earlier bit-smearing implementation only propagated the top bit over
    64 positions, so it produced values that are not powers of two once
    ``n`` exceeded ``2**64``. ``int.bit_length`` has no such width limit.

    Args:
        n: Integer to round up.

    Returns:
        ``n`` rounded up to a power of two. Non-positive inputs yield ``0``,
        which is what the bit-smearing version returned for ``0`` and small
        negative values.
    """
    if n <= 0:
        return 0
    # (n - 1).bit_length() is the exponent of the next power of two >= n.
    return 1 << (n - 1).bit_length()

25 

26 

def unwrap_if_constexpr(o):
    """Return ``o.value`` when ``o`` is a ``tl.constexpr``; otherwise return ``o`` itself."""
    if isinstance(o, tl.constexpr):
        return o.value
    return o

29 

30 

@tl.constexpr
def get_int_t(num_bits: tl.constexpr, signed: tl.constexpr) -> tl.dtype:
    """Look up the Triton integer dtype with the given bit width and signedness.

    Both arguments may arrive either as plain Python values or wrapped in
    ``tl.constexpr``; they are unwrapped before the lookup.
    """
    width = unwrap_if_constexpr(num_bits)
    is_signed = unwrap_if_constexpr(signed)
    return tl.core.get_int_dtype(width, is_signed)

36 

37 

@tl.constexpr
def one_zeros(num_bits: tl.constexpr) -> int:
    """Return the ``num_bits``-wide pattern ``100...0`` (only the top bit set)."""
    width = unwrap_if_constexpr(num_bits)
    top_bit_position = width - 1
    return 1 << top_bit_position

42 

43 

@tl.constexpr
def zero_ones(num_bits: tl.constexpr) -> int:
    """Return the ``num_bits``-wide pattern ``011...1`` (every bit set except the top one)."""
    width = unwrap_if_constexpr(num_bits)
    top_bit = 1 << (width - 1)
    return top_bit - 1

48 

49 

@triton.jit
def uint_to_uint(x, descending: tl.constexpr = False):
    # Produce the unsigned radix key for an unsigned input.
    # Ascending: the value is already its own key.
    # Descending: bitwise NOT reverses the unsigned ordering.
    out = ~x if descending else x
    return out

54 

55 

@triton.jit
def int_to_uint(x, descending: tl.constexpr = False):
    # Produce an unsigned radix key for a signed integer input.
    # The value is bit-cast (not value-converted) to the unsigned type of the
    # same width and then XOR-ed with a mask chosen by sort direction.
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    udtype = get_int_t(num_bits, False)
    ux = tl.cast(x, udtype, bitcast=True)
    if descending:
        # 0111111....1
        # Flips every bit except the sign bit, i.e. the complement of the
        # ascending key, so larger signed values get smaller keys.
        bit_mask: tl.constexpr = zero_ones(num_bits)
        bit_mask_tensor = tl.full((), value=bit_mask, dtype=udtype)
        out = ux ^ bit_mask_tensor
    else:
        # 1000000...0
        # Flipping only the sign bit makes negatives compare below
        # non-negatives when the bits are read as unsigned.
        sign_bit_mask: tl.constexpr = one_zeros(num_bits)
        sign_bit_mask_tensor = tl.full((), value=sign_bit_mask, dtype=udtype)
        out = ux ^ sign_bit_mask_tensor
    return out

72 

73 

@triton.jit
def floating_to_uint(x, descending: tl.constexpr = False):
    # Produce an unsigned radix key for a floating-point input by working on
    # its raw bit pattern (bit-cast to same-width signed and unsigned ints).
    # NOTE(review): the ordering argument assumes a sign-magnitude float
    # layout such as IEEE-754; NaN placement is not addressed here.
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    sdtype = get_int_t(num_bits, True)
    udtype = get_int_t(num_bits, False)
    sx = x.to(sdtype, bitcast=True)
    ux = x.to(udtype, bitcast=True)

    sign_bit_mask_v: tl.constexpr = one_zeros(num_bits)
    sign_bit_mask = tl.full((), value=sign_bit_mask_v, dtype=udtype)
    # mind the dtype, right_shift for signed is arithmetic right shift
    # Fix for triton 3.1 or else `sx >> rshift_bits` is promoted to int32
    rshift_bits = tl.full((), value=num_bits - 1, dtype=sdtype)
    # The arithmetic shift replicates the sign bit across the whole word.
    mask = sign_bit_mask | (sx >> rshift_bits).to(udtype, bitcast=True)
    tl.static_assert(mask.dtype == udtype, "type mismatch")
    # 1000000000...0 for positive
    # 1111111111...1 for negative
    if descending:
        # XOR with the complemented mask yields the bitwise NOT of the
        # ascending key, which reverses the order.
        out = ux ^ (~mask)
    else:
        # Positives get only their sign bit flipped; negatives get all bits flipped.
        out = ux ^ mask
    return out.to(udtype, bitcast=True)

96 

97 

@triton.jit
def convert_to_uint_preverse_order(x: tl.tensor, descending: tl.constexpr = False):
    # Dispatch on the element dtype of `x` to the matching key conversion
    # (floating point, signed integer, or unsigned integer) and return the
    # resulting unsigned key.
    # NOTE(review): `out` is never assigned if none of the three dtype checks
    # matches; confirm every dtype passed in falls into one of these branches.
    if x.dtype.is_floating():
        out = floating_to_uint(x, descending)
    elif x.dtype.is_int_signed():
        out = int_to_uint(x, descending)
    elif x.dtype.is_int_unsigned():
        out = uint_to_uint(x, descending)
    return out

107 

108 

@triton.jit
def compute_global_hist_kernel(
    arr_ptr,
    out_ptr,
    num_passes,
    m,
    n,
    tiles_n_per_cta,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    num_bits_per_pass: tl.constexpr,
    descending: tl.constexpr,
    USE_UINT16: tl.constexpr,
):
    # Accumulate, for every row and every radix pass, a histogram of the
    # digit values found in this program's slice of the row. Partial counts
    # are merged into `out_ptr` with atomic adds.
    #
    # arr_ptr: (m, n)
    # out_ptr: (m, n_passes, r), where r = 2 ** k_bits is the number of bins
    # Launch axis 0 is decomposed into a row index (pid % m) and a column
    # chunk index (pid // m).
    pid = tl.program_id(0)
    pid_n = pid // m
    pid_m = pid % m

    r: tl.constexpr = 2**num_bits_per_pass
    bfe_mask: tl.constexpr = (1 << num_bits_per_pass) - 1  # a.k.a. 2 ** k_bits - 1
    CTA_TILE_N: tl.constexpr = TILE_N * tiles_n_per_cta
    # Column range [cta_n_start, cta_n_end) handled by this program, clamped to n.
    cta_n_start = CTA_TILE_N * pid_n
    dsize = cta_n_start + CTA_TILE_N
    cta_n_end = tl.where(dsize < n, dsize, n)

    arr_partial_ptr = arr_ptr + pid_m * n
    range_tile_r = tl.arange(0, TILE_R)
    range_tile_n = tl.arange(0, TILE_N)
    # Counter type used while accumulating; widened to int32 before the atomic add.
    acc_type = tl.int32
    if tl.constexpr(USE_UINT16):
        acc_type = tl.uint16
    for p in range(0, num_passes):  # parallel
        bit_offset = p * num_bits_per_pass
        for r_start in range(0, r, TILE_R):  # parallel
            bin_indices = r_start + range_tile_r
            acc = tl.zeros((TILE_R, TILE_N), dtype=acc_type)
            for n_start in range(cta_n_start, cta_n_end, TILE_N):  # sequential
                n_offsets = n_start + range_tile_n  # (TILE_N, )
                mask = n_offsets < cta_n_end
                arr = tl.load(arr_partial_ptr + n_offsets, mask=mask)
                arr = convert_to_uint_preverse_order(arr, descending)
                # Extract the digit for pass p from the order-preserving key.
                key = (arr >> bit_offset) & bfe_mask  # (TILE_N, )
                # One-hot comparison of each bin against each key; lanes
                # outside the valid range contribute zero.
                matches = tl.where(
                    mask, (bin_indices[:, None] == key).to(acc_type), 0
                )  # (TILE_R, TILE_N)
                acc += matches
            local_sum = tl.sum(acc, axis=1).to(tl.int32)
            # Merge this program's counts for bins [r_start, r_start + TILE_R).
            tl.atomic_add(
                out_ptr + (pid_m * num_passes + p) * r + bin_indices,
                local_sum,
                sem="relaxed",
            )

163 

164 

@triton.jit
def sweep(
    arr_ptr,
    associate_arr_ptr,  # inputs: (key & value)
    out_ptr,
    associate_out_ptr,  # outputs: (key & value)
    excumsum_bins_ptr,
    status_ptr,  # aux input and status
    n_passes,
    pass_id,
    bit_offset,
    m,
    N,
    OUT_N,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    k_bits: tl.constexpr,
    descending: tl.constexpr,
    USE_UINT16: tl.constexpr,
):
    # One radix pass: scatter each element of a tile (and its associated
    # value, when given) to its destination for the digit at `bit_offset`.
    # The destination is the bin's exclusive start offset plus the number of
    # same-bin elements that precede it in the row; the cross-tile part of
    # that count is obtained through the per-(row, bin, tile) status words.
    #
    # r: num_bins = 2 ** k_bits
    # OUT_N: number of tiles along N (the caller in this file passes
    #        cdiv(N, TILE_N))

    # arr_ptr: (m, N)
    # out_ptr: (m, N)
    # excumsum_bins_ptr: (m, n_passes, r)
    # status_ptr: (m, r, OUT_N)

    # grid: axis 0 carries both the row and the tile index (pid % m, pid // m);
    #       axis 1 selects a group of TILE_R bins

    # load data
    pid = tl.program_id(0)
    pid_m = pid % m
    pid_n = pid // m
    pid_r = tl.program_id(1)

    # bit masks
    # Status word layout: bit 30 marks "tile-local count published", bit 31
    # marks "inclusive prefix published", the low 30 bits hold the count.
    aggregate_mask: tl.constexpr = 1 << 30
    inclusive_prefix_mask: tl.constexpr = 1 << 31
    v_mask: tl.constexpr = (1 << 30) - 1
    bfe_mask: tl.constexpr = (1 << k_bits) - 1  # a.k.a. 2 ** k_bits - 1

    # initialize flag to zero-local sum is not ready
    r: tl.constexpr = 2**k_bits
    cta_r_start = pid_r * TILE_R
    cta_r_end = tl.minimum(cta_r_start + TILE_R, r)

    # cumsum for a bin_index
    n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N)  # (TILE_N, )
    mask = n_offsets < N
    arr = tl.load(arr_ptr + pid_m * N + n_offsets, mask=mask)
    arr_u = convert_to_uint_preverse_order(arr, descending)
    key = (arr_u >> bit_offset) & bfe_mask  # (TILE_N, )
    if associate_arr_ptr is not None:
        associate_arr = tl.load(associate_arr_ptr + pid_m * N + n_offsets, mask=mask)

    # Element type used for the per-tile reductions; results are widened to
    # uint32 right after each reduction.
    dt = tl.uint32
    if tl.constexpr(USE_UINT16):
        dt = tl.uint16

    # since triton can only use scalar as condition, loop by bin_index
    # status must be pre zero-initialized, or else we have to initialize it
    for bin_index in range(cta_r_start, cta_r_end):
        matches = tl.where(mask, key == bin_index, False)  # (TILE_N, ) bool
        # cta level cumsum per bin
        # CAUTION: tl.sum in triton 3.2 does not promote type
        local_sum = tl.sum(matches.to(dtype=dt), axis=0).to(tl.uint32)
        # Publish this tile's own count so that later tiles can consume it.
        pack0 = aggregate_mask | local_sum
        status_offset = (pid_m * r + bin_index) * OUT_N + pid_n
        tl.store(status_ptr + status_offset, pack0, cache_modifier=".cg")

        # decoupled lookback
        # Walk backwards over preceding tiles, summing their published counts.
        exclusive_prefix = tl.zeros((), dtype=tl.uint32)
        i_lookback = pid_n - 1
        while i_lookback >= 0:
            flag_offset_i = pid_m * (r * OUT_N) + bin_index * OUT_N + i_lookback
            pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)  # uint32
            # Spin until that tile has published something (a zero word means "not ready").
            while pack1 == 0:
                pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)
            exclusive_prefix += pack1 & v_mask
            # If only a local count was found, step to the previous tile; if an
            # inclusive prefix was found, the product is 0 and the index becomes
            # -1, which ends the walk.
            i_lookback = ((pack1 & aggregate_mask) == aggregate_mask) * i_lookback - 1
        # Publish the inclusive prefix so later tiles can stop their walk here.
        pack2 = inclusive_prefix_mask | (exclusive_prefix + local_sum)
        tl.store(status_ptr + status_offset, pack2, cache_modifier=".cg")

        # Rank of each matching element among the matches of this tile.
        local_ex_cumsum = (
            tl.cumsum(matches.to(dt), axis=0).to(tl.uint32) - matches
        )  # (TILE_N, )
        ex_cumsum_in_bin = (
            exclusive_prefix + local_ex_cumsum
        )  # global ex_cumsum_in_bin (TILE_N, )

        # ex_cumsum_bins (m, n_passes, r)
        ex_cumsum_bins = tl.load(
            excumsum_bins_ptr + (pid_m * n_passes + pass_id) * r + bin_index
        )  # scalar
        pos = ex_cumsum_bins + ex_cumsum_in_bin  # (TILE_N, )

        # scatter
        # Only lanes whose digit equals bin_index are written in this iteration.
        tl.store(out_ptr + pid_m * N + pos, arr, mask=matches)
        if associate_arr_ptr is not None:
            tl.store(associate_out_ptr + pid_m * N + pos, associate_arr, mask=matches)

266 

267 

def radix_sort(arr, k_bits=8, descending=False):
    """Sort ``arr`` along its last dimension with a least-significant-digit radix sort.

    The tensor is addressed by the kernels as ``m`` rows of ``n`` elements
    (``row * n + column``), so it is expected to be contiguous.

    Fixes relative to the previous revision:
      * The passes ping-pong between two buffers. The caller's tensor used to
        be one of them, so it was overwritten from the second pass on (and,
        for an even number of passes, was itself returned). The passes now
        run on a private copy, leaving ``arr`` untouched.
      * ``indices_out`` was allocated with the default dtype instead of the
        dtype of ``indices_in``; the two index buffers now share one dtype so
        indices are never round-tripped through a floating-point buffer.

    Args:
        arr: Tensor to sort along its last dimension.
        k_bits: Number of key bits consumed per pass (``2 ** k_bits`` bins).
        descending: Sort in descending order when true.

    Returns:
        ``(values, indices)``: the sorted values and the int32 positions of
        those values within each original row, both shaped like ``arr``.
    """
    n = arr.shape[-1]
    m = arr.numel() // n
    assert n < (1 << 30), "we have not implemented 2**30 per launch"
    dtype = arr.dtype
    num_bits = 1 if dtype == torch.bool else (arr.itemsize * 8)

    # Tiling for the histogram kernel.
    TILE_N = 128
    if m > 2048 and n < 512:
        TILE_N = 512

    tiles_n_per_cta = 8
    CTA_TILE_N = tiles_n_per_cta * TILE_N

    num_bins = 1 << k_bits
    n_passes = triton.cdiv(num_bits, k_bits)
    TILE_R = 4

    grid_n = triton.cdiv(n, CTA_TILE_N)
    grid_for_global_hist = (m * grid_n, 1, 1)

    USE_UINT16 = False

    with torch_device_fn.device(arr.device):
        # Per-row, per-pass digit histogram; zero-filled because the kernel
        # accumulates into it with atomic adds.
        global_hist = torch.zeros(
            (m, n_passes, num_bins), device=arr.device, dtype=torch.int32
        )
        compute_global_hist_kernel[grid_for_global_hist](
            arr,
            global_hist,
            n_passes,
            m,
            n,
            tiles_n_per_cta,
            TILE_N,
            TILE_R,
            k_bits,
            descending,
            USE_UINT16,
        )
        # Exclusive prefix sum over the bins = inclusive cumsum minus the counts.
        # NOTE(review): the original note here read "[DIPU] cumsum result is
        # wrong" - presumably the reason the flag_gems ops are used; confirm.
        ex_cumsum_bins = flag_gems.sub(
            flag_gems.cumsum(global_hist, -1), global_hist
        )
        ex_cumsum_bins = ex_cumsum_bins.to(torch.uint32)

        # sort
        # Work on a private copy so the ping-pong below never writes into the
        # caller's tensor.
        arr_in = torch.clone(arr)
        indices_in = (
            torch.arange(0, n, dtype=torch.int32, device=arr.device)
            .broadcast_to(arr.shape)
            .contiguous()
        )
        arr_out = torch.empty_like(arr)
        # Same dtype as indices_in: the two buffers swap roles every pass.
        indices_out = torch.empty(
            indices_in.shape, dtype=indices_in.dtype, device=indices_in.device
        ).as_strided(indices_in.shape, indices_in.stride())

        # Tiling for the scatter passes.
        TILE_R = 2 if n > 2048 else num_bins
        grid_r = triton.cdiv(num_bins, TILE_R)
        TILE_N = 2048
        if n <= 32768:
            if m > 2048 and n <= 128:
                TILE_N = 128
            elif m < 32 and n > 2048:
                TILE_N = 256
        grid_n = triton.cdiv(n, TILE_N)
        grid_for_sweep = (m * grid_n, grid_r)

        USE_UINT16 = n <= 4096

        status = torch.empty(
            (m, num_bins, grid_n), device=arr.device, dtype=torch.uint32
        )

        for i in range(0, n_passes):
            bit_offset = i * k_bits
            # The kernel treats a zero status word as "not ready", so the
            # buffer has to be cleared before every pass.
            status = status.zero_()
            sweep[grid_for_sweep](
                arr_in,
                indices_in,
                arr_out,
                indices_out,
                ex_cumsum_bins,
                status,
                n_passes,
                i,
                bit_offset,
                m,
                n,
                grid_n,
                TILE_N,
                TILE_R,
                k_bits,
                descending,
                USE_UINT16,
            )
            # The output of this pass is the input of the next one.
            arr_in, arr_out = arr_out, arr_in
            indices_in, indices_out = indices_out, indices_in

    return arr_in, indices_in

372 

373 

@libentry()
@triton.jit()
def sort_kernel(
    in_ptr,
    out_ptr,
    out_index_ptr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
    IS_FLOAT: tl.constexpr,
):
    # Sort one row of N elements per program: load the row into a block of
    # BLOCK_SIZE lanes, sort values together with their lane indices via
    # `argsort`, and store the first N sorted values and indices.
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < N
    # Each program handles the row selected by program_id(0).
    offset = tl.program_id(0) * N + cols
    in_ptr += offset
    out_ptr += offset
    out_index_ptr += offset

    # Lanes beyond N are filled with a sentinel chosen by sort direction.
    # NOTE(review): presumably the dtype's max (ascending) or min (descending)
    # so that padding sorts to the end - confirm against the helper definitions.
    if IS_FLOAT:
        mask_val = _get_finfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        in_val = tl.load(in_ptr, mask=mask, other=mask_val)
    else:
        mask_val = _get_iinfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
        in_val = tl.load(in_ptr, mask=mask, other=mask_val)

    index_val = tl.arange(0, BLOCK_SIZE)

    sorted_in_val, sorted_index_val = argsort(
        in_val, index_val, 0, descending=DESCENDING
    )
    tl.store(out_ptr, sorted_in_val, mask=mask)
    tl.store(out_index_ptr, sorted_index_val, mask=mask)

406 

407 

def sort(inp, dim=-1, descending=False):
    """Sort ``inp`` along ``dim`` and return what ``sort_stable`` returns.

    Only a stable radix sort is implemented, so this entry point simply
    forwards to ``sort_stable`` with ``stable=False``.
    """
    logger.debug("GEMS SORT")
    sorted_pair = sort_stable(inp, stable=False, dim=dim, descending=descending)
    return sorted_pair

412 

413 

def sort_stable(inp, *, stable, dim=-1, descending=False):
    """Sort ``inp`` along ``dim`` using the radix sort in this module.

    Args:
        inp: Tensor to sort.
        stable: Accepted for API compatibility and ignored; the radix sort
            is the only implementation, whatever the flag says.
        dim: Dimension to sort along; negative values count from the end.
        descending: Sort in descending order when true.

    Returns:
        ``(values, indices)`` with ``indices`` as int64. For inputs with more
        than one element along ``dim`` both are freshly cloned tensors.
    """
    device = inp.device
    logger.debug("GEMS SORT.STABLE")
    # We only implement stable radix sort here
    _ = stable
    sort_elem_cnt = inp.shape[dim]
    if inp.numel() == 0:
        # Nothing to sort. Without this guard radix_sort divides by the
        # length of the sort dimension, which may be zero.
        return (
            torch.empty_like(inp),
            torch.empty(inp.shape, dtype=torch.int64, device=device),
        )
    if sort_elem_cnt == 1:
        return inp, torch.zeros(inp.shape, dtype=torch.int64, device=inp.device)

    if dim < 0:
        dim = dim + inp.ndim
    if dim != inp.ndim - 1:
        # Move the sort dimension last. The permutation is done on the CPU
        # and copied back - presumably a backend workaround; TODO confirm.
        inp = torch.movedim(inp.cpu(), dim, -1).contiguous().to(device=device)
    else:
        inp = inp.contiguous()

    dtype = inp.dtype
    num_bits_per_pass = 1 if dtype == torch.bool else 2
    out, out_index = radix_sort(inp, num_bits_per_pass, descending)

    if dim != inp.ndim - 1:
        # Restore the original dimension order.
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    # [sunrise fix] Per reported feedback: returning
    # `out, out_index.to(torch.int64)` directly appeared to hand back internal
    # memory that was overwritten after the return, losing the data. Cloning
    # gives the caller tensors with their own storage.
    return out.clone(), out_index.to(torch.int64).clone()