Coverage for src/flag_gems/ops/sub.py: 79%

62 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils import pointwise_dynamic 

7from flag_gems.utils.pointwise_dynamic import ComplexMode 

8 

9logger = logging.getLogger(__name__) 

10 

11 

# Element-wise kernel computing ``x - y * alpha``.
# ``is_tensor=[True, True, False]`` presumably marks ``x`` and ``y`` as tensor
# operands and ``alpha`` as a scalar; the "DEFAULT" promotion is declared over
# arguments 0 and 1 -- confirm against ``pointwise_dynamic``.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func(x, y, alpha):
    # Scale the subtrahend first, then subtract it from ``x``.
    scaled = y * alpha
    return x - scaled

16 

17 

# Element-wise kernel computing ``x - y * alpha`` for the tensor/scalar case.
# ``is_tensor=[True, False, False]`` presumably marks only ``x`` as a tensor
# operand, with ``y`` and ``alpha`` passed as scalars -- confirm against
# ``pointwise_dynamic``.
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func_tensor_scalar(x, y, alpha):
    # Scale the subtrahend first, then subtract it from ``x``.
    scaled = y * alpha
    return x - scaled

24 

25 

# Element-wise kernel computing ``x - y * alpha`` for the scalar/tensor case.
# ``is_tensor=[False, True, False]`` presumably marks only ``y`` as a tensor
# operand, with ``x`` and ``alpha`` passed as scalars -- confirm against
# ``pointwise_dynamic``.
@pointwise_dynamic(
    is_tensor=[False, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func_scalar_tensor(x, y, alpha):
    # Scale the subtrahend first, then subtract it from ``x``.
    scaled = y * alpha
    return x - scaled

32 

33 

# Register element-wise complex support on all three kernels. The two
# scalar-operand variants additionally request scalar tensorization and name
# the tensor/tensor kernel ``sub_func`` as their fallback target.
sub_func.register_complex(mode=ComplexMode.ELEMENTWISE)
sub_func_tensor_scalar.register_complex(
    fallback_target=sub_func,
    tensorize_scalars=True,
    mode=ComplexMode.ELEMENTWISE,
)
sub_func_scalar_tensor.register_complex(
    fallback_target=sub_func,
    tensorize_scalars=True,
    mode=ComplexMode.ELEMENTWISE,
)

42 

43 

def _complex_operand_as_real(value, dtype, like):
    """Return ``value`` in complex ``dtype`` as a real view (trailing dim of 2).

    Args:
        value: A tensor (real or complex) or a Python number.
        dtype: The complex dtype both operands are converted to.
        like: The other operand. It is only used when ``value`` is a Python
            number, in which case it must be a tensor: the scalar is created
            on ``like.device`` and expanded to ``like``'s shape.
    """
    if isinstance(value, torch.Tensor):
        as_complex = value.to(dtype)
    else:
        as_complex = torch.tensor(value, dtype=dtype, device=like.device).expand_as(
            like
        )
    return torch.view_as_real(as_complex)


def sub(A, B, *, alpha=1):
    """Compute ``A - B * alpha`` for any mix of tensors and Python scalars.

    Args:
        A: Minuend; a ``torch.Tensor`` or a Python number.
        B: Subtrahend; a ``torch.Tensor`` or a Python number.
        alpha: Multiplier applied to ``B`` before the subtraction.

    Returns:
        A tensor holding the result. When neither operand is a tensor the
        result is built with ``torch.tensor`` from the Python arithmetic.
    """
    logger.debug("GEMS SUB")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)

    if not (A_is_tensor or B_is_tensor):
        # Both scalar (real or complex): plain Python arithmetic is enough and
        # avoids calling tensor-only helpers on Python numbers.
        return torch.tensor(A - B * alpha)

    A_is_complex = A.is_complex() if A_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)

    if A_is_complex or B_is_complex:
        # Promote once, up front, so a real operand is never truncated to the
        # other operand's (possibly narrower) complex dtype before promotion.
        result_dtype = torch.result_type(A, B)
        # Python complex scalars cannot be passed to ``torch.view_as_real``;
        # the helper turns them into tensors first.
        Ar = _complex_operand_as_real(A, result_dtype, B)
        Br = _complex_operand_as_real(B, result_dtype, A)
        # NOTE(review): alpha multiplies the real and imaginary components
        # independently, which matches complex scaling only for a real alpha.
        out_real = sub_func(Ar, Br, alpha)
        return torch.view_as_complex(out_real).to(result_dtype)

    if A_is_tensor and B_is_tensor:
        return sub_func(A, B, alpha)
    if A_is_tensor:
        return sub_func_tensor_scalar(A, B, alpha)
    return sub_func_scalar_tensor(A, B, alpha)

93 

94 

def sub_(A, B, *, alpha=1):
    """Compute ``A - B * alpha`` with ``A`` passed as the output (``out0=A``).

    Args:
        A: Tensor that is both the minuend and the output target.
        B: Subtrahend; a ``torch.Tensor`` or a scalar.
        alpha: Multiplier applied to ``B`` before the subtraction.

    Returns:
        Whatever the selected kernel returns when called with ``out0=A``.
    """
    logger.debug("GEMS SUB_")
    # Pick the kernel by the kind of ``B``; ``A`` is always passed as a tensor.
    kernel = sub_func if isinstance(B, torch.Tensor) else sub_func_tensor_scalar
    return kernel(A, B, alpha, out0=A)

99 else: 

100 return sub_func_tensor_scalar(A, B, alpha, out0=A)