Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/add.py: 0%
66 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
6from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module (any leading dots stripped from __name__).
_parent_logger = logging.getLogger("flag_gems")
logger = _parent_logger.getChild(__name__.lstrip("."))
del _parent_logger
# Tensor + tensor kernel computing x + y * alpha element-wise.
# Arguments 0 (x) and 1 (y) are tensors; alpha is a scalar. The
# promotion_methods entry presumably derives the output dtype from
# arguments 0 and 1 using pointwise_dynamic's "DEFAULT" rule.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func(x, y, alpha):
    scaled_y = y * alpha
    return x + scaled_y
# Tensor + scalar kernel computing x + y * alpha element-wise.
# Only argument 0 (x) is a tensor; y and alpha are scalars. The output
# dtype presumably comes from arguments 0 and 1 via the "DEFAULT" rule.
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha):
    scaled_y = y * alpha
    return x + scaled_y
# Scalar + tensor kernel computing x + y * alpha element-wise.
# Only argument 1 (y) is a tensor; x and alpha are scalars. The output
# dtype presumably comes from arguments 0 and 1 via the "DEFAULT" rule.
@pointwise_dynamic(
    is_tensor=[False, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha):
    scaled_y = y * alpha
    return x + scaled_y
def _align_devices(A, B):
    """Return ``(A, B)`` placed on a single device for a kernel launch.

    A 0-dim CPU tensor follows the other operand's device (``torch.add``
    accepts such CPU scalars next to device tensors); otherwise ``B`` is
    moved to ``A``'s device.
    """
    if A.device == B.device:
        return A, B
    if A.device.type == "cpu" and A.dim() == 0:
        return A.to(B.device), B
    return A, B.to(A.device)


def _as_complex_tensor(x, dtype, device):
    """Return tensor-or-Python-scalar *x* as a *dtype* tensor on *device*.

    *dtype* is expected to be complex. A Python scalar becomes a 0-dim
    tensor whose real view has shape ``(2,)``; that broadcasts against the
    other operand's ``(..., 2)`` real view.
    """
    if isinstance(x, torch.Tensor):
        # view_as_real rejects tensors with an unresolved conjugate bit.
        return x.to(device=device, dtype=dtype).resolve_conj()
    return torch.tensor(x, dtype=dtype, device=device)


def _add_complex(A, B, alpha):
    """Compute ``A + alpha * B`` when at least one operand is complex.

    Both operands are converted to the complex result dtype *before* the
    addition (so e.g. complex64 + float64 is computed in double precision),
    then viewed as real tensors with a trailing ``(real, imag)`` dimension
    so the real ``add_func`` kernel adds the two components independently.
    """
    result_dtype = torch.result_type(A, B)
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        A, B = _align_devices(A, B)
    device = A.device if isinstance(A, torch.Tensor) else B.device
    A_c = _as_complex_tensor(A, result_dtype, device)
    B_c = _as_complex_tensor(B, result_dtype, device)
    if isinstance(alpha, complex):
        # A complex alpha mixes the real and imaginary components, which a
        # component-wise real kernel cannot express; fold it into B here.
        B_c = B_c * alpha
        alpha = 1
    out_real = add_func(torch.view_as_real(A_c), torch.view_as_real(B_c), alpha)
    # view_as_complex requires a unit-stride trailing dimension of size 2.
    return torch.view_as_complex(out_real.contiguous()).to(result_dtype)


def add(A, B, *, alpha=1):
    """Compute ``A + alpha * B`` (mirrors ``torch.add``) with Triton kernels.

    Args:
        A: First operand, a tensor or a Python scalar.
        B: Second operand, a tensor or a Python scalar.
        alpha: Scalar multiplier applied to ``B``; may be complex when the
            result is complex.

    Returns:
        A new tensor holding ``A + alpha * B``. When both operands are
        Python scalars, a tensor built with ``torch.tensor`` from the
        Python result.
    """
    logger.debug("GEMS_KUNLUNXIN ADD")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)
    if not (A_is_tensor or B_is_tensor):
        # Two Python scalars (possibly complex): plain arithmetic, no kernel.
        return torch.tensor(A + B * alpha)

    A_is_complex = A.is_complex() if A_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)
    if A_is_complex or B_is_complex:
        return _add_complex(A, B, alpha)

    if A_is_tensor and B_is_tensor:
        A, B = _align_devices(A, B)
        return add_func(A, B, alpha)
    if A_is_tensor:
        return add_func_tensor_scalar(A, B, alpha)
    return add_func_scalar_tensor(A, B, alpha)
def add_(A, B, *, alpha=1):
    """In-place ``A += alpha * B``, mirroring ``torch.Tensor.add_``.

    Args:
        A: Tensor updated in place; must be a ``torch.Tensor``.
        B: Tensor or Python scalar added to ``A``. A tensor on another
            device is moved to ``A``'s device first.
        alpha: Scalar multiplier for ``B``. Defaults to the integer ``1``,
            matching ``add`` and the ATen schema. A float default is passed
            to the kernel as a floating-point scalar and can force integer
            tensors through floating-point arithmetic.

    Returns:
        For complex ``A``, ``A`` itself. Otherwise whatever
        ``add_func``/``add_func_tensor_scalar`` return when called with
        ``out0=A``. This is presumably ``A``; confirm against
        ``pointwise_dynamic``.

    Raises:
        ValueError: If ``A`` is not a tensor, since there is nothing to
            update in place.
    """
    logger.debug("GEMS_KUNLUNXIN ADD_")
    if not isinstance(A, torch.Tensor):
        raise ValueError(
            "add_ expects a torch.Tensor as its first argument, "
            f"got {type(A).__name__}"
        )

    if A.is_complex():
        # The kernels operate on real dtypes: write into A through its
        # (..., 2) real view, with B converted to A's complex dtype.
        if isinstance(B, torch.Tensor):
            B_c = B.to(device=A.device, dtype=A.dtype).resolve_conj()
        else:
            B_c = torch.tensor(B, dtype=A.dtype, device=A.device)
        if isinstance(alpha, complex):
            # A complex alpha mixes real and imaginary parts; apply it here
            # because the component-wise real kernel cannot.
            B_c = B_c * alpha
            alpha = 1
        A_real = torch.view_as_real(A)
        add_func(A_real, torch.view_as_real(B_c), alpha, out0=A_real)
        return A

    if isinstance(B, torch.Tensor):
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha, out0=A)
    return add_func_tensor_scalar(A, B, alpha, out0=A)