Coverage for src/flag_gems/runtime/backend/_arm/ops/where.py: 0%

126 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8 

9 

@pointwise_dynamic(
    is_tensor=[True, True, True],
    promotion_methods=[(1, 2, "NO_OPMATH")],
)
@triton.jit
def where_inner(condition, self, other):
    # Element-wise select: yields `self` where `condition` holds and
    # `other` elsewhere. All three operands are declared as tensors to
    # `pointwise_dynamic`; the promotion rule is keyed on arguments 1 and 2
    # (`self` and `other`).
    selected = tl.where(condition, self, other)
    return selected

17 

18 

@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _where_scalar_self_kernel(
    condition_ptr,
    other_ptr,
    out_ptr,
    scalar,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Flat select with a scalar "true" value:
    #   out[i] = scalar if condition[i] else other[i]   for i < n_elements
    # Each program instance covers the BLOCK_SIZE-wide tile chosen by its
    # program id along axis 0.
    tile_start = tl.program_id(0) * BLOCK_SIZE
    offsets = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements
    # Lanes past the end load a placeholder; the store below masks them out.
    take_scalar = tl.load(condition_ptr + offsets, mask=in_bounds, other=0).to(tl.int1)
    fallback = tl.load(other_ptr + offsets, mask=in_bounds, other=0.0)
    selected = tl.where(take_scalar, scalar, fallback)
    tl.store(out_ptr + offsets, selected, mask=in_bounds)

35 

36 

@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _where_scalar_other_kernel(
    condition_ptr,
    self_ptr,
    out_ptr,
    scalar,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Flat select with a scalar "false" value:
    #   out[i] = self[i] if condition[i] else scalar   for i < n_elements
    # Each program instance covers the BLOCK_SIZE-wide tile chosen by its
    # program id along axis 0.
    tile_start = tl.program_id(0) * BLOCK_SIZE
    offsets = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements
    # Lanes past the end load a placeholder; the store below masks them out.
    take_self = tl.load(condition_ptr + offsets, mask=in_bounds, other=0).to(tl.int1)
    self_vals = tl.load(self_ptr + offsets, mask=in_bounds, other=0.0)
    selected = tl.where(take_self, self_vals, scalar)
    tl.store(out_ptr + offsets, selected, mask=in_bounds)

53 

54 

@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _where_scalar_self_single_program_kernel(
    condition_ptr,
    other_ptr,
    out_ptr,
    scalar,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Same computation as the tiled scalar-self kernel,
    #   out[i] = scalar if condition[i] else other[i]   for i < n_elements
    # but the program id is never consulted: the body walks the whole
    # buffer itself in BLOCK_SIZE-wide steps.
    lanes = tl.arange(0, BLOCK_SIZE)
    for start in range(0, n_elements, BLOCK_SIZE):
        offsets = start + lanes
        in_bounds = offsets < n_elements
        # Only the final step can be partial; masked lanes are never stored.
        take_scalar = tl.load(condition_ptr + offsets, mask=in_bounds, other=0).to(
            tl.int1
        )
        fallback = tl.load(other_ptr + offsets, mask=in_bounds, other=0.0)
        selected = tl.where(take_scalar, scalar, fallback)
        tl.store(out_ptr + offsets, selected, mask=in_bounds)

72 

73 

@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _where_scalar_other_single_program_kernel(
    condition_ptr,
    self_ptr,
    out_ptr,
    scalar,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Same computation as the tiled scalar-other kernel,
    #   out[i] = self[i] if condition[i] else scalar   for i < n_elements
    # but the program id is never consulted: the body walks the whole
    # buffer itself in BLOCK_SIZE-wide steps.
    lanes = tl.arange(0, BLOCK_SIZE)
    for start in range(0, n_elements, BLOCK_SIZE):
        offsets = start + lanes
        in_bounds = offsets < n_elements
        # Only the final step can be partial; masked lanes are never stored.
        take_self = tl.load(condition_ptr + offsets, mask=in_bounds, other=0).to(
            tl.int1
        )
        self_vals = tl.load(self_ptr + offsets, mask=in_bounds, other=0.0)
        selected = tl.where(take_self, self_vals, scalar)
        tl.store(out_ptr + offsets, selected, mask=in_bounds)

91 

92 

93def _as_scalar(v): 

94 if isinstance(v, torch.Tensor): 

95 if v.numel() != 1: 

96 return None 

97 return v.item() 

98 if isinstance(v, (int, float, bool)): 

99 return v 

100 return None 

101 

102 

def _where_scalar_tensor_fastpath(condition, self, other, out):
    """Try to compute ``where`` with a flat "one scalar + one tensor" kernel.

    The fast path applies when ``condition`` is a contiguous boolean CPU
    tensor, exactly one of ``self``/``other`` is used as a scalar (a
    one-element tensor or a Python number) and the remaining operand is a
    contiguous tensor with the same shape as ``condition`` and ``out``.

    Args:
        condition: Selection mask.
        self: Value taken where ``condition`` is true (tensor or scalar).
        other: Value taken where ``condition`` is false (tensor or scalar).
        out: Pre-allocated destination tensor.

    Returns:
        ``True`` if ``out`` has been completely written by a fast-path
        kernel; ``False`` if nothing was written and the caller must fall
        back to the generic implementation.
    """
    if not isinstance(condition, torch.Tensor) or condition.dtype is not torch.bool:
        return False
    if condition.device.type != "cpu":
        return False
    if not condition.is_contiguous() or not out.is_contiguous():
        return False
    # The kernels use one flat offset for every buffer and write exactly
    # condition.numel() elements, so no broadcasting may be involved: `out`
    # must have precisely the shape of `condition`.
    if out.shape != condition.shape:
        return False
    # The scalar is handed to the kernel as a Python float, which Triton
    # types as fp32. Restrict the fast path to dtypes whose values survive
    # that round trip; float64, integer, bool and complex inputs use the
    # generic path instead of being silently rounded or rejected.
    if out.dtype not in (torch.float16, torch.bfloat16, torch.float32):
        return False

    block_size = 256
    # Up to this many elements a single program loops over the whole buffer.
    single_program_max = 262144

    def _eligible(scalar, tensor):
        # Same shape (not merely same numel) guarantees flat indices line up.
        return (
            scalar is not None
            and tensor is not None
            and tensor.is_contiguous()
            and tensor.shape == condition.shape
            and tensor.dtype == out.dtype
        )

    self_scalar = _as_scalar(self)
    other_scalar = _as_scalar(other)
    self_tensor = self if isinstance(self, torch.Tensor) else None
    other_tensor = other if isinstance(other, torch.Tensor) else None

    if _eligible(self_scalar, other_tensor):
        scalar, tensor = self_scalar, other_tensor
        single_kernel = _where_scalar_self_single_program_kernel
        tiled_kernel = _where_scalar_self_kernel
    elif _eligible(other_scalar, self_tensor):
        scalar, tensor = other_scalar, self_tensor
        single_kernel = _where_scalar_other_single_program_kernel
        tiled_kernel = _where_scalar_other_kernel
    else:
        return False

    n = condition.numel()
    if n == 0:
        # Nothing to write; skip the kernel launch entirely.
        return True

    cond_flat = condition.view(-1)
    tensor_flat = tensor.view(-1)
    out_flat = out.view(-1)

    if n <= single_program_max:
        kernel, grid = single_kernel, (1,)
    else:
        kernel, grid = tiled_kernel, (triton.cdiv(n, block_size),)

    kernel[grid](
        cond_flat,
        tensor_flat,
        out_flat,
        float(scalar),
        n,
        BLOCK_SIZE=block_size,
        num_warps=1,
        num_stages=1,
    )
    return True

194 

195 

def where_self_out(condition, self, other, out=None):
    """Select element-wise between ``self`` and ``other`` based on ``condition``.

    Non-tensor arguments are wrapped with ``torch.tensor``; ``self`` and
    ``other`` are converted to ``torch.result_type(self, other)``.

    Args:
        condition: Boolean mask (tensor or Python value).
        self: Values used where ``condition`` is true.
        other: Values used where ``condition`` is false.
        out: Optional destination tensor. When omitted, a tensor with the
            broadcast shape of the three operands is allocated on the
            device of ``condition``.

    Returns:
        The tensor holding the result (``out`` when it was provided).

    Raises:
        AssertionError: If ``out`` does not have the promoted dtype, or if
            ``condition`` is not of boolean dtype.
    """
    logging.debug("GEMS WHERE_SELF_OUT")
    result_type = torch.result_type(self, other)
    if out is not None:
        assert (
            out.dtype == result_type
        ), f"Expected out type to be {result_type}, but got {out.dtype}."

    c, a, b = (
        x if isinstance(x, torch.Tensor) else torch.tensor(x)
        for x in (condition, self, other)
    )

    if a.dtype != result_type:
        a = a.to(result_type)
    if b.dtype != result_type:
        b = b.to(result_type)

    # NOTE: operand devices are not reconciled or validated here.

    # Report the dtype of the converted tensor: `condition` itself may be a
    # plain Python value without a `.dtype` attribute.
    assert (
        c.dtype == torch.bool
    ), f"where expected condition to be a boolean tensor, but got a tensor with dtype {c.dtype}"

    if out is None:
        out_shape = torch.broadcast_shapes(c.shape, a.shape, b.shape)
        out = torch.empty(out_shape, dtype=result_type, device=c.device)

    # Prefer the specialised scalar/tensor kernels when they apply.
    if _where_scalar_tensor_fastpath(c, a, b, out):
        return out

    ndim = max(c.ndim, a.ndim, b.ndim)
    where_inner.instantiate(ndim)
    where_inner(c, a, b, out0=out)
    return out

247 

248 

def where_self(condition, self, other):
    """Out-of-place ``where``: log the call, then defer to ``where_self_out``.

    Returns the tensor produced by ``where_self_out`` when no ``out`` is given.
    """
    logging.debug("GEMS WHERE_SELF")
    result = where_self_out(condition, self, other)
    return result

252 

253 

def where_scalar_self(condition, self, other):
    """Entry point logged as WHERE_SCALAR_SELF; defers to ``where_self_out``.

    Returns the tensor produced by ``where_self_out`` when no ``out`` is given.
    """
    logging.debug("GEMS WHERE_SCALAR_SELF")
    result = where_self_out(condition, self, other)
    return result

257 

258 

def where_scalar_other(condition, self, other):
    """Entry point logged as WHERE_SCALAR_OTHER; defers to ``where_self_out``.

    Returns the tensor produced by ``where_self_out`` when no ``out`` is given.
    """
    logging.debug("GEMS WHERE_SCALAR_OTHER")
    result = where_self_out(condition, self, other)
    return result