Coverage for src/flag_gems/runtime/backend/_cambricon/ops/add.py: 0%
64 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
6from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the shared "flag_gems" logger, named after this
# module (any leading dots are stripped from the module name).
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
# Elementwise ``x + y * alpha`` where both ``x`` and ``y`` are tensors
# (``is_tensor`` marks arguments 0 and 1 as tensors; ``alpha`` and ``inplace``
# are passed as non-tensor values).
# ``promotion_methods=[(0, 1, "DEFAULT")]`` presumably derives the output dtype
# from arguments 0 and 1 with the default promotion rule — TODO confirm
# against pointwise_dynamic.
# ``inplace`` is not read by the kernel body; it appears to exist only so the
# in-place caller can share this signature (NOTE(review): confirm).
@pointwise_dynamic(
    is_tensor=[True, True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func(x, y, alpha, inplace):
    return x + y * alpha
# Elementwise ``x + y * alpha`` where ``x`` is a tensor and ``y`` is a scalar
# (``is_tensor`` marks only argument 0 as a tensor).
# ``promotion_methods=[(0, 1, "DEFAULT")]`` presumably derives the output dtype
# from arguments 0 and 1 with the default promotion rule — TODO confirm.
# ``inplace`` is not read by the kernel body (kept for signature parity with
# add_func — NOTE(review): confirm).
@pointwise_dynamic(
    is_tensor=[True, False, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha, inplace):
    return x + y * alpha
# Elementwise ``x + y * alpha`` where ``x`` is a scalar and ``y`` is a tensor
# (``is_tensor`` marks only argument 1 as a tensor).
# ``promotion_methods=[(0, 1, "DEFAULT")]`` presumably derives the output dtype
# from arguments 0 and 1 with the default promotion rule — TODO confirm.
# ``inplace`` is not read by the kernel body (kept for signature parity with
# add_func — NOTE(review): confirm).
@pointwise_dynamic(
    is_tensor=[False, True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha, inplace):
    return x + y * alpha
def _target_device(A, B):
    """Return the device on which ``A`` and ``B`` should be combined.

    At least one of ``A``/``B`` must be a ``torch.Tensor``. A non-tensor
    operand defers to the tensor's device. When both are tensors on different
    devices and ``A`` is a 0-dim CPU tensor (which PyTorch allows to mix with
    tensors on any device, like a scalar), ``B``'s device wins; otherwise
    ``A``'s device is used, as before (``B`` is moved onto ``A.device``).
    """
    if not isinstance(A, torch.Tensor):
        return B.device
    if not isinstance(B, torch.Tensor):
        return A.device
    if A.device != B.device and A.device.type == "cpu" and A.dim() == 0:
        return B.device
    return A.device


def _to_real_view(x, dtype, device):
    """Cast ``x`` (tensor or Python scalar) to complex ``dtype`` on ``device``
    and return its ``torch.view_as_real`` view (extra trailing dim of size 2).
    """
    if isinstance(x, torch.Tensor):
        x = x.to(device=device, dtype=dtype)
    else:
        x = torch.tensor(x, dtype=dtype, device=device)
    # view_as_real rejects tensors whose lazy conjugate bit is still set.
    return torch.view_as_real(x.resolve_conj())


def add(A, B, *, alpha=1):
    """Compute ``A + alpha * B``.

    Args:
        A, B: tensors or Python scalars, real or complex.
        alpha: scalar multiplier applied to ``B`` (default 1).

    Returns:
        A new tensor. Two Python scalars are combined in Python and wrapped in
        ``torch.tensor``. Complex operands are added through their
        ``view_as_real`` views and returned as ``torch.result_type(A, B)``.
        Real tensor/scalar combinations use the matching Triton kernel.
    """
    logger.debug("GEMS_CAMBRICON ADD")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)

    if not (A_is_tensor or B_is_tensor):
        # Two Python scalars (real or complex): no kernel launch needed.
        return torch.tensor(A + B * alpha)

    A_is_complex = (A_is_tensor and A.is_complex()) or isinstance(A, complex)
    B_is_complex = (B_is_tensor and B.is_complex()) or isinstance(B, complex)

    if A_is_complex or B_is_complex:
        # Run the real-valued kernel on (real, imag) pairs. Both operands are
        # cast to the final complex dtype first, so a wider real operand is
        # not truncated to a narrower complex operand's precision.
        # NOTE(review): correct only for a real ``alpha``; a complex alpha
        # would need cross terms between real and imaginary parts.
        out_dtype = torch.result_type(A, B)
        device = _target_device(A, B)
        Ar = _to_real_view(A, out_dtype, device)
        Br = _to_real_view(B, out_dtype, device)
        out_real = add_func(Ar, Br, alpha, False)
        return torch.view_as_complex(out_real).to(out_dtype)

    if A_is_tensor and B_is_tensor:
        # Align devices; .to() is a no-op for a tensor already on `device`.
        device = _target_device(A, B)
        return add_func(A.to(device), B.to(device), alpha, False)
    if A_is_tensor:
        return add_func_tensor_scalar(A, B, alpha, False)
    return add_func_scalar_tensor(A, B, alpha, False)
def add_(A, B, *, alpha=1):
    """In-place ``A += alpha * B``.

    Args:
        A: tensor updated in place.
        B: tensor or Python scalar. A tensor on another device is moved to
            ``A.device``.
        alpha: scalar multiplier applied to ``B`` (default 1).

    Returns:
        ``A`` for complex inputs; otherwise whatever the kernel returns when
        writing into ``out0=A``.

    Raises:
        ValueError: if ``A`` is not a tensor (there is nothing to update).
    """
    logger.debug("GEMS_CAMBRICON ADD_")
    if not isinstance(A, torch.Tensor):
        raise ValueError(
            "add_ expects a torch.Tensor as its first argument, "
            f"got {type(A).__name__}."
        )

    if A.is_complex():
        # Update A through its real view so the real-valued kernel can be
        # reused; B is cast to A's dtype because in-place ops keep A's dtype.
        # NOTE(review): correct only for a real ``alpha``; view_as_real also
        # rejects an A whose conjugate bit is set.
        if isinstance(B, torch.Tensor):
            B = B.to(device=A.device, dtype=A.dtype)
        else:
            B = torch.tensor(B, dtype=A.dtype, device=A.device)
        A_real = torch.view_as_real(A)
        B_real = torch.view_as_real(B.resolve_conj())
        add_func(A_real, B_real, alpha, True, out0=A_real)
        return A

    if isinstance(B, torch.Tensor):
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha, True, out0=A)
    return add_func_tensor_scalar(A, B, alpha, True, out0=A)