Coverage for src/flag_gems/runtime/backend/_cambricon/ops/sub.py: 0%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
6from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger: a child of the package-level "flag_gems" logger, named after
# this module (leading dots stripped from __name__).
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)
# Tensor - tensor kernel: elementwise x - y * alpha, with the result dtype
# promoted from the first two arguments ("DEFAULT" promotion of args 0 and 1).
# `inplace` is accepted for interface uniformity but not read by the body.
@pointwise_dynamic(
    is_tensor=[True, True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func(x, y, alpha, inplace):
    scaled = y * alpha
    return x - scaled
# Tensor - scalar kernel: same arithmetic as sub_func, but only `x` is flagged
# as a tensor in `is_tensor`; `y` is passed as a scalar.
# `inplace` is accepted for interface uniformity but not read by the body.
@pointwise_dynamic(
    is_tensor=[True, False, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func_tensor_scalar(x, y, alpha, inplace):
    scaled = y * alpha
    return x - scaled
# Scalar - tensor kernel: same arithmetic as sub_func, but only `y` is flagged
# as a tensor in `is_tensor`; `x` is passed as a scalar.
# `inplace` is accepted for interface uniformity but not read by the body.
@pointwise_dynamic(
    is_tensor=[False, True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func_scalar_tensor(x, y, alpha, inplace):
    scaled = y * alpha
    return x - scaled
def _as_complex_tensor(x, dtype, like):
    """Return ``x`` as a tensor of complex ``dtype`` that ``view_as_real`` accepts.

    Args:
        x: A tensor or a Python scalar.
        dtype: The complex dtype to cast (tensor) or materialize (scalar) ``x`` in.
        like: A tensor. A scalar ``x`` is created on ``like.device`` and
            expanded (a broadcast view, no copy) to ``like``'s shape.

    Returns:
        A tensor of ``dtype`` with no pending conjugate bit.
    """
    if isinstance(x, torch.Tensor):
        # view_as_real raises on tensors with an unresolved conjugate bit
        # (e.g. the result of x.conj()), so materialize it first.
        return x.to(dtype).resolve_conj()
    return torch.tensor(x, dtype=dtype, device=like.device).expand_as(like)


def sub(A, B, *, alpha=1):
    """Compute ``A - B * alpha`` elementwise.

    Args:
        A: Minuend; a tensor or a Python scalar.
        B: Subtrahend; a tensor or a Python scalar.
        alpha: Scalar multiplier applied to ``B`` (keyword-only). For complex
            inputs it is applied to the real and imaginary parts separately,
            which is only mathematically correct for a real ``alpha``.

    Returns:
        A tensor. Complex results have dtype ``torch.result_type(A, B)``; when
        both inputs are scalars the result is ``torch.tensor(A - B * alpha)``.
    """
    logger.debug("GEMS_CAMBRICON SUB")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)
    if not A_is_tensor and not B_is_tensor:
        # Both scalar: there is no device tensor to launch a kernel on. This
        # check comes first so complex scalar pairs land here as well.
        return torch.tensor(A - B * alpha)

    A_is_complex = A.is_complex() if A_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)
    if A_is_complex or B_is_complex:
        # The real-valued kernel is run on (..., 2) real views of complex data
        # (presumably because the Triton kernels do not take complex dtypes).
        # Both operands go straight to the final complex dtype, so a
        # higher-precision real operand is not truncated on the way, and the
        # complex operand may be a Python scalar as well as a tensor.
        out_dtype = torch.result_type(A, B)
        like = A if A_is_tensor else B
        Ar = torch.view_as_real(_as_complex_tensor(A, out_dtype, like))
        Br = torch.view_as_real(_as_complex_tensor(B, out_dtype, like))
        out_real = sub_func(Ar, Br, alpha, False)
        return torch.view_as_complex(out_real).to(out_dtype)

    if A_is_tensor and B_is_tensor:
        return sub_func(A, B, alpha, False)
    if A_is_tensor:
        return sub_func_tensor_scalar(A, B, alpha, False)
    return sub_func_scalar_tensor(A, B, alpha, False)
def sub_(A, B, *, alpha=1):
    """In-place subtraction: write ``A - B * alpha`` into ``A``.

    The tensor/tensor kernel is used when ``B`` is a tensor, otherwise the
    tensor/scalar kernel; either is called with ``inplace=True`` and
    ``out0=A``, and its return value is passed through.

    NOTE(review): unlike ``sub``, there is no complex-dtype handling here.
    """
    logger.debug("GEMS_CAMBRICON SUB_")
    kernel = sub_func if isinstance(B, torch.Tensor) else sub_func_tensor_scalar
    return kernel(A, B, alpha, True, out0=A)