Coverage for src/flag_gems/runtime/backend/_ascend/ops/cummin.py: 0%

260 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2import math 

3from typing import List, Tuple, Union 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_max 

13 

14Tensor = torch.Tensor 

15 

16logger = logging.getLogger(__name__) 

17 

18 

@triton.jit
def tl_cummin(input, index, axis=0):
    # Inclusive scan of (value, index) pairs along `axis`, combined with
    # tle.minimum_with_index_tie_break_right. Returns the scanned values and
    # the scanned indices as a pair.
    scanned_vals, scanned_idx = tl.associative_scan(
        (input, index),
        axis=axis,
        combine_fn=tle.minimum_with_index_tie_break_right,
    )
    return scanned_vals, scanned_idx

24 

25 

@triton.jit
def tl_min_tie_break_right(input, index, axis=None, keep_dims=False):
    # Reduction of (value, index) pairs along `axis` using the same combiner
    # as tl_cummin (tle.minimum_with_index_tie_break_right). Returns the
    # reduced value and its paired index.
    min_val, min_idx = tl.reduce(
        (input, index),
        axis=axis,
        combine_fn=tle.minimum_with_index_tie_break_right,
        keep_dims=keep_dims,
    )
    return min_val, min_idx

34 

35 

# Second pass of the flat (1-D) block scan, applied in place to `out` and
# `out_indices`.
#
# Program `pid` covers elements [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE),
# masked to `n_elements`. For pid > 0 it reads a carry-in value/index from
# `partial_min[pid - 1]` / `partial_min_indices[pid - 1]` (scan_then_fan_col
# scans those buffers in place before launching this kernel) and merges it
# into every element of the block. Program 0 has no predecessor and stores
# nothing.
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def add_base_min_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    out_ptrs = out + offset
    out_indices_ptrs = out_indices + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    # NOTE: this rebinds the `out_indices` parameter name from the base
    # pointer to the loaded index values; `out_indices_ptrs` keeps the
    # addresses used by the store below.
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    if pid > 0:
        # Carry-in: a single scalar value/index pair from the previous block.
        partial_min_ptrs = partial_min + pid - 1
        last_part_min_via_min = tl.load(partial_min_ptrs)
        partial_min_indices_ptrs = partial_min_indices + pid - 1
        last_part_min_index_via_min = tl.load(partial_min_indices_ptrs)

        # Keep the current element when it is strictly smaller ...
        use_cur = out_vals < last_part_min_via_min
        equal = out_vals == last_part_min_via_min
        # `x != x` is true only for NaN.
        cur_is_nan = out_vals != out_vals
        prev_is_nan = last_part_min_via_min != last_part_min_via_min
        # ... or when it is NaN and the carry-in is not (NaN wins over numbers;
        # a NaN carry-in is never displaced by a non-NaN current value) ...
        use_cur |= cur_is_nan & ~prev_is_nan
        # ... or on a tie (equal values, or both NaN) when the current index is
        # the larger one.
        equal |= cur_is_nan & prev_is_nan
        use_cur |= equal & (out_indices > last_part_min_index_via_min)

        final_vals = tl.where(use_cur, out_vals, last_part_min_via_min)
        final_indices = tl.where(use_cur, out_indices, last_part_min_index_via_min)
        # Cast back to the dtype that was loaded from `out` before storing.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)

73 

74 

# First pass of the flat (1-D) block scan.
#
# Program `pid` scans elements [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE) of
# `inp` (masked to `n_elements`) with tl_cummin and stores the block-local
# running minimum to `out` and its paired index to `out_indices`.
#
# Compile-time switches:
#   NEED_PARTIAL    - also reduce the block to a single (value, index) pair and
#                     store it at `partial_min[pid]` / `partial_min_indices[pid]`.
#   USE_OUT_INDICES - take the starting index of each element from
#                     `out_indices` instead of using its flat position.
#
# NOTE(review): `in_indices` is never read; the USE_OUT_INDICES branch loads
# from `out_indices`. scan_then_fan_col passes the same tensor for both
# arguments, so this makes no difference for that caller.
@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def scan_part_min_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Out-of-range lanes are filled with get_dtype_max(<input element type>).
    max_value = get_dtype_max(inp.type.element_ty)
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
    # Pick the dtype used for the scan: int64 / uint64 / fp64 are kept as
    # loaded, every other integer type is cast to int32, everything else is
    # cast to float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        # The flat element position doubles as the index.
        in_indices_vals = offset
    result, cummin_indices = tl_cummin(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        part_min_via_min, part_min_indices_via_min = tl_min_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )
        if tl.constexpr(not USE_OUT_INDICES):
            # NOTE(review): in this branch `in_indices_vals` is `offset`, which
            # already contains `pid * BLOCK_SIZE`. Unless the reduce returns a
            # block-local position on this backend, the block base is added
            # twice here for pid > 0 - TODO confirm on the target hardware.
            part_min_indices_via_min = pid * BLOCK_SIZE + part_min_indices_via_min

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        # One scalar pair per program, unmasked.
        partial_min_ptrs = partial_min + pid
        tl.store(partial_min_ptrs, part_min_via_min)

        partial_min_indices_ptrs = partial_min_indices + pid
        tl.store(partial_min_indices_ptrs, part_min_indices_via_min)

131 

132 

def scan_then_fan_col(inp, out, out_indices, n_ele, dtype, use_out_indices=False):
    """Run the two-pass block scan for cumulative min over a flat buffer.

    Pass 1 (``scan_part_min_kernel``) scans each block of up to 512 elements
    independently. When more than one block is needed, the per-block minima
    are scanned recursively by this same function and then folded back into
    every block by ``add_base_min_kernel``.

    Args:
        inp: Input tensor, read as ``n_ele`` consecutive elements.
        out: Tensor receiving the running minima (written in place).
        out_indices: Tensor receiving the indices of the running minima
            (written in place). Also read as the starting indices when
            ``use_out_indices`` is true.
        n_ele: Number of elements to scan. Nothing is launched when it is
            zero or negative.
        dtype: dtype of the temporary per-block minimum buffer.
        use_out_indices: Forwarded to the kernel as ``USE_OUT_INDICES``; set
            by the recursive call so that indices already stored in
            ``out_indices`` are propagated instead of flat positions.

    Returns:
        None. Results are written into ``out`` and ``out_indices``.
    """
    # An empty scan would yield BLOCK_SIZE == 0 and a division by zero below;
    # there is nothing to compute, so bail out before touching the device.
    if n_ele <= 0:
        return

    BLOCK_SIZE = 512
    if n_ele <= BLOCK_SIZE:
        BLOCK_SIZE = triton.next_power_of_2(n_ele)
    # Integer ceil-division; avoids the float round trip of math.ceil(a / b).
    part_num = (n_ele + BLOCK_SIZE - 1) // BLOCK_SIZE
    need_partial = part_num >= 2

    partial_min = None
    partial_min_indices = None
    if need_partial:
        partial_min = torch.empty(part_num, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(
            part_num, dtype=torch.int64, device=inp.device
        )

    grid = (part_num,)
    with torch_device_fn.device(inp.device):
        scan_part_min_kernel[grid](
            inp,
            out,
            out_indices,
            out_indices,
            partial_min,
            partial_min_indices,
            n_ele,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Turn the per-block minima into running minima (in place), then merge
        # each block with the running minimum of the blocks before it.
        scan_then_fan_col(
            partial_min,
            partial_min,
            partial_min_indices,
            part_num,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_min_kernel[grid](
                out, out_indices, partial_min, partial_min_indices, n_ele, BLOCK_SIZE
            )

176 

177 

# First pass of the block scan for buffers addressed as (A, B, C), scanning
# along B (element (a, b, c) lives at a * B * C + b * C + c).
#
# The launch grid is 3-D: axis 0 selects `a`, axis 1 selects the block along
# B, axis 2 selects `c`. Each program scans BLOCK_SIZE consecutive `b`
# positions (masked to B) with tl_cummin and stores the block-local running
# minimum / index to `out` / `out_indices`.
#
# Compile-time switches:
#   NEED_PARTIAL    - also reduce the block to one (value, index) pair and
#                     store it in `partial_min` / `partial_min_indices`,
#                     addressed as (A, part_num, C).
#   USE_OUT_INDICES - take the starting index of each element from
#                     `out_indices` instead of using its `b` position.
#
# NOTE(review): `in_indices` is never read; the USE_OUT_INDICES branch loads
# from `out_indices`. scan_then_fan passes the same tensor for both arguments.
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_min_abc_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    # Element offsets of this block, and the slot for this block's partial
    # result in the (A, part_num, C) partial buffers.
    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    inp_ptrs = inp + offset
    # Out-of-range lanes are filled with get_dtype_max(<input element type>).
    max_value = get_dtype_max(inp.type.element_ty)
    inp_vals = tl.load(inp_ptrs, mask=mask, other=max_value)
    # Pick the dtype used for the scan: int64 / uint64 / fp64 are kept as
    # loaded, every other integer type is cast to int32, everything else is
    # cast to float32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        # The position along B doubles as the index.
        in_indices_vals = b_idx
    result, cummin_indices = tl_cummin(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        part_min_via_min, part_min_indices_via_min = tl_min_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )
        if tl.constexpr(not USE_OUT_INDICES):
            # NOTE(review): in this branch `in_indices_vals` is `b_idx`, which
            # already contains `pid_b * BLOCK_SIZE`. Unless the reduce returns
            # a block-local position on this backend, the block base is added
            # twice here for pid_b > 0 - TODO confirm on the target hardware.
            part_min_indices_via_min = pid_b * BLOCK_SIZE + part_min_indices_via_min

    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummin_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        # One scalar pair per program, unmasked.
        partial_min_ptrs = partial_min + part_offset
        tl.store(partial_min_ptrs, part_min_via_min)

        partial_min_indices_ptrs = partial_min_indices + part_offset
        tl.store(partial_min_indices_ptrs, part_min_indices_via_min)

245 

246 

# Second pass of the (A, B, C) block scan, applied in place to `out` and
# `out_indices`.
#
# Uses the same 3-D grid and addressing as scan_part_min_abc_kernel. For
# pid_b > 0 the program reads the carry-in value/index stored for block
# `pid_b - 1` of the same (a, c) line in `partial_min` / `partial_min_indices`
# (scan_then_fan scans those buffers in place before launching this kernel)
# and merges it into every element of its block. Programs with pid_b == 0
# store nothing.
@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_min_abc_kernel(
    out,
    out_indices,
    partial_min,
    partial_min_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    base_offset = a_idx * B * C + c_idx
    offset = base_offset + b_idx * C
    base_part_offset = a_idx * part_num * C + c_idx
    # Slot of the previous block's partial result. For pid_b == 0 this is a
    # negative offset, but it is only dereferenced under `pid_b > 0` below.
    last_part_offset = base_part_offset + (pid_b - 1) * C

    mask = b_idx < B
    out_ptrs = out + offset
    out_vals = tl.load(out_ptrs, mask=mask)
    out_indices_ptrs = out_indices + offset
    # NOTE: this rebinds the `out_indices` parameter name from the base
    # pointer to the loaded index values; `out_indices_ptrs` keeps the
    # addresses used by the store below.
    out_indices = tl.load(out_indices_ptrs, mask=mask)

    if pid_b > 0:
        # Carry-in: a single scalar value/index pair from the previous block.
        partial_min_ptrs = partial_min + last_part_offset
        last_part_min_via_min = tl.load(partial_min_ptrs)
        partial_min_index_ptrs = partial_min_indices + last_part_offset
        last_part_min_index_via_min = tl.load(partial_min_index_ptrs)

        # Keep the current element when it is strictly smaller ...
        use_cur = out_vals < last_part_min_via_min
        equal = out_vals == last_part_min_via_min
        # `x != x` is true only for NaN.
        cur_is_nan = out_vals != out_vals
        prev_is_nan = last_part_min_via_min != last_part_min_via_min
        # ... or when it is NaN and the carry-in is not (NaN wins over numbers;
        # a NaN carry-in is never displaced by a non-NaN current value) ...
        use_cur |= cur_is_nan & ~prev_is_nan
        # ... or on a tie (equal values, or both NaN) when the current index is
        # the larger one.
        equal |= cur_is_nan & prev_is_nan
        use_cur |= equal & (out_indices > last_part_min_index_via_min)

        final_vals = tl.where(use_cur, out_vals, last_part_min_via_min)
        final_indices = tl.where(use_cur, out_indices, last_part_min_index_via_min)
        # Cast back to the dtype that was loaded from `out` before storing.
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)
        tl.store(out_indices_ptrs, final_indices, mask=mask)

296 

297 

def scan_then_fan(inp, out, out_indices, A, B, C, dtype, use_out_indices=False):
    """Run the two-pass block scan for cumulative min along the middle axis.

    The buffers are addressed as contiguous ``(A, B, C)`` and scanned along
    ``B``. Pass 1 (``scan_part_min_abc_kernel``) scans each block of up to 512
    positions along ``B`` independently. When more than one block is needed,
    the per-block minima (shape ``(A, part_num, C)``) are scanned recursively
    by this same function and then folded back into every block by
    ``add_base_min_abc_kernel``.

    Args:
        inp: Input tensor.
        out: Tensor receiving the running minima (written in place).
        out_indices: Tensor receiving the indices of the running minima
            (written in place). Also read as the starting indices when
            ``use_out_indices`` is true.
        A: Product of the dimensions before the scanned one.
        B: Length of the scanned dimension.
        C: Product of the dimensions after the scanned one.
        dtype: dtype of the temporary per-block minimum buffer.
        use_out_indices: Forwarded to the kernel as ``USE_OUT_INDICES``; set
            by the recursive call so that indices already stored in
            ``out_indices`` are propagated instead of positions along ``B``.

    Returns:
        None. Results are written into ``out`` and ``out_indices``.
    """
    # Any zero-sized dimension means there is nothing to scan. B == 0 would
    # also produce BLOCK_SIZE == 0 and a division by zero below.
    if A <= 0 or B <= 0 or C <= 0:
        return

    BLOCK_SIZE = 512
    if B <= BLOCK_SIZE:
        BLOCK_SIZE = triton.next_power_of_2(B)
    # Integer ceil-division; avoids the float round trip of math.ceil(a / b).
    part_num = (B + BLOCK_SIZE - 1) // BLOCK_SIZE
    need_partial = part_num >= 2

    partial_min = None
    partial_min_indices = None
    if need_partial:
        partial_min = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)
        partial_min_indices = torch.empty(
            A, part_num, C, dtype=torch.int64, device=inp.device
        )

    grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        scan_part_min_abc_kernel[grid](
            inp,
            out,
            out_indices,
            out_indices,
            partial_min,
            partial_min_indices,
            B,
            C,
            part_num,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Turn the per-block minima into running minima (in place), then merge
        # each block with the running minimum of the blocks before it.
        scan_then_fan(
            partial_min,
            partial_min,
            partial_min_indices,
            A,
            part_num,
            C,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_min_abc_kernel[grid](
                out,
                out_indices,
                partial_min,
                partial_min_indices,
                B,
                C,
                part_num,
                BLOCK_SIZE,
            )

352 

353 

# Single-pass cumulative min for buffers addressed as (A, B, C), scanning
# along B (element (a, b, c) lives at a * B * C + b * C + c).
#
# The launch grid is 2-D: axis 0 selects `a`, axis 1 selects `c`. Each program
# walks its whole (a, :, c) line in `loop_num` chunks of BLOCK_SIZE, carrying
# the running minimum and its index from one chunk to the next, so no partial
# buffers or second kernel are needed.
@libentry()
@triton.jit()
def scan_part_min_abc_loop_kernel(
    inp,
    out,
    out_indices,
    B,
    C,
    loop_num,
    BLOCK_SIZE: tl.constexpr,
):
    pid_a = tle.program_id(0)
    pid_c = tle.program_id(1)

    a_idx = pid_a
    c_idx = pid_c
    t_idx = tl.arange(0, BLOCK_SIZE)
    ac_offset = a_idx * B * C + c_idx

    # init
    # Choose the dtype used for the scan: fp16 / bf16 are widened to float32,
    # int8 / int16 to int32, every other element type is used as is.
    max_value = get_dtype_max(inp.type.element_ty)
    if tl.constexpr(inp.type.element_ty.is_fp16()) or tl.constexpr(
        inp.type.element_ty.is_bf16()
    ):
        compute_dtype = tl.float32
    elif tl.constexpr(inp.type.element_ty.is_int8()) or tl.constexpr(
        inp.type.element_ty.is_int16()
    ):
        compute_dtype = tl.int32
    else:
        compute_dtype = inp.type.element_ty

    # Loop-carried scalars: running minimum so far and its index. They start
    # at (get_dtype_max value, 0).
    prev_min_val = tl.full([], max_value, dtype=compute_dtype)
    prev_min_val_idx = tl.full([], 0, dtype=tl.int64)
    # Selects the last lane of a chunk, used to extract the carry-out below.
    last_mask = t_idx == (BLOCK_SIZE - 1)

    for l_idx in tl.range(loop_num):
        b_idx = l_idx * BLOCK_SIZE + t_idx
        mask = b_idx < B
        offset = ac_offset + b_idx * C

        # Out-of-range lanes are filled with `max_value`.
        inp_vals = tl.load(inp + offset, mask=mask, other=max_value)
        # Only promote if necessary
        if tl.constexpr(compute_dtype != inp.type.element_ty):
            vals = inp_vals.to(compute_dtype)
        else:
            vals = inp_vals
        # The position along B doubles as the index.
        idxs = b_idx

        # cummin
        # Chunk-local running minimum and index.
        result, cummin_indices = tl_cummin(vals, idxs, axis=0)

        # broadcast
        # Expand the carried scalars to one value per lane.
        prev_min_val_b = tl.broadcast_to(prev_min_val, (BLOCK_SIZE,))
        prev_min_val_idx_b = tl.broadcast_to(prev_min_val_idx, (BLOCK_SIZE,))

        # Handle NaN and tie-breaking logic
        if tl.constexpr(compute_dtype.is_floating()):
            # For floats: handle NaN propagation + tie-break right
            # `x != x` is true only for NaN.
            prev_is_nan = prev_min_val != prev_min_val
            result_is_nan = result != result
            prev_nan_mask = tl.broadcast_to(prev_is_nan, (BLOCK_SIZE,))

            # Take the chunk-local result when it is NaN, or when the carry is
            # not NaN and the result is <= the carry (`<=` makes the later
            # position win ties). A NaN carry is kept against any non-NaN
            # result.
            use_result = result_is_nan | (~prev_nan_mask & (result <= prev_min_val_b))
        else:
            # For integers: simple tie-break right
            use_result = result <= prev_min_val_b

        final_vals = tl.where(use_result, result, prev_min_val_b)
        final_indices = tl.where(use_result, cummin_indices, prev_min_val_idx_b)

        # update global min val and idx
        # Every lane except the last is zeroed, so the sum equals the value
        # held by the last lane; that becomes the carry for the next chunk.
        prev_min_val = tl.sum(tl.where(last_mask, final_vals, 0), axis=0)
        prev_min_val_idx = tl.sum(tl.where(last_mask, final_indices, 0), axis=0)

        # store result
        # Values are cast to the element type of `out` before storing.
        tl.store(out + offset, final_vals.to(out.type.element_ty), mask=mask)
        tl.store(out_indices + offset, final_indices, mask=mask)

432 

433 

def scan_then_fan_loop(inp, out, out_indices, A, B, C, dtype):
    """Launch the single-pass looped cumulative-min kernel along the middle axis.

    The buffers are addressed as contiguous ``(A, B, C)``. One program is
    launched per ``(a, c)`` pair and walks the whole ``B`` axis in chunks of
    ``BLOCK_SIZE``.

    Args:
        inp: Input tensor.
        out: Tensor receiving the running minima (written in place).
        out_indices: Tensor receiving the indices of the running minima
            (written in place).
        A: Product of the dimensions before the scanned one.
        B: Length of the scanned dimension.
        C: Product of the dimensions after the scanned one.
        dtype: Unused here; kept so the signature matches the other
            ``scan_then_fan*`` helpers.

    Returns:
        None. Results are written into ``out`` and ``out_indices``.
    """
    # Any zero-sized dimension means there is nothing to scan. B == 0 would
    # also produce BLOCK_SIZE == 0 and a division by zero below.
    if A <= 0 or B <= 0 or C <= 0:
        return

    # TODO(all): tune on target board
    BLOCK_SIZE = 512
    if B < 1024 * 4:
        # Short lines are covered by a single chunk.
        BLOCK_SIZE = triton.next_power_of_2(B)
    # Integer ceil-division; avoids the float round trip of math.ceil(a / b).
    loop_num = (B + BLOCK_SIZE - 1) // BLOCK_SIZE

    grid = (A, C)
    with torch_device_fn.device(inp.device):
        scan_part_min_abc_loop_kernel[grid](
            inp,
            out,
            out_indices,
            B,
            C,
            loop_num,
            BLOCK_SIZE,
        )

452 

453 

def cummin(
    input: Tensor,
    dim: int,
    *,
    out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None,
) -> torch.return_types.cummin:
    """Cumulative minimum of ``input`` along ``dim``, with indices.

    The input is made contiguous and viewed as ``(M, N, K)`` where ``N`` is
    the size of ``dim``, ``M`` the product of the leading dimensions and ``K``
    the product of the trailing ones. One of three launch strategies is then
    selected from ``M`` and ``K``.

    Args:
        input: Tensor to scan.
        dim: Dimension to scan along; negative values count from the end.
        out: Accepted for signature compatibility only. It is not written to;
            freshly allocated tensors are always returned.

    Returns:
        A ``(values, indices)`` pair. ``values`` has the dtype of ``input``
        (``torch.int64`` for boolean input); ``indices`` is ``torch.int64``.
        Both have the shape of ``input``.

    Raises:
        AssertionError: If ``dim`` is out of range for ``input``.
    """
    logger.debug("GEMS_ASCEND CUMMIN")

    # Boolean inputs produce int64 values.
    value_dtype = torch.int64 if input.dtype is torch.bool else input.dtype

    # 0-dim tensors: the only element is its own running minimum at index 0.
    # (The generic path below cannot handle ndim == 0: `dim % 0` is undefined.)
    if input.ndim == 0:
        assert dim in (-1, 0), "Invalid dim"
        return (
            input.to(dtype=value_dtype, copy=True),
            torch.zeros_like(input, dtype=torch.int64),
        )

    assert dim >= -input.ndim and dim < input.ndim, "Invalid dim"
    shape = input.shape
    dim = dim % input.ndim
    M = math.prod(shape[:dim])
    N = shape[dim]
    input = input.contiguous()

    values = torch.empty_like(input, dtype=value_dtype)
    indices = torch.empty_like(input, dtype=torch.int64)

    # Empty tensors: nothing to compute, and `numel() // M // N` below would
    # divide by zero.
    if input.numel() == 0:
        return values, indices

    K = input.numel() // M // N

    # Half-precision inputs use float32 for the temporary per-block buffers.
    compute_dtype = values.dtype
    if input.dtype in (torch.float16, torch.bfloat16):
        compute_dtype = torch.float32

    if M == 1 and K == 1:
        # Effectively 1-D.
        scan_then_fan_col(input, values, indices, N, compute_dtype)
    elif M * K <= 16:
        # Few independent lines: split each line across programs.
        scan_then_fan(input, values, indices, M, N, K, compute_dtype)
    else:
        # Many independent lines: one program per line, looping along N.
        scan_then_fan_loop(input, values, indices, M, N, K, compute_dtype)
    return values, indices