Coverage for src/flag_gems/runtime/backend/_cambricon/ops/mul.py: 0%

72 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from ..utils.pointwise_dynamic import pointwise_dynamic 

7 

8logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

9 

10 

# Elementwise tensor * tensor product. The third argument (`inplace`) is a
# non-tensor flag that the kernel body itself never reads.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def mul_func(x, y, inplace):
    product = x * y
    return product

15 

16 

# Elementwise tensor * scalar product: only the first argument is declared as
# a tensor. The `inplace` flag is accepted but not read by the kernel body.
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def mul_func_scalar(x, y, inplace):
    product = x * y
    return product

23 

24 

# Complex product on split components:
#   (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
# Returns the real and imaginary components as two separate outputs.
@pointwise_dynamic(
    is_tensor=[True, True, True, True],  # ar, ai, br, bi
    num_outputs=2,
    promotion_methods=[
        (0, 1, 2, 3, "DEFAULT"),
        (0, 1, 2, 3, "DEFAULT"),
    ],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    return ar * br - ai * bi, ar * bi + ai * br

35 

36 

def _mul_complex(A, B, A_is_complex, B_is_complex):
    """Multiply two operands when at least one is complex and one is a tensor.

    Args:
        A: Tensor or Python scalar.
        B: Tensor or Python scalar.
        A_is_complex: True when ``A`` is a complex tensor or a Python complex.
        B_is_complex: True when ``B`` is a complex tensor or a Python complex.

    Returns:
        A complex tensor cast to ``torch.result_type(A, B)``.
    """
    # Computed from the caller's operands, before any scalar is wrapped below.
    out_dtype = torch.result_type(A, B)

    # torch.view_as_real only accepts tensors, so a Python complex scalar is
    # first materialized as a 0-dim tensor on the tensor operand's device.
    if A_is_complex and not isinstance(A, torch.Tensor):
        A = torch.tensor(A, dtype=out_dtype, device=B.device)
    if B_is_complex and not isinstance(B, torch.Tensor):
        B = torch.tensor(B, dtype=out_dtype, device=A.device)

    if A_is_complex and B_is_complex:
        a_parts = torch.view_as_real(A)
        b_parts = torch.view_as_real(B)
        ar, ai = a_parts[..., 0], a_parts[..., 1]
        br, bi = b_parts[..., 0], b_parts[..., 1]
        common_dtype = torch.promote_types(ar.dtype, br.dtype)
        ar, ai = ar.to(common_dtype), ai.to(common_dtype)
        br, bi = br.to(common_dtype), bi.to(common_dtype)

        # Allocate outputs at the broadcast shape; sizing them like `ar` is
        # wrong whenever A and B have different (broadcastable) shapes.
        out_shape = torch.broadcast_shapes(ar.shape, br.shape)
        real_out = torch.empty(out_shape, dtype=common_dtype, device=ar.device)
        imag_out = torch.empty(out_shape, dtype=common_dtype, device=ar.device)
        mul_complex_kernel(ar, ai, br, bi, out0=real_out, out1=imag_out)

        out = torch.view_as_complex(torch.stack((real_out, imag_out), dim=-1))
        return out.to(out_dtype)

    # Exactly one operand is complex. Multiplication is commutative, so both
    # orderings reduce to scaling the (real, imag) pairs by the real operand.
    complex_operand, real_operand = (A, B) if A_is_complex else (B, A)
    parts = torch.view_as_real(complex_operand)
    if isinstance(real_operand, torch.Tensor):
        # The trailing unit dim broadcasts against the size-2 (real, imag) axis.
        scaled = mul_func(parts, real_operand.unsqueeze(-1), False)
    else:
        scaled = mul_func_scalar(parts, real_operand, False)
    return torch.view_as_complex(scaled).to(out_dtype)


def mul(A, B):
    """Multiply ``A`` by ``B`` elementwise and return a new tensor.

    Either operand may be a tensor or a Python scalar. Complex operands
    (complex tensors or Python ``complex`` values) are processed through their
    real views; real operands are dispatched to the tensor-tensor or
    tensor-scalar kernel.

    Args:
        A: Tensor or Python scalar.
        B: Tensor or Python scalar.

    Returns:
        The product as a tensor. When both operands are Python scalars the
        result is ``torch.tensor(A * B)``.

    Raises:
        AssertionError: If two real tensors live on different devices and the
            0-dim one that would be moved is not on the CPU.
    """
    logger.debug("GEMS_CAMBRICON MUL")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)

    if not A_is_tensor and not B_is_tensor:
        # Both scalar (including Python complex values, which have no tensor
        # view to operate on).
        return torch.tensor(A * B)

    A_is_complex = A.is_complex() if A_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)
    if A_is_complex or B_is_complex:
        return _mul_complex(A, B, A_is_complex, B_is_complex)

    if A_is_tensor and B_is_tensor:
        if A.device != B.device:
            # Only a 0-dim tensor is moved to the other operand's device.
            if A.dim() == 0:
                assert A.device == torch.device("cpu"), "expect scalar tensor on cpu"
                A = A.to(B.device)
            elif B.dim() == 0:
                assert B.device == torch.device("cpu"), "expect scalar tensor on cpu"
                B = B.to(A.device)
        return mul_func(A, B, False)

    if A_is_tensor:
        return mul_func_scalar(A, B, False)
    return mul_func_scalar(B, A, False)

98 

99 

def mul_(A, B):
    """Multiply ``A`` by ``B`` in place, passing ``A`` as the kernel output.

    Args:
        A: Tensor that receives the result.
        B: Tensor or Python scalar multiplier.

    Returns:
        Whatever the underlying kernel call returns for ``out0=A``.
    """
    logger.debug("GEMS_CAMBRICON MUL_")

    # Scalar multiplier: use the tensor-by-scalar kernel directly.
    if not isinstance(B, torch.Tensor):
        return mul_func_scalar(A, B, True, out0=A)

    # A 0-dim tensor on another device is moved next to A before the launch.
    needs_move = B.device != A.device and B.dim() == 0
    if needs_move:
        assert B.device == torch.device("cpu"), "expect scalar tensor on cpu"
        B = B.to(A.device)
    return mul_func(A, B, True, out0=A)

107 else: 

108 return mul_func_scalar(A, B, True, out0=A)