Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/cummin.py: 0%

242 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2import math 

3from typing import List, Tuple, Union 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as ext 

12from flag_gems.utils.limits import get_dtype_max 

13 

# Alias used in this module's type annotations.
Tensor = torch.Tensor

# Child of the "flag_gems" logger, named after this module (leading dots stripped).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

17 

18 

@triton.jit
def tl_cummin(input, index, axis=0):
    """Prefix-min scan over ``(input, index)`` pairs along ``axis``.

    Returns a ``(values, indices)`` tuple produced by ``tl.associative_scan``
    with ``ext.minimum_with_index_tie_break_right`` as the combine function.
    NOTE(review): tie-breaking toward the right (later index) and the NaN
    ordering are properties of that helper, which is not visible here —
    confirm in triton_lang_extension.
    """
    return tl.associative_scan(
        (input, index), axis, ext.minimum_with_index_tie_break_right
    )

24 

25 

@triton.jit
def tl_min_tie_break_right(input, index, axis=None, keep_dims=False):
    """Reduce ``(input, index)`` pairs to their minimum along ``axis``.

    Uses the same combine function as ``tl_cummin``, so the block-level
    reduction and the block-level scan order elements identically.
    Returns ``(min_value, min_index)``. With ``axis=None``, ``tl.reduce``
    reduces over all dimensions.
    """
    return tl.reduce(
        (input, index),
        axis,
        ext.minimum_with_index_tie_break_right,
        keep_dims=keep_dims,
    )

34 

35 

@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def add_base_min_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Fold the prefix minimum of all preceding blocks into one block of ``out``.

    Program ``pid`` owns elements ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``.
    ``partial_min[pid - 1]`` / ``partial_min_indices[pid - 1]`` hold the running
    minimum (and its index) of blocks ``0..pid-1``, so block 0 is left untouched.

    Merge rule (as in ``torch.cummin``):
    - Ties keep the block's own (later) index.
    - For floating dtypes, a NaN already in the block is kept.
    - A NaN prefix overrides any non-NaN block value.

    ``tl.minimum`` alone does not guarantee NaN propagation, so the selection
    is done explicitly.
    """
    pid = ext.program_id(0)

    # Block 0 has no preceding blocks; skip its loads and stores entirely.
    if pid > 0:
        offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offset < n_elements

        out_ptrs = out + offset
        out_indices_ptrs = out_indices + offset
        out_vals = tl.load(out_ptrs, mask=mask)
        out_idx_vals = tl.load(out_indices_ptrs, mask=mask)

        prefix_min = tl.load(partial_min + pid - 1)
        prefix_min_idx = tl.load(partial_min_indices + pid - 1)
        prefix_min_b = tl.broadcast_to(prefix_min, (BLOCK_SIZE,))
        prefix_min_idx_b = tl.broadcast_to(prefix_min_idx, (BLOCK_SIZE,))

        if tl.constexpr(out_vals.dtype.is_floating()):
            block_is_nan = out_vals != out_vals
            prefix_is_nan = tl.broadcast_to(prefix_min != prefix_min, (BLOCK_SIZE,))
            keep_block = block_is_nan | (~prefix_is_nan & (out_vals <= prefix_min_b))
        else:
            keep_block = out_vals <= prefix_min_b

        final_vals = tl.where(keep_block, out_vals, prefix_min_b)
        final_indices = tl.where(keep_block, out_idx_vals, prefix_min_idx_b)
        # partial_min may be wider (e.g. fp32 for fp16 output); cast back.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)

67 

68 

@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def scan_part_min_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Block-local cumulative min over a flat buffer of ``n_elements`` values.

    Program ``pid`` scans elements ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``
    and writes the running minimum to ``out`` and its index to ``out_indices``.

    Args:
        in_indices: per-element indices, read only when ``USE_OUT_INDICES``.
        NEED_PARTIAL: also store the block's minimum and its index at
            ``partial_min[pid]`` / ``partial_min_indices[pid]``.
        USE_OUT_INDICES: take each element's index from ``in_indices``
            (used when scanning block minima whose positions were recorded
            earlier); otherwise the index is the flat offset.
    """
    pid = ext.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Tail lanes read the dtype maximum so they cannot lower the minimum.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
    # 64-bit types are scanned as-is; other ints widen to int32, other floats to fp32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        in_indices_ptrs = in_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        in_indices_vals = offset
    result, cummin_indices = tl_cummin(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        part_min_via_min, part_min_indices_via_min = tl_min_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        partial_min_ptrs = partial_min + pid
        tl.store(partial_min_ptrs, part_min_via_min)

        partial_min_indices_ptrs = partial_min_indices + pid
        tl.store(partial_min_indices_ptrs, part_min_indices_via_min)

123 

124 

def scan_then_fan_col(inp, out, out_indices, n_ele, dtype, use_out_indices=False):
    """Cumulative min (values and int64 indices) over a flat buffer of ``n_ele`` elements.

    The scan runs in three stages:
    1. ``scan_part_min_kernel`` scans each BLOCK_SIZE block independently.
    2. If there are several blocks, the per-block minima (a ``part_num``
       buffer of ``dtype``) are cummin-scanned recursively in place.
    3. ``add_base_min_kernel`` folds each block's prefix minimum back in.

    Args:
        inp: input buffer; results go to ``out`` / ``out_indices``.
        n_ele: number of elements to scan.
        dtype: dtype of the intermediate partial-min buffer.
        use_out_indices: take element indices from ``out_indices`` instead of
            the flat position (set for the recursive pass).
    """
    # Up to 4096 elements are handled by a single block; beyond that, 1024-wide blocks.
    BLOCK_SIZE = triton.next_power_of_2(n_ele) if n_ele <= 1024 * 4 else 1024
    part_num = triton.cdiv(n_ele, BLOCK_SIZE)
    need_partial = part_num >= 2
    if need_partial:
        partial_min = torch.empty(part_num, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(
            part_num, dtype=torch.int64, device=inp.device
        )
    else:
        partial_min = partial_min_indices = None

    grid = (part_num,)
    with torch_device_fn.device(inp.device):
        scan_part_min_kernel[grid](
            inp,
            out,
            out_indices,
            out_indices,
            partial_min,
            partial_min_indices,
            n_ele,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if not need_partial:
        return

    # Turn the per-block minima into prefix minima (in place), keeping their indices.
    scan_then_fan_col(
        partial_min,
        partial_min,
        partial_min_indices,
        part_num,
        dtype,
        use_out_indices=True,
    )
    with torch_device_fn.device(inp.device):
        add_base_min_kernel[grid](
            out, out_indices, partial_min, partial_min_indices, n_ele, BLOCK_SIZE
        )

168 

169 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_min_abc_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Chunk-local cumulative min along the B axis of a contiguous ``(A, B, C)`` buffer.

    Program ``(pid_a, pid_b, pid_c)`` scans positions
    ``[pid_b * BLOCK_SIZE, (pid_b + 1) * BLOCK_SIZE)`` of B at ``(pid_a, pid_c)``.

    Args:
        in_indices: per-element indices, read only when ``USE_OUT_INDICES``.
        NEED_PARTIAL: also store the chunk's minimum and its index into the
            ``(A, part_num, C)`` buffers ``partial_min`` / ``partial_min_indices``.
        USE_OUT_INDICES: take each element's index from ``in_indices``;
            otherwise the index is the position along B.
    """
    pid_a = ext.program_id(0)
    pid_b = ext.program_id(1)
    pid_c = ext.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    # Tail lanes read the dtype maximum so they cannot lower the minimum.
    max_value = get_dtype_max(inp.type.element_ty)
    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
    # 64-bit types are scanned as-is; other ints widen to int32, other floats to fp32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        in_indices_ptrs = in_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        in_indices_vals = b_idx
    result, cummin_indices = tl_cummin(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        part_min_via_min, part_min_indices_via_min = tl_min_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        partial_min_ptrs = partial_min + part_offset
        tl.store(partial_min_ptrs, part_min_via_min)

        partial_min_indices_ptrs = partial_min_indices + part_offset
        tl.store(partial_min_indices_ptrs, part_min_indices_via_min)

235 

236 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_min_abc_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Fold the prefix minimum of preceding chunks into one chunk of an ``(A, B, C)`` scan.

    Program ``(pid_a, pid_b, pid_c)`` owns positions
    ``[pid_b * BLOCK_SIZE, (pid_b + 1) * BLOCK_SIZE)`` of the B axis at
    ``(pid_a, pid_c)``. ``partial_min`` / ``partial_min_indices`` are laid out
    as ``(A, part_num, C)``. Their entry ``pid_b - 1`` holds the running
    minimum (and its index) of chunks ``0..pid_b-1``, so chunk 0 is untouched.

    Merge rule (as in ``torch.cummin``):
    - Ties keep the chunk's own (later) index.
    - For floating dtypes, a NaN already in the chunk is kept.
    - A NaN prefix overrides any non-NaN chunk value.

    ``tl.minimum`` alone does not guarantee NaN propagation, so the selection
    is done explicitly.
    """
    pid_a = ext.program_id(0)
    pid_b = ext.program_id(1)
    pid_c = ext.program_id(2)

    # Chunk 0 has no preceding chunks; skip its loads and stores entirely.
    if pid_b > 0:
        b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = b_idx < B
        offset = pid_a * B * C + b_idx * C + pid_c
        prev_part_offset = pid_a * part_num * C + (pid_b - 1) * C + pid_c

        out_ptrs = out + offset
        out_indices_ptrs = out_indices + offset
        out_vals = tl.load(out_ptrs, mask=mask)
        out_idx_vals = tl.load(out_indices_ptrs, mask=mask)

        prefix_min = tl.load(partial_min + prev_part_offset)
        prefix_min_idx = tl.load(partial_min_indices + prev_part_offset)
        prefix_min_b = tl.broadcast_to(prefix_min, (BLOCK_SIZE,))
        prefix_min_idx_b = tl.broadcast_to(prefix_min_idx, (BLOCK_SIZE,))

        if tl.constexpr(out_vals.dtype.is_floating()):
            block_is_nan = out_vals != out_vals
            prefix_is_nan = tl.broadcast_to(prefix_min != prefix_min, (BLOCK_SIZE,))
            keep_block = block_is_nan | (~prefix_is_nan & (out_vals <= prefix_min_b))
        else:
            keep_block = out_vals <= prefix_min_b

        final_vals = tl.where(keep_block, out_vals, prefix_min_b)
        final_indices = tl.where(keep_block, out_idx_vals, prefix_min_idx_b)
        # partial_min may be wider (e.g. fp32 for fp16 output); cast back.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)

280 

281 

def scan_then_fan(inp, out, out_indices, A, B, C, dtype, use_out_indices=False):
    """Cumulative min along the middle axis of a contiguous ``(A, B, C)`` buffer.

    The scan runs in three stages:
    1. ``scan_part_min_abc_kernel`` scans each BLOCK_SIZE chunk of B
       independently (grid ``(A, part_num, C)``).
    2. If B spans several chunks, the per-chunk minima (an
       ``(A, part_num, C)`` buffer of ``dtype``) are cummin-scanned
       recursively in place.
    3. ``add_base_min_abc_kernel`` folds each chunk's prefix minimum back in.

    Args:
        inp: input buffer; results go to ``out`` / ``out_indices``.
        A, B, C: sizes before, along, and after the scan axis.
        dtype: dtype of the intermediate partial-min buffer.
        use_out_indices: take element indices from ``out_indices`` instead of
            the position along B (set for the recursive pass).
    """
    # Up to 4096 positions fit in a single chunk; beyond that, 1024-wide chunks.
    BLOCK_SIZE = triton.next_power_of_2(B) if B <= 1024 * 4 else 1024
    part_num = triton.cdiv(B, BLOCK_SIZE)
    need_partial = part_num >= 2
    if need_partial:
        partial_min = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(
            A, part_num, C, dtype=torch.int64, device=inp.device
        )
    else:
        partial_min = partial_min_indices = None

    grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        scan_part_min_abc_kernel[grid](
            inp,
            out,
            out_indices,
            out_indices,
            partial_min,
            partial_min_indices,
            B,
            C,
            part_num,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if not need_partial:
        return

    # Turn the per-chunk minima into prefix minima along B (in place).
    scan_then_fan(
        partial_min,
        partial_min,
        partial_min_indices,
        A,
        part_num,
        C,
        dtype,
        use_out_indices=True,
    )
    with torch_device_fn.device(inp.device):
        add_base_min_abc_kernel[grid](
            out,
            out_indices,
            partial_min,
            partial_min_indices,
            B,
            C,
            part_num,
            BLOCK_SIZE,
        )

336 

337 

@libentry()
@triton.jit()
def scan_part_min_abc_loop_kernel(
    inp,
    out,
    out_indices,
    B,
    C,
    loop_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Cumulative min along the B axis of a contiguous ``(A, B, C)`` buffer.

    One program per ``(a, c)`` pair walks the whole B axis in ``loop_num``
    chunks of ``BLOCK_SIZE``. It carries the running minimum value and index
    of all previous chunks in ``prev_min_val`` / ``prev_min_val_idx``.
    """
    pid_a = ext.program_id(0)
    pid_c = ext.program_id(1)

    a_idx = pid_a
    c_idx = pid_c
    t_idx = tl.arange(0, BLOCK_SIZE)
    ac_offset = a_idx * B * C + c_idx

    max_value = get_dtype_max(inp.type.element_ty)
    # fp16/bf16 are compared in fp32 and int8/int16 in int32; other dtypes as-is.
    if tl.constexpr(inp.type.element_ty.is_fp16()) or tl.constexpr(
        inp.type.element_ty.is_bf16()
    ):
        compute_dtype = tl.float32
    elif tl.constexpr(inp.type.element_ty.is_int8()) or tl.constexpr(
        inp.type.element_ty.is_int16()
    ):
        compute_dtype = tl.int32
    else:
        compute_dtype = inp.type.element_ty

    # Carry from previous chunks. Starting at the dtype maximum means the first
    # chunk's values are always taken.
    prev_min_val = tl.full([], max_value, dtype=compute_dtype)
    prev_min_val_idx = tl.full([], 0, dtype=tl.int64)
    # Selects the last lane of a chunk, from which the next carry is extracted.
    last_mask = t_idx == (BLOCK_SIZE - 1)

    for l_idx in tl.range(loop_num):
        b_idx = l_idx * BLOCK_SIZE + t_idx
        mask = b_idx < B
        offset = ac_offset + b_idx * C

        # Out-of-range lanes read the dtype maximum so they cannot lower the minimum.
        inp_vals = tl.load(inp + offset, mask=mask, other=max_value)
        if tl.constexpr(compute_dtype != inp.type.element_ty):
            vals = inp_vals.to(compute_dtype)
        else:
            vals = inp_vals
        idxs = b_idx

        # Chunk-local cumulative min and the index of each running minimum.
        result, cummin_indices = tl_cummin(vals, idxs, axis=0)

        prev_min_val_b = tl.broadcast_to(prev_min_val, (BLOCK_SIZE,))
        prev_min_val_idx_b = tl.broadcast_to(prev_min_val_idx, (BLOCK_SIZE,))

        # Merge with the carry. On ties (<=) the chunk's value wins, i.e. the later index.
        # For floats, a NaN chunk result is always taken, and a NaN carry is kept
        # over any non-NaN chunk value.
        if tl.constexpr(compute_dtype.is_floating()):
            prev_is_nan = prev_min_val != prev_min_val
            result_is_nan = result != result
            prev_nan_mask = tl.broadcast_to(prev_is_nan, (BLOCK_SIZE,))
            use_result = result_is_nan | (~prev_nan_mask & (result <= prev_min_val_b))
        else:
            use_result = result <= prev_min_val_b

        final_vals = tl.where(use_result, result, prev_min_val_b)
        final_indices = tl.where(use_result, cummin_indices, prev_min_val_idx_b)

        # New carry = last lane, extracted by summing a one-hot selection.
        # NOTE(review): -0.0 + 0.0 == +0.0, so a -0.0 carry becomes +0.0 here.
        prev_min_val = tl.sum(tl.where(last_mask, final_vals, 0), axis=0)
        prev_min_val_idx = tl.sum(tl.where(last_mask, final_indices, 0), axis=0)

        tl.store(out + offset, final_vals.to(out.type.element_ty), mask=mask)
        tl.store(out_indices + offset, final_indices, mask=mask)

406 

407 

def scan_then_fan_loop(inp, out, out_indices, A, B, C, dtype):
    """Cumulative min along the middle axis of an ``(A, B, C)`` buffer.

    Launches ``scan_part_min_abc_loop_kernel`` on a ``(A, C)`` grid. Each
    program walks the whole B axis in ``loop_num`` chunks, so no partial-min
    buffers or recursion are needed. ``dtype`` is not used here.
    NOTE(review): ``is_use_mask_zero`` is a backend-specific launch option
    whose semantics are not visible in this file.
    """
    if B >= 1024 * 4:
        block_size = 1024
    else:
        block_size = triton.next_power_of_2(B)
    n_chunks = math.ceil(B / block_size)

    launch_grid = (A, C)
    with torch_device_fn.device(inp.device):
        scan_part_min_abc_loop_kernel[launch_grid](
            inp,
            out,
            out_indices,
            B,
            C,
            n_chunks,
            block_size,
            is_use_mask_zero=True,
        )

426 

427 

def _cummin_return(
    values: Tensor,
    indices: Tensor,
    out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None],
) -> torch.return_types.cummin:
    """Package ``(values, indices)``, writing them into ``out`` when it is given.

    Follows torch's ``out=`` convention: the provided tensors are resized to
    the result shape, filled in place, and returned.

    Raises:
        TypeError: if ``out`` is not a pair of tensors.
    """
    if out is None:
        return torch.return_types.cummin((values, indices))
    if isinstance(out, Tensor) or len(out) != 2:
        raise TypeError("cummin(): out must be a (values, indices) pair of tensors")
    out_values, out_indices = out
    out_values.resize_(values.shape).copy_(values)
    out_indices.resize_(indices.shape).copy_(indices)
    return torch.return_types.cummin((out_values, out_indices))


def cummin(
    input: Tensor,
    dim: int,
    *,
    out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None,
) -> torch.return_types.cummin:
    """Cumulative minimum of ``input`` along ``dim`` (Kunlunxin backend).

    Args:
        input: tensor to scan; it is made contiguous before the kernels run.
        dim: dimension to scan; negative values count from the end. A 0-dim
            input accepts ``-1`` and ``0``, as ``torch.cummin`` does.
        out: optional ``(values, indices)`` pair that receives the result.

    Returns:
        ``torch.return_types.cummin(values, indices)`` with int64 ``indices``.
        Bool inputs produce int64 ``values`` on this backend.

    Raises:
        AssertionError: if ``dim`` is out of range.
        TypeError: if ``out`` is given but is not a pair of tensors.
    """
    logger.debug("GEMS_KUNLUNXIN cummin")
    # A 0-dim tensor is treated as a single-element vector.
    ndim = max(input.ndim, 1)
    assert dim >= -ndim and dim < ndim, "Invalid dim"

    dtype = input.dtype
    if dtype is torch.bool:
        dtype = torch.int64

    # Trivial cases: a scalar is its own cummin (index 0), and an empty tensor
    # yields empty results. Both would otherwise hit `dim % 0` or a
    # zero-division in the launch-size math below.
    if input.ndim == 0 or input.numel() == 0:
        values = input.to(dtype=dtype, copy=True)
        indices = torch.zeros_like(input, dtype=torch.int64)
        return _cummin_return(values, indices, out)

    dim = dim % input.ndim
    shape = input.shape
    M = math.prod(shape[:dim])  # elements before the scan dim
    N = shape[dim]  # length of the scan dim
    input = input.contiguous()
    K = input.numel() // M // N  # elements after the scan dim

    values = torch.empty_like(input, dtype=dtype)
    indices = torch.empty_like(input, dtype=torch.int64)

    # dtype of the intermediate partial-min buffers; fp16/bf16 use fp32.
    compute_dtype = values.dtype
    if input.dtype == torch.float16 or input.dtype == torch.bfloat16:
        compute_dtype = torch.float32

    if M == 1 and K == 1:
        scan_then_fan_col(input, values, indices, N, compute_dtype)
    elif M * K <= 16:
        scan_then_fan(input, values, indices, M, N, K, compute_dtype)
    else:
        scan_then_fan_loop(input, values, indices, M, N, K, compute_dtype)
    return _cummin_return(values, indices, out)