Coverage for src/flag_gems/runtime/backend/_arm/ops/cumsum.py: 0%

242 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import device, torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Rebind `device` from the imported runtime object to its `.name` attribute;
# normed_cumsum passes this value to torch_device_fn.get_device_properties.
device = device.name

13 

14 

15# @libentry() 

@triton.jit(do_not_specialize=["n_elements", "part_num"])
def scan_part_sum_kernel(
    inp,
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of the flat (1-D) scan-then-fan cumsum.

    Program ``pid`` handles elements ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``.
    It writes their inclusive cumsum to the same positions of ``out`` and the
    block total to ``partial_sum[pid]``.

    Args:
        inp: pointer to the input elements.
        out: pointer to the output elements (scan_then_fan_col passes the same
            buffer as ``inp`` when scanning the partial sums).
        partial_sum: pointer to one slot per program for the block totals.
        n_elements: number of valid elements; lanes at or past it are masked.
        part_num: not read by the kernel body.
        BLOCK_SIZE: elements handled per program.
    """
    # int64 program id so pid * BLOCK_SIZE cannot wrap for >= 2**31 elements
    # (same promotion block_cumsum_kernel uses).
    pid = tle.program_id(0).to(tl.int64)
    offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offset < n_elements

    # other=0: masked tail lanes must not feed undefined values into tl.sum.
    inp_vals = tl.load(inp + offset, mask=mask, other=0)
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        acc_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        # Narrow integers (incl. bool) accumulate in int64: int32 accumulation
        # wraps at 2**31 even though cumsum() allocates an int64 output for them.
        acc_vals = inp_vals.to(tl.int64)
    else:
        # fp16 / bf16 / fp32 accumulate in fp32.
        acc_vals = inp_vals.to(tl.float32)

    result = tl.cumsum(acc_vals, axis=0)
    block_total = tl.sum(acc_vals)

    tl.store(out + offset, result, mask=mask)
    tl.store(partial_sum + pid, block_total)

49 

50 

51# @libentry() 

@triton.jit(do_not_specialize=["n_elements", "part_num"])
def add_base_sum_kernel(
    out,
    partial_sum,
    n_elements,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Second pass of the flat (1-D) scan-then-fan cumsum.

    Every program ``pid > 0`` adds ``partial_sum[pid - 1]`` to its
    ``BLOCK_SIZE`` elements of ``out``. The result is cast back to ``out``'s
    element dtype. scan_then_fan_col scans ``partial_sum`` in place before
    launching this kernel, so that slot holds the total of all earlier blocks.

    Args:
        out: pointer to the block-local cumsums written by scan_part_sum_kernel.
        partial_sum: pointer to the scanned block totals.
        n_elements: number of valid elements; lanes at or past it are masked.
        part_num: not read by the kernel body.
        BLOCK_SIZE: elements handled per program.
    """
    # int64 program id so pid * BLOCK_SIZE cannot wrap for >= 2**31 elements.
    pid = tle.program_id(0).to(tl.int64)

    # Block 0 has nothing before it; skip its load/store entirely.
    if pid > 0:
        offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offset < n_elements
        out_ptrs = out + offset
        out_vals = tl.load(out_ptrs, mask=mask)
        prev_total = tl.load(partial_sum + pid - 1)
        final_vals = out_vals + prev_total
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)

73 

74 

75# @libentry() 

@triton.jit(do_not_specialize=["part_num"])
def scan_part_sum_abc_kernel(
    inp,
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """First pass of the scan-then-fan cumsum along the middle axis of (A, B, C).

    ``inp`` and ``out`` are addressed as contiguous ``(A, B, C)`` arrays:
    element ``(a, b, c)`` sits at ``a * B * C + b * C + c``. ``partial_sum`` is
    addressed as ``(A, part_num, C)``.

    Program ``(a, pb, c)`` scans ``b`` in ``[pb * BLOCK_SIZE, (pb + 1) * BLOCK_SIZE)``
    (masked at ``B``). It writes the inclusive cumsum to ``out`` and the block
    total to ``partial_sum[a, pb, c]``.
    """
    # int64 program ids: a * B * C and b * C can exceed int32 on large tensors.
    pid_a = tle.program_id(0).to(tl.int64)
    pid_b = tle.program_id(1).to(tl.int64)
    pid_c = tle.program_id(2).to(tl.int64)

    b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    offset = pid_a * B * C + b_idx * C + pid_c
    part_offset = pid_a * part_num * C + pid_b * C + pid_c

    mask = b_idx < B
    # other=0: masked tail lanes must not feed undefined values into tl.sum.
    inp_vals = tl.load(inp + offset, mask=mask, other=0)
    if (
        tl.constexpr(inp_vals.dtype.is_int64())
        or tl.constexpr(inp_vals.dtype.is_uint64())
    ) or tl.constexpr(inp_vals.dtype.is_fp64()):
        acc_vals = inp_vals
    elif tl.constexpr(inp_vals.dtype.is_int()):
        # Narrow integers (incl. bool) accumulate in int64: int32 accumulation
        # wraps at 2**31 even though cumsum() allocates an int64 output for them.
        acc_vals = inp_vals.to(tl.int64)
    else:
        # fp16 / bf16 / fp32 accumulate in fp32.
        acc_vals = inp_vals.to(tl.float32)

    result = tl.cumsum(acc_vals, axis=0)
    block_total = tl.sum(acc_vals)

    tl.store(out + offset, result, mask=mask)
    tl.store(partial_sum + part_offset, block_total)

119 

120 

121# @libentry() 

@triton.jit(do_not_specialize=["part_num"])
def add_base_sum_abc_kernel(
    out,
    partial_sum,
    B,
    C,
    part_num,
    BLOCK_SIZE: tl.constexpr,
):
    """Second pass of the scan-then-fan cumsum along the middle axis of (A, B, C).

    Program ``(a, pb, c)`` with ``pb > 0`` adds ``partial_sum[a, pb - 1, c]`` to
    its block of ``out``. The result is cast back to ``out``'s element dtype.
    ``out`` is addressed as contiguous ``(A, B, C)`` and ``partial_sum`` as
    ``(A, part_num, C)``. scan_then_fan scans ``partial_sum`` along its middle
    axis before launching this kernel, so that slot holds the total of all
    earlier blocks.
    """
    # int64 program ids: a * B * C and b * C can exceed int32 on large tensors.
    pid_a = tle.program_id(0).to(tl.int64)
    pid_b = tle.program_id(1).to(tl.int64)
    pid_c = tle.program_id(2).to(tl.int64)

    # Block 0 along B has nothing before it; skip its load/store entirely.
    if pid_b > 0:
        b_idx = pid_b * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = b_idx < B
        out_ptrs = out + pid_a * B * C + b_idx * C + pid_c
        out_vals = tl.load(out_ptrs, mask=mask)

        prev_part_offset = pid_a * part_num * C + (pid_b - 1) * C + pid_c
        prev_total = tl.load(partial_sum + prev_part_offset)

        final_vals = out_vals + prev_total
        tl.store(out_ptrs, final_vals.to(out_vals.dtype), mask=mask)

154 

155 

def scan_then_fan_col(inp, out, n_ele, dtype, block_size=64):
    """Inclusive cumsum of the first ``n_ele`` elements of a flat buffer.

    Scan-then-fan:
    1. Each ``block_size`` chunk is scanned independently and its total is
       written to a scratch ``partial_sum`` buffer.
    2. If there is more than one chunk, the totals are scanned recursively
       (in place).
    3. Each chunk after the first gets the total of all earlier chunks added.

    Args:
        inp: input tensor (only its storage and ``.device`` are used).
        out: output tensor; may be the same tensor as ``inp``.
        n_ele: number of elements to scan. Nothing is launched when it is 0.
        dtype: torch dtype of the ``partial_sum`` scratch buffer.
        block_size: elements per program. ``tl.arange`` requires a power of two.
    """
    # TODO(all): tune block_size on target board
    if n_ele <= 0:
        return
    part_num = triton.cdiv(n_ele, block_size)  # exact integer ceil-division
    partial_sum = torch.empty(part_num, dtype=dtype, device=inp.device)

    grid = (part_num,)
    scan_part_sum_kernel[grid](inp, out, partial_sum, n_ele, part_num, block_size)

    if part_num >= 2:
        scan_then_fan_col(partial_sum, partial_sum, part_num, dtype, block_size)
        add_base_sum_kernel[grid](out, partial_sum, n_ele, part_num, block_size)

172 

173 

def scan_then_fan(inp, out, A, B, C, dtype, block_size=64):
    """Inclusive cumsum along the middle axis of buffers viewed as contiguous (A, B, C).

    Each ``block_size`` chunk along ``B`` is scanned and its total is stored in
    an ``(A, part_num, C)`` scratch buffer. The totals are scanned recursively
    along ``part_num``, then added back to every chunk after the first.

    Args:
        inp: input tensor (only its storage and ``.device`` are used).
        out: output tensor; may be the same tensor as ``inp``.
        A, B, C: sizes of the leading, scanned and trailing axes. Nothing is
            launched when any of them is 0.
        dtype: torch dtype of the ``partial_sum`` scratch buffer.
        block_size: elements per program along ``B``. ``tl.arange`` requires a
            power of two.
    """
    # TODO(all): tune block_size on target board
    if A * B * C == 0:
        return
    part_num = triton.cdiv(B, block_size)  # exact integer ceil-division
    partial_sum = torch.empty(A, part_num, C, dtype=dtype, device=inp.device)

    # NOTE(review): part_num and C go on grid axes 1 and 2. Backends that cap
    # those axes (CUDA caps them at 65535) cannot launch very large B or C.
    # TODO confirm the limit on this backend.
    grid = (A, part_num, C)
    scan_part_sum_abc_kernel[grid](inp, out, partial_sum, B, C, part_num, block_size)

    if part_num >= 2:
        scan_then_fan(partial_sum, partial_sum, A, part_num, C, dtype, block_size)
        add_base_sum_abc_kernel[grid](out, partial_sum, B, C, part_num, block_size)

190 

191 

def cumsum(inp, dim=1, *, dtype=None):
    """Inclusive cumulative sum of ``inp`` along ``dim`` (``torch.cumsum`` semantics).

    Args:
        inp: input tensor. A non-contiguous input is made contiguous first.
        dim: axis to scan; negative values count from the end. For a 0-d
            tensor only 0 and -1 are accepted.
        dtype: output dtype. When ``None``, the input dtype is used, except
            that bool and integer inputs are promoted to ``torch.int64``.
            An explicitly requested dtype is honored as-is.

    Returns:
        A new contiguous tensor with the same shape as ``inp``.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logging.debug("GEMS CUMSUM")
    if inp.ndim == 0:
        # Treat a 0-d tensor as a single-element vector.
        assert dim in (-1, 0), "Invalid dim"
        return cumsum(inp.reshape(1), dim=0, dtype=dtype).reshape(())
    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
    dim = dim % inp.ndim

    if dtype is None:
        dtype = inp.dtype
        # Only the implicit dtype is promoted (as torch.cumsum does).
        if dtype in (
            torch.bool,
            torch.int8,
            torch.uint8,
            torch.int16,
            torch.int32,
            torch.int64,
        ):
            dtype = torch.int64

    inp = inp.contiguous()
    out = torch.empty_like(inp, dtype=dtype)
    if inp.numel() == 0:
        # Nothing to scan. This also avoids dividing by a zero-sized axis below.
        return out

    shape = inp.shape
    M = math.prod(shape[:dim])  # flattened size of the axes before `dim`
    N = shape[dim]  # length of the scanned axis
    K = inp.numel() // M // N  # flattened size of the axes after `dim`

    # Half-precision outputs get fp32 partial sums; every other output dtype
    # (including an explicitly requested fp64) keeps its own precision.
    if out.dtype in (torch.float16, torch.bfloat16):
        compute_dtype = torch.float32
    else:
        compute_dtype = out.dtype

    if M == 1 and K == 1:
        scan_then_fan_col(inp, out, N, compute_dtype)
    else:
        scan_then_fan(inp, out, M, N, K, compute_dtype)
    return out

226 

227 

@libentry()
@triton.jit(do_not_specialize=["K"])
def normed_cumsum_kernel(inp, out, K, BLOCK: tl.constexpr):
    """Write ``cumsum(row) / sum(row)`` for one contiguous row of length ``K``.

    Program ``i`` reads elements ``[i * K, i * K + K)``. Only the first
    ``BLOCK`` of them are processed, so callers must pass ``BLOCK >= K``.
    fp16/bf16 inputs are accumulated in fp32; the store casts back to
    ``out``'s element type.
    """
    # int64 program id so the row start cannot wrap for >= 2**31 elements.
    row_start = tle.program_id(0).to(tl.int64) * K
    row_off = tl.arange(0, BLOCK)
    mask = row_off < K
    x = tl.load(inp + row_start + row_off, mask=mask, other=0)
    # Promote bf16 as well as fp16, matching block_cumsum_kernel.
    if x.dtype.is_fp16() | x.dtype.is_bf16():
        x = x.to(tl.float32)
    y_sum = tl.sum(x, 0)
    y = tl.cumsum(x, 0)
    y = y / y_sum
    tl.store(out + row_start + row_off, y, mask=mask)

240 

241 

@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_cumsum_kernel(
    inp,
    out,
    sums,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    OUTPUT_SUMS: tl.constexpr,
    NORMALIZE: tl.constexpr,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Inclusive cumsum of one column chunk for a batch of rows.

    Program ``(gridx, gridy)`` handles:
    - rows ``[gridy * r, min((gridy + 1) * r, R))``;
    - columns ``[gridx * t * TILE, (gridx + 1) * t * TILE)``, masked at ``K``.

    The columns are walked in ``t`` tiles of ``TILE``, carrying the running
    total from tile to tile. Element ``(row, col)`` of ``inp`` is at
    ``row * r_stride + col * k_stride``. ``out`` uses the same strides unless
    ``HAS_OUT_LAYOUT``, in which case it uses ``out_r_stride`` / ``out_k_stride``.

    OUTPUT_SUMS: after each tile, store the running total into
        ``sums[row * num_programs(0) + gridx]``. After the last tile that slot
        holds the chunk total.
    NORMALIZE: after the scan, divide the chunk's outputs by the chunk total.

    fp16/bf16 values are accumulated in fp32 and fp64 values in fp64.
    """
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    n_chunks = tle.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        # The running total is loop-carried, so its dtype must match tile_sum
        # on every iteration; fp64 inputs therefore need an fp64 accumulator.
        curr_cumsum = tl.zeros((1,), tl.float32)
        if inp.dtype.element_ty.is_fp64():
            curr_cumsum = curr_cumsum.to(tl.float64)

        # Separate input/output row offsets: overwriting one shared
        # `row_offset` with the output stride made later tiles read `inp`
        # with out_r_stride.
        in_row_offset = row * r_stride
        if HAS_OUT_LAYOUT:
            out_row_offset = row * out_r_stride
        else:
            out_row_offset = in_row_offset

        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            col_mask = cols < K
            x = tl.load(inp + in_row_offset + cols * k_stride, mask=col_mask, other=0)
            if x.dtype.is_fp16() | x.dtype.is_bf16():
                x = x.to(tl.float32)
            tile_sum = tl.sum(x, 0)[None]
            tile_cumsum = tl.cumsum(x, 0) + curr_cumsum
            curr_cumsum += tile_sum
            if HAS_OUT_LAYOUT:
                out_cols_offset = cols * out_k_stride
            else:
                out_cols_offset = cols * k_stride
            tl.store(
                out + out_row_offset + out_cols_offset, tile_cumsum, mask=col_mask
            )
            if OUTPUT_SUMS:
                tl.store(sums + row * n_chunks + gridx[None], curr_cumsum)
            cols += TILE

        if NORMALIZE:
            # Second sweep over the same chunk: divide by the chunk total.
            cols = gridx * t * TILE + tl.arange(0, TILE)
            for _ in range(0, t):
                col_mask = cols < K
                if HAS_OUT_LAYOUT:
                    out_cols_offset = cols * out_k_stride
                else:
                    out_cols_offset = cols * k_stride
                out_ptrs = out + out_row_offset + out_cols_offset
                x = tl.load(out_ptrs, mask=col_mask, other=0)
                if x.dtype.is_fp16() | x.dtype.is_bf16():
                    x = x.to(tl.float32)
                x = x / curr_cumsum
                tl.store(out_ptrs, x, mask=col_mask)
                cols += TILE

309 

310 

@libentry()
@triton.jit(
    do_not_specialize=[
        "r",
        "t",
        "R",
        "K",
        "r_stride",
        "out_r_stride",
    ]
)
def block_update_kernel(
    inp,
    base,
    rscale_ptr,
    out,
    r,
    t,
    R,
    K,
    r_stride,
    k_stride,
    out_r_stride,
    out_k_stride,
    rscale_stride,
    HAS_OUT_LAYOUT: tl.constexpr,
    TILE: tl.constexpr,
):
    """Compute ``out = (inp + base[row, chunk]) / rscale[row]`` for one chunk of rows.

    Program ``(gridx, gridy)`` covers the same tile as block_cumsum_kernel:
    - rows ``[gridy * r, min((gridy + 1) * r, R))``;
    - columns ``[gridx * t * TILE, (gridx + 1) * t * TILE)``, masked at ``K``.

    Addressing:
    - ``base`` is a contiguous ``(R, num_programs(0))`` array, so the offset
      for ``(row, gridx)`` is ``base[row * num_programs(0) + gridx]``.
    - ``rscale`` for a row is at ``rscale_ptr + row * rscale_stride``.
    - ``inp`` / ``out`` strides follow the same convention as
      block_cumsum_kernel.
    """
    gridx = tle.program_id(0).to(tl.int64)
    gridy = tle.program_id(1).to(tl.int64)
    # Row stride of `base`. This was previously taken from num_programs(1),
    # and the row range started at gridy instead of gridy * r; both misread
    # `base` and overlapped row batches.
    n_chunks = tle.num_programs(0)

    for row in range(gridy * r, min((gridy + 1) * r, R)):
        d = tl.load(base + row * n_chunks + gridx)
        rscale = tl.load(rscale_ptr + row * rscale_stride)

        in_row_offset = row * r_stride
        if HAS_OUT_LAYOUT:
            out_row_offset = row * out_r_stride
        else:
            out_row_offset = in_row_offset

        cols = gridx * t * TILE + tl.arange(0, TILE)
        for _ in range(0, t):
            col_mask = cols < K
            x = tl.load(inp + in_row_offset + cols * k_stride, mask=col_mask, other=0)
            x += d
            x /= rscale
            if HAS_OUT_LAYOUT:
                out_cols_offset = cols * out_k_stride
            else:
                out_cols_offset = cols * k_stride
            tl.store(out + out_row_offset + out_cols_offset, x, mask=col_mask)
            cols += TILE

366 

367 

# Largest number of programs normed_cumsum launches along grid axis 1. With
# more rows than this, each program handles `batch` rows.
# NOTE(review): presumably mirrors CUDA's 65535 gridDim.y limit; confirm for
# this backend.
GRID_Y_LIMIT = 65535

369 

370 

def _is_non_overlapping_and_dense(t):
    """Return True if ``t`` is a permutation of a contiguous tensor (size-1 dims ignored)."""
    expected_stride = 1
    for size, stride in sorted(zip(t.shape, t.stride()), key=lambda p: p[1]):
        if size == 1:
            continue
        if stride != expected_stride:
            return False
        expected_stride *= size
    return True


def normed_cumsum(inp, dim=-1):
    """Return ``cumsum(inp, dim) / sum(inp, dim)`` for a floating-point tensor.

    Each row along ``dim`` is split into ``n_chunks`` chunks of ``n_tiles``
    tiles, one chunk per program.
    - With a single chunk, block_cumsum_kernel scans and normalizes in one pass.
    - Otherwise there are three launches:
      1. Per-chunk scans, which also record the chunk totals.
      2. A scan over the chunk totals.
      3. An update that adds each chunk's preceding total and divides by the
         row total.

    Args:
        inp: fp16 / bf16 / fp32 / fp64 tensor.
        dim: axis to scan; negative values count from the end.

    Returns:
        A tensor with ``inp``'s shape and dtype.
    """
    logging.debug("GEMS NORMED_CUMSUM")
    assert inp.dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)
    dim = dim % inp.ndim
    if inp.numel() == 0:
        # Nothing to scan. This also avoids N // K and cdiv(..., n_rows)
        # dividing by zero.
        return torch.empty_like(inp)
    # The stride arithmetic below (r_stride derived from K, `out` sharing inp's
    # strides) only holds when elements tile memory densely. Copy sliced or
    # strided views first.
    if not _is_non_overlapping_and_dense(inp):
        inp = inp.contiguous()
    N = inp.numel()
    K = inp.size(dim)
    orig_dim = dim
    # First and last dims are easier to handle, so transpose a middle dim to
    # the last position (and transpose the result back at the end).
    ranked_dims = sorted(range(inp.ndim), key=lambda i: inp.stride(i), reverse=True)
    is_mid_dim = dim not in (ranked_dims[0], ranked_dims[-1])
    if is_mid_dim:
        inp = inp.transpose(dim, -1).contiguous()
        dim = -1
    out = torch.empty_like(inp)

    # Pass one scans a (batch, n_tiles * TILE) sized block within each program.
    device_props = torch_device_fn.get_device_properties(device)
    if isinstance(device_props, dict):
        num_sms = int(device_props.get("multi_processor_count", 1))
    else:
        num_sms = device_props.multi_processor_count
    TILE = 2048
    # Each row is split into n_chunks chunks, each composed of n_tiles tiles.
    # Different chunks are assigned to different programs.
    n_rows = N // K
    n_chunks = min(triton.cdiv(num_sms, n_rows), triton.cdiv(K, TILE))
    n_tiles = triton.cdiv(triton.cdiv(K, TILE), n_chunks)
    k_stride = inp.stride(dim)
    r_stride = inp.size(dim) if k_stride == 1 else 1
    if n_rows > GRID_Y_LIMIT:
        batch = triton.cdiv(n_rows, GRID_Y_LIMIT)
        n_batch = triton.cdiv(n_rows, batch)
    else:
        batch = 1
        n_batch = n_rows

    grid = (n_chunks, n_batch)
    if n_chunks == 1:
        block_cumsum_kernel[grid](
            inp,
            out,
            0,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            OUTPUT_SUMS=False,
            NORMALIZE=True,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
    else:
        if inp.dtype != torch.float64:
            acc_dtype = torch.float32
        else:
            acc_dtype = torch.float64
        sums = torch.empty((n_rows, n_chunks), dtype=acc_dtype, device=inp.device)
        cumsums = torch.empty_like(sums)
        block_cumsum_kernel[grid](
            inp,
            out,
            sums,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            OUTPUT_SUMS=True,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )
        # Pass two: scan the per-chunk totals of every row.
        block_cumsum_kernel[(1, n_batch)](
            sums,
            cumsums,
            0,
            batch,
            1,
            n_rows,
            n_chunks,
            n_chunks,
            1,
            n_chunks,
            1,
            OUTPUT_SUMS=False,
            NORMALIZE=False,
            HAS_OUT_LAYOUT=True,
            TILE=TILE,
        )
        # Pass three: add each chunk's preceding total (cumsums - sums) and
        # divide by the row total (last inclusive chunk cumsum).
        rscale = cumsums[..., -1]
        block_update_kernel[grid](
            out,
            cumsums - sums,
            rscale,
            out,
            batch,
            n_tiles,
            n_rows,
            K,
            r_stride,
            k_stride,
            r_stride,
            k_stride,
            n_chunks,
            HAS_OUT_LAYOUT=False,
            TILE=TILE,
        )

    if is_mid_dim:
        # Undo the transpose so the result has the caller's shape.
        out = out.transpose(orig_dim, -1).contiguous()
    return out