Coverage for src/flag_gems/ops/add.py: 78%
68 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils import pointwise_dynamic
7from flag_gems.utils.pointwise_dynamic import ComplexMode
# Module-scoped logger, named after this module so its output can be filtered.
logger = logging.getLogger(name=__name__)
# Tensor-tensor add kernel computing ``x + y * alpha`` elementwise.
# is_tensor=[True, True, False]: ``x`` and ``y`` are tensor arguments and
# ``alpha`` is passed as a non-tensor (scalar) argument.
# promotion_methods=[(0, 1, "DEFAULT")]: the output dtype is derived from
# arguments 0 and 1 — presumably torch-style default type promotion; confirm
# in flag_gems.utils.pointwise_dynamic.
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def add_func(x, y, alpha):
    return x + y * alpha
# Tensor-scalar variant of the add kernel: ``x + y * alpha`` where only ``x``
# is a tensor (is_tensor=[True, False, False]); ``y`` and ``alpha`` are
# passed as non-tensor arguments. `add`/`add_` use this when ``A`` is a tensor
# and ``B`` is not. Output dtype promotion is derived from arguments 0 and 1
# ("DEFAULT" method — presumably torch-style; confirm in pointwise_dynamic).
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha):
    return x + y * alpha
# Scalar-tensor variant of the add kernel: ``x + y * alpha`` where only ``y``
# is a tensor (is_tensor=[False, True, False]); ``x`` and ``alpha`` are
# passed as non-tensor arguments. `add` uses this when ``B`` is a tensor and
# ``A`` is not. Output dtype promotion is derived from arguments 0 and 1
# ("DEFAULT" method — presumably torch-style; confirm in pointwise_dynamic).
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha):
    return x + y * alpha
# Register elementwise complex support on every add kernel. The tensor-tensor
# kernel is registered first; the two scalar-mixing variants are registered
# with tensorize_scalars=True and fallback_target=add_func (judging by the
# keyword names, their scalar operand is turned into a tensor and the work is
# routed to add_func — confirm in pointwise_dynamic.register_complex).
add_func.register_complex(mode=ComplexMode.ELEMENTWISE)
for _scalar_variant in (add_func_tensor_scalar, add_func_scalar_tensor):
    _scalar_variant.register_complex(
        mode=ComplexMode.ELEMENTWISE,
        tensorize_scalars=True,
        fallback_target=add_func,
    )
del _scalar_variant
def _is_complex(value):
    """Return True for complex tensors and Python ``complex`` scalars."""
    if isinstance(value, torch.Tensor):
        return value.is_complex()
    return isinstance(value, complex)


def _complex_real_view(value, dtype, device):
    """Convert ``value`` to a complex ``dtype`` tensor on ``device`` and view it as real.

    ``value`` may be a tensor (real or complex) or a Python scalar. The
    returned tensor has an extra trailing dimension of size 2 holding the
    (real, imag) components, as produced by ``torch.view_as_real``.
    """
    if isinstance(value, torch.Tensor):
        t = value.to(device=device, dtype=dtype)
    else:
        t = torch.tensor(value, dtype=dtype, device=device)
    # view_as_real rejects tensors with an unresolved conjugate bit.
    return torch.view_as_real(t.resolve_conj())


def _add_complex(A, B, alpha):
    """Compute complex ``A + alpha * B`` with the real kernel on real views.

    Preconditions: at least one of ``A``/``B`` is a tensor and at least one is
    complex. Both operands are converted to the promoted result dtype
    (``torch.result_type(A, B)``). Converting to that dtype rather than to the
    other operand's dtype avoids precision loss, e.g. float64 into complex64.
    Complex addition is component-wise, so adding the ``(..., 2)`` real views
    gives the complex sum, but only for a real multiplier. A complex ``alpha``
    is therefore folded into ``B`` before the kernel runs.
    """
    out_dtype = torch.result_type(A, B)
    device = A.device if isinstance(A, torch.Tensor) else B.device

    Ar = _complex_real_view(A, out_dtype, device)
    if isinstance(alpha, complex):
        if isinstance(B, torch.Tensor):
            # (br + i*bi) * (ar + i*ai), computed on the real components.
            b_re, b_im = _complex_real_view(B, out_dtype, device).unbind(-1)
            Br = torch.stack(
                (
                    b_re * alpha.real - b_im * alpha.imag,
                    b_re * alpha.imag + b_im * alpha.real,
                ),
                dim=-1,
            )
        else:
            Br = _complex_real_view(B * alpha, out_dtype, device)
        alpha = 1
    else:
        Br = _complex_real_view(B, out_dtype, device)

    # Scalars become 0-dim tensors whose real view has shape (2,), which
    # broadcasts against the trailing (real, imag) dimension of the other view.
    out_real = add_func(Ar, Br, alpha)
    return torch.view_as_complex(out_real)


def add(A, B, *, alpha=1):
    """Compute ``A + alpha * B`` with PyTorch-style type promotion.

    Args:
        A: Tensor or Python scalar; may be complex.
        B: Tensor or Python scalar; may be complex.
        alpha: Scalar multiplier applied to ``B``. It is passed to the kernels
            as their non-tensor argument. A complex ``alpha`` is supported when
            at least one operand is complex.

    Returns:
        A new tensor containing ``A + alpha * B``. When neither operand is a
        tensor, the Python result is wrapped with ``torch.tensor``.
    """
    logger.debug("GEMS ADD")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)

    if not (A_is_tensor or B_is_tensor):
        # Plain Python arithmetic handles real and complex scalars alike.
        return torch.tensor(A + B * alpha)

    if _is_complex(A) or _is_complex(B):
        return _add_complex(A, B, alpha)

    if A_is_tensor and B_is_tensor:
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha)
    if A_is_tensor:
        return add_func_tensor_scalar(A, B, alpha)
    return add_func_scalar_tensor(A, B, alpha)
def add_(A, B, *, alpha=1):
    """In-place ``A += alpha * B``, writing the result into ``A``.

    Args:
        A: Tensor that receives the result (passed to the kernel as ``out0``).
        B: Tensor or Python scalar. A tensor ``B`` on another device is moved
            to ``A``'s device first.
        alpha: Scalar multiplier applied to ``B``.

    Returns:
        Whatever the kernel returns when called with ``out0=A``. This is
        presumably ``A`` itself; confirm in pointwise_dynamic.

    Raises:
        ValueError: If ``A`` is not a tensor, since there is nothing to update
            in place.
        RuntimeError: If the promoted result dtype cannot be cast back to
            ``A.dtype`` (e.g. a float ``B`` into an integer ``A``). This
            mirrors PyTorch's in-place casting rule.
    """
    logger.debug("GEMS ADD_")
    if not isinstance(A, torch.Tensor):
        raise ValueError(
            f"add_ requires a torch.Tensor as its first argument, "
            f"got {type(A).__name__}."
        )

    # In-place ops must not silently narrow the promoted result type.
    result_dtype = torch.result_type(A, B)
    if not torch.can_cast(result_dtype, A.dtype):
        raise RuntimeError(
            f"result type {result_dtype} can't be cast to the desired "
            f"output type {A.dtype}"
        )

    if isinstance(B, torch.Tensor):
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha, out0=A)
    return add_func_tensor_scalar(A, B, alpha, out0=A)