Coverage for src/flag_gems/runtime/backend/_sunrise/ops/sub.py: 0%

96 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils import pointwise_dynamic 

7from flag_gems.utils.pointwise_dynamic import ComplexMode 

8 

9logger = logging.getLogger(__name__) 

10 

11 

# Tensor - tensor variant: `x` and `y` are tensors, `alpha` is a scalar.
# The result dtype follows the "DEFAULT" promotion of operands 0 and 1.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func(x, y, alpha):
    # Scale the subtrahend by `alpha`, then subtract it from `x`.
    scaled = y * alpha
    return x - scaled

16 

17 

# Tensor - scalar variant: only `x` is a tensor; `y` and `alpha` are scalars.
# The result dtype follows the "DEFAULT" promotion of operands 0 and 1.
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func_tensor_scalar(x, y, alpha):
    # Scale the subtrahend by `alpha`, then subtract it from `x`.
    scaled = y * alpha
    return x - scaled

24 

25 

# Scalar - tensor variant: only `y` is a tensor; `x` and `alpha` are scalars.
# The result dtype follows the "DEFAULT" promotion of operands 0 and 1.
@pointwise_dynamic(
    is_tensor=[False, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func_scalar_tensor(x, y, alpha):
    # Scale the subtrahend by `alpha`, then subtract it from `x`.
    scaled = y * alpha
    return x - scaled

32 

33 

# Register complex support (elementwise) on every subtraction variant.
# The two scalar variants additionally pass `tensorize_scalars=True` and name
# the tensor-tensor kernel `sub_func` as their `fallback_target`.
sub_func.register_complex(mode=ComplexMode.ELEMENTWISE)

sub_func_tensor_scalar.register_complex(
    mode=ComplexMode.ELEMENTWISE,
    tensorize_scalars=True,
    fallback_target=sub_func,
)

sub_func_scalar_tensor.register_complex(
    mode=ComplexMode.ELEMENTWISE,
    tensorize_scalars=True,
    fallback_target=sub_func,
)

42 

43 

44def _view_as_real_ptpu_safe(x: torch.Tensor) -> torch.Tensor: 

45 """`torch.view_as_real(x)` with a CPU bounce when x is on PTPU.""" 

46 try: 

47 return torch.view_as_real(x) 

48 except NotImplementedError: 

49 if x.device.type != "ptpu": 

50 raise 

51 return torch.view_as_real(x.cpu()).to(x.device) 

52 

53 

54def _view_as_complex_ptpu_safe(x: torch.Tensor) -> torch.Tensor: 

55 """`torch.view_as_complex(x)` with a CPU bounce when x is on PTPU.""" 

56 try: 

57 return torch.view_as_complex(x) 

58 except NotImplementedError: 

59 if x.device.type != "ptpu": 

60 raise 

61 return torch.view_as_complex(x.cpu()).to(x.device) 

62 

63 

64def _scalar_complex_as_real_ptpu_safe( 

65 scalar, complex_dtype: torch.dtype, target_shape, device: torch.device 

66) -> torch.Tensor: 

67 """Broadcast a python complex scalar to a `view_as_real`-shaped tensor.""" 

68 cpu_scalar = torch.tensor(scalar, dtype=complex_dtype, device="cpu").expand( 

69 target_shape 

70 ) 

71 cpu_real = torch.view_as_real(cpu_scalar).contiguous() 

72 if device.type == "cpu": 

73 return cpu_real 

74 return cpu_real.to(device) 

75 

76 

def _operand_as_real_ptpu_safe(
    value, complex_dtype: torch.dtype, target_shape, device: torch.device
) -> torch.Tensor:
    """Convert one operand (tensor or scalar) to its real-view representation.

    Non-tensor values are broadcast to `target_shape` through
    `_scalar_complex_as_real_ptpu_safe`. Tensors are cast to `complex_dtype`
    only when they are not already complex, then passed to
    `_view_as_real_ptpu_safe`; `target_shape` and `device` are not used for
    tensor operands.
    """
    if not isinstance(value, torch.Tensor):
        return _scalar_complex_as_real_ptpu_safe(
            value, complex_dtype, target_shape, device
        )
    complex_value = value if value.is_complex() else value.to(complex_dtype)
    return _view_as_real_ptpu_safe(complex_value)

84 

85 

86def _complex_sub(A, B, alpha): 

87 result_dtype = torch.result_type(A, B) 

88 shape_a = A.shape if isinstance(A, torch.Tensor) else torch.Size([]) 

89 shape_b = B.shape if isinstance(B, torch.Tensor) else torch.Size([]) 

90 target_shape = torch.broadcast_shapes(shape_a, shape_b) 

91 device = A.device if isinstance(A, torch.Tensor) else B.device 

92 

93 A_is_complex = (isinstance(A, torch.Tensor) and A.is_complex()) or isinstance( 

94 A, complex 

95 ) 

96 B_is_complex = (isinstance(B, torch.Tensor) and B.is_complex()) or isinstance( 

97 B, complex 

98 ) 

99 

100 if A_is_complex and B_is_complex: 

101 Ar = _operand_as_real_ptpu_safe(A, result_dtype, target_shape, device) 

102 Br = _operand_as_real_ptpu_safe(B, result_dtype, target_shape, device) 

103 common_dtype = torch.promote_types(Ar.dtype, Br.dtype) 

104 Ar, Br = Ar.to(common_dtype), Br.to(common_dtype) 

105 out_real = sub_func(Ar, Br, alpha) 

106 return _view_as_complex_ptpu_safe(out_real.contiguous()).to(result_dtype) 

107 

108 if A_is_complex: 

109 Ar = _operand_as_real_ptpu_safe(A, result_dtype, target_shape, device) 

110 if isinstance(B, torch.Tensor): 

111 Br = _operand_as_real_ptpu_safe(B, result_dtype, target_shape, device) 

112 else: 

113 Br = _scalar_complex_as_real_ptpu_safe( 

114 B, result_dtype, target_shape, device 

115 ) 

116 common_dtype = torch.promote_types(Ar.dtype, Br.dtype) 

117 Ar, Br = Ar.to(common_dtype), Br.to(common_dtype) 

118 out_real = sub_func(Ar, Br, alpha) 

119 return _view_as_complex_ptpu_safe(out_real.contiguous()).to(result_dtype) 

120 

121 Br = _operand_as_real_ptpu_safe(B, result_dtype, target_shape, device) 

122 if isinstance(A, torch.Tensor): 

123 Ar = _operand_as_real_ptpu_safe(A, result_dtype, target_shape, device) 

124 else: 

125 Ar = _scalar_complex_as_real_ptpu_safe(A, result_dtype, target_shape, device) 

126 common_dtype = torch.promote_types(Ar.dtype, Br.dtype) 

127 Ar, Br = Ar.to(common_dtype), Br.to(common_dtype) 

128 out_real = sub_func(Ar, Br, alpha) 

129 return _view_as_complex_ptpu_safe(out_real.contiguous()).to(result_dtype) 

130 

131 

def sub(A, B, *, alpha=1):
    """Compute ``A - B * alpha`` for any mix of tensor and scalar operands.

    Dispatch order:
      1. If either operand is complex (a complex tensor or a Python
         `complex`), delegate to `_complex_sub`.
      2. Tensor - tensor uses `sub_func`.
      3. Tensor - scalar uses `sub_func_tensor_scalar`.
      4. Scalar - tensor uses `sub_func_scalar_tensor`.
      5. Scalar - scalar is evaluated in Python and wrapped in a tensor.
    """
    logger.debug("GEMS SUB")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    # A tensor reports its own complexity; anything else counts as complex
    # only when it is a Python `complex`.
    a_complex = A.is_complex() if a_is_tensor else isinstance(A, complex)
    b_complex = B.is_complex() if b_is_tensor else isinstance(B, complex)
    if a_complex or b_complex:
        return _complex_sub(A, B, alpha)

    if a_is_tensor and b_is_tensor:
        return sub_func(A, B, alpha)
    if a_is_tensor:
        return sub_func_tensor_scalar(A, B, alpha)
    if b_is_tensor:
        return sub_func_scalar_tensor(A, B, alpha)
    return torch.tensor(A - B * alpha)

150 

151 

def sub_(A, B, *, alpha=1):
    """In-place variant of `sub`: compute ``A - B * alpha`` with `A` as output.

    `A` is passed to the kernel as `out0`. A tensor `B` selects `sub_func`;
    any other `B` selects `sub_func_tensor_scalar`. The kernel's return value
    is returned to the caller.
    """
    logger.debug("GEMS SUB_")
    kernel = sub_func if isinstance(B, torch.Tensor) else sub_func_tensor_scalar
    return kernel(A, B, alpha, out0=A)