Coverage for src/flag_gems/runtime/backend/_sunrise/ops/mean.py: 0%

192 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import math 

3from functools import reduce 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import dim_compress, libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as ext 

13 

14logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

15 

16 

@libentry()
@triton.jit
def mean_kernel_1(
    inp,
    mid,
    M,
    BLOCK_SIZE: tl.constexpr,
):
    # Stage one of the two-stage full reduction: each program sums one
    # BLOCK_SIZE-wide slice of the flat input and writes it to mid[program id].

    # Accumulation dtype: float16/bfloat16 inputs are summed in float32, every
    # other element type is summed as-is. This must remain a compile-time
    # (constexpr) branch because it selects a dtype.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    block_id = ext.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < M

    # Lanes past the end of the input load 0 so they do not affect the sum.
    vals = tl.load(inp + idx, mask=in_bounds, other=0).to(cdtype)
    tl.store(mid + block_id, tl.sum(vals))

42 

43 

@libentry()
@triton.jit
def mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr):
    # Stage two of the full reduction: a single program adds up the MID_SIZE
    # partial sums held in `mid`, divides by the element count M and stores
    # the resulting mean at `out`.

    # Same accumulation rule as stage one, keyed on the dtype of `mid`.
    # Kept as a compile-time (constexpr) branch because it selects a dtype.
    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
        mid.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = mid.dtype.element_ty

    lanes = tl.arange(0, BLOCK_MID)
    # BLOCK_MID may exceed MID_SIZE; the surplus lanes read 0.
    partials = tl.load(mid + lanes, mask=lanes < MID_SIZE, other=0).to(cdtype)
    total = tl.sum(partials)
    # Dividing the grand total by M turns the sum into the mean.
    tl.store(out, total / M)

62 

63 

def mean(inp, *, dtype=None):
    """Return the mean of all elements of ``inp`` as a 0-d tensor.

    The reduction runs in two kernel launches: ``mean_kernel_1`` writes one
    partial sum per block into a temporary buffer, then ``mean_kernel_2``
    adds the partial sums and divides by the element count.

    Args:
        inp: Input tensor. It is made contiguous before launch because the
            kernels index it as a flat array.
        dtype: Dtype of the result; defaults to ``inp.dtype``.

    Returns:
        A 0-d tensor of ``dtype`` on ``inp.device``. For an empty input the
        result is NaN, which is what ``torch.mean`` yields for
        floating-point dtypes.
    """
    logger.debug("GEMS MEAN")
    if dtype is None:
        dtype = inp.dtype

    M = inp.numel()
    if M == 0:
        # Without this guard block_size would be 0 and triton.cdiv below
        # would raise ZeroDivisionError.
        return torch.full([], float("nan"), dtype=dtype, device=inp.device)

    # The kernels address `inp` linearly, so a strided/transposed view would
    # otherwise be read in storage order rather than logical order.
    inp = inp.contiguous()

    # Blocks of roughly sqrt(M) elements keep both stages about equally sized.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
    mid_size = triton.cdiv(M, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    # Stage one accumulates half-precision data in float32; keep the partial
    # sums in float32 as well so they are not rounded (or overflowed) before
    # stage two combines them.
    mid_dtype = torch.float32 if dtype in (torch.float16, torch.bfloat16) else dtype
    mid = torch.empty((mid_size,), dtype=mid_dtype, device=inp.device)
    out = torch.empty([], dtype=dtype, device=inp.device)

    with torch_device_fn.device(inp.device):
        mean_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)
        mean_kernel_2[(1, 1, 1)](mid, out, M, mid_size, block_mid)
    return out

80 

81 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("mean_non_inner"))
@triton.jit
def mean_dim_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Mean over the middle axis of an input indexed as a flat (M, N, K) array.
    # Program (pid_m, pid_k) handles row pid_m and one TILE_K-wide slice of K,
    # writing the result to the output indexed as a flat (M, K) array.

    # Accumulation dtype: float16/bfloat16 are summed in float32. This must
    # remain a compile-time (constexpr) branch because it selects a dtype.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    pid_m = ext.program_id(0)
    pid_k = ext.program_id(1)

    # Shape [1, TILE_K] so it broadcasts against the [TILE_N, 1] row offsets.
    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)[None, :]
    k_mask = k_offsets < K

    if ONE_TILE_PER_CTA:
        # A single tile spans the whole reduction axis; load it once.
        n_offsets = tl.arange(0, TILE_N)[:, None]
        tile_offsets = pid_m * N * K + n_offsets * K + k_offsets
        tile_mask = (n_offsets < N) & k_mask
        acc = tl.load(input_ptr + tile_offsets, mask=tile_mask, other=0).to(cdtype)
    else:
        # Walk the reduction axis tile by tile, accumulating element-wise.
        acc = tl.zeros([TILE_N, TILE_K], dtype=cdtype)
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)[:, None]
            tile_offsets = pid_m * N * K + n_offsets * K + k_offsets
            tile_mask = (n_offsets < N) & k_mask
            acc += tl.load(input_ptr + tile_offsets, mask=tile_mask, other=0).to(
                cdtype
            )

    # Collapse the tile along N; keep_dims leaves a [1, TILE_K] result that
    # lines up with k_offsets for the store.
    col_sum = tl.sum(acc, axis=0, keep_dims=True)
    col_mean = col_sum / N
    tl.store(output_ptr + (pid_m * K + k_offsets), col_mean, mask=k_mask)

134 

135 

@libentry()
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def mean_dim_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
    ONE_TILE_PER_CTA: tl.constexpr,
):
    # Mean over the last axis of an input indexed as a flat (M, N) array.
    # Each program reduces one row and writes a single value to output[pid_m].
    # NOTE(review): the heuristics come from the "softmax_inner" config rather
    # than a mean-specific one; confirm that reuse is intentional.

    # Accumulation dtype: float16/bfloat16 are summed in float32. This must
    # remain a compile-time (constexpr) branch because it selects a dtype.
    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
        input_ptr.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = input_ptr.dtype.element_ty

    pid_m = ext.program_id(0)

    if ONE_TILE_PER_CTA:
        # The whole row fits in one tile; load it once.
        n_offsets = tl.arange(0, TILE_N)
        row_offsets = pid_m * N + n_offsets
        acc = tl.load(input_ptr + row_offsets, mask=n_offsets < N, other=0).to(
            cdtype
        )
    else:
        # Walk the row tile by tile, accumulating lane-wise.
        acc = tl.zeros(
            [
                TILE_N,
            ],
            dtype=cdtype,
        )
        for start_n in range(0, N, TILE_N):
            n_offsets = start_n + tl.arange(0, TILE_N)
            row_offsets = pid_m * N + n_offsets
            acc += tl.load(input_ptr + row_offsets, mask=n_offsets < N, other=0).to(
                cdtype
            )

    # Fold the lanes into one scalar and divide by the row length.
    row_sum = tl.sum(acc, axis=0)
    row_mean = row_sum / N
    tl.store(output_ptr + pid_m, row_mean)

184 

185 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def mean_dim_kernel(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Row-wise mean of an input indexed as a flat (M, N) array. Each program
    # handles BLOCK_M rows and sweeps their N columns in BLOCK_N-wide steps.

    # Accumulation dtype: float16/bfloat16 are summed in float32. This must
    # remain a compile-time (constexpr) branch because it selects a dtype.
    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
        inp.dtype.element_ty == tl.bfloat16
    ):
        cdtype = tl.float32
    else:
        cdtype = inp.dtype.element_ty

    # Row indices owned by this program, shaped [BLOCK_M, 1] for broadcasting.
    rows = ext.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_mask = rows < M
    row_ptrs = inp + rows * N
    out_ptrs = out + rows

    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
    for off in range(0, N, BLOCK_N):
        cols = off + tl.arange(0, BLOCK_N)[None, :]
        col_mask = cols < N
        mask = row_mask and col_mask

        # Out-of-range rows/columns load 0 and so add nothing to the sum.
        acc += tl.load(row_ptrs + cols, mask, other=0).to(cdtype)

    # Reduce across columns, restore the [BLOCK_M, 1] shape, divide by N.
    row_mean = tl.sum(acc, axis=1)[:, None] / N
    tl.store(out_ptrs, row_mean, row_mask)

224 

225 

def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
    """Compute the mean of ``inp`` over one or more dimensions.

    Args:
        inp: Input tensor.
        dim: ``None``, an int, or an iterable of ints. ``None`` or an empty
            iterable reduces over every element. Negative values are wrapped
            modulo ``inp.ndim``.
        keepdim: When true, reduced dimensions are kept with size 1;
            otherwise they are squeezed out of the result.
        dtype: Dtype of the result; defaults to ``inp.dtype``. A bool dtype
            is replaced by ``torch.int64`` (and the input converted).
        out: Optional pre-allocated result tensor for the dimension-wise
            paths. It is written by the kernel as-is, so it presumably has to
            be contiguous with the keepdim shape -- TODO confirm with callers.
            It is not used on the full-reduction path.

    Returns:
        The tensor of means.

    Raises:
        TypeError: If ``dim`` is neither ``None``, an int, nor iterable.
    """
    logger.debug("GEMS MEAN_DIM")
    if dtype is None:
        dtype = inp.dtype
    if dtype is torch.bool:
        inp = inp.to(torch.int64)
        dtype = torch.int64

    # -------- normalize dim to a list of ints --------
    # This happens before the "reduce everything" check so that None, () and
    # [] are all treated alike instead of only the literal [].
    if dim is None:
        dim = []
    elif isinstance(dim, int):
        dim = [dim]
    else:
        try:
            dim = list(dim)
        except TypeError:
            raise TypeError(
                f"dim must be an int, iterable of ints, or [], got {type(dim)}"
            ) from None
    # -------------------------------------------------

    if not dim or inp.ndim == 0:
        # Mean over all elements. A 0-d input also lands here: it has a
        # single element, and wrapping dims modulo ndim == 0 would fail.
        result = mean(inp, dtype=dtype)
        if keepdim:
            result = torch.reshape(result, [1] * inp.ndim)
        return result

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]

    if len(dim) == 1:
        dim0 = dim[0]
        N = inp.shape[dim0]  # reduction length
        # product of dims before dim0; use initializer 1 for empty slice
        M = reduce(lambda x, y: x * y, shape[:dim0], 1)
        inp = inp.contiguous()
        # whatever remains after M and N is the product of the trailing dims
        K = inp.numel() // M // N
        shape[dim0] = 1
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        with torch_device_fn.device(inp.device):
            if K > 1:
                # Reduced axis is not innermost: view the data as (M, N, K).
                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
                mean_dim_kernel_non_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                    K,
                )
            else:
                # Reduced axis is effectively innermost: view as (M, N).
                grid = (M, 1, 1)
                mean_dim_kernel_inner[grid](
                    out,
                    inp,
                    M,
                    N,
                )
        if not keepdim:
            out = out.squeeze(dim=dim0)
        return out
    else:
        # Several reduced dims: dim_compress rearranges the input for the
        # (M, N) row-reduction kernel -- presumably by moving the reduced
        # dims innermost; verify against its implementation.
        inp = dim_compress(inp, dim)
        N = 1
        for i in dim:
            N *= shape[i]
            shape[i] = 1
        M = inp.numel() // N
        if out is None:
            out = torch.empty(shape, dtype=dtype, device=inp.device)

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            mean_dim_kernel[grid](inp, out, M, N)
        if not keepdim:
            out = out.squeeze(dim=dim)
        return out

306 

307 

def mean_dim(inp, dim=None, keepdim=False, *, dtype=None):
    """Return the mean of ``inp`` over ``dim``.

    Thin entry point: logs the call and forwards every argument unchanged to
    ``mean_dim_comm`` without supplying an ``out`` tensor.
    """
    logger.debug("GEMS MEAN_DIM (wrapper)")
    result = mean_dim_comm(inp, dim, keepdim, dtype=dtype)
    return result