Coverage for src/flag_gems/ops/sort.py: 41%

211 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.ops.topk import _get_finfo_val, _get_iinfo_val, argsort 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10 

11logger = logging.getLogger(__name__) 

12 

13 

def unwrap_if_constexpr(o):
    """Return the wrapped value of a ``tl.constexpr``; pass anything else through."""
    if isinstance(o, tl.constexpr):
        return o.value
    return o

16 

17 

@tl.constexpr
def get_int_t(num_bits: tl.constexpr, signed: tl.constexpr) -> tl.dtype:
    """Look up the Triton integer dtype for a bit width and signedness.

    Either argument may be a plain Python value or a ``tl.constexpr`` wrapper;
    both are unwrapped before the lookup.
    """
    width = unwrap_if_constexpr(num_bits)
    is_signed = unwrap_if_constexpr(signed)
    return tl.core.get_int_dtype(width, is_signed)

23 

24 

@tl.constexpr
def one_zeros(num_bits: tl.constexpr) -> int:
    """Return the ``num_bits``-wide pattern ``100...0`` (only the top bit set)."""
    width = unwrap_if_constexpr(num_bits)
    top_bit_position = width - 1
    return 1 << top_bit_position

29 

30 

@tl.constexpr
def zero_ones(num_bits: tl.constexpr) -> int:
    """Return the ``num_bits``-wide pattern ``011...1`` (all bits but the top one)."""
    width = unwrap_if_constexpr(num_bits)
    top_bit = 1 << (width - 1)
    return top_bit - 1

35 

36 

@triton.jit
def uint_to_uint(x, descending: tl.constexpr = False):
    """Produce the radix key for unsigned integers.

    Ascending keeps the bits as they are; descending inverts every bit.
    """
    if descending:
        out = ~x
    else:
        out = x
    return out

41 

42 

@triton.jit
def int_to_uint(x, descending: tl.constexpr = False):
    """Produce the radix key for signed integers.

    The input is bit-cast to the unsigned type of the same width and XOR-ed
    with a constant: the sign bit alone (``100...0``) for ascending order, or
    every bit except the sign bit (``011...1``) for descending order.
    """
    width: tl.constexpr = x.dtype.primitive_bitwidth
    unsigned_t = get_int_t(width, False)
    raw_bits = x.to(unsigned_t, bitcast=True)
    if descending:
        # 0111111....1
        low_bits_value: tl.constexpr = zero_ones(width)
        low_bits = tl.full((), value=low_bits_value, dtype=unsigned_t)
        out = raw_bits ^ low_bits
    else:
        # 1000000...0
        sign_bit_value: tl.constexpr = one_zeros(width)
        sign_bit = tl.full((), value=sign_bit_value, dtype=unsigned_t)
        out = raw_bits ^ sign_bit
    return out

59 

60 

@triton.jit
def floating_to_uint(x, descending: tl.constexpr = False):
    """Produce the radix key for floating-point values.

    The value is bit-cast to the signed and unsigned integer types of the same
    width. A per-element XOR mask is then built from the sign bit: ``100...0``
    for non-negative inputs (only the sign bit is flipped) and ``111...1`` for
    negative inputs (every bit is flipped). For descending order the
    complement of that mask is used instead.
    """
    num_bits: tl.constexpr = x.dtype.primitive_bitwidth
    sdtype = get_int_t(num_bits, True)
    udtype = get_int_t(num_bits, False)
    # Two views of the same bits: signed (for the arithmetic shift below) and
    # unsigned (the one that is actually transformed and returned).
    sx = x.to(sdtype, bitcast=True)
    ux = x.to(udtype, bitcast=True)

    sign_bit_mask_v: tl.constexpr = one_zeros(num_bits)
    sign_bit_mask = tl.full((), value=sign_bit_mask_v, dtype=udtype)
    # Mind the dtype: right shift on a signed type is an arithmetic shift, so
    # shifting by (num_bits - 1) replicates the sign bit into every position.
    # Fix for triton 3.1: the shift amount is materialised with dtype `sdtype`,
    # or else `sx >> rshift_bits` is promoted to int32.
    rshift_bits = tl.full((), value=num_bits - 1, dtype=sdtype)
    mask = sign_bit_mask | (sx >> rshift_bits).to(udtype, bitcast=True)
    tl.static_assert(mask.dtype == udtype, "type mismatch")
    # mask is now:
    # 1000000000...0 for positive
    # 1111111111...1 for negative
    if descending:
        out = ux ^ (~mask)
    else:
        out = ux ^ mask
    return out.to(udtype, bitcast=True)

83 

84 

@triton.jit
def convert_to_uint_preverse_order(x: tl.tensor, descending: tl.constexpr = False):
    """Pick the key transform that matches the element type of ``x``.

    Floating-point, signed-integer and unsigned-integer inputs are routed to
    ``floating_to_uint``, ``int_to_uint`` and ``uint_to_uint`` respectively.
    """
    if x.dtype.is_floating():
        radix_key = floating_to_uint(x, descending)
    elif x.dtype.is_int_signed():
        radix_key = int_to_uint(x, descending)
    elif x.dtype.is_int_unsigned():
        radix_key = uint_to_uint(x, descending)
    return radix_key

94 

95 

@triton.jit
def compute_global_hist_kernel(
    arr_ptr,
    out_ptr,
    num_passes,
    m,
    n,
    tiles_n_per_cta,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    num_bits_per_pass: tl.constexpr,
    descending: tl.constexpr,
):
    """Accumulate, for every radix pass, the per-row histogram of digit values.

    Each program handles one row (``pid_m``) and one contiguous slice of up to
    ``TILE_N * tiles_n_per_cta`` elements of that row (``pid_n``). For every
    pass and every bin it counts the elements of its slice whose digit equals
    the bin index and atomically adds the count into ``out_ptr``.

    Args:
        arr_ptr: input keys, addressed as ``(m, n)`` row-major.
        out_ptr: histogram, addressed as ``(m, num_passes, r)`` with
            ``r = 2 ** num_bits_per_pass``; counts are added to it, so the
            caller is expected to zero it beforehand.
        num_passes: number of digits (passes) to histogram.
        m: number of rows.
        n: number of elements per row.
        tiles_n_per_cta: number of ``TILE_N`` tiles handled by one program.
        TILE_N: elements processed per inner-loop step.
        TILE_R: bins processed per step.
        num_bits_per_pass: width of one digit in bits.
        descending: forwarded to the key transform.
    """
    # arr_ptr: (m, n)
    # out_ptr: (m, n_passes, r), where r = 2 ** k_bits is the number of bins
    pid = tl.program_id(0)
    pid_n = pid // m
    pid_m = pid % m

    r: tl.constexpr = 2**num_bits_per_pass
    bfe_mask: tl.constexpr = (1 << num_bits_per_pass) - 1  # a.k.a. 2 ** k_bits - 1
    # NOTE(review): tiles_n_per_cta is not declared tl.constexpr in the
    # signature, yet this product is annotated as constexpr - confirm intended.
    CTA_TILE_N: tl.constexpr = TILE_N * tiles_n_per_cta
    cta_n_start = CTA_TILE_N * pid_n
    cta_n_end = tl.minimum(cta_n_start + CTA_TILE_N, n)

    for p in range(0, num_passes):  # parallel
        bit_offset = p * num_bits_per_pass
        for r_start in range(0, r, TILE_R):  # parallel
            bin_indices = r_start + tl.arange(0, TILE_R)
            # When r is smaller than (or not a multiple of) TILE_R, some lanes
            # index bins that do not exist; they must not touch memory.
            bin_mask = bin_indices < r
            acc = tl.zeros((TILE_R, TILE_N), dtype=tl.int64)
            for n_start in range(cta_n_start, cta_n_end, TILE_N):  # sequential
                n_offsets = n_start + tl.arange(0, TILE_N)  # (TILE_N, )
                mask = n_offsets < cta_n_end
                arr = tl.load(arr_ptr + pid_m * n + n_offsets, mask=mask)
                arr = convert_to_uint_preverse_order(arr, descending)
                key = (arr >> bit_offset) & bfe_mask  # (TILE_N, )
                matches = tl.where(
                    mask, (bin_indices[:, None] == key), False
                )  # (TILE_R, TILE_N)
                acc += matches
            local_sum = tl.sum(acc, axis=1)
            tl.atomic_add(
                out_ptr + pid_m * num_passes * r + p * r + bin_indices,
                local_sum,
                mask=bin_mask,
                sem="relaxed",
            )

142 

143 

@triton.jit
def sweep(
    arr_ptr,
    associate_arr_ptr,  # inputs: (key & value)
    out_ptr,
    associate_out_ptr,  # outputs: (key & value)
    excumsum_bins_ptr,
    status_ptr,  # aux input and status
    n_passes,
    pass_id,
    bit_offset,
    m,
    N,
    OUT_N,
    TILE_N: tl.constexpr,
    TILE_R: tl.constexpr,
    k_bits: tl.constexpr,
    descending: tl.constexpr,
):
    """Scatter one tile of keys (and optional values) for a single radix pass.

    Each program owns one tile of ``TILE_N`` elements of one row and a range
    of up to ``TILE_R`` bins. For every bin in its range it counts the tile's
    elements whose digit equals the bin, publishes that count in
    ``status_ptr``, sums the counts published by the preceding tiles of the
    same row and bin (decoupled look-back), and stores the matching elements
    at ``bin base offset + prefix over earlier tiles + rank inside the tile``.

    Status word layout (uint32): bit 30 marks a tile-local count, bit 31 marks
    an inclusive prefix, the low 30 bits carry the count. A word equal to 0
    means "not published yet".
    """
    # r: num_bins = 2 ** k_bits
    # OUT_N: number of tiles along N (the host in this file passes
    #        grid_n = cdiv(N, TILE_N))

    # arr_ptr: (m, N)
    # out_ptr: (m, N)
    # excumsum_bins_ptr: (m, n_passes, r)
    # status_ptr: (m, r, OUT_N)

    # grid: axis 0 = m * grid_n (row and tile folded together), axis 1 = grid_r

    # decode program ids
    pid = tl.program_id(0)
    pid_m = pid % m
    pid_n = pid // m
    pid_r = tl.program_id(1)

    # bit masks
    aggregate_mask: tl.constexpr = 1 << 30
    inclusive_prefix_mask: tl.constexpr = 1 << 31
    v_mask: tl.constexpr = (1 << 30) - 1
    bfe_mask: tl.constexpr = (1 << k_bits) - 1  # a.k.a. 2 ** k_bits - 1

    # range of bins handled by this program (clamped to the number of bins)
    r: tl.constexpr = 2**k_bits
    cta_r_start = pid_r * TILE_R
    cta_r_end = tl.minimum(cta_r_start + TILE_R, r)

    # load the tile and extract the digit for this pass
    n_offsets = pid_n * TILE_N + tl.arange(0, TILE_N)  # (TILE_N, )
    mask = n_offsets < N
    arr = tl.load(arr_ptr + pid_m * N + n_offsets, mask=mask)
    arr_u = convert_to_uint_preverse_order(arr, descending)
    key = (arr_u >> bit_offset) & bfe_mask  # (TILE_N, )
    if associate_arr_ptr is not None:
        associate_arr = tl.load(associate_arr_ptr + pid_m * N + n_offsets, mask=mask)
    # since triton can only use scalar as condition, loop by bin_index
    # status must be pre zero-initialized, or else we have to initialize it
    for bin_index in range(cta_r_start, cta_r_end):
        matches = tl.where(mask, key == bin_index, False)  # (TILE_N, ) bool
        # cta level count per bin
        # CAUTION: tl.sum in triton 3.2 does not promote type
        local_sum = tl.sum(matches.to(tl.uint32), axis=0)
        # publish the tile-local count first so later tiles can make progress
        pack0 = aggregate_mask | local_sum
        status_offset = pid_m * (r * OUT_N) + bin_index * OUT_N + pid_n
        tl.store(status_ptr + status_offset, pack0, cache_modifier=".cg")

        # decoupled lookback: walk the preceding tiles from nearest to farthest
        exclusive_prefix = tl.zeros((), dtype=tl.uint32)
        i_lookback = pid_n - 1
        while i_lookback >= 0:
            flag_offset_i = pid_m * (r * OUT_N) + bin_index * OUT_N + i_lookback
            pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)  # uint32
            # spin until that tile has published something (non-zero word)
            while pack1 == 0:
                pack1 = tl.load(status_ptr + flag_offset_i, volatile=True)
            exclusive_prefix += pack1 & v_mask
            if (pack1 & aggregate_mask) == aggregate_mask:
                # only a tile-local count: keep looking further back
                i_lookback -= 1
            else:
                # an inclusive prefix already covers all earlier tiles: stop
                i_lookback = -1
        # upgrade this tile's status to an inclusive prefix
        pack2 = inclusive_prefix_mask | (exclusive_prefix + local_sum)
        tl.store(status_ptr + status_offset, pack2, cache_modifier=".cg")

        # rank of each matching element among the matches inside this tile
        local_ex_cumsum = (
            tl.cumsum(matches.to(tl.uint32), axis=0) - matches
        )  # (TILE_N, )
        ex_cumsum_in_bin = (
            exclusive_prefix + local_ex_cumsum
        )  # global ex_cumsum_in_bin (TILE_N, )

        # ex_cumsum_bins (m, n_passes, r): start offset of this bin in the row
        ex_cumsum_bins = tl.load(
            excumsum_bins_ptr + pid_m * (n_passes * r) + pass_id * r + bin_index
        )  # scalar
        pos = ex_cumsum_bins + ex_cumsum_in_bin  # (TILE_N, )

        # scatter
        tl.store(out_ptr + pid_m * N + pos, arr, mask=matches)
        if associate_arr_ptr is not None:
            tl.store(associate_out_ptr + pid_m * N + pos, associate_arr, mask=matches)

242 

243 

def radix_sort(arr, k_bits=8, descending=False):
    """Sort ``arr`` along its last dimension with a least-significant-digit radix sort.

    The tensor is treated as ``m`` rows of ``n = arr.shape[-1]`` elements. One
    kernel launch builds the digit histograms for all passes, then one
    ``sweep`` launch per pass scatters keys and their indices, ping-ponging
    between two buffers.

    Args:
        arr: tensor to sort. NOTE(review): the kernels address rows as
            ``row * n + offset``, so ``arr`` is presumably expected to be
            contiguous; the caller in this file makes it so.
        k_bits: number of key bits consumed per pass (``2 ** k_bits`` bins).
        descending: sort in descending instead of ascending order.

    Returns:
        ``(values, indices)``: the sorted values and, as int64, the position
        each value had along the last dimension of ``arr``.
    """
    if arr.numel() == 0:
        # Nothing to sort. Returning here also avoids the division by zero in
        # `arr.numel() // n` when the last dimension itself has size 0.
        empty_indices = torch.empty(arr.shape, dtype=torch.int64, device=arr.device)
        return torch.empty_like(arr), empty_indices

    n = arr.shape[-1]
    m = arr.numel() // n
    # The sweep kernel keeps counts in the low 30 bits of a status word and
    # uses bits 30/31 as flags, so a row must hold fewer than 2**30 elements.
    assert n < (1 << 30), "we have not implemented 2**30 per launch"
    dtype = arr.dtype
    num_bits = 1 if dtype == torch.bool else (arr.itemsize * 8)

    TILE_N = 1024
    tiles_n_per_cta = 8
    CTA_TILE_N = tiles_n_per_cta * TILE_N

    num_bins = 2**k_bits
    n_passes = triton.cdiv(num_bits, k_bits)
    TILE_R = 16

    grid_n = triton.cdiv(n, CTA_TILE_N)
    grid_for_global_hist = (m * grid_n, 1, 1)

    with torch_device_fn.device(arr.device):
        # The histogram kernel accumulates with atomics, so start from zeros.
        global_hist = torch.zeros(
            (m, n_passes, num_bins), device=arr.device, dtype=torch.int32
        )
        compute_global_hist_kernel[grid_for_global_hist](
            arr,
            global_hist,
            n_passes,
            m,
            n,
            tiles_n_per_cta,
            TILE_N,
            TILE_R,
            k_bits,
            descending,
        )
        # Exclusive prefix sum over bins: first output slot of every bin.
        ex_cumsum_bins = torch.cumsum(global_hist, -1) - global_hist
        ex_cumsum_bins = ex_cumsum_bins.to(torch.uint32)

        # sort
        arr_in = torch.clone(arr)
        indices_in = (
            torch.arange(0, n, dtype=torch.int64, device=arr_in.device)
            .broadcast_to(arr.shape)
            .contiguous()
        )
        arr_out = torch.empty_like(arr)
        indices_out = torch.empty_like(indices_in)

        # The sweep kernel uses its own tiling.
        TILE_R = 8
        grid_r = triton.cdiv(num_bins, TILE_R)
        TILE_N = 2048
        grid_n = triton.cdiv(n, TILE_N)
        grid_for_sweep = (m * grid_n, grid_r)

        status = torch.empty(
            (m, num_bins, grid_n), device=arr.device, dtype=torch.uint32
        )

        for i in range(0, n_passes):
            bit_offset = i * k_bits
            # A zero status word means "not published yet" inside sweep.
            status.zero_()
            sweep[grid_for_sweep](
                arr_in,
                indices_in,
                arr_out,
                indices_out,
                ex_cumsum_bins,
                status,
                n_passes,
                i,
                bit_offset,
                m,
                n,
                grid_n,
                TILE_N,
                TILE_R,
                k_bits,
                descending,
            )
            # The output of this pass is the input of the next one.
            arr_in, arr_out = arr_out, arr_in
            indices_in, indices_out = indices_out, indices_in

    return arr_in, indices_in

327 

328 

@libentry()
@triton.jit()
def sort_kernel(
    in_ptr,
    out_ptr,
    out_index_ptr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    DESCENDING: tl.constexpr,
    IS_FLOAT: tl.constexpr,
):
    """Sort one row of ``N`` elements per program and store values and indices.

    The row is loaded into a block of ``BLOCK_SIZE`` lanes; lanes at or beyond
    ``N`` are filled with a padding value and are not written back.
    """
    lanes = tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < N
    row_offsets = tl.program_id(0) * N + lanes
    in_ptr += row_offsets
    out_ptr += row_offsets
    out_index_ptr += row_offsets

    # Padding for out-of-range lanes - presumably the dtype's max for an
    # ascending sort and its min for a descending one (verify in topk helpers).
    if IS_FLOAT:
        mask_val = _get_finfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
    else:
        mask_val = _get_iinfo_val(in_ptr.dtype.element_ty, return_max=not DESCENDING)
    in_val = tl.load(in_ptr, mask=in_bounds, other=mask_val)

    positions = tl.arange(0, BLOCK_SIZE)

    sorted_vals, sorted_positions = argsort(
        in_val, positions, 0, descending=DESCENDING
    )
    tl.store(out_ptr, sorted_vals, mask=in_bounds)
    tl.store(out_index_ptr, sorted_positions, mask=in_bounds)

361 

362 

def sort(inp, dim=-1, descending=False):
    """Sort ``inp`` along ``dim`` and return ``(values, indices)``.

    Delegates to ``sort_stable`` with ``stable=False``.
    """
    logger.debug("GEMS SORT")
    # Only a stable radix sort is implemented, so reuse it here.
    result = sort_stable(inp, stable=False, dim=dim, descending=descending)
    return result

367 

368 

def sort_stable(inp, *, stable, dim=-1, descending=False):
    """Sort ``inp`` along ``dim`` with the radix sort and return ``(values, indices)``.

    Args:
        inp: tensor to sort.
        stable: accepted for interface compatibility and ignored; the radix
            sort used here is the only implementation.
        dim: dimension to sort along; negative values count from the end.
        descending: sort in descending instead of ascending order.

    Returns:
        ``(values, indices)`` with the shape of ``inp``; ``indices`` is int64.
    """
    logger.debug("GEMS SORT.STABLE")
    # We only implement stable radix sort here
    _ = stable

    # Trivial cases: a 0-dim tensor, or a sort dimension holding zero or one
    # element. The values are returned as a copy so the result never aliases
    # the caller's tensor, and every index is 0 (or there are none).
    if inp.ndim == 0 or inp.shape[dim] <= 1:
        return inp.clone(), torch.zeros_like(inp, dtype=torch.int64)

    if dim < 0:
        dim = dim + inp.ndim
    # radix_sort works on the last dimension only.
    needs_move = dim != inp.ndim - 1
    if needs_move:
        inp = torch.movedim(inp, dim, -1).contiguous()
    else:
        inp = inp.contiguous()

    dtype = inp.dtype
    # Booleans carry a single key bit; every other dtype is sorted 4 bits per pass.
    num_bits_per_pass = 1 if dtype == torch.bool else 4
    out, out_index = radix_sort(inp, num_bits_per_pass, descending)

    if needs_move:
        # Put the sorted dimension back where the caller had it.
        out = torch.movedim(out, -1, dim)
        out_index = torch.movedim(out_index, -1, dim)
    return out, out_index