Coverage for src/flag_gems/runtime/backend/_ascend/ops/cummax.py: 0%

260 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import math 

3from typing import List, Tuple, Union 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.runtime import torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as tle 

12from flag_gems.utils.limits import get_dtype_min 

13 

# Module-level logger; cummax() emits a debug record through it on every call.
logger = logging.getLogger(__name__)

# Short alias for torch.Tensor, used in the public signature of cummax().
Tensor = torch.Tensor

17 

18 

@triton.jit
def tl_cummax(input, index, axis=0):
    """Inclusive running maximum of ``input`` along ``axis``, carrying ``index``.

    Pairs are combined with ``tle.maximum_with_index_tie_break_right``; judging
    by its name, ties resolve to the right-hand (later) index -- confirm against
    that helper. Returns a ``(values, indices)`` tuple.
    """
    scanned_vals, scanned_idx = tl.associative_scan(
        (input, index), axis, tle.maximum_with_index_tie_break_right
    )
    return scanned_vals, scanned_idx

24 

25 

@triton.jit
def tl_max_tie_break_right(input, index, axis=None, keep_dims=False):
    """Reduce ``(input, index)`` pairs to their maximum along ``axis``.

    Uses the same combine function as ``tl_cummax``
    (``tle.maximum_with_index_tie_break_right``), so the reduced index follows
    the same tie/NaN rules as the running scan. Returns ``(max_value, index)``.
    """
    max_val, max_idx = tl.reduce(
        (input, index),
        axis,
        tle.maximum_with_index_tie_break_right,
        keep_dims=keep_dims,
    )
    return max_val, max_idx

34 

35 

@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def add_base_max_kernel(
    out,
    out_indices,
    partial_max,
    partial_max_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Merge the carried-in maximum of earlier blocks into this block's scan.

    Program ``pid`` owns elements ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``
    of ``out`` / ``out_indices``, which already hold a block-local running max.
    For ``pid > 0`` it reads entry ``pid - 1`` of ``partial_max`` /
    ``partial_max_indices`` and keeps, per lane, whichever of (block value,
    carry) wins under the NaN-aware, tie-break-right rule below. The result is
    written back in place. Block 0 has no carry and is left unchanged.
    """
    pid = tle.program_id(0)
    lanes = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = lanes < n_elements

    val_ptrs = out + lanes
    idx_ptrs = out_indices + lanes
    cur_val = tl.load(val_ptrs, mask=in_bounds)
    cur_idx = tl.load(idx_ptrs, mask=in_bounds)

    if pid > 0:
        # Carry-in from the preceding block.
        carry_val = tl.load(partial_max + (pid - 1))
        carry_idx = tl.load(partial_max_indices + (pid - 1))

        cur_nan = cur_val != cur_val
        carry_nan = carry_val != carry_val
        # A NaN outranks every non-NaN value; otherwise compare normally.
        greater = (cur_val > carry_val) | (cur_nan & ~carry_nan)
        # Two NaNs count as a tie, so the index tie-break decides.
        tied = (cur_val == carry_val) | (cur_nan & carry_nan)
        take_cur = greater | (tied & (cur_idx > carry_idx))

        merged_val = tl.where(take_cur, cur_val, carry_val)
        merged_idx = tl.where(take_cur, cur_idx, carry_idx)
        # carry_val may be wider (compute dtype); cast back to out's dtype.
        tl.store(val_ptrs, merged_val.to(cur_val.dtype), mask=in_bounds)
        tl.store(idx_ptrs, merged_idx, mask=in_bounds)

74 

75 

@libentry()
@triton.jit(do_not_specialize=["n_elements"])
def scan_part_max_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_max,
    partial_max_indices,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Block-local cumulative max over a flat buffer of ``n_elements`` values.

    Program ``pid`` scans ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)`` and
    writes the running max and its index to ``out`` / ``out_indices``. When
    ``NEED_PARTIAL`` is set it also writes the block's overall max and index to
    ``partial_max[pid]`` / ``partial_max_indices[pid]`` so a later pass can
    carry maxima across blocks.

    Indices are global in both modes:
      * ``USE_OUT_INDICES=False``: the flat element offset.
      * ``USE_OUT_INDICES=True``: read from ``out_indices``. This is the
        recursive pass over partials, which already hold global indices.

    ``in_indices`` is not read; callers pass ``out_indices`` for it as well.
    """
    pid = tle.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Out-of-range lanes are padded with the dtype minimum. They sit after
    # every valid lane, so they cannot change the inclusive scan of valid
    # lanes. They can only affect the last block's partial, which the carry
    # pass never reads (it reads partial[pid - 1]).
    min_value = get_dtype_min(inp.type.element_ty)
    inp_ptrs = inp + offset
    inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
    # 64-bit types are scanned natively; other ints -> int32, floats -> fp32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        in_indices_vals = offset
    result, cummax_indices = tl_cummax(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # in_indices_vals is already global in both branches above, so the
        # reduced index is stored as-is. Adding pid * BLOCK_SIZE here would
        # count the block offset twice and corrupt the carried indices.
        part_max_via_max, part_max_indices_via_max = tl_max_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    # tl.store casts the compute-dtype values to out's element type.
    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummax_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        partial_max_ptrs = partial_max + pid
        tl.store(partial_max_ptrs, part_max_via_max)

        partial_max_indices_ptrs = partial_max_indices + pid
        tl.store(partial_max_indices_ptrs, part_max_indices_via_max)

132 

133 

def scan_then_fan_col(inp, out, out_indices, n_ele, dtype, use_out_indices=False):
    """Cumulative max over a flat, contiguous buffer of ``n_ele`` elements.

    This is a hierarchical scan:
      1. Scan each ``BLOCK_SIZE`` chunk and record each chunk's max and index
         (``scan_part_max_kernel``).
      2. If there are two or more chunks, recursively scan those per-chunk
         maxima.
      3. Fold the scanned maxima back into every chunk
         (``add_base_max_kernel``).

    Args:
        inp: input tensor (read).
        out: output values tensor (written; may alias ``inp`` in the
            recursive call).
        out_indices: int64 output indices. They are also read when
            ``use_out_indices`` is True.
        n_ele: number of elements to scan. Nothing is done when it is 0.
        dtype: dtype of the per-chunk partial-max scratch buffer.
        use_out_indices: take input indices from ``out_indices`` instead of
            the element offsets (used for the recursive pass over partials).
    """
    if n_ele <= 0:
        # Nothing to scan. This also avoids a zero BLOCK_SIZE below.
        return

    BLOCK_SIZE = 512
    if n_ele <= BLOCK_SIZE:
        BLOCK_SIZE = triton.next_power_of_2(n_ele)
    # Exact integer ceil-div (no float rounding for large sizes).
    part_num = (n_ele + BLOCK_SIZE - 1) // BLOCK_SIZE
    need_partial = part_num >= 2
    if need_partial:
        partial_max = torch.empty(part_num, dtype=dtype, device=inp.device)
        partial_max_indices = torch.empty(
            part_num, dtype=torch.int64, device=inp.device
        )
    else:
        partial_max = None
        partial_max_indices = None

    grid = (part_num,)
    with torch_device_fn.device(inp.device):
        scan_part_max_kernel[grid](
            inp,
            out,
            out_indices,
            out_indices,
            partial_max,
            partial_max_indices,
            n_ele,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Turn per-chunk maxima into running maxima (in place), then fold the
        # running maximum of chunks 0..pid-1 into chunk pid.
        scan_then_fan_col(
            partial_max,
            partial_max,
            partial_max_indices,
            part_num,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_max_kernel[grid](
                out, out_indices, partial_max, partial_max_indices, n_ele, BLOCK_SIZE
            )

177 

178 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_max_abc_kernel(
    inp,
    out,
    in_indices,
    out_indices,
    partial_max,
    partial_max_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
    NEED_PARTIAL: tl.constexpr,
    USE_OUT_INDICES: tl.constexpr,
):
    """Block-local cumulative max along B for data laid out as (A, B, C).

    The grid is ``(A, part_num, C)``. Program ``(a, pb, c)`` scans rows
    ``[pb * BLOCK_SIZE, (pb + 1) * BLOCK_SIZE)`` of column ``(a, c)``, with
    element offset ``a * B * C + b * C + c``. It writes the running max and
    its index to ``out`` / ``out_indices``. With ``NEED_PARTIAL`` it also
    stores the block's overall max and index at ``partial_*[a, pb, c]`` (flat
    offset ``a * part_num * C + pb * C + c``).

    Indices are global along B in both modes:
      * ``USE_OUT_INDICES=False``: the row index ``b``.
      * ``USE_OUT_INDICES=True``: read from ``out_indices`` (the recursive
        pass over partials).

    ``in_indices`` is not read; callers pass ``out_indices`` for it as well.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    # Out-of-range rows are padded with the dtype minimum. They follow every
    # valid row, so the inclusive scan of valid rows is unaffected.
    mask = b_idx < B
    inp_ptrs = inp + offset
    min_value = get_dtype_min(inp.type.element_ty)
    inp_vals = tl.load(inp_ptrs, mask=mask, other=min_value)
    # 64-bit types are scanned natively; other ints -> int32, floats -> fp32.
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        inp_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)
    if tl.constexpr(USE_OUT_INDICES):
        in_indices_ptrs = out_indices + offset
        in_indices_vals = tl.load(in_indices_ptrs, mask=mask)
    else:
        in_indices_vals = b_idx
    result, cummax_indices = tl_cummax(inp_vals, in_indices_vals, axis=0)

    if tl.constexpr(NEED_PARTIAL):
        # b_idx already includes pid_b * BLOCK_SIZE, so the reduced index is
        # global and is stored as-is. Adding the block offset again would
        # double-count it.
        part_max_via_max, part_max_indices_via_max = tl_max_tie_break_right(
            inp_vals, in_indices_vals, axis=0
        )

    # tl.store casts the compute-dtype values to out's element type.
    out_ptrs = out + offset
    tl.store(out_ptrs, result, mask=mask)

    out_indices_ptrs = out_indices + offset
    tl.store(out_indices_ptrs, cummax_indices, mask=mask)

    if tl.constexpr(NEED_PARTIAL):
        partial_max_ptrs = partial_max + part_offset
        tl.store(partial_max_ptrs, part_max_via_max)

        partial_max_indices_ptrs = partial_max_indices + part_offset
        tl.store(partial_max_indices_ptrs, part_max_indices_via_max)

246 

247 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_max_abc_kernel(
    out,
    out_indices,
    partial_max,
    partial_max_indices,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Fold the carried-in maximum into each B-block of an (A, B, C) layout.

    Program ``(a, pb, c)`` owns rows ``[pb * BLOCK_SIZE, (pb + 1) * BLOCK_SIZE)``
    of column ``(a, c)`` in ``out`` / ``out_indices``, which already hold a
    block-local running max. For ``pb > 0`` it reads the entry for block
    ``pb - 1`` of ``partial_max`` / ``partial_max_indices`` (flat offset
    ``a * part_num * C + (pb - 1) * C + c``). Per row it keeps whichever of
    (block value, carry) wins under the NaN-aware, tie-break-right rule and
    writes the result back in place.
    """
    pid_a = tle.program_id(0)
    pid_b = tle.program_id(1)
    pid_c = tle.program_id(2)

    rows = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = rows < B
    column_base = pid_a * B * C + pid_c
    lanes = column_base + rows * C

    val_ptrs = out + lanes
    idx_ptrs = out_indices + lanes
    cur_val = tl.load(val_ptrs, mask=in_bounds)
    cur_idx = tl.load(idx_ptrs, mask=in_bounds)

    if pid_b > 0:
        # Carry-in from the preceding B-block of the same (a, c) column.
        carry_slot = pid_a * part_num * C + pid_c + (pid_b - 1) * C
        carry_val = tl.load(partial_max + carry_slot)
        carry_idx = tl.load(partial_max_indices + carry_slot)

        cur_nan = cur_val != cur_val
        carry_nan = carry_val != carry_val
        # A NaN outranks every non-NaN value; otherwise compare normally.
        greater = (cur_val > carry_val) | (cur_nan & ~carry_nan)
        # Two NaNs count as a tie, so the index tie-break decides.
        tied = (cur_val == carry_val) | (cur_nan & carry_nan)
        take_cur = greater | (tied & (cur_idx > carry_idx))

        merged_val = tl.where(take_cur, cur_val, carry_val)
        merged_idx = tl.where(take_cur, cur_idx, carry_idx)
        # carry_val may be wider (compute dtype); cast back to out's dtype.
        tl.store(val_ptrs, merged_val.to(cur_val.dtype), mask=in_bounds)
        tl.store(idx_ptrs, merged_idx, mask=in_bounds)

297 

298 

def scan_then_fan(inp, out, out_indices, A, B, C, dtype, use_out_indices=False):
    """Cumulative max along B for a contiguous tensor viewed as (A, B, C).

    Each (a, c) column is split into ``BLOCK_SIZE`` chunks along B and scanned
    by ``scan_part_max_abc_kernel``. If a column spans two or more chunks, the
    per-chunk maxima, stored as an (A, part_num, C) scratch tensor, are
    scanned recursively and folded back with ``add_base_max_abc_kernel``.

    Args:
        inp: input tensor (read).
        out: output values tensor (written; may alias ``inp`` in the
            recursive call).
        out_indices: int64 output indices. They are also read when
            ``use_out_indices`` is True.
        A, B, C: leading, scanned and trailing sizes. Nothing is done if any
            of them is 0.
        dtype: dtype of the partial-max scratch buffer.
        use_out_indices: take input indices from ``out_indices`` instead of
            the row index (used for the recursive pass over partials).
    """
    if A <= 0 or B <= 0 or C <= 0:
        # Empty problem: avoids a zero BLOCK_SIZE and an empty launch grid.
        return

    BLOCK_SIZE = 512
    if B <= BLOCK_SIZE:
        BLOCK_SIZE = triton.next_power_of_2(B)
    # Exact integer ceil-div (no float rounding for large sizes).
    part_num = (B + BLOCK_SIZE - 1) // BLOCK_SIZE
    need_partial = part_num >= 2
    if need_partial:
        partial_max = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)
        partial_max_indices = torch.empty(
            A, part_num, C, dtype=torch.int64, device=inp.device
        )
    else:
        partial_max = None
        partial_max_indices = None

    grid = (A, part_num, C)
    with torch_device_fn.device(inp.device):
        scan_part_max_abc_kernel[grid](
            inp,
            out,
            out_indices,
            out_indices,
            partial_max,
            partial_max_indices,
            B,
            C,
            part_num,
            BLOCK_SIZE,
            need_partial,
            use_out_indices,
        )

    if need_partial:
        # Running max over chunks (in place), then fold it into each chunk.
        scan_then_fan(
            partial_max,
            partial_max,
            partial_max_indices,
            A,
            part_num,
            C,
            dtype,
            use_out_indices=True,
        )
        with torch_device_fn.device(inp.device):
            add_base_max_abc_kernel[grid](
                out,
                out_indices,
                partial_max,
                partial_max_indices,
                B,
                C,
                part_num,
                BLOCK_SIZE,
            )

353 

354 

@libentry()
@triton.jit()
def scan_part_max_abc_loop_kernel(
    inp,
    out,
    out_indices,
    B,
    C,
    loop_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Cumulative max along B, one program per (a, c) column of an (A, B, C) layout.

    The grid is ``(A, C)``. Each program walks the column in ``loop_num``
    chunks of ``BLOCK_SIZE`` rows, with element offset ``a * B * C + b * C + c``.
    It scans each chunk with ``tl_cummax`` and merges it with the running max
    carried over from the previous chunks. Ties go to the later row. For float
    compute types, a NaN in the chunk wins and a NaN carry is kept until a
    later NaN replaces it.

    Args:
        inp: input pointer.
        out: output values pointer (values cast to out's element type).
        out_indices: output indices pointer (row index along B).
        B, C: scanned and trailing sizes.
        loop_num: number of BLOCK_SIZE chunks covering B.
        BLOCK_SIZE: chunk length.
    """
    pid_a = tle.program_id(0)
    pid_c = tle.program_id(1)

    a_idx = pid_a
    c_idx = pid_c
    t_idx = tl.arange(0, BLOCK_SIZE)
    ac_offset = a_idx * B * C + c_idx

    # Choose the compute dtype. fp16/bf16 go to fp32. Every integer type
    # narrower than 32 bits (including bool/int1 and the unsigned ones) goes to
    # int32, as the non-loop kernels in this file also do. Without this, tl.sum
    # below widens sub-32-bit integers, which changes the type of the
    # loop-carried running max between iterations.
    min_value = get_dtype_min(inp.type.element_ty)
    if tl.constexpr(inp.type.element_ty.is_fp16()) or tl.constexpr(
        inp.type.element_ty.is_bf16()
    ):
        compute_dtype = tl.float32
    elif (
        tl.constexpr(inp.type.element_ty.is_int1())
        or tl.constexpr(inp.type.element_ty.is_int8())
    ) or (
        (
            tl.constexpr(inp.type.element_ty.is_int16())
            or tl.constexpr(inp.type.element_ty.is_uint8())
        )
        or tl.constexpr(inp.type.element_ty.is_uint16())
    ):
        compute_dtype = tl.int32
    else:
        compute_dtype = inp.type.element_ty

    # Running max/index carried across chunks. Seeded with the dtype minimum,
    # so the first chunk's values always win the ">=" comparison below.
    prev_max_val = tl.full([], min_value, dtype=compute_dtype)
    prev_max_val_idx = tl.full([], 0, dtype=tl.int64)
    # Selects the last lane of a chunk, whose merged value is the new carry.
    last_mask = t_idx == (BLOCK_SIZE - 1)

    for l_idx in tl.range(loop_num):
        b_idx = l_idx * BLOCK_SIZE + t_idx
        mask = b_idx < B
        offset = ac_offset + b_idx * C

        inp_vals = tl.load(inp + offset, mask=mask, other=min_value)
        # Only promote if necessary.
        if tl.constexpr(compute_dtype != inp.type.element_ty):
            vals = inp_vals.to(compute_dtype)
        else:
            vals = inp_vals
        idxs = b_idx

        # Chunk-local running max.
        result, cummax_indices = tl_cummax(vals, idxs, axis=0)

        prev_max_val_b = tl.broadcast_to(prev_max_val, (BLOCK_SIZE,))
        prev_max_val_idx_b = tl.broadcast_to(prev_max_val_idx, (BLOCK_SIZE,))

        # Merge with the carry. Tie-break right: ">=" prefers the chunk.
        if tl.constexpr(compute_dtype.is_floating()):
            prev_is_nan = prev_max_val != prev_max_val
            result_is_nan = result != result
            prev_nan_mask = tl.broadcast_to(prev_is_nan, (BLOCK_SIZE,))

            use_result = result_is_nan | (~prev_nan_mask & (result >= prev_max_val_b))
        else:
            use_result = result >= prev_max_val_b

        final_vals = tl.where(use_result, result, prev_max_val_b)
        final_indices = tl.where(use_result, cummax_indices, prev_max_val_idx_b)

        # Extract the last lane as the new carry: every other lane is zeroed,
        # so the sum equals that lane. The .to() casts pin the loop-carried
        # dtypes. NOTE(review): for floats a -0.0 carry becomes +0.0 here
        # (-0.0 + 0.0 == +0.0), so later carried values can lose the sign of
        # zero. The indices are unaffected.
        prev_max_val = tl.sum(tl.where(last_mask, final_vals, 0), axis=0).to(
            compute_dtype
        )
        prev_max_val_idx = tl.sum(tl.where(last_mask, final_indices, 0), axis=0).to(
            tl.int64
        )

        tl.store(out + offset, final_vals.to(out.type.element_ty), mask=mask)
        tl.store(out_indices + offset, final_indices, mask=mask)

433 

434 

def scan_then_fan_loop(inp, out, out_indices, A, B, C, dtype):
    """Cumulative max along B using one looping program per (a, c) column.

    Launches ``scan_part_max_abc_loop_kernel`` on an ``(A, C)`` grid. Each
    program walks the B axis in ``BLOCK_SIZE`` chunks.

    Args:
        inp: contiguous input tensor viewed as (A, B, C).
        out: output values tensor.
        out_indices: int64 output indices tensor.
        A, B, C: leading, scanned and trailing sizes. Nothing is done if any
            of them is 0.
        dtype: unused here; kept for signature parity with the other
            scan helpers.
    """
    if A <= 0 or B <= 0 or C <= 0:
        # Empty problem: avoids a zero BLOCK_SIZE and an empty launch grid.
        return

    BLOCK_SIZE = 512
    if B < BLOCK_SIZE:
        BLOCK_SIZE = triton.next_power_of_2(B)
    # Exact integer ceil-div (no float rounding for large sizes).
    loop_num = (B + BLOCK_SIZE - 1) // BLOCK_SIZE

    grid = (A, C)
    with torch_device_fn.device(inp.device):
        scan_part_max_abc_loop_kernel[grid](
            inp,
            out,
            out_indices,
            B,
            C,
            loop_num,
            BLOCK_SIZE,
        )

452 

453 

def cummax(
    input: Tensor,
    dim: int,
    *,
    out: Union[Tensor, Tuple[Tensor, ...], List[Tensor], None] = None,
) -> torch.return_types.cummax:
    """Cumulative maximum of ``input`` along ``dim`` (Ascend backend).

    Returns ``(values, indices)``. ``values`` holds the running maximum along
    ``dim``. ``indices`` (int64) holds the position along ``dim`` where that
    maximum was taken. Tie and NaN handling come from the scan kernels.

    Args:
        input: tensor to scan. A 0-d tensor is treated as one element, and
            ``dim`` must then be 0 or -1.
        dim: dimension to scan. Negative values count from the end.
        out: optional ``(values, indices)`` pair of tensors. When given, they
            are resized to ``input.shape``, filled, and returned.

    Raises:
        AssertionError: if ``dim`` is out of range.
        TypeError: if ``out`` is not a tuple/list of two tensors.
    """
    logger.debug("GEMS_ASCEND CUMMAX")
    # A 0-d tensor behaves like a 1-element vector along dim 0.
    ndim = max(input.ndim, 1)
    assert dim >= -ndim and dim < ndim, "Invalid dim"
    dim = dim % ndim
    shape = input.shape if input.ndim > 0 else (1,)
    input = input.contiguous()

    # NOTE(review): torch.cummax keeps bool values as bool, but this path
    # returns int64 values for bool input. Confirm whether that is a
    # deliberate backend workaround.
    dtype = input.dtype
    if dtype is torch.bool:
        dtype = torch.int64
    values = torch.empty_like(input, dtype=dtype)
    indices = torch.empty_like(input, dtype=torch.int64)

    # Empty tensors have nothing to scan. Computing K would also divide by 0.
    if input.numel() > 0:
        M = math.prod(shape[:dim])
        N = shape[dim]
        K = input.numel() // M // N

        compute_dtype = values.dtype
        if input.dtype in (torch.float16, torch.bfloat16):
            compute_dtype = torch.float32

        if M == 1 and K == 1:
            # A single contiguous column: hierarchical 1-D scan.
            scan_then_fan_col(input, values, indices, N, compute_dtype)
        elif M * K <= 16:
            # Few columns: parallelize over blocks of the scanned dim.
            scan_then_fan(input, values, indices, M, N, K, compute_dtype)
        else:
            # Many columns: one looping program per column.
            scan_then_fan_loop(input, values, indices, M, N, K, compute_dtype)

    if out is None:
        return values, indices

    # Honor the out= argument instead of silently ignoring it.
    if isinstance(out, Tensor) or len(out) != 2:
        raise TypeError(
            "cummax(): out must be a tuple/list of two tensors (values, indices)"
        )
    values_out, indices_out = out
    values_out.resize_(values.shape).copy_(values)
    indices_out.resize_(indices.shape).copy_(indices)
    return values_out, indices_out