Coverage for src/flag_gems/runtime/backend/_ascend/ops/cumsum.py: 0%

258 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7import triton.runtime.driver as driver 

8 

9from flag_gems.runtime import device, torch_device_fn 

10from flag_gems.utils import libentry 

11from flag_gems.utils import triton_lang_extension as ext 

12 

# Logger for this op: the last component of this module's dotted name (e.g.
# "cumsum" for this file) is appended to the Ascend ops logging namespace.
logger = logging.getLogger(
    "flag_gems.runtime._ascend.ops." + __name__.rpartition(".")[2]
)

14 

15 

def get_npu_properties():
    """Return the Triton driver's device properties for the current NPU.

    The result is looked up by key elsewhere in this module, e.g.
    ``get_npu_properties()["num_vectorcore"]``.
    """
    current = torch.npu.current_device()
    props = driver.active.utils.get_device_properties(current)
    return props


# From here on, the module-level name `device` holds only the *name* of the
# device object imported from flag_gems.runtime (a plain value, not the
# object), so code below must not call `device.name` again.
device = device.name

22 

23 

@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_sum_kernel(
    inp,
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Inclusive scan of one BLOCK_SIZE slice of a flat array.

    Program ``pid`` scans ``inp[pid * BLOCK_SIZE : (pid + 1) * BLOCK_SIZE]``
    (clipped to ``n_elements``) into the same positions of ``out`` and stores
    the slice total in ``partial_sum[pid]``.

    Accumulation dtype: int64/uint64/fp64 values are used as loaded, other
    integer types are converted to int32, everything else to fp32.
    NOTE(review): integer inputs narrower than 64 bits therefore accumulate in
    int32 even when ``out`` is int64, so very large sums can overflow --
    confirm this trade-off is intended.
    """
    pid = ext.program_id(0)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # Masked-off lanes are filled with 0 so they cannot feed undefined values
    # into tl.cumsum / tl.sum below.
    inp_vals = tl.load(inp + offset, mask=mask, other=0)
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        pass  # already a 64-bit type; accumulate as loaded
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)

    result = tl.cumsum(inp_vals, axis=0)
    part_sum_via_sum = tl.sum(inp_vals)

    tl.store(out + offset, result, mask=mask)
    tl.store(partial_sum + pid, part_sum_via_sum)

58 

59 

@libentry()
@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_sum_kernel(
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Add the running total of all preceding slices to one slice of ``out``.

    ``partial_sum[i]`` is expected to already hold the inclusive prefix sum of
    the per-slice totals, so slice ``pid`` (pid > 0) is shifted by
    ``partial_sum[pid - 1]``. Slice 0 needs no correction and is not written.
    """
    block_id = ext.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    ptrs = out + idx
    vals = tl.load(ptrs, mask=in_bounds)

    if block_id > 0:
        carry = tl.load(partial_sum + (block_id - 1))
        shifted = vals + carry
        # Store back in out's element type (carry may be a wider dtype).
        tl.store(ptrs, shifted.to(vals.dtype), mask=in_bounds)

82 

83 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def scan_part_sum_abc_kernel(
    inp,
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Inclusive scan along the middle axis of a row-major (A, B, C) array.

    Program ``(a, pb, c)`` scans elements ``[a, pb*BLOCK_SIZE : (pb+1)*BLOCK_SIZE, c]``
    (clipped to ``B``) into the same positions of ``out`` and stores the slice
    total in ``partial_sum[a, pb, c]``, where ``partial_sum`` is laid out as a
    row-major (A, part_num, C) buffer.

    Accumulation dtype follows scan_part_sum_kernel: int64/uint64/fp64 as
    loaded, other integers as int32, everything else as fp32.
    """
    pid_a = ext.program_id(0)
    pid_b = ext.program_id(1)
    pid_c = ext.program_id(2)

    a_idx = pid_a
    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    c_idx = pid_c

    offset = a_idx * B * C + b_idx * C + c_idx
    base_part_offset = a_idx * part_num * C + c_idx
    part_offset = base_part_offset + pid_b * C

    mask = b_idx < B
    # Masked-off lanes are filled with 0 so they cannot feed undefined values
    # into tl.cumsum / tl.sum below.
    inp_vals = tl.load(inp + offset, mask=mask, other=0)
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        pass  # already a 64-bit type; accumulate as loaded
    elif tl.constexpr(inp_vals.dtype.is_int()):
        inp_vals = inp_vals.to(tl.int32)
    else:
        inp_vals = inp_vals.to(tl.float32)

    result = tl.cumsum(inp_vals, axis=0)
    part_sum_via_sum = tl.sum(inp_vals)

    tl.store(out + offset, result, mask=mask)
    tl.store(partial_sum + part_offset, part_sum_via_sum)

128 

129 

@libentry()
@triton.jit(do_not_specialize=["part_num"])
def add_base_sum_abc_kernel(
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Shift one B-slice of a row-major (A, B, C) scan by the preceding total.

    ``partial_sum`` is a row-major (A, part_num, C) buffer that is expected to
    already hold inclusive prefix sums along its middle axis. Slice ``pb``
    (pb > 0) of lane ``(a, c)`` gets ``partial_sum[a, pb - 1, c]`` added;
    slice 0 is not written.
    """
    a = ext.program_id(0)
    pb = ext.program_id(1)
    c = ext.program_id(2)

    rows = pb * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = rows < B
    lane_start = a * B * C + c
    targets = out + (lane_start + rows * C)
    current = tl.load(targets, mask=in_range)

    if pb > 0:
        prev_slot = a * part_num * C + c + (pb - 1) * C
        carry = tl.load(partial_sum + prev_slot)
        updated = current + carry
        # Store back in out's element type (carry may be a wider dtype).
        tl.store(targets, updated.to(current.dtype), mask=in_range)

163 

164 

def scan_then_fan_col(inp, out, n_ele, dtype):
    """Inclusive scan of a flat contiguous array of ``n_ele`` elements.

    Step 1 scans each slice independently and records slice totals in a
    ``dtype`` buffer. When there are at least two slices, that buffer is
    scanned recursively (in place), and step 2 adds the preceding running
    total to every slice after the first.
    """
    # Small inputs use a single power-of-two slice; larger ones 1024-wide slices.
    block = triton.next_power_of_2(n_ele) if n_ele <= 1024 * 4 else 1024
    n_parts = math.ceil(n_ele / block)
    totals = torch.empty(n_parts, dtype=dtype, device=inp.device)
    launch_grid = (n_parts,)

    with torch_device_fn.device(inp.device):
        scan_part_sum_kernel[launch_grid](inp, out, totals, n_ele, n_parts, block)

    if n_parts < 2:
        return

    scan_then_fan_col(totals, totals, n_parts, dtype)
    with torch_device_fn.device(inp.device):
        add_base_sum_kernel[launch_grid](out, totals, n_ele, n_parts, block)

180 

181 

def scan_then_fan(inp, out, A, B, C, dtype):
    """Inclusive scan along the middle axis of a contiguous (A, B, C) array.

    A 3-D grid of (A, part_num, C) programs each scans one slice of B and
    records its total in an (A, part_num, C) buffer of ``dtype``. With at
    least two slices, that buffer is scanned recursively (in place) along its
    middle axis, and the preceding running total is added to every later slice.
    """
    block = 1024 if B > 1024 * 4 else triton.next_power_of_2(B)
    n_parts = math.ceil(B / block)
    totals = torch.empty(A, n_parts, C, dtype=dtype, device=inp.device)
    launch_grid = (A, n_parts, C)

    with torch_device_fn.device(inp.device):
        scan_part_sum_abc_kernel[launch_grid](
            inp, out, totals, B, C, n_parts, block
        )

    if n_parts < 2:
        return

    scan_then_fan(totals, totals, A, n_parts, C, dtype)
    with torch_device_fn.device(inp.device):
        add_base_sum_abc_kernel[launch_grid](out, totals, B, C, n_parts, block)

199 

200 

def cumsum(inp, dim=1, *, dtype=None):
    """Cumulative sum of ``inp`` along ``dim`` using the Ascend Triton kernels.

    Args:
        inp: input tensor; it is made contiguous before the kernels run.
        dim: dimension to scan; negative values count from the end.
        dtype: optional output dtype. When omitted, bool and
            int8/int16/int32/uint8 inputs produce an int64 result; other
            inputs keep ``inp.dtype``.

    Returns:
        A new tensor with the shape of ``inp`` holding the running sums.

    Raises:
        AssertionError: if ``dim`` is outside ``[-inp.ndim, inp.ndim)``.
    """
    logger.debug("GEMS_ASCEND CUMSUM")
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    shape = inp.shape
    dim = dim % inp.ndim
    M = math.prod(shape[:dim])  # product of the leading dims
    N = shape[dim]
    inp = inp.contiguous()

    if dtype is None:
        dtype = inp.dtype
        # Boolean and narrower integer inputs default to an int64 result.
        if dtype is torch.bool or dtype in (
            torch.int8,
            torch.int16,
            torch.int32,
            torch.uint8,
        ):
            dtype = torch.int64
    out = torch.empty_like(inp, dtype=dtype)

    # Nothing to scan; returning here also avoids dividing by a zero-sized
    # dimension when computing K below.
    if out.numel() == 0:
        return out

    K = inp.numel() // M // N

    # Half-precision results still get fp32 partial sums so that block totals
    # do not lose precision before they are added back.
    if out.dtype in (torch.float16, torch.bfloat16):
        compute_dtype = torch.float32
    else:
        compute_dtype = out.dtype

    if M == 1 and K == 1:
        scan_then_fan_col(inp, out, N, compute_dtype)
    else:
        scan_then_fan(inp, out, M, N, K, compute_dtype)
    return out

230 

231 

@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    """Normalized cumulative sum of one contiguous row of length ``K``.

    Program ``i`` reads ``inp[i * K : i * K + K]``, computes its running sum,
    divides it by the row total and writes the result to the same positions
    of ``out``. Only a single BLOCK-wide tile is loaded, so elements past
    ``BLOCK`` are not read; callers presumably pass ``BLOCK >= K``.
    """
    row_start = ext.program_id(0) * K
    row_off = tl.arange(0, BLOCK)
    mask = row_off < K
    x = tl.load(inp + row_start + row_off, mask=mask, other=0)
    # Accumulate both half-precision formats in fp32, matching block_cumsum_kernel.
    if x.dtype.is_fp16() | x.dtype.is_bf16():
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0) / y_sum
    tl.store(out + row_start + row_off, y, mask=mask)

244 

245 

@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Scan one (rows x columns) chunk of a strided 2-D view, row by row.

    Program ``(x, y)`` handles rows ``[y * r, min((y + 1) * r, R))`` and
    columns ``[x * t * TILE, (x + 1) * t * TILE)``. It walks the columns in
    ``t`` tiles of ``TILE`` elements and carries the running total from tile
    to tile. Element (row, col) of ``inp`` lives at
    ``row * r_stride + col * k_stride``. ``out`` uses ``out_r_stride`` /
    ``out_k_stride`` when HAS_OUT_LAYOUT is set and the input strides
    otherwise.

    OUTPUT_SUMS: also store each row's chunk total at
        ``sums[row * num_programs(0) + x]``.
    NORMALIZE: afterwards divide the stored chunk by that chunk total.

    NOTE(review): the running total is an fp32 accumulator and fp64 inputs
    are not converted, so ``curr_cumsum += tile_sum`` changes its type for
    fp64 -- confirm fp64 is actually supported by this kernel.
    """
    gridx = ext.program_id(0).to(tl.int64)
    gridy = ext.program_id(1).to(tl.int64)
    n_chunks = ext.num_programs(0)
    cols_base = gridx * t * TILE + tl.arange(0, TILE)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # Input and output row offsets are kept separate so that a distinct
        # output layout can never leak into a later tile's input address.
        in_row = row * r_stride
        if HAS_OUT_LAYOUT:
            out_row = row * out_r_stride
        else:
            out_row = in_row

        curr_cumsum = tl.zeros((1,), tl.float32)
        for ti in range(0, t):
            cols = cols_base + ti * TILE
            col_mask = cols < K
            x = tl.load(inp + in_row + cols * k_stride, mask=col_mask, other=0)
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            if HAS_OUT_LAYOUT:
                out_cols = cols * out_k_stride
            else:
                out_cols = cols * k_stride
            tl.store(out + out_row + out_cols, tile_cumsum, mask=col_mask)

        if OUTPUT_SUMS:
            tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)

        if NORMALIZE:
            for ti in range(0, t):
                cols = cols_base + ti * TILE
                col_mask = cols < K
                if HAS_OUT_LAYOUT:
                    out_cols = cols * out_k_stride
                else:
                    out_cols = cols * k_stride
                y = tl.load(out + out_row + out_cols, mask=col_mask, other=0)
                if y.dtype.is_fp16() | y.dtype.is_bf16():
                    y = y.to(tl.float32)
                y = y / curr_cumsum
                tl.store(out + out_row + out_cols, y, mask=col_mask)

313 

314 

@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Offset and rescale one (rows x columns) chunk: ``out = (inp + d) / s``.

    Program ``(x, y)`` covers the same tile as block_cumsum_kernel: rows
    ``[y * r, min((y + 1) * r, R))`` and columns
    ``[x * t * TILE, (x + 1) * t * TILE)``. For each row, ``d`` is
    ``base[row * num_programs(0) + x]`` and ``s`` is
    ``rscale_ptr[row * rscale_stride]``. Element addressing matches
    block_cumsum_kernel (``out`` uses the ``out_*`` strides only when
    HAS_OUT_LAYOUT is set).
    """
    gridx = ext.program_id(0).to(tl.int64)
    gridy = ext.program_id(1).to(tl.int64)
    # `base` holds one entry per (row, chunk), row-major with num_programs(0)
    # chunks per row -- the same layout block_cumsum_kernel uses for `sums`.
    n_chunks = ext.num_programs(0)
    cols_base = gridx * t * TILE + tl.arange(0, TILE)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        d = tl.load(base + row * n_chunks + gridx)
        rscale = tl.load(rscale_ptr + row * rscale_stride)
        in_row = row * r_stride
        if HAS_OUT_LAYOUT:
            out_row = row * out_r_stride
        else:
            out_row = in_row
        for ti in range(0, t):
            cols = cols_base + ti * TILE
            col_mask = cols < K
            x = tl.load(inp + in_row + cols * k_stride, mask=col_mask, other=0)
            x += d
            x /= rscale
            if HAS_OUT_LAYOUT:
                out_cols = cols * out_k_stride
            else:
                out_cols = cols * k_stride
            tl.store(out + out_row + out_cols, x, mask=col_mask)

369 cols += TILE 

370 

371 

# Largest number of programs normed_cumsum launches along grid axis 1; with
# more rows than this, rows are grouped so each program handles a batch.
# (Presumably mirrors a device grid-dimension limit -- confirm.)
GRID_Y_LIMIT = 2**16 - 1  # 65535

373 

374 

def normed_cumsum(inp, dim=-1):
    """Cumulative sum along ``dim`` divided by the total along ``dim``.

    Args:
        inp: fp16, bf16, fp32 or fp64 tensor.
        dim: dimension to scan; negative values count from the end.

    Returns:
        A tensor with the same shape as ``inp``; for nonzero totals the last
        entry along ``dim`` is 1 (up to rounding). When ``dim`` is a middle
        dimension (by stride), the kernels run on a transposed contiguous
        copy and the result is returned as a transposed view of that buffer.

    Raises:
        AssertionError: if ``inp`` is not one of the supported float dtypes.
    """
    logger.debug("GEMS_ASCEND NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    if inp.numel() == 0:
        # Nothing to scan; also avoids dividing by a zero-sized dim below.
        return torch.empty_like(inp)
    dim = dim % inp.ndim
    N = inp.numel()
    K = inp.size(dim)
    # First and last dims (by stride) map directly onto the (rows, K) view the
    # kernels expect; a middle dim is transposed to the end instead, and the
    # result is transposed back before returning.
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    orig_dim = dim
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        dim = -1
    out = torch.empty_like(inp)
    with torch_device_fn.device(inp.device.index):
        # Pass one, scan a (batch, n_tiles * TILE) sized block within each cta
        num_sms = get_npu_properties()["num_vectorcore"]
        TILE = 2048
        # Each row is split into n_chunks chunks, each comprised of n_tiles
        # tiles. Different chunks are assigned to different ctas.
        n_rows = N // K
        n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
        n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
        k_stride = inp.stride(dim)
        # NOTE(review): assumes rows are densely packed (row stride K for a
        # last-dim scan, 1 for a first-dim scan) -- confirm for padded views.
        r_stride = inp.size(dim) if k_stride == 1 else 1
        if n_rows > GRID_Y_LIMIT:
            batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
            n_batch = triton.cdiv(n_rows, batch)
        else:
            batch = 1
            n_batch = n_rows

        grid = (n_chunks, n_batch)
        if n_chunks == 1:
            # A single chunk spans each whole row: scan and normalize in one pass.
            block_cumsum_kernel[grid](
                inp,
                out,
                0,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=False,
                NORMALIZE=True,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
        else:
            acc_dtype = (
                torch.float64 if inp.dtype == torch.float64 else torch.float32
            )
            sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=inp.device)
            cumsums = torch.empty_like(sums)
            block_cumsum_kernel[grid](
                inp,
                out,
                sums,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                OUTPUT_SUMS=True,
                NORMALIZE=False,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
            # Pass two, scan partial cumsums
            block_cumsum_kernel[(1, n_batch)](
                sums,
                cumsums,
                0,
                batch,
                1,
                n_rows,
                n_chunks,
                n_chunks,
                1,
                n_chunks,
                1,
                OUTPUT_SUMS=False,
                NORMALIZE=False,
                HAS_OUT_LAYOUT=True,
                TILE=TILE,
            )
            # Pass three: add each chunk's exclusive prefix (cumsums - sums)
            # and divide by the row total (last inclusive prefix).
            rscale = cumsums[..., -1]
            block_update_kernel[grid](
                out,
                cumsums - sums,
                rscale,
                out,
                batch,
                n_tiles,
                n_rows,
                K,
                r_stride,
                k_stride,
                r_stride,
                k_stride,
                n_chunks,
                HAS_OUT_LAYOUT=False,
                TILE=TILE,
            )
    # Undo the middle-dim transpose so the result has the caller's shape.
    return out.transpose(orig_dim, -1) if is_mid_dim else out