Coverage for src/flag_gems/runtime/backend/_cambricon/ops/sub.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from ..utils.pointwise_dynamic import pointwise_dynamic 

7 

# Per-module logger created as a child of the package-level "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip(".")
)

9 

10 

# Element-wise ``x - y * alpha`` where both ``x`` and ``y`` are tensors.
# ``alpha`` and ``inplace`` are non-tensor arguments (is_tensor flags below);
# the output dtype is chosen by pointwise_dynamic's "DEFAULT" promotion over
# arguments 0 and 1. ``inplace`` is not read by the kernel body itself.
@pointwise_dynamic(
    is_tensor=[True, True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func(x, y, alpha, inplace):
    # Scale the subtrahend first, then subtract it from the minuend.
    scaled_y = y * alpha
    return x - scaled_y

17 

18 

# Element-wise ``x - y * alpha`` where ``x`` is a tensor and ``y`` is a scalar.
# Only argument 0 is flagged as a tensor; promotion still considers
# arguments 0 and 1 under pointwise_dynamic's "DEFAULT" method.
# ``inplace`` is not read by the kernel body itself.
@pointwise_dynamic(
    is_tensor=[True, False, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func_tensor_scalar(x, y, alpha, inplace):
    # Scale the scalar subtrahend, then subtract it from each tensor element.
    scaled_y = y * alpha
    return x - scaled_y

25 

26 

# Element-wise ``x - y * alpha`` where ``x`` is a scalar and ``y`` is a tensor.
# Only argument 1 is flagged as a tensor; promotion considers arguments 0 and
# 1 under pointwise_dynamic's "DEFAULT" method.
# ``inplace`` is not read by the kernel body itself.
@pointwise_dynamic(
    is_tensor=[False, True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func_scalar_tensor(x, y, alpha, inplace):
    # Scale each tensor element, then subtract it from the scalar minuend.
    scaled_y = y * alpha
    return x - scaled_y

33 

34 

def _is_complex(x):
    """Return True if ``x`` is a complex tensor or a Python ``complex`` scalar."""
    if isinstance(x, torch.Tensor):
        return x.is_complex()
    return isinstance(x, complex)


def _as_complex_tensor(x, dtype, device):
    """Return ``x`` (tensor or Python scalar) as a tensor of complex ``dtype``.

    Tensors keep their own shape and device. Python scalars become 0-dim
    tensors on ``device``; after ``view_as_real`` they have shape ``(2,)``
    and broadcast against the other operand's trailing (real, imag) axis.
    """
    if isinstance(x, torch.Tensor):
        # view_as_real rejects tensors with an unresolved conjugate bit;
        # resolve_conj() is a no-op for ordinary tensors.
        return x.to(dtype).resolve_conj()
    return torch.tensor(x, dtype=dtype, device=device)


def _sub_complex(A, B, alpha):
    """Compute ``A - B * alpha`` when at least one of ``A``/``B`` is complex.

    At least one operand must be a tensor. Both operands are converted to
    the torch result dtype of ``(A, B)`` and reinterpreted with
    ``view_as_real`` (trailing axis = real, imag). The subtraction then runs
    through the real ``sub_func`` kernel. A real ``alpha`` scales the real
    and imaginary parts independently, so this is exact. A complex ``alpha``
    mixes the parts, so it is applied to ``B`` beforehand.
    """
    out_dtype = torch.result_type(A, B)
    # Python scalars are materialized on the device of the tensor operand.
    device = A.device if isinstance(A, torch.Tensor) else B.device
    A_c = _as_complex_tensor(A, out_dtype, device)
    B_c = _as_complex_tensor(B, out_dtype, device)
    if isinstance(alpha, complex):
        # The element-wise kernel on the real view cannot express complex
        # scaling, so scale B up front and subtract with unit alpha.
        B_c = B_c * alpha
        alpha = 1
    out_real = sub_func(
        torch.view_as_real(A_c), torch.view_as_real(B_c), alpha, False
    )
    return torch.view_as_complex(out_real).to(out_dtype)


def sub(A, B, *, alpha=1):
    """Return ``A - B * alpha`` (Cambricon implementation of ``torch.sub``).

    Args:
        A: Minuend, a tensor or a Python scalar.
        B: Subtrahend, a tensor or a Python scalar.
        alpha: Scalar multiplier applied to ``B`` before subtracting.

    Returns:
        A new tensor. If either operand is complex, the result has
        ``torch.result_type(A, B)``. If both operands are Python scalars,
        the result is built with ``torch.tensor`` on the default device.
    """
    logger.debug("GEMS_CAMBRICON SUB")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)
    if not (A_is_tensor or B_is_tensor):
        # Both scalar (complex scalars included): no kernel launch needed.
        return torch.tensor(A - B * alpha)
    if _is_complex(A) or _is_complex(B):
        return _sub_complex(A, B, alpha)
    if A_is_tensor and B_is_tensor:
        return sub_func(A, B, alpha, False)
    if A_is_tensor:
        return sub_func_tensor_scalar(A, B, alpha, False)
    return sub_func_scalar_tensor(A, B, alpha, False)

84 

85 

def sub_(A, B, *, alpha=1):
    """In-place variant of ``sub``: computes ``A - B * alpha`` into ``A``.

    The kernel is chosen by the type of ``B``: the tensor-tensor kernel when
    ``B`` is a tensor, otherwise the tensor-scalar kernel. ``A`` is passed as
    the ``out0`` output buffer, and the kernel call's return value is
    returned unchanged (presumably ``A`` itself — confirm against
    pointwise_dynamic).
    """
    logger.debug("GEMS_CAMBRICON SUB_")
    kernel = sub_func if isinstance(B, torch.Tensor) else sub_func_tensor_scalar
    return kernel(A, B, alpha, True, out0=A)