Coverage for src/flag_gems/runtime/backend/_cambricon/ops/add.py: 0%

64 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from ..utils.pointwise_dynamic import pointwise_dynamic 

7 

# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module.
_FLAG_GEMS_LOGGER = logging.getLogger("flag_gems")
logger = _FLAG_GEMS_LOGGER.getChild(__name__.lstrip("."))

9 

10 

# Elementwise `x + y * alpha` for two tensor operands. `is_tensor` presumably
# marks x and y as tensors and alpha/inplace as plain scalars; the promotion
# rule pairs arguments 0 and 1 with "DEFAULT" (semantics defined by
# pointwise_dynamic).
@pointwise_dynamic(
    is_tensor=[True, True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func(x, y, alpha, inplace):
    # `inplace` is not read by the kernel body.
    scaled_y = y * alpha
    return x + scaled_y

17 

18 

# Elementwise `x + y * alpha` where x is a tensor and y is a scalar (per the
# is_tensor flags); otherwise identical to add_func.
@pointwise_dynamic(
    is_tensor=[True, False, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha, inplace):
    # `inplace` is not read by the kernel body.
    scaled_y = y * alpha
    return x + scaled_y

25 

26 

# Elementwise `x + y * alpha` where x is a scalar and y is a tensor (per the
# is_tensor flags); otherwise identical to add_func.
@pointwise_dynamic(
    is_tensor=[False, True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha, inplace):
    # `inplace` is not read by the kernel body.
    scaled_y = y * alpha
    return x + scaled_y

33 

34 

def _is_complex(x):
    """Return True if ``x`` is a complex tensor or a Python ``complex`` scalar."""
    return (isinstance(x, torch.Tensor) and x.is_complex()) or isinstance(x, complex)


def _add_complex(A, B, alpha):
    """Compute ``A + alpha * B`` when at least one operand is complex.

    The Triton kernels only handle real dtypes. Each operand is therefore cast
    to the common result dtype ``torch.result_type(A, B)`` and reinterpreted
    with ``torch.view_as_real`` (a trailing dimension of size 2 holding the
    real and imaginary parts). The parts are added with ``add_func``, and the
    real result is reinterpreted back as complex.

    At least one of ``A``/``B`` must be a tensor; the caller guarantees this.

    NOTE(review): ``alpha`` scales the real and imaginary parts of ``B``
    independently, which matches complex multiplication only for a real
    ``alpha``.
    """
    # Cast to the final dtype *before* adding, so a higher-precision real
    # operand is not first truncated to the complex operand's precision.
    out_dtype = torch.result_type(A, B)
    # A tensor operand fixes the device and the shape that Python scalars are
    # expanded to. A takes precedence, as in the real tensor-tensor path.
    anchor = A if isinstance(A, torch.Tensor) else B

    def as_real(x):
        if isinstance(x, torch.Tensor):
            x = x.to(device=anchor.device, dtype=out_dtype)
        else:
            x = torch.tensor(x, dtype=out_dtype, device=anchor.device).expand_as(
                anchor
            )
        return torch.view_as_real(x)

    out_real = add_func(as_real(A), as_real(B), alpha, False)
    return torch.view_as_complex(out_real).to(out_dtype)


def add(A, B, *, alpha=1):
    """Compute ``A + alpha * B`` using the Cambricon Triton kernels.

    Args:
        A: a tensor or a Python scalar (complex allowed).
        B: a tensor or a Python scalar (complex allowed).
        alpha: scalar multiplier applied to ``B``.

    Returns:
        A tensor with the result. If neither operand is a tensor, the value
        is computed in Python and wrapped with ``torch.tensor``. If either
        operand is complex, the result has dtype ``torch.result_type(A, B)``.
    """
    logger.debug("GEMS_CAMBRICON ADD")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)

    # No tensor operand: nothing to launch a kernel on. This check comes first
    # so that complex Python scalars never reach torch.view_as_real.
    if not (A_is_tensor or B_is_tensor):
        return torch.tensor(A + B * alpha)

    if _is_complex(A) or _is_complex(B):
        return _add_complex(A, B, alpha)

    if A_is_tensor and B_is_tensor:
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha, False)
    if A_is_tensor:
        return add_func_tensor_scalar(A, B, alpha, False)
    return add_func_scalar_tensor(A, B, alpha, False)

85 

86 

def add_(A, B, *, alpha=1):
    """In-place ``A += alpha * B`` using the Cambricon Triton kernels.

    The kernel result is written into ``A`` (passed as ``out0``). The return
    value is whatever the kernel wrapper returns for ``out0=A``, presumably
    ``A`` itself.

    NOTE(review): complex operands are not special-cased here, unlike in
    ``add`` — confirm whether complex in-place add must be supported.

    Args:
        A: the tensor to update in place.
        B: a tensor (moved to ``A``'s device if needed) or a Python scalar.
        alpha: scalar multiplier applied to ``B``.

    Raises:
        ValueError: if ``A`` is not a ``torch.Tensor``.
    """
    logger.debug("GEMS_CAMBRICON ADD_")
    if not isinstance(A, torch.Tensor):
        # The previous message claimed this was unreachable, but any
        # non-tensor A lands here.
        raise ValueError(
            f"add_ expects a torch.Tensor as its first argument, got {type(A).__name__}."
        )
    if isinstance(B, torch.Tensor):
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha, True, out0=A)
    return add_func_tensor_scalar(A, B, alpha, True, out0=A)