Coverage for src/flag_gems/runtime/backend/_spacemit/ops/where.py: 0%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils.pointwise_dynamic import pointwise_dynamic 

8 

9logger = logging.getLogger(__name__) 

10 

11 

# Element-wise select kernel. All three operands are declared as tensors, and
# the output dtype is promoted from operands 1 and 2 (`self` and `other`)
# using the "NO_OPMATH" method; the condition takes no part in promotion.
@pointwise_dynamic(
    is_tensor=[True, True, True],
    promotion_methods=[(1, 2, "NO_OPMATH")],
)
@triton.jit
def where_inner(condition, self, other):
    # Take `self` where `condition` holds, `other` everywhere else.
    selected = tl.where(condition, self, other)
    return selected

19 

20 

def where_self_out(condition, self, other, out=None):
    """Select values from ``self`` or ``other`` according to ``condition``.

    Non-tensor arguments are wrapped with ``torch.tensor``. ``self`` and
    ``other`` are cast to ``torch.result_type(self, other)`` before the
    element-wise selection is delegated to ``where_inner``.

    Args:
        condition: Boolean tensor, or a Python value that ``torch.tensor``
            turns into one.
        self: Tensor or scalar providing the values used where ``condition``
            is true.
        other: Tensor or scalar providing the values used where ``condition``
            is false.
        out: Optional destination tensor. Its dtype must equal the promoted
            result type. When omitted, a tensor with the broadcast shape of
            the three operands is allocated on the operands' device.

    Returns:
        ``out`` when it was supplied, otherwise the newly allocated tensor.

    Raises:
        AssertionError: If ``out`` has the wrong dtype, any operand is not on
            the CPU, the operands live on different devices, or the
            condition is not of dtype ``torch.bool``.
    """
    logger.debug("GEMS_SPACEMIT WHERE_SELF_OUT")
    result_type = torch.result_type(self, other)
    if out is not None:
        assert (
            out.dtype == result_type
        ), f"Expected out type to be {result_type}, but got {out.dtype}."

    # Wrap Python scalars so every operand exposes dtype/device/ndim/shape.
    c, a, b = (
        x if isinstance(x, torch.Tensor) else torch.tensor(x)
        for x in (condition, self, other)
    )

    # Both value operands must already carry the promoted dtype.
    if a.dtype != result_type:
        a = a.to(result_type)
    if b.dtype != result_type:
        b = b.to(result_type)

    devices = [x.device for x in (c, a, b)]

    assert all(device.type == "cpu" for device in devices), (
        "CPU only. Expected all tensors to be on CPU, " f"but found devices {devices}"
    )

    # Zero-dimensional operands may be moved onto the reference device.
    device = devices[0]
    if c.device != device and c.ndim == 0:
        c = c.to(device)
    if a.device != device and a.ndim == 0:
        a = a.to(device)
    if b.device != device and b.ndim == 0:
        b = b.to(device)

    assert (
        len(set(devices)) == 1
    ), f"Expected all tensors to be on the same device, but found at least two devices, {devices}"
    # Report the dtype of the wrapped tensor `c`: the raw `condition` may be a
    # Python scalar without a `.dtype`, which would turn this diagnostic into
    # an AttributeError.
    assert (
        c.dtype == torch.bool
    ), f"where expected condition to be a boolean tensor, but got a tensor with dtype {c.dtype}"

    if out is None:
        out_shape = torch.broadcast_shapes(c.shape, a.shape, b.shape)
        out = torch.empty(out_shape, dtype=result_type, device=device)

    # Specialise the kernel for the highest operand rank before launching it.
    ndim = max(c.ndim, a.ndim, b.ndim)
    where_inner.instantiate(ndim)
    where_inner(c, a, b, out0=out)
    return out

70 

71 

def where_self(condition, self, other):
    """Log the call and return ``where_self_out(condition, self, other)``.

    No ``out`` tensor is passed, so ``where_self_out`` receives its default
    ``out=None``.
    """
    logger.debug("GEMS_SPACEMIT WHERE_SELF")
    result = where_self_out(condition, self, other, out=None)
    return result

75 

76 

def where_scalar_self(condition, self, other):
    """Log the call and return ``where_self_out(condition, self, other)``.

    NOTE(review): the name suggests ``self`` is a Python scalar here; this
    wrapper itself does not check that -- confirm against the dispatcher.
    """
    logger.debug("GEMS_SPACEMIT WHERE_SCALAR_SELF")
    result = where_self_out(condition, self, other, out=None)
    return result

80 

81 

def where_scalar_other(condition, self, other):
    """Log the call and return ``where_self_out(condition, self, other)``.

    NOTE(review): the name suggests ``other`` is a Python scalar here; this
    wrapper itself does not check that -- confirm against the dispatcher.
    """
    logger.debug("GEMS_SPACEMIT WHERE_SCALAR_OTHER")
    result = where_self_out(condition, self, other, out=None)
    return result