Coverage for src/flag_gems/runtime/backend/_spacemit/ops/where.py: 0%
47 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils.pointwise_dynamic import pointwise_dynamic
# Module-level logger; each public op in this file emits a debug line through it.
logger = logging.getLogger(__name__)
# Elementwise body of ``where``: at every position pick ``self`` when
# ``condition`` is true, otherwise ``other`` (``tl.where`` semantics).
# ``pointwise_dynamic`` turns this scalar-style function into a broadcasting
# pointwise kernel. All three arguments are declared as tensors, and the
# output dtype is promoted from arguments 1 and 2 (``self``/``other``) using
# the "NO_OPMATH" rule. NOTE(review): exact meaning of "NO_OPMATH" lives in
# flag_gems' pointwise_dynamic — confirm there.
@pointwise_dynamic(
    is_tensor=[True] * 3,
    promotion_methods=[(1, 2, "NO_OPMATH")],
)
@triton.jit
def where_inner(condition, self, other):
    selected = tl.where(condition, self, other)
    return selected
def where_self_out(condition, self, other, out=None):
    """Select elements from ``self`` or ``other`` according to ``condition``.

    CPU-only implementation of ``torch.where(condition, self, other, out=out)``
    for the SpacemiT backend; the elementwise work is done by ``where_inner``.

    Args:
        condition: Boolean tensor (or Python bool) choosing between the two
            value operands; true selects ``self``, false selects ``other``.
        self: Tensor or Python scalar supplying values where ``condition`` is true.
        other: Tensor or Python scalar supplying values where ``condition`` is false.
        out: Optional pre-allocated output tensor. Its dtype must equal
            ``torch.result_type(self, other)``.

    Returns:
        ``out`` when given, otherwise a new CPU tensor whose shape is the
        broadcast of the three operand shapes and whose dtype is
        ``torch.result_type(self, other)``.

    Raises:
        AssertionError: if ``out`` has the wrong dtype, if any operand is not
            on the CPU, or if ``condition`` is not boolean.
    """
    logger.debug("GEMS_SPACEMIT WHERE_SELF_OUT")
    result_type = torch.result_type(self, other)
    if out is not None:
        assert (
            out.dtype == result_type
        ), f"Expected out type to be {result_type}, but got {out.dtype}."

    # Wrap Python scalars as 0-dim tensors so all three operands can be
    # handled uniformly below.
    c, a, b = (
        x if isinstance(x, torch.Tensor) else torch.tensor(x)
        for x in (condition, self, other)
    )

    # Cast both value operands to the promoted result dtype up front.
    if a.dtype != result_type:
        a = a.to(result_type)
    if b.dtype != result_type:
        b = b.to(result_type)

    devices = [x.device for x in (c, a, b)]
    # This backend runs only on the host CPU. Once this holds, every operand
    # already shares one device, so no cross-device moves of 0-dim tensors
    # and no separate same-device check are needed.
    assert all(device.type == "cpu" for device in devices), (
        "CPU only. Expected all tensors to be on CPU, " f"but found devices {devices}"
    )
    device = devices[0]

    # Report the dtype of the tensor that was actually checked: ``condition``
    # itself may be a Python scalar with no ``.dtype`` attribute.
    assert (
        c.dtype == torch.bool
    ), f"where expected condition to be a boolean tensor, but got a tensor with dtype {c.dtype}"

    if out is None:
        out_shape = torch.broadcast_shapes(c.shape, a.shape, b.shape)
        out = torch.empty(out_shape, dtype=result_type, device=device)

    ndim = max(c.ndim, a.ndim, b.ndim)
    where_inner.instantiate(ndim)
    where_inner(c, a, b, out0=out)
    return out
def where_self(condition, self, other):
    """Compute ``where(condition, self, other)`` into a newly allocated tensor.

    Logs the call, then delegates everything (scalar wrapping, dtype
    promotion, validation, kernel launch) to ``where_self_out`` with no
    ``out`` tensor.
    """
    logger.debug("GEMS_SPACEMIT WHERE_SELF")
    result = where_self_out(condition, self, other, out=None)
    return result
def where_scalar_self(condition, self, other):
    """``where`` entry point whose name indicates ``self`` may be a scalar.

    ``where_self_out`` already wraps non-tensor operands as 0-dim tensors, so
    this wrapper only logs and forwards, allocating a fresh result.
    """
    logger.debug("GEMS_SPACEMIT WHERE_SCALAR_SELF")
    result = where_self_out(condition, self, other, out=None)
    return result
def where_scalar_other(condition, self, other):
    """``where`` entry point whose name indicates ``other`` may be a scalar.

    ``where_self_out`` already wraps non-tensor operands as 0-dim tensors, so
    this wrapper only logs and forwards, allocating a fresh result.
    """
    logger.debug("GEMS_SPACEMIT WHERE_SCALAR_OTHER")
    result = where_self_out(condition, self, other, out=None)
    return result