Coverage for src/flag_gems/runtime/backend/_sunrise/ops/where.py: 0%

51 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8from flag_gems.utils.pointwise_dynamic import CodeGenConfig 

9 

# Module logger, created as a child of the "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Per-axis grid-size limits, passed to CodeGenConfig as ``max_grid_size``.
MAX_GRID_SIZES = (65535,) * 3

# Code-generation settings handed to ``pointwise_dynamic`` for ``where_inner``.
config = CodeGenConfig(
    max_grid_size=MAX_GRID_SIZES,
    max_tile_size=256,
    max_num_warps_per_cta=16,
    prefer_1d_tile=True,
    prefer_block_pointer=True,
)

20 

21 

# Elementwise select kernel generated by ``pointwise_dynamic``.
# All three operands are declared as tensors, and the output dtype is derived
# from operands 1 and 2 (``self`` / ``other``) using the "NO_OPMATH" promotion rule.
@pointwise_dynamic(
    config=config,
    promotion_methods=[(1, 2, "NO_OPMATH")],
    is_tensor=[True] * 3,
)
@triton.jit
def where_inner(condition, self, other):
    # Take ``self`` where ``condition`` is true, ``other`` elsewhere.
    picked = tl.where(condition, self, other)
    return picked

28 

29 

def where_self_out(condition, self, other, out=None):
    """Select elements from ``self`` or ``other`` according to ``condition``.

    Non-tensor arguments are wrapped with ``torch.tensor``. ``self`` and
    ``other`` are cast to ``torch.result_type(self, other)``. Zero-dim tensors
    that are not on the first non-CPU operand device are moved to that device
    before the ``where_inner`` kernel runs.

    Args:
        condition: Boolean tensor (or Python scalar) choosing ``self`` where true.
        self: Tensor or Python scalar used where ``condition`` is true.
        other: Tensor or Python scalar used where ``condition`` is false.
        out: Optional output tensor. Its dtype must equal
            ``torch.result_type(self, other)``.

    Returns:
        ``out`` if it was given; otherwise a newly allocated tensor of the
        broadcast shape of the three operands, on the compute device.

    Raises:
        AssertionError: if ``out`` has the wrong dtype, if every operand is on
            the CPU, if the operands (after moving zero-dim tensors) are not
            all on the same device, or if ``condition`` is not boolean.
    """
    logger.debug("GEMS WHERE_SELF_OUT")
    result_type = torch.result_type(self, other)
    if out is not None:
        assert (
            out.dtype == result_type
        ), f"Expected out type to be {result_type}, but got {out.dtype}."

    # Wrap Python scalars so every operand exposes .dtype / .device / .ndim.
    c, a, b = (
        x if isinstance(x, torch.Tensor) else torch.tensor(x)
        for x in (condition, self, other)
    )

    if a.dtype != result_type:
        a = a.to(result_type)
    if b.dtype != result_type:
        b = b.to(result_type)

    # The kernel must run on a non-CPU device; use the first one found.
    devices = [t.device for t in (c, a, b) if t.device.type != "cpu"]
    assert len(devices), "CPU only. There seems a mistake to dispatch to here."
    device = devices[0]

    # Zero-dim operands (including wrapped Python scalars) may live elsewhere;
    # move only those to the compute device.
    c, a, b = (
        t.to(device) if t.device != device and t.ndim == 0 else t
        for t in (c, a, b)
    )

    # Check the devices of all operands *after* the moves. Previously CPU
    # devices were filtered out first, so a non-scalar CPU tensor slipped
    # through and was handed to the kernel.
    operand_devices = [t.device for t in (c, a, b)]
    assert (
        len(set(operand_devices)) == 1
    ), f"Expected all tensors to be on the same device, but found at least two devices, {operand_devices}"
    # Report ``c.dtype``: ``condition`` itself may be a Python scalar with no
    # ``.dtype`` attribute, which would turn the assertion into AttributeError.
    assert (
        c.dtype == torch.bool
    ), f"where expected condition to be a boolean tensor, but got a tensor with dtype {c.dtype}"

    if out is None:
        out_shape = torch.broadcast_shapes(c.shape, a.shape, b.shape)
        out = torch.empty(out_shape, dtype=result_type, device=device)

    ndim = max(c.ndim, a.ndim, b.ndim)
    where_inner.instantiate(ndim)
    where_inner(c, a, b, out0=out)
    return out

78 

79 

def where_self(condition, self, other):
    """Delegate ``where(condition, self, other)`` to ``where_self_out``.

    No ``out`` tensor is passed, so ``where_self_out`` allocates the result.
    """
    logger.debug("GEMS WHERE_SELF")
    result = where_self_out(condition, self, other, out=None)
    return result

83 

84 

def where_scalar_self(condition, self, other):
    """Delegate ``where(condition, self, other)`` to ``where_self_out``.

    Per its name, this is meant for a scalar ``self``. ``where_self_out``
    wraps any non-tensor operand with ``torch.tensor`` itself, so no extra
    handling is done here.
    """
    logger.debug("GEMS WHERE_SCALAR_SELF")
    result = where_self_out(condition, self, other, out=None)
    return result

88 

89 

def where_scalar_other(condition, self, other):
    """Delegate ``where(condition, self, other)`` to ``where_self_out``.

    Per its name, this is meant for a scalar ``other``. ``where_self_out``
    wraps any non-tensor operand with ``torch.tensor`` itself, so no extra
    handling is done here.
    """
    logger.debug("GEMS WHERE_SCALAR_OTHER")
    result = where_self_out(condition, self, other, out=None)
    return result