Coverage for src/flag_gems/ops/add.py: 78%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils import pointwise_dynamic 

7from flag_gems.utils.pointwise_dynamic import ComplexMode 

8 

# Module-level logger; add() and add_() emit a debug record on every call.
logger = logging.getLogger(__name__)

10 

11 

# Elementwise kernel computing ``x + y * alpha`` where x and y are both
# tensors and alpha is a non-tensor scalar (is_tensor=[True, True, False]).
# add() / add_() call it for tensor-tensor operands, and add() also calls it
# on torch.view_as_real views for complex operands.
# promotion_methods=[(0, 1, "DEFAULT")] derives the output dtype from args 0
# and 1 only — NOTE(review): the exact "DEFAULT" rule lives in
# pointwise_dynamic; confirm there.
@pointwise_dynamic(is_tensor=[True, True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def add_func(x, y, alpha):
    return x + y * alpha

16 

17 

# Variant of add_func for a tensor ``x`` and a non-tensor scalar ``y``
# (is_tensor=[True, False, False]); alpha is also a scalar.
# add() / add_() use it when only the first operand is a tensor.
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha):
    return x + y * alpha

24 

25 

# Variant of add_func for a non-tensor scalar ``x`` and a tensor ``y``
# (is_tensor=[False, True, False]); alpha is also a scalar.
# add() uses it when only the second operand is a tensor.
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha):
    return x + y * alpha

32 

33 

# Opt all three kernels into pointwise_dynamic's element-wise complex mode.
# The two mixed tensor/scalar variants are registered identically: they ask
# for their scalar argument to be tensorized (presumably so the complex path
# can treat it like a tensor — confirm in pointwise_dynamic) and name
# add_func as their fallback_target.
add_func.register_complex(mode=ComplexMode.ELEMENTWISE)
for _scalar_variant in (add_func_tensor_scalar, add_func_scalar_tensor):
    _scalar_variant.register_complex(
        mode=ComplexMode.ELEMENTWISE,
        tensorize_scalars=True,
        fallback_target=add_func,
    )
# Keep the module namespace identical to the unrolled form.
del _scalar_variant

42 

43 

def _is_complex_operand(x):
    """Return True if ``x`` is a complex tensor or a Python complex scalar."""
    if isinstance(x, torch.Tensor):
        return x.is_complex()
    return isinstance(x, complex)


def _as_complex_tensor(x, dtype, like):
    """Materialize operand ``x`` as a ``dtype`` tensor on ``like``'s device.

    Tensors are cast/moved; Python scalars are wrapped and broadcast to
    ``like``'s shape. A lazily-conjugated tensor is resolved because
    ``torch.view_as_real`` refuses tensors with the conjugate bit set.
    """
    if isinstance(x, torch.Tensor):
        xc = x.to(device=like.device, dtype=dtype)
    else:
        xc = torch.tensor(x, dtype=dtype, device=like.device).expand_as(like)
    return xc.resolve_conj()


def _add_complex(A, B, alpha):
    """Compute ``A + B * alpha`` when at least one operand is complex.

    At least one of ``A`` / ``B`` must be a tensor. Both operands are cast to
    ``torch.result_type(A, B)`` up front, so no precision is lost by casting
    to a narrower complex dtype first. The second operand is moved to the
    first tensor operand's device, matching the real tensor-tensor path.
    The sum is then computed by add_func on the ``view_as_real`` views.
    """
    out_dtype = torch.result_type(A, B)
    ref = A if isinstance(A, torch.Tensor) else B
    Ac = _as_complex_tensor(A, out_dtype, ref)
    Bc = _as_complex_tensor(B, out_dtype, ref)
    if isinstance(alpha, complex):
        # A complex alpha mixes real and imaginary parts, so it cannot be
        # applied component-wise on the real views; fold it into B instead.
        Bc = Bc * alpha
        alpha = 1
    out_real = add_func(torch.view_as_real(Ac), torch.view_as_real(Bc), alpha)
    return torch.view_as_complex(out_real).to(out_dtype)


def add(A, B, *, alpha=1):
    """Return ``A + B * alpha`` (``torch.add`` semantics) via Triton kernels.

    Args:
        A: Tensor or Python scalar.
        B: Tensor or Python scalar.
        alpha: Scalar multiplier applied to ``B``; may be complex when the
            result is complex.

    Returns:
        A new tensor. When neither operand is a tensor, the result is
        ``torch.tensor(A + B * alpha)``.
    """
    logger.debug("GEMS ADD")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)
    if not (A_is_tensor or B_is_tensor):
        # Pure-Python arithmetic (real or complex); no kernel launch needed.
        return torch.tensor(A + B * alpha)
    if _is_complex_operand(A) or _is_complex_operand(B):
        return _add_complex(A, B, alpha)
    if A_is_tensor and B_is_tensor:
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha)
    if A_is_tensor:
        return add_func_tensor_scalar(A, B, alpha)
    return add_func_scalar_tensor(A, B, alpha)

94 

95 

def add_(A, B, *, alpha=1):
    """In-place ``A += B * alpha``; writes into ``A`` and returns the kernel result.

    Args:
        A: Tensor that receives the result (passed as ``out0``).
        B: Tensor or Python scalar. A tensor on a different device is moved
            to ``A``'s device first.
        alpha: Scalar multiplier applied to ``B``.

    Raises:
        ValueError: if ``A`` is not a tensor, since there is no storage to
            write the result into.
    """
    logger.debug("GEMS ADD_")
    if not isinstance(A, torch.Tensor):
        # Reachable for add_(scalar, ...): report the actual problem instead
        # of claiming the branch is unreachable.
        raise ValueError(
            f"add_ requires A to be a torch.Tensor, got {type(A).__name__}."
        )
    if isinstance(B, torch.Tensor):
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha, out0=A)
    return add_func_tensor_scalar(A, B, alpha, out0=A)