Coverage for src/flag_gems/runtime/backend/_arm/ops/where.py: 0%
126 statements
Report generated by coverage.py v7.6.9 on 2026-06-05 07:36 +0800.
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
@pointwise_dynamic(
    is_tensor=[True, True, True],
    promotion_methods=[(1, 2, "NO_OPMATH")],
)
@triton.jit
def where_inner(condition, self, other):
    # Elementwise select: take `self` where `condition` holds, `other` elsewhere.
    # The promotion_methods entry names argument indices 1 and 2 (self, other);
    # presumably the output dtype is promoted from those two — confirm against
    # pointwise_dynamic.
    picked = tl.where(condition, self, other)
    return picked
@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _where_scalar_self_kernel(
    condition_ptr,
    other_ptr,
    out_ptr,
    scalar,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Tiled kernel: program p handles elements [p * BLOCK_SIZE, (p + 1) * BLOCK_SIZE)
    # and writes out[i] = scalar if condition[i] else other[i].
    # Lanes at or beyond n_elements are masked off for both loads and the store.
    tile_start = tl.program_id(0) * BLOCK_SIZE
    offsets = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements
    take_scalar = tl.load(condition_ptr + offsets, mask=in_bounds, other=0).to(tl.int1)
    fallback = tl.load(other_ptr + offsets, mask=in_bounds, other=0.0)
    selected = tl.where(take_scalar, scalar, fallback)
    tl.store(out_ptr + offsets, selected, mask=in_bounds)
@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _where_scalar_other_kernel(
    condition_ptr,
    self_ptr,
    out_ptr,
    scalar,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Tiled kernel: program p handles elements [p * BLOCK_SIZE, (p + 1) * BLOCK_SIZE)
    # and writes out[i] = self[i] if condition[i] else scalar.
    # Lanes at or beyond n_elements are masked off for both loads and the store.
    tile_start = tl.program_id(0) * BLOCK_SIZE
    offsets = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = offsets < n_elements
    take_tensor = tl.load(condition_ptr + offsets, mask=in_bounds, other=0).to(tl.int1)
    tensor_vals = tl.load(self_ptr + offsets, mask=in_bounds, other=0.0)
    selected = tl.where(take_tensor, tensor_vals, scalar)
    tl.store(out_ptr + offsets, selected, mask=in_bounds)
@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _where_scalar_self_single_program_kernel(
    condition_ptr,
    other_ptr,
    out_ptr,
    scalar,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Single-program variant: one program walks the whole buffer in
    # BLOCK_SIZE-wide steps, writing out[i] = scalar if condition[i] else other[i].
    lane = tl.arange(0, BLOCK_SIZE)
    for start in range(0, n_elements, BLOCK_SIZE):
        offsets = start + lane
        # Only the final step can be partial; mask lanes past the end.
        in_bounds = offsets < n_elements
        take_scalar = tl.load(condition_ptr + offsets, mask=in_bounds, other=0).to(tl.int1)
        fallback = tl.load(other_ptr + offsets, mask=in_bounds, other=0.0)
        selected = tl.where(take_scalar, scalar, fallback)
        tl.store(out_ptr + offsets, selected, mask=in_bounds)
@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _where_scalar_other_single_program_kernel(
    condition_ptr,
    self_ptr,
    out_ptr,
    scalar,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Single-program variant: one program walks the whole buffer in
    # BLOCK_SIZE-wide steps, writing out[i] = self[i] if condition[i] else scalar.
    lane = tl.arange(0, BLOCK_SIZE)
    for start in range(0, n_elements, BLOCK_SIZE):
        offsets = start + lane
        # Only the final step can be partial; mask lanes past the end.
        in_bounds = offsets < n_elements
        take_tensor = tl.load(condition_ptr + offsets, mask=in_bounds, other=0).to(tl.int1)
        tensor_vals = tl.load(self_ptr + offsets, mask=in_bounds, other=0.0)
        selected = tl.where(take_tensor, tensor_vals, scalar)
        tl.store(out_ptr + offsets, selected, mask=in_bounds)
93def _as_scalar(v):
94 if isinstance(v, torch.Tensor):
95 if v.numel() != 1:
96 return None
97 return v.item()
98 if isinstance(v, (int, float, bool)):
99 return v
100 return None
103def _where_scalar_tensor_fastpath(condition, self, other, out):
104 if not isinstance(condition, torch.Tensor) or condition.dtype is not torch.bool:
105 return False
106 if condition.device.type != "cpu":
107 return False
108 if not condition.is_contiguous() or not out.is_contiguous():
109 return False
111 self_scalar = _as_scalar(self)
112 other_scalar = _as_scalar(other)
113 self_tensor = self if isinstance(self, torch.Tensor) else None
114 other_tensor = other if isinstance(other, torch.Tensor) else None
116 # Only specialize one-scalar + one-tensor, contiguous, same flattened size.
117 if (
118 self_scalar is not None
119 and other_tensor is not None
120 and other_tensor.is_contiguous()
121 ):
122 if other_tensor.numel() != condition.numel():
123 return False
124 if other_tensor.dtype != out.dtype:
125 return False
126 cond_flat = condition.view(-1)
127 other_flat = other_tensor.view(-1)
128 out_flat = out.view(-1)
129 n = cond_flat.numel()
130 if n <= 262144:
131 _where_scalar_self_single_program_kernel[(1,)](
132 cond_flat,
133 other_flat,
134 out_flat,
135 float(self_scalar),
136 n,
137 BLOCK_SIZE=256,
138 num_warps=1,
139 num_stages=1,
140 )
141 else:
142 grid = (triton.cdiv(n, 256),)
143 _where_scalar_self_kernel[grid](
144 cond_flat,
145 other_flat,
146 out_flat,
147 float(self_scalar),
148 n,
149 BLOCK_SIZE=256,
150 num_warps=1,
151 num_stages=1,
152 )
153 return True
155 if (
156 other_scalar is not None
157 and self_tensor is not None
158 and self_tensor.is_contiguous()
159 ):
160 if self_tensor.numel() != condition.numel():
161 return False
162 if self_tensor.dtype != out.dtype:
163 return False
164 cond_flat = condition.view(-1)
165 self_flat = self_tensor.view(-1)
166 out_flat = out.view(-1)
167 n = cond_flat.numel()
168 if n <= 262144:
169 _where_scalar_other_single_program_kernel[(1,)](
170 cond_flat,
171 self_flat,
172 out_flat,
173 float(other_scalar),
174 n,
175 BLOCK_SIZE=256,
176 num_warps=1,
177 num_stages=1,
178 )
179 else:
180 grid = (triton.cdiv(n, 256),)
181 _where_scalar_other_kernel[grid](
182 cond_flat,
183 self_flat,
184 out_flat,
185 float(other_scalar),
186 n,
187 BLOCK_SIZE=256,
188 num_warps=1,
189 num_stages=1,
190 )
191 return True
193 return False
def where_self_out(condition, self, other, out=None):
    """Select elements from ``self`` or ``other`` depending on ``condition``.

    Non-tensor arguments are wrapped with ``torch.tensor``. ``self`` and
    ``other`` are cast to their common dtype, ``torch.result_type(self, other)``.
    A scalar/tensor fast path is tried first; otherwise the broadcasting
    ``where_inner`` pointwise kernel is used.

    Args:
        condition: boolean tensor, or a value convertible to one.
        self: values taken where ``condition`` is true; tensor or scalar.
        other: values taken where ``condition`` is false; tensor or scalar.
        out: optional destination tensor; its dtype must equal the result dtype.

    Returns:
        ``out`` when given. Otherwise a new tensor with the broadcast shape of
        the three inputs, allocated on ``condition``'s device.

    Raises:
        AssertionError: if ``out`` has the wrong dtype, or if ``condition`` is
            not boolean.
    """
    logging.debug("GEMS WHERE_SELF_OUT")
    result_type = torch.result_type(self, other)
    if out is not None:
        assert (
            out.dtype == result_type
        ), f"Expected out type to be {result_type}, but got {out.dtype}."

    c, a, b = (
        x if isinstance(x, torch.Tensor) else torch.tensor(x)
        for x in (condition, self, other)
    )
    if a.dtype != result_type:
        a = a.to(result_type)
    if b.dtype != result_type:
        b = b.to(result_type)

    # NOTE(review): no device reconciliation or same-device check is done here;
    # the inputs are assumed to already share a device.
    # Report c.dtype rather than condition.dtype: a non-tensor condition (e.g. 1)
    # has no .dtype, which would raise AttributeError instead of this assertion.
    assert (
        c.dtype == torch.bool
    ), f"where expected condition to be a boolean tensor, but got a tensor with dtype {c.dtype}"

    if out is None:
        out_shape = torch.broadcast_shapes(c.shape, a.shape, b.shape)
        out = torch.empty(out_shape, dtype=result_type, device=c.device)

    if _where_scalar_tensor_fastpath(c, a, b, out):
        return out

    ndim = max(c.ndim, a.ndim, b.ndim)
    where_inner.instantiate(ndim)
    where_inner(c, a, b, out0=out)
    return out
def where_self(condition, self, other):
    """Compute ``where(condition, self, other)`` into a freshly allocated tensor.

    Thin wrapper: logs the entry and delegates to ``where_self_out`` without
    an ``out`` tensor.
    """
    logging.debug("GEMS WHERE_SELF")
    return where_self_out(condition=condition, self=self, other=other)
def where_scalar_self(condition, self, other):
    """Compute ``where(condition, self, other)`` where ``self`` may be a scalar.

    Thin wrapper: logs the entry and delegates to ``where_self_out``, which
    wraps non-tensor operands, without an ``out`` tensor.
    """
    logging.debug("GEMS WHERE_SCALAR_SELF")
    return where_self_out(condition=condition, self=self, other=other)
def where_scalar_other(condition, self, other):
    """Compute ``where(condition, self, other)`` where ``other`` may be a scalar.

    Thin wrapper: logs the entry and delegates to ``where_self_out``, which
    wraps non-tensor operands, without an ``out`` tensor.
    """
    logging.debug("GEMS WHERE_SCALAR_OTHER")
    return where_self_out(condition=condition, self=self, other=other)