Coverage for src/flag_gems/runtime/backend/_sunrise/ops/where.py: 0%
51 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
8from flag_gems.utils.pointwise_dynamic import CodeGenConfig
# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module (leading dots stripped from __name__).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# One limit per grid axis; passed below as CodeGenConfig.max_grid_size.
MAX_GRID_SIZES = (65535,) * 3

# Code-generation settings handed to pointwise_dynamic for the where kernel.
config = CodeGenConfig(
    max_grid_size=MAX_GRID_SIZES,
    max_tile_size=256,
    max_num_warps_per_cta=16,
    prefer_1d_tile=True,
    prefer_block_pointer=True,
)
# Element-wise select kernel. pointwise_dynamic is told that all three
# arguments are tensors and that arguments 1 and 2 are promoted together with
# the "NO_OPMATH" rule, using the module-level `config`.
@pointwise_dynamic(
    config=config,
    is_tensor=[True, True, True],
    promotion_methods=[(1, 2, "NO_OPMATH")],
)
@triton.jit
def where_inner(condition, self, other):
    # Delegate the per-element selection to Triton's tl.where.
    selected = tl.where(condition, self, other)
    return selected
def where_self_out(condition, self, other, out=None):
    """Compute ``where(condition, self, other)``, optionally into ``out``.

    Args:
        condition: Tensor or Python scalar; after wrapping it must have dtype
            ``torch.bool``.
        self: Tensor or Python scalar used by the kernel as the second operand.
        other: Tensor or Python scalar used by the kernel as the third operand.
        out: Optional pre-allocated result tensor. Its dtype must equal
            ``torch.result_type(self, other)``. When ``None``, a tensor of the
            broadcast shape is allocated on the selected device.

    Returns:
        ``out`` (either the tensor passed in or the newly allocated one).

    Raises:
        AssertionError: if ``out`` has the wrong dtype, if every operand is on
            CPU, if operands live on different devices, or if ``condition`` is
            not boolean.
    """
    logger.debug("GEMS WHERE_SELF_OUT")
    result_type = torch.result_type(self, other)
    if out is not None:
        assert (
            out.dtype == result_type
        ), f"Expected out type to be {result_type}, but got {out.dtype}."

    # Wrap Python scalars as 0-dim tensors so all operands are handled alike.
    c, a, b = (
        x if isinstance(x, torch.Tensor) else torch.tensor(x)
        for x in (condition, self, other)
    )

    # Both value operands are cast to the promoted dtype before the kernel runs.
    if a.dtype != result_type:
        a = a.to(result_type)
    if b.dtype != result_type:
        b = b.to(result_type)

    # Non-CPU devices of the operands as they were passed in.
    devices = [t.device for t in (c, a, b) if t.device.type != "cpu"]

    assert devices, "CPU only. There seems a mistake to dispatch to here."

    # 0-dim operands that live elsewhere are moved to the first non-CPU device.
    device = devices[0]
    if c.device != device and c.ndim == 0:
        c = c.to(device)
    if a.device != device and a.ndim == 0:
        a = a.to(device)
    if b.device != device and b.ndim == 0:
        b = b.to(device)

    assert (
        len(set(devices)) == 1
    ), f"Expected all tensors to be on the same device, but found at least two devices, {devices}"

    # Operands with ndim > 0 are never moved above, so a CPU tensor mixed with
    # device tensors would otherwise slip past the check on `devices`.
    stray = [t.device for t in (c, a, b) if t.device != device]
    assert (
        not stray
    ), f"Expected all tensors to be on the same device, but found at least two devices, {[device, *stray]}"

    # Report the dtype of the wrapped tensor: `condition` itself may be a
    # Python scalar, which has no `.dtype` attribute.
    assert (
        c.dtype == torch.bool
    ), f"where expected condition to be a boolean tensor, but got a tensor with dtype {c.dtype}"

    if out is None:
        out_shape = torch.broadcast_shapes(c.shape, a.shape, b.shape)
        out = torch.empty(out_shape, dtype=result_type, device=device)

    ndim = max(c.ndim, a.ndim, b.ndim)
    where_inner.instantiate(ndim)
    where_inner(c, a, b, out0=out)
    return out
def where_self(condition, self, other):
    """Log the call and delegate to ``where_self_out`` with no ``out`` tensor."""
    logger.debug("GEMS WHERE_SELF")
    result = where_self_out(condition, self, other, out=None)
    return result
def where_scalar_self(condition, self, other):
    """Log the call and delegate to ``where_self_out`` with no ``out`` tensor.

    NOTE(review): the name suggests this entry point serves the overload where
    ``self`` is a Python scalar -- confirm against the op registration.
    """
    logger.debug("GEMS WHERE_SCALAR_SELF")
    result = where_self_out(condition, self, other, out=None)
    return result
def where_scalar_other(condition, self, other):
    """Log the call and delegate to ``where_self_out`` with no ``out`` tensor.

    NOTE(review): the name suggests this entry point serves the overload where
    ``other`` is a Python scalar -- confirm against the op registration.
    """
    logger.debug("GEMS WHERE_SCALAR_OTHER")
    result = where_self_out(condition, self, other, out=None)
    return result