Coverage for src/flag_gems/runtime/backend/_sunrise/ops/sub.py: 0%
96 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils import pointwise_dynamic
7from flag_gems.utils.pointwise_dynamic import ComplexMode
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func(x, y, alpha):
    # Elementwise x - alpha * y. x and y are tensors and alpha is a scalar
    # (is_tensor=[True, True, False]). Type promotion considers only
    # arguments 0 and 1.
    scaled_y = y * alpha
    return x - scaled_y
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func_tensor_scalar(x, y, alpha):
    # Elementwise x - alpha * y, where x is a tensor and both y and alpha are
    # scalars (is_tensor=[True, False, False]).
    scaled_y = y * alpha
    return x - scaled_y
@pointwise_dynamic(
    is_tensor=[False, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func_scalar_tensor(x, y, alpha):
    # Elementwise x - alpha * y, where x is a scalar, y is a tensor and alpha
    # is a scalar (is_tensor=[False, True, False]).
    scaled_y = y * alpha
    return x - scaled_y
# Enable complex-number support (ComplexMode.ELEMENTWISE) on every sub kernel.
# The two mixed tensor/scalar kernels name `sub_func` as their fallback and
# set tensorize_scalars=True. Presumably this means their scalar operand is
# turned into a tensor and the tensor/tensor kernel is used instead; confirm
# against pointwise_dynamic.register_complex.
sub_func.register_complex(mode=ComplexMode.ELEMENTWISE)
sub_func_tensor_scalar.register_complex(
    fallback_target=sub_func,
    tensorize_scalars=True,
    mode=ComplexMode.ELEMENTWISE,
)
sub_func_scalar_tensor.register_complex(
    fallback_target=sub_func,
    tensorize_scalars=True,
    mode=ComplexMode.ELEMENTWISE,
)
44def _view_as_real_ptpu_safe(x: torch.Tensor) -> torch.Tensor:
45 """`torch.view_as_real(x)` with a CPU bounce when x is on PTPU."""
46 try:
47 return torch.view_as_real(x)
48 except NotImplementedError:
49 if x.device.type != "ptpu":
50 raise
51 return torch.view_as_real(x.cpu()).to(x.device)
54def _view_as_complex_ptpu_safe(x: torch.Tensor) -> torch.Tensor:
55 """`torch.view_as_complex(x)` with a CPU bounce when x is on PTPU."""
56 try:
57 return torch.view_as_complex(x)
58 except NotImplementedError:
59 if x.device.type != "ptpu":
60 raise
61 return torch.view_as_complex(x.cpu()).to(x.device)
64def _scalar_complex_as_real_ptpu_safe(
65 scalar, complex_dtype: torch.dtype, target_shape, device: torch.device
66) -> torch.Tensor:
67 """Broadcast a python complex scalar to a `view_as_real`-shaped tensor."""
68 cpu_scalar = torch.tensor(scalar, dtype=complex_dtype, device="cpu").expand(
69 target_shape
70 )
71 cpu_real = torch.view_as_real(cpu_scalar).contiguous()
72 if device.type == "cpu":
73 return cpu_real
74 return cpu_real.to(device)
def _operand_as_real_ptpu_safe(
    value, complex_dtype: torch.dtype, target_shape, device: torch.device
) -> torch.Tensor:
    """Convert a sub operand (tensor or Python scalar) to its real-pair form.

    Non-tensor values are broadcast to ``target_shape`` on ``device`` as
    ``complex_dtype``. Real tensors are cast to ``complex_dtype`` first. Complex
    tensors keep their own dtype, and for all tensors ``target_shape`` and
    ``device`` are ignored. The conversion uses the PTPU-safe helpers.
    """
    if not isinstance(value, torch.Tensor):
        return _scalar_complex_as_real_ptpu_safe(
            value, complex_dtype, target_shape, device
        )
    as_complex = value if value.is_complex() else value.to(complex_dtype)
    return _view_as_real_ptpu_safe(as_complex)
86def _complex_sub(A, B, alpha):
87 result_dtype = torch.result_type(A, B)
88 shape_a = A.shape if isinstance(A, torch.Tensor) else torch.Size([])
89 shape_b = B.shape if isinstance(B, torch.Tensor) else torch.Size([])
90 target_shape = torch.broadcast_shapes(shape_a, shape_b)
91 device = A.device if isinstance(A, torch.Tensor) else B.device
93 A_is_complex = (isinstance(A, torch.Tensor) and A.is_complex()) or isinstance(
94 A, complex
95 )
96 B_is_complex = (isinstance(B, torch.Tensor) and B.is_complex()) or isinstance(
97 B, complex
98 )
100 if A_is_complex and B_is_complex:
101 Ar = _operand_as_real_ptpu_safe(A, result_dtype, target_shape, device)
102 Br = _operand_as_real_ptpu_safe(B, result_dtype, target_shape, device)
103 common_dtype = torch.promote_types(Ar.dtype, Br.dtype)
104 Ar, Br = Ar.to(common_dtype), Br.to(common_dtype)
105 out_real = sub_func(Ar, Br, alpha)
106 return _view_as_complex_ptpu_safe(out_real.contiguous()).to(result_dtype)
108 if A_is_complex:
109 Ar = _operand_as_real_ptpu_safe(A, result_dtype, target_shape, device)
110 if isinstance(B, torch.Tensor):
111 Br = _operand_as_real_ptpu_safe(B, result_dtype, target_shape, device)
112 else:
113 Br = _scalar_complex_as_real_ptpu_safe(
114 B, result_dtype, target_shape, device
115 )
116 common_dtype = torch.promote_types(Ar.dtype, Br.dtype)
117 Ar, Br = Ar.to(common_dtype), Br.to(common_dtype)
118 out_real = sub_func(Ar, Br, alpha)
119 return _view_as_complex_ptpu_safe(out_real.contiguous()).to(result_dtype)
121 Br = _operand_as_real_ptpu_safe(B, result_dtype, target_shape, device)
122 if isinstance(A, torch.Tensor):
123 Ar = _operand_as_real_ptpu_safe(A, result_dtype, target_shape, device)
124 else:
125 Ar = _scalar_complex_as_real_ptpu_safe(A, result_dtype, target_shape, device)
126 common_dtype = torch.promote_types(Ar.dtype, Br.dtype)
127 Ar, Br = Ar.to(common_dtype), Br.to(common_dtype)
128 out_real = sub_func(Ar, Br, alpha)
129 return _view_as_complex_ptpu_safe(out_real.contiguous()).to(result_dtype)
def sub(A, B, *, alpha=1):
    """Compute ``A - alpha * B``.

    Args:
        A: Minuend; a tensor or a Python scalar.
        B: Subtrahend; a tensor or a Python scalar.
        alpha: Multiplier applied to ``B`` before subtracting (keyword-only).

    Returns:
        A new tensor.
        - Two Python scalars (real or complex): computed in Python and wrapped
          with ``torch.tensor``.
        - Any complex operand: handled by ``_complex_sub``.
        - Otherwise: dispatched to the sub kernel matching which operands are
          tensors.
    """
    logger.debug("GEMS SUB")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if not (a_is_tensor or b_is_tensor):
        # Checked before the complex dispatch: `_complex_sub` derives its
        # device from a tensor operand, so it must not receive two scalars.
        return torch.tensor(A - B * alpha)

    a_is_complex = (a_is_tensor and A.is_complex()) or isinstance(A, complex)
    b_is_complex = (b_is_tensor and B.is_complex()) or isinstance(B, complex)
    if a_is_complex or b_is_complex:
        # Complex inputs go through the real-view path with PTPU fallbacks.
        return _complex_sub(A, B, alpha)

    if a_is_tensor and b_is_tensor:
        return sub_func(A, B, alpha)
    if a_is_tensor:
        return sub_func_tensor_scalar(A, B, alpha)
    return sub_func_scalar_tensor(A, B, alpha)
def sub_(A, B, *, alpha=1):
    """In-place ``A -= alpha * B``.

    A tensor ``B`` uses the tensor/tensor kernel; anything else uses the
    tensor/scalar kernel. The kernel writes into ``A`` (``out0=A``), and its
    return value is passed through.
    """
    logger.debug("GEMS SUB_")
    kernel = sub_func if isinstance(B, torch.Tensor) else sub_func_tensor_scalar
    return kernel(A, B, alpha, out0=A)