Coverage for src/flag_gems/runtime/backend/_mthreads/ops/all.py: 0%

141 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3from typing import Sequence 

4 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems import runtime 

10from flag_gems.runtime import torch_device_fn 

11from flag_gems.utils import dim_compress, libentry, libtuner 

12from flag_gems.utils import triton_lang_extension as ext 

13 

# Last dotted component of this module's import path; it becomes the final
# segment of the logger name below.
shortname = __name__.rpartition(".")[2]
logger = logging.getLogger(f"flag_gems.runtime.backend._mthreads.ops.{shortname}")

16 

# Autotuning candidates for `all_kernel_dim_strided`, one entry per
# (BLOCK_M, BLOCK_N, num_warps) row, in the order the tuner receives them.
NAIVE_REDUCTION_CONFIGS = [
    triton.Config({"BLOCK_M": block_m, "BLOCK_N": block_n}, num_warps=warps)
    for block_m, block_n, warps in (
        (4, 1024, 4),
        (8, 1024, 4),
        (16, 1024, 8),
        (32, 512, 8),
        (32, 256, 4),
        (32, 128, 4),
        (32, 64, 4),
        (16, 256, 4),
        (16, 128, 4),
        (64, 256, 8),
        (64, 128, 4),
        (64, 64, 4),
        (64, 64, 8),
        (64, 32, 4),
        (128, 256, 8),
        (256, 256, 8),
    )
]

35 

36 

@triton.jit
def reduce_all(a, b):
    """Combine two partial results with a logical AND.

    Passed as ``combine_fn`` to ``tl.reduce`` by the kernels in this module.
    """
    both = a and b
    return both

40 

41 

@triton.autotune(configs=NAIVE_REDUCTION_CONFIGS, key=["M", "N"])
@triton.jit
def all_kernel_dim_strided(
    inp,
    out,
    M,
    N,
    INNER,
    STRIDE_OUTER,
    STRIDE_REDUCE,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """AND-reduce ``N`` strided elements for each of ``M`` output rows.

    Every program handles ``BLOCK_M`` consecutive rows. Row ``r`` starts at
    ``inp + (r // INNER) * STRIDE_OUTER + (r % INNER)`` and its ``N`` elements
    are ``STRIDE_REDUCE`` apart. An element counts as true when it is
    non-zero; the per-row result is stored at ``out + r``.
    """
    block_id = tl.program_id(0)
    row_idx = (block_id * BLOCK_M + tl.arange(0, BLOCK_M)).to(tl.int64)
    row_ok = row_idx < M

    # Address of the first reduced element of every row in this block.
    row_start = inp + (row_idx // INNER) * STRIDE_OUTER + row_idx % INNER

    running = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for start in range(0, N, BLOCK_N):
        col_idx = (start + tl.arange(0, BLOCK_N)).to(tl.int64)
        col_ok = col_idx < N
        in_bounds = row_ok[:, None] and col_ok[None, :]
        ptrs = row_start[:, None] + col_idx[None, :] * STRIDE_REDUCE
        # Masked-out lanes read a non-zero filler so they cannot turn the
        # accumulated AND false.
        tile = tl.load(ptrs, in_bounds, other=1.0)
        running = running and (tile != 0)
    row_result = tl.reduce(running, axis=1, combine_fn=reduce_all)
    tl.store(out + row_idx, row_result, mask=row_ok)

76 

77 

78def _flatten_dim(shape: Sequence[int], dim: int): 

79 dim = dim % len(shape) 

80 n = shape[dim] 

81 inner = math.prod(shape[dim + 1 :]) if dim + 1 < len(shape) else 1 

82 outer = math.prod(shape[:dim]) if dim > 0 else 1 

83 return dim, n, inner, outer 

84 

85 

def triton_all_dim_strided(
    inp: torch.Tensor, dim: int, keepdim: bool = False
) -> torch.Tensor:
    """Reduce ``inp`` along a single dimension with logical AND.

    Args:
        inp: Tensor to reduce; an element counts as true when it is non-zero.
        dim: Dimension to reduce; negative values wrap around.
        keepdim: Keep the reduced dimension with size 1 when true.

    Returns:
        A ``torch.bool`` tensor on ``inp.device`` with ``dim`` reduced.
    """
    dim = dim % inp.ndim
    # The kernel addresses memory as a dense (outer, n, inner) block whose
    # innermost stride is 1, so materialise that layout for arbitrary views.
    inp = inp.contiguous()
    shape = list(inp.shape)
    dim, n, inner, outer = _flatten_dim(shape, dim)
    m = outer * inner

    # Derive the strides from the shape instead of `inp.stride()`: the stride
    # reported for a size-1 dimension is not constrained by contiguity and
    # must not leak into the address computation.
    stride_reduce = inner
    stride_outer = inner * n

    out_flat = torch.empty((m,), dtype=torch.bool, device=inp.device)
    if m > 0:
        # An empty output has nothing to compute, so no launch is needed.
        grid = lambda meta: (triton.cdiv(m, meta["BLOCK_M"]),)
        all_kernel_dim_strided[grid](
            inp,
            out_flat,
            m,
            n,
            inner,
            stride_outer,
            stride_reduce,
        )

    shape[dim] = 1
    out = out_flat.view(shape)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out

115 

116 

@libentry()
@libtuner(
    configs=runtime.get_tuned_config("naive_reduction"),
    key=["M", "N"],
)
@triton.jit
def all_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """AND-reduce each row of an input addressed as a dense ``(M, N)`` block.

    Element ``(r, c)`` is read from ``inp + r * N + c``; an element counts as
    true when it is non-zero. Every program handles ``BLOCK_M`` rows and
    stores the result for row ``r`` at ``out + r``.
    """
    block_id = ext.program_id(0)
    row_idx = block_id * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_ptr = inp + row_idx * N
    out_ptr = out + row_idx
    row_ok = row_idx < M

    running = tl.full([BLOCK_M, BLOCK_N], value=1, dtype=tl.int1)
    for start in range(0, N, BLOCK_N):
        col_idx = start + tl.arange(0, BLOCK_N)[None, :]
        col_ok = col_idx < N
        in_bounds = row_ok and col_ok

        # Masked-out lanes read a non-zero filler so they cannot turn the
        # accumulated AND false.
        tile = tl.load(row_ptr + col_idx, in_bounds, other=1.0)
        running = running and (tile != 0)
    row_result = tl.reduce(running, axis=1, combine_fn=reduce_all)
    tl.store(out_ptr, row_result[:, None], row_ok)

147 

148 

@libentry()
@triton.jit
def all_kernel_1(
    inp,
    mid,
    n_elements,
    mid_size,
    BLOCK_SIZE: tl.constexpr,
):
    """First stage of the full reduction: one partial AND per program.

    Program ``p`` reads the ``BLOCK_SIZE`` elements starting at
    ``inp + p * BLOCK_SIZE`` (bounded by ``n_elements``), tests them against
    zero and stores the AND of the block at ``mid + p``. ``mid_size`` is
    accepted but not read by the kernel body.
    """
    block_id = ext.program_id(0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # Lanes past the end read a non-zero filler and so count as true.
    chunk = tl.load(inp + idx, mask=in_bounds, other=1.0)
    block_result = tl.reduce(chunk != 0, axis=0, combine_fn=reduce_all)
    tl.store(mid + block_id, block_result)

166 

167 

@libentry()
@triton.jit
def all_kernel_2(mid, out, MID_SIZE, BLOCK_MID: tl.constexpr):
    """Second stage of the full reduction: AND the partial results.

    Loads the first ``MID_SIZE`` entries of ``mid`` in a single block of
    ``BLOCK_MID`` lanes, reduces them with logical AND and writes the scalar
    result to ``out``.
    """
    idx = tl.arange(0, BLOCK_MID)
    in_bounds = idx < MID_SIZE
    # Padding lanes read 1 so they count as true.
    partials = tl.load(mid + idx, mask=in_bounds, other=1).to(tl.int1)
    final = tl.reduce(partials, axis=0, combine_fn=reduce_all)
    tl.store(out, final)

177 

178 

def all(inp):
    """Return a 0-dim ``torch.bool`` tensor telling whether every element is non-zero.

    The reduction runs in two kernel launches: ``all_kernel_1`` produces one
    partial result per block of ``block_size`` elements, and ``all_kernel_2``
    combines those ``mid_size`` partial results into the final scalar.

    Args:
        inp: Tensor to reduce over all of its elements.

    Returns:
        A scalar ``torch.bool`` tensor on ``inp.device``. An empty input
        yields True.
    """
    logger.debug("GEMS_MTHREADS ALL")
    n_elements = inp.numel()
    if n_elements == 0:
        # AND over no elements is vacuously true. The block-size arithmetic
        # below is not defined for zero elements, so answer directly.
        return torch.ones([], dtype=torch.bool, device=inp.device)

    # all_kernel_1 walks the input linearly from its base pointer, which only
    # visits the right elements when the tensor is densely laid out.
    inp = inp.contiguous()

    # Roughly sqrt(n) elements per block (rounded up to a power of two, then
    # doubled), capped at 4096 and at the padded input size.
    block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))
    block_size = min(block_size * 2, 4096, triton.next_power_of_2(n_elements))
    mid_size = triton.cdiv(n_elements, block_size)
    block_mid = triton.next_power_of_2(mid_size)

    mid = torch.empty((mid_size,), dtype=torch.bool, device=inp.device)
    out = torch.empty([], dtype=torch.bool, device=inp.device)

    with torch_device_fn.device(inp.device):
        all_kernel_1[(mid_size, 1)](inp, mid, n_elements, mid_size, block_size)
        all_kernel_2[(1, 1)](mid, out, mid_size, block_mid)

    return out

195 

196 

def all_dim(inp, dim=None, keepdim=False):
    """Logical-AND reduction over one dimension, or over everything.

    With ``dim=None`` the whole tensor is reduced via ``all``; when
    ``keepdim`` is set the scalar result is reshaped to ``inp.ndim`` axes of
    size 1. Otherwise ``dim`` must lie in ``[-inp.ndim, inp.ndim)`` and the
    reduction is delegated to ``triton_all_dim_strided``.
    """
    logger.debug("GEMS_MTHREADS ALL DIM")
    if dim is not None:
        assert -inp.ndim <= dim < inp.ndim, "Invalid dim"
        axis = dim % inp.ndim
        with torch_device_fn.device(inp.device):
            return triton_all_dim_strided(inp, dim=axis, keepdim=keepdim)

    result = all(inp)
    if keepdim:
        result = torch.reshape(result, [1] * inp.ndim)
    return result

210 

211 

def all_dims(inp, dim=None, keepdim=False):
    """Logical-AND reduction over several dimensions at once.

    Args:
        inp: Tensor to reduce; an element counts as true when it is non-zero.
        dim: ``None`` or a single int (both forwarded to ``all_dim``), or an
            iterable of dimension indices, each in ``[-inp.ndim, inp.ndim)``.
        keepdim: Keep the reduced dimensions with size 1 when true.

    Returns:
        A ``torch.bool`` tensor on ``inp.device``.

    Raises:
        AssertionError: If any entry of ``dim`` is out of range.
    """
    logger.debug("GEMS_MTHREADS ALL DIMS")

    if dim is None or isinstance(dim, int):
        return all_dim(inp, dim=dim, keepdim=keepdim)

    # Materialise once so one-shot iterables survive validation.
    dim = list(dim)
    # Check each entry explicitly: asserting on a bare generator expression
    # is always true, and the builtin `all` is shadowed by this module's own
    # `all` function.
    for d in dim:
        assert -inp.ndim <= d < inp.ndim, "Invalid dim"

    shape = list(inp.shape)
    dim = [d % inp.ndim for d in dim]
    # NOTE(review): the kernel reads `inp` as a dense (M, N) block, so this
    # presumably moves the reduced dims last and returns a contiguous
    # tensor - confirm against flag_gems.utils.dim_compress.
    inp = dim_compress(inp, dim)
    N = 1
    for i in dim:
        N *= shape[i]
        shape[i] = 1
    # Number of output elements, taken from the kept dims. Dividing numel()
    # by N would fail when a reduced dim has size 0.
    M = math.prod(shape)

    out = torch.empty(shape, dtype=torch.bool, device=inp.device)

    if M > 0:
        # An empty output has nothing to compute, so no launch is needed.
        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(inp.device):
            all_kernel_dim[grid](inp, out, M, N)
    if not keepdim:
        out = out.squeeze(dim=dim)
    return out

236 

237 

238__all__ = ["all", "all_dim", "all_dims"]