Coverage for src/flag_gems/ops/mean.py: 46%

223 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3from functools import reduce 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import dim_compress, libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as ext 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@libentry()
@triton.jit
def mean_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # First stage of the full reduction: each program sums one
    # BLOCK_SIZE-wide slice of `inp` and writes that partial sum to
    # `mid[program id]`.

    # float16 / bfloat16 inputs are accumulated in float32; any other
    # element type is accumulated in its own dtype.
    elem_ty = inp.dtype.element_ty
    if tl.constexpr(elem_ty == tl.float16) or tl.constexpr(elem_ty == tl.bfloat16):
        acc_ty = tl.float32
    else:
        acc_ty = elem_ty

    block_id = ext.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Lanes at or beyond M load 0 so they do not contribute to the sum.
    in_bounds = idx < M

    vals = tl.load(inp + idx, mask=in_bounds, other=0).to(acc_ty)
    partial = tl.sum(vals)
    tl.store(mid + block_id, partial)

42 

43 

@libentry()
@triton.jit
def mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Second stage of the full reduction: sum the MID_SIZE partial sums held
    # in `mid`, divide by the total element count M and store the scalar
    # result at `out`.

    # float16 / bfloat16 partial sums are accumulated in float32; any other
    # element type is accumulated in its own dtype.
    elem_ty = mid.dtype.element_ty
    if tl.constexpr(elem_ty == tl.float16) or tl.constexpr(elem_ty == tl.bfloat16):
        acc_ty = tl.float32
    else:
        acc_ty = elem_ty

    idx = tl.arange(0, BLOCK_MID)
    # BLOCK_MID may exceed MID_SIZE; the padding lanes load 0.
    in_bounds = idx < MID_SIZE
    partials = tl.load(mid + idx, mask=in_bounds, other=0).to(acc_ty)
    total = tl.sum(partials)
    tl.store(out, total / M)

62 

63 

def mean(inp, *, dtype=None):
    """Return the mean over all elements of ``inp`` as a 0-dim tensor.

    The reduction runs in two kernel launches: ``mean_kernel_1`` writes one
    partial sum per block of roughly ``sqrt(M)`` elements into a temporary
    buffer, then ``mean_kernel_2`` adds the partial sums and divides by the
    element count ``M``.

    Args:
        inp: Input tensor; it is made contiguous before the launch.
        dtype: Dtype of the returned tensor. Defaults to ``inp.dtype``.

    Returns:
        A 0-dim tensor of ``dtype`` on ``inp.device``. For an empty input the
        result is NaN, matching ``torch.mean``.
    """
    logger.debug("GEMS MEAN")
    inp = inp.contiguous()
    M = inp.numel()
    if dtype is None:
        dtype = inp.dtype

    if M == 0:
        # With no elements the block size below would be 0 and
        # triton.cdiv(M, 0) would raise ZeroDivisionError. The mean of an
        # empty tensor is NaN, so return that directly.
        return torch.full([], float("nan"), dtype=dtype, device=inp.device)

    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    # mean_kernel_1 accumulates half-precision inputs in float32. Keep the
    # partial sums in float32 as well, so they are not narrowed (and possibly
    # overflowed) between the two stages; the final store into `out`
    # performs the cast to the requested dtype.
    if dtype in (torch.float16, torch.bfloat16):
        mid_dtype = torch.float32
    else:
        mid_dtype = dtype
    mid = torch.empty((mid_size,), dtype=mid_dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        mean_kernel_2[(1, 1, 1)](mid, out, M, mid_size, block_mid)
    return out

81 

82 

@libentry()
@triton.jit
def mean_dim_kernel_non_inner_vec(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    BLOCK_SIZE_K: tl.constexpr,  # rows of the per-program K tile
    VEC_SIZE: tl.constexpr,  # consecutive K elements per tile row
):
    # Mean over the middle axis of an input addressed as (M, N, K).
    # Program (pid 0, pid 1) handles row `pid 0` of M and a tile of
    # BLOCK_SIZE_K * VEC_SIZE consecutive K positions, looping over all N.

    # float16 / bfloat16 inputs are accumulated in float32; any other
    # element type is accumulated in its own dtype.
    elem_ty = input_ptr.dtype.element_ty
    if tl.constexpr(elem_ty == tl.float16) or tl.constexpr(elem_ty == tl.bfloat16):
        acc_ty = tl.float32
    else:
        acc_ty = elem_ty

    row = ext.program_id(0)
    k_block = ext.program_id(1)

    # K indices arranged as [BLOCK_SIZE_K, VEC_SIZE]; each tile row covers
    # VEC_SIZE adjacent positions.
    # NOTE(review): the layout is presumably meant to encourage vectorized
    # loads when VEC_SIZE > 1 -- confirm on the target backend.
    tile_start = k_block * BLOCK_SIZE_K * VEC_SIZE
    k_idx = (
        tile_start
        + tl.arange(0, BLOCK_SIZE_K)[:, None] * VEC_SIZE
        + tl.arange(0, VEC_SIZE)[None, :]
    )
    k_ok = k_idx < K

    running = tl.zeros((BLOCK_SIZE_K, VEC_SIZE), dtype=acc_ty)
    row_start = row * N * K

    for n in range(N):
        src = row_start + n * K + k_idx
        chunk = tl.load(input_ptr + src, mask=k_ok, other=0.0)
        running += chunk.to(acc_ty)

    # Divide the accumulated sums by the reduction length and write them to
    # the (M, K) output.
    dst = row * K + k_idx
    tl.store(output_ptr + dst, running / N, mask=k_ok)

134 

135 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("mean_non_inner"))
@triton.jit
def mean_dim_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Mean over the middle axis of an input addressed as (M, N, K).
    # Program (pid 0, pid 1) produces TILE_K outputs of row `pid 0`.

    # float16 / bfloat16 inputs are accumulated in float32; any other
    # element type is accumulated in its own dtype.
    elem_ty = input_ptr.dtype.element_ty
    if tl.constexpr(elem_ty == tl.float16) or tl.constexpr(elem_ty == tl.bfloat16):
        acc_ty = tl.float32
    else:
        acc_ty = elem_ty

    row = ext.program_id(0)
    k_block = ext.program_id(1)

    k_idx = k_block * TILE_K + tl.arange(0, TILE_K)[None, :]

    if ONE_TILE_PER_CTA:
        # A single [TILE_N, TILE_K] tile is loaded and reduced directly.
        n_idx = tl.arange(0, TILE_N)[:, None]
        src = row * N * K + n_idx * K + k_idx
        valid = (n_idx < N) & (k_idx < K)
        tile = tl.load(input_ptr + src, mask=valid, other=0).to(acc_ty)
        # keep_dims keeps the result shaped [1, TILE_K], matching k_idx.
        col_sum = tl.sum(tile, axis=0, keep_dims=True)
    else:
        # Walk N in TILE_N steps, accumulating element-wise, then reduce the
        # accumulated tile once at the end.
        running = tl.zeros([TILE_N, TILE_K], dtype=acc_ty)
        for start_n in range(0, N, TILE_N):
            n_idx = start_n + tl.arange(0, TILE_N)[:, None]
            src = row * N * K + n_idx * K + k_idx
            valid = (n_idx < N) & (k_idx < K)
            tile = tl.load(input_ptr + src, mask=valid, other=0).to(acc_ty)
            running += tile
        col_sum = tl.sum(running, axis=0, keep_dims=True)

    # Both paths finish the same way: divide by N and store the row of means.
    result = col_sum / N
    dst = row * K + k_idx
    tl.store(output_ptr + dst, result, mask=k_idx < K)

188 

189 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def mean_dim_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Mean over the last axis of an input addressed as (M, N): each program
    # reduces one row of N elements and stores one scalar.

    # float16 / bfloat16 inputs are accumulated in float32; any other
    # element type is accumulated in its own dtype.
    elem_ty = input_ptr.dtype.element_ty
    if tl.constexpr(elem_ty == tl.float16) or tl.constexpr(elem_ty == tl.bfloat16):
        acc_ty = tl.float32
    else:
        acc_ty = elem_ty

    row = ext.program_id(0)

    if ONE_TILE_PER_CTA:
        # A single tile of TILE_N lanes is loaded and reduced directly.
        n_idx = tl.arange(0, TILE_N)
        valid = n_idx < N
        vals = tl.load(input_ptr + (row * N + n_idx), mask=valid, other=0).to(acc_ty)
        row_sum = tl.sum(vals, axis=0)
    else:
        # Walk the row in TILE_N steps, accumulating per lane, then reduce
        # the lanes once at the end.
        running = tl.zeros([TILE_N], dtype=acc_ty)
        for start_n in range(0, N, TILE_N):
            n_idx = start_n + tl.arange(0, TILE_N)
            src = row * N + n_idx
            valid = n_idx < N
            vals = tl.load(input_ptr + src, mask=valid, other=0).to(acc_ty)
            running += vals
        row_sum = tl.sum(running, axis=0)

    # Both paths finish the same way: divide by N and store at index `row`.
    result = row_sum / N
    tl.store(output_ptr + row, result)

238 

239 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def mean_dim_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise mean of an input addressed as (M, N): each program reduces
    # BLOCK_M rows, walking the N columns in BLOCK_N steps, and stores one
    # value per row into `out`.

    # float16 / bfloat16 inputs are accumulated in float32; any other
    # element type is accumulated in its own dtype.
    elem_ty = inp.dtype.element_ty
    if tl.constexpr(elem_ty == tl.float16) or tl.constexpr(elem_ty == tl.bfloat16):
        acc_ty = tl.float32
    else:
        acc_ty = elem_ty

    # Rows handled by this program, shaped [BLOCK_M, 1].
    rows = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ptrs = inp + rows * N
    out_ptrs = out + rows
    row_mask = rows < M

    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=acc_ty)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        # Element-wise AND of the two masks (broadcast to
        # [BLOCK_M, BLOCK_N]), written with `&` like the other kernels in
        # this file rather than Python's `and`.
        mask = row_mask & col_mask

        vals = tl.load(row_ptrs + cols, mask, other=0).to(acc_ty)
        acc += vals

    # Reduce across columns, restore the [BLOCK_M, 1] shape and divide by N.
    row_sum = tl.sum(acc, axis=1)[:, None]
    row_mean = row_sum / N
    tl.store(out_ptrs, row_mean, row_mask)

278 

279 

def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Compute the mean of ``inp`` over ``dim``.

    Args:
        inp: Input tensor.
        dim: An int, an iterable of ints, or ``None`` / an empty iterable to
            reduce over every element. Negative values count from the last
            dimension.
        keepdim: If true, reduced dimensions are kept with size 1; otherwise
            they are squeezed out of the result.
        dtype: Dtype of the result. Defaults to ``inp.dtype``. ``torch.bool``
            is replaced by ``torch.int64`` (the input is converted as well).
        out: Optional preallocated tensor the kernels write into. It is
            ignored when reducing over every element.
            NOTE(review): presumably expected to be contiguous with the
            keepdim result shape -- confirm against callers.

    Returns:
        The tensor of means.

    Raises:
        TypeError: If ``dim`` is neither an int, an iterable, nor ``None``.
        IndexError: If an entry of ``dim`` is outside ``[-ndim, ndim - 1]``
            (``[-1, 0]`` for a 0-dim input).
    """
    logger.debug("GEMS MEAN_DIM")
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        inp = inp.to(torch.int64)
        dtype = torch.int64

    # -------- normalize dim to a list of ints --------
    if dim is None:
        # The signature default: reduce over everything.
        dim = []
    elif isinstance(dim, int):
        dim = [dim]
    else:
        try:
            dim = list(dim)
        except TypeError as err:
            raise TypeError(
                f"dim must be an int, iterable of ints, or [], got {type(dim)}"
            ) from err

    ndim = inp.ndim
    # A 0-dim tensor accepts dim in [-1, 0], as PyTorch does.
    bound = max(ndim, 1)
    for d in dim:
        if not -bound <= d < bound:
            # Reject instead of silently wrapping with `%` onto another axis.
            raise IndexError(
                "Dimension out of range (expected to be in range of "
                f"[{-bound}, {bound - 1}], but got {d})"
            )
    # -------------------------------------------------

    if not dim or ndim == 0:
        # mean over all elements (covers None, [], () and 0-dim inputs)
        reduced = mean(inp, dtype=dtype)
        if keepdim:
            reduced = torch.reshape(reduced, [1] * ndim)
        return reduced

    dim = [d % ndim for d in dim]
    shape = list(inp.shape)

    if len(dim) == 1:
        dim0 = dim[0]
        N = inp.shape[dim0]  # reduction length
        # product of dims before dim0; use initializer 1 for empty slice
        M = reduce(lambda x, y: x * y, shape[:dim0], 1)
        inp = inp.contiguous()
        # product of dims after dim0
        K = inp.numel() // M // N
        shape[dim0] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K >= 1024:
                # Wide trailing extent: kernel that tiles K and loops over N.
                if inp.dtype in (torch.float16, torch.bfloat16):
                    VEC_SIZE = 8
                    BLOCK_SIZE_K = 128
                else:
                    VEC_SIZE = 1
                    BLOCK_SIZE_K = min(triton.next_power_of_2(K), 512)
                grid = (M, triton.cdiv(K, BLOCK_SIZE_K * VEC_SIZE))
                mean_dim_kernel_non_inner_vec[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                    BLOCK_SIZE_K=BLOCK_SIZE_K,
                    VEC_SIZE=VEC_SIZE,
                    num_warps=8 if BLOCK_SIZE_K <= 128 else 16,
                )
            elif K > 1:
                # Reduced dim is not the innermost one; tile sizes come from
                # the kernel's heuristics.
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                mean_dim_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                # K == 1: the reduced dim is effectively the innermost one.
                grid = (M, 1, 1)
                mean_dim_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )
        if not keepdim:
            out = out.squeeze(dim=dim0)
        return out
    else:
        # Several dims: dim_compress rearranges the input for the row-wise
        # (M, N) kernel.
        # NOTE(review): presumably it moves the reduced dims to the end and
        # makes the result contiguous -- confirm in flag_gems.utils.
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            mean_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out

379 

380 

def mean_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Compute the mean of ``inp`` over ``dim`` by delegating to ``mean_dim_comm``.

    Takes the same arguments as ``mean_dim_comm`` except ``out``, which is
    not forwarded, so the result is always freshly allocated by the callee.
    """
    logger.debug("GEMS MEAN_DIM (wrapper)")

    result = mean_dim_comm(inp, dim, keepdim, dtype=dtype)
    return result