Coverage for src/flag_gems/runtime/backend/_sunrise/ops/sum.py: 0%

280 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2import math 

3from functools import reduce 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.ops.zeros import zero_ 

11from flag_gems.runtime import torch_device_fn 

12from flag_gems.utils import dim_compress, libentry, libtuner 

13from flag_gems.utils import triton_lang_extension as ext 

14 

# Module logger: a child of the package-wide "flag_gems" logger named after
# this module (any leading dots in the module name are stripped first).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

16 

17 

@libentry()
@triton.jit
def sum_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the full reduction: one partial sum per program.

    Program ``p`` loads the ``BLOCK_SIZE`` elements of ``inp`` starting at
    flat offset ``p * BLOCK_SIZE`` (offsets at or beyond ``M`` are read as 0),
    adds them up and stores the partial sum to ``mid[p]``.
    """
    # fp16 / bf16 inputs are accumulated in fp32; every other element type is
    # accumulated in its own type.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    block_id = ext.program_id(0)
    elem_idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = elem_idx < M

    values = tl.load(inp + elem_idx, mask=in_bounds, other=0).to(cdtype)
    partial = tl.sum(values)
    tl.store(mid + block_id, partial)

42 

43 

@libentry()
@triton.jit
def sum_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
    """Second stage of the full reduction: fold the partial sums into one value.

    Loads the first ``mid_size`` entries of ``mid`` in a single block of
    ``BLOCK_MID`` lanes (lanes at or beyond ``mid_size`` read as 0), sums
    them and stores the scalar result through ``out``.
    """
    # fp16 / bf16 partials are accumulated in fp32; other types are kept as-is.
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = mid.dtype.element_ty

    lanes = tl.arange(0, BLOCK_MID)
    valid = lanes < mid_size
    partials = tl.load(mid + lanes, mask=valid, other=0).to(cdtype)
    total = tl.sum(partials)
    tl.store(out, total)

60 

61 

@libentry()
@triton.autotune(configs=runtime.get_tuned_config("sum"), key=["M", "N"])
@triton.jit
def sum_kernel_dim0(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Reduce over the leading axis of a buffer addressed as ``inp[n * M + m]``.

    Each program handles ``BLOCK_M`` consecutive values of ``m``. It walks
    ``n`` from 0 to ``N`` in steps of ``BLOCK_N``, accumulating a
    ``[BLOCK_N, BLOCK_M]`` tile, then collapses the tile along axis 0 and
    stores the result to ``out[m]`` for every ``m < M``.
    """
    # fp16 / bf16 inputs are accumulated in fp32; other types are kept as-is.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Output positions owned by this program, laid out as a [1, BLOCK_M] row.
    m_idx = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[None, :]
    m_valid = m_idx < M
    inp = inp + m_idx
    out = out + m_idx

    acc = tl.zeros([BLOCK_N, BLOCK_M], dtype=cdtype)
    for n_start in range(0, N, BLOCK_N):
        # Reduced-axis positions of this step, laid out as a [BLOCK_N, 1] column.
        n_idx = n_start + tl.arange(0, BLOCK_N)[:, None]
        tile_valid = m_valid & (n_idx < N)
        acc += tl.load(inp + n_idx * M, tile_valid, other=0).to(cdtype)

    reduced = tl.sum(acc, axis=0)[None, :]
    tl.store(out, reduced, m_valid)

96 

97 

def sum(inp, *, dtype=None):
    """Return the sum of every element of ``inp`` as a 0-d tensor.

    The reduction runs as two kernel launches: ``sum_kernel_1`` produces one
    partial sum per block of about ``sqrt(numel)`` elements, then
    ``sum_kernel_2`` folds those partials into the final value.

    Args:
        inp: Tensor to reduce. A contiguous copy is taken when needed because
            the kernels address the data by flat offset.
        dtype: Result dtype. Defaults to ``inp.dtype``, except that a bool
            input is converted to and summed as ``torch.int64``.

    Returns:
        A 0-d tensor of ``dtype`` on ``inp.device``.
    """
    logger.debug("GEMS SUM")
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64

    out = torch.empty([], dtype=dtype, device=inp.device)
    if M == 0:
        # The block-size computation below degenerates for an empty input
        # (block_size == 0 makes triton.cdiv divide by zero). The sum of no
        # elements is 0, so return a zeroed scalar without launching anything.
        zero_(out)
        return out

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    # One partial sum per first-stage program.
    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out

118 

119 

def sum_out(inp, *, dtype=None, out):
    """Sum every element of ``inp`` and write the result into ``out``.

    Same two-stage reduction as ``sum``, but the final value is stored
    through the caller-provided ``out`` tensor instead of a fresh one.

    Args:
        inp: Tensor to reduce. A contiguous copy is taken when needed because
            the kernels address the data by flat offset.
        dtype: dtype of the intermediate partial sums. Defaults to
            ``inp.dtype``, except that a bool input is converted to and
            summed as ``torch.int64``.
        out: Destination tensor; the scalar result is stored at its first
            element.

    Returns:
        ``out``.
    """
    logger.debug("GEMS SUM_OUT")
    # The kernels read ``inp`` by flat offset, so a strided view must be
    # materialised first (matches ``sum``).
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            inp = inp.to(torch.int64)
            dtype = torch.int64

    if M == 0:
        # The block-size computation below degenerates for an empty input
        # (block_size == 0 makes triton.cdiv divide by zero). The sum of no
        # elements is 0.
        zero_(out)
        return out

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    # One partial sum per first-stage program.
    mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)
    with torch_device_fn.device(inp.device):
        sum_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        sum_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)
    return out

137 

138 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("sum_non_inner"))
@triton.jit
def sum_dim_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Reduce the middle axis of a buffer addressed as ``input[m, n, k]``.

    The input is read at flat offset ``m * N * K + n * K + k`` and the result
    is written at ``m * K + k``. Grid axis 0 selects ``m``; grid axis 1
    selects a tile of ``TILE_K`` consecutive ``k`` values. When
    ``ONE_TILE_PER_CTA`` is set a single ``[TILE_N, TILE_K]`` load covers the
    reduced axis; otherwise the kernel loops over it in steps of ``TILE_N``.
    """
    # fp16 / bf16 inputs are accumulated in fp32; other types are kept as-is.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)

    # [1, TILE_K] row of k positions handled by this program.
    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)[None, :]
    k_valid = k_offsets < K

    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)[:, None]
        tile_offsets = pid_m * N * K + n_offsets * K + k_offsets
        tile_valid = (n_offsets < N) & k_valid
        tile = tl.load(input_ptr + tile_offsets, mask=tile_valid, other=0).to(cdtype)
        out = tl.sum(tile, axis=0, keep_dims=True)
    else:
        acc = tl.zeros([TILE_N, TILE_K], dtype=cdtype)

        # specialization does not improve performance in this example, as tested
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)[:, None]
            tile_offsets = pid_m * N * K + n_offsets * K + k_offsets
            tile_valid = (n_offsets < N) & k_valid
            acc += tl.load(input_ptr + tile_offsets, mask=tile_valid, other=0).to(
                cdtype
            )
        out = tl.sum(acc, axis=0, keep_dims=True)

    # Both variants finish with the same masked store of the [1, TILE_K] row.
    out_offsets = pid_m * K + k_offsets
    tl.store(output_ptr + out_offsets, out, mask=k_valid)

188 

189 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("sum_inner"))
@triton.jit
def sum_dim_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    """Reduce the last axis of a buffer addressed as ``input[m * N + n]``.

    Each program owns one ``m`` (grid axis 0) and stores a single value to
    ``output[m]``. When ``ONE_TILE_PER_CTA`` is set one load of ``TILE_N``
    lanes covers the row; otherwise the row is walked in steps of ``TILE_N``.
    """
    # fp16 / bf16 inputs are accumulated in fp32; other types are kept as-is.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    pid_m = ext.program_id(0)

    if ONE_TILE_PER_CTA:
        n_offsets = tl.arange(0, TILE_N)
        row_offsets = pid_m * N + n_offsets
        lane_valid = n_offsets < N
        row = tl.load(input_ptr + row_offsets, mask=lane_valid, other=0).to(cdtype)
        out = tl.sum(row, axis=0)
    else:
        acc = tl.zeros([TILE_N], dtype=cdtype)
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            row_offsets = pid_m * N + n_offsets
            lane_valid = n_offsets < N
            acc += tl.load(input_ptr + row_offsets, mask=lane_valid, other=0).to(
                cdtype
            )
        out = tl.sum(acc, axis=0)

    # Both variants finish with the same scalar store to output[pid_m].
    tl.store(output_ptr + pid_m, out)

236 

237 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def sum_dim_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Reduce the last axis of a buffer addressed as ``inp[m * N + n]``.

    Each program handles ``BLOCK_M`` consecutive rows. It walks the columns
    in steps of ``BLOCK_N``, accumulating a ``[BLOCK_M, BLOCK_N]`` tile, then
    collapses the tile along axis 1 and stores one value per row to
    ``out[m]`` for every ``m < M``.
    """
    # fp16 / bf16 inputs are accumulated in fp32; other types are kept as-is.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Map the program id to the rows of inp it should compute ([BLOCK_M, 1]).
    pid = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    inp = inp + pid * N
    out = out + pid
    row_mask = pid < M

    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # Elementwise AND of the two broadcast masks; written with ``&`` to
        # match sum_kernel_dim0 rather than relying on Python ``and``.
        mask = row_mask & col_mask

        a = tl.load(inp + cols, mask, other=0).to(cdtype)
        _sum += a
    row_total = tl.sum(_sum, axis=1)[:, None]
    tl.store(out, row_total, row_mask)

275 

276 

def sum_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Shared implementation of the dimension-wise sum.

    Args:
        inp: Tensor to reduce.
        dim: ``None`` or an empty sequence to reduce over everything, a
            single int, or a list/tuple of ints. Negative values are wrapped
            with ``d % inp.ndim``.
        keepdim: Keep the reduced dimensions as size-1 axes in the result.
        dtype: Result dtype. Defaults to ``inp.dtype``, with bool promoted to
            ``torch.int64``.
        out: Optional destination tensor. NOTE: it is only written on the
            two generic kernel paths at the bottom of this function; the
            ``dim is None``, empty-``dim`` and ``check_dim0`` paths return a
            newly computed tensor and do not touch ``out``.

    Returns:
        The reduced tensor.
    """
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    if dim is None:
        result = torch.sum(inp, dtype=dtype)
        if keepdim:
            result = result.reshape([1] * inp.ndim)
        return result

    # Accept a bare int or any sequence of ints (list, tuple, ...), the same
    # inputs sum_dim's empty-tensor path already understands. Previously only
    # a list worked: ``()`` slipped past the ``dim == []`` test and an int
    # could not be iterated.
    dim = [dim] if isinstance(dim, int) else list(dim)

    if not dim:
        total = sum(inp, dtype=dtype)
        if keepdim:
            total = torch.reshape(total, [1] * inp.ndim)
        return total

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]

    if check_dim0(inp, dim):
        return sum_dim0(inp, dim, keepdim, dtype)

    if len(dim) == 1:
        axis = dim[0]
        # View the contiguous data as (M, N, K) with N the reduced axis.
        N = inp.shape[axis]
        M = reduce(lambda x, y: x * y, shape[:axis], 1)
        inp = inp.contiguous()
        K = inp.numel() // M // N
        shape[axis] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                # Reduced axis is not the innermost one.
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                sum_dim_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                # Reduced axis is the innermost one: one program per row.
                grid = (M, 1, 1)
                sum_dim_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )
        if not keepdim:
            out = out.squeeze(dim=axis)
        return out
    else:
        # Several axes: dim_compress rearranges the input for a (M, N)
        # reduction over the last axis, N being the product of reduced sizes.
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            sum_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out

349 

350 

def check_dim0(inp, dim):
    """Decide whether reducing ``dim`` can take the leading-axis fast path.

    Returns ``True`` only when all of the following hold:
      * not every axis is being reduced (``len(dim) != inp.ndim``);
      * at least one axis of size other than 1 survives the reduction;
      * every axis in front of ``max(dim)`` is either reduced or has size
        at most 1.

    ``dim`` is used to index ``inp.shape`` directly.
    """
    remaining = list(inp.shape)
    if len(dim) == len(remaining):
        return False

    # Mark the reduced axes as size 1, as they will appear in the output.
    for axis in dim:
        remaining[axis] = 1

    if all(extent == 1 for extent in remaining):
        return False

    # Any surviving axis larger than 1 in front of the last reduced axis
    # rules the fast path out.
    return not any(remaining[i] > 1 for i in range(max(dim)))

364 

365 

def sum_dim0(inp, dim, keepdim, dtype):
    """Sum over leading axes using ``sum_kernel_dim0``.

    Args:
        inp: Tensor to reduce.
        dim: List of non-negative axis indices to reduce; their sizes are
            multiplied into ``N``, the length of the reduced extent.
        keepdim: Keep the reduced axes as size-1 dimensions in the result.
        dtype: dtype of the output tensor.

    Returns:
        A new tensor holding the reduction.
    """
    # The kernel addresses the data as a flat (N, M) buffer, so a strided
    # view must be materialised first, as the single-axis path in
    # sum_dim_comm already does.
    inp = inp.contiguous()
    shape = list(inp.shape)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    # M is the number of output elements.
    M = inp.numel() // N
    out = torch.empty(shape, dtype=dtype, device=inp.device)
    grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
    with torch_device_fn.device(inp.device):
        sum_kernel_dim0[grid](inp, out, M, N)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out

380 

381 

def sum_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Sum ``inp`` over ``dim``.

    Non-empty inputs are forwarded to ``sum_dim_comm``. A zero-element input
    is answered directly with a zero-filled tensor of the reduced shape.

    Args:
        inp: Tensor to reduce.
        dim: ``None`` or an empty list/tuple for a full reduction, otherwise
            an int or a list/tuple of ints (negative values wrap around).
        keepdim: Keep the reduced dimensions as size-1 axes.
        dtype: Result dtype; defaults to ``inp.dtype`` with bool promoted to
            ``torch.int64``.
    """
    logger.debug("GEMS SUM_DIM")
    if inp.numel() != 0:
        return sum_dim_comm(inp, dim, keepdim, dtype=dtype)

    # Zero-element input: only the output shape has to be worked out.
    if dtype is None:
        dtype = inp.dtype
        if dtype is torch.bool:
            dtype = torch.int64

    rank = inp.ndim
    is_sequence = isinstance(dim, (list, tuple))
    if dim is None or (is_sequence and len(dim) == 0):
        # Full reduction: scalar result, or all-ones shape with keepdim.
        out_shape = [1] * rank if keepdim else []
    else:
        out_shape = list(inp.shape)
        axes = dim if is_sequence else [dim]
        if keepdim:
            for d in axes:
                out_shape[d % rank] = 1
        else:
            # Drop axes from the back so earlier indices stay valid.
            for index in sorted((d % rank for d in axes), reverse=True):
                out_shape.pop(index)

    out = torch.empty(out_shape, dtype=dtype, device=inp.device)
    zero_(out)
    return out

418 

419 

def sum_dim_out(inp, dim=None, keepdim=False, *, dtype=None, out):
    """``out=`` variant of ``sum_dim``: sum ``inp`` over ``dim``.

    Args:
        inp: Tensor to reduce.
        dim: Axes to reduce, forwarded to ``sum_dim_comm``.
        keepdim: Keep the reduced dimensions as size-1 axes.
        dtype: Result dtype, forwarded to ``sum_dim_comm``.
        out: Destination tensor, forwarded to ``sum_dim_comm``.

    Returns:
        Whatever ``sum_dim_comm`` returns, or ``out`` zero-filled when
        ``inp`` has no elements.
    """
    logger.debug("GEMS SUM_DIM_OUT")
    if inp.numel() == 0:
        # Mirror sum_dim: a zero-element input sums to zeros. Without this
        # guard the kernel paths can end up dividing by a reduced extent of 0.
        # Assumes the caller supplied ``out`` with the reduced shape.
        zero_(out)
        return out
    return sum_dim_comm(inp, dim, keepdim, dtype=dtype, out=out)