Coverage for src/flag_gems/runtime/backend/_sunrise/ops/mul.py: 0%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils import pointwise_dynamic 

7from flag_gems.utils.codegen_config_utils import CodeGenConfig 

8 

logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))


# Code-generation settings for the broadcasting tensor-tensor kernel
# (``mul_func``). The four leading fields are passed positionally in
# CodeGenConfig's declaration order; the per-field notes below are inferred
# from the values and should be confirmed against CodeGenConfig.
config_for_broadcast = CodeGenConfig(
    8192,  # presumably the maximum tile size -- TODO confirm
    (65536,) * 3,  # presumably the per-axis maximum grid size -- TODO confirm
    32,  # presumably the maximum warps per CTA -- TODO confirm
    True,  # boolean option; meaning not visible here -- TODO confirm
    prefer_1d_tile=False,
)

20 

21 

@pointwise_dynamic(
    is_tensor=[True, True],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_for_broadcast,
)
@triton.jit
def mul_func(x, y):
    """Triton elementwise body for tensor * tensor.

    Both operands are tensors. The result dtype follows the "DEFAULT"
    promotion of both inputs, and code generation uses
    ``config_for_broadcast``.
    """
    product = x * y
    return product

30 

31 

@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def mul_func_scalar(x, y):
    """Triton elementwise body for tensor * non-tensor value.

    ``x`` is the tensor operand and ``y`` the non-tensor operand. The result
    dtype follows the "DEFAULT" promotion of both.
    """
    scaled = x * y
    return scaled

36 

37 

@pointwise_dynamic(
    is_tensor=[True, True, True, True],  # ar, ai, br, bi
    num_outputs=2,
    promotion_methods=[
        (0, 1, 2, 3, "DEFAULT"),
        (0, 1, 2, 3, "DEFAULT"),
    ],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    """Triton body for (ar + i*ai) * (br + i*bi) on split components.

    Takes the real/imaginary parts of both operands as four tensors. Returns
    two outputs: the real part and the imaginary part of the product.
    """
    re_part = ar * br - ai * bi
    im_part = ar * bi + ai * br
    return re_part, im_part

48 

49 

def _mul_complex_complex(A, B, out_dtype):
    """Multiply two complex tensors; result on ``A``'s device, cast to ``out_dtype``.

    ``view_as_real``/``view_as_complex`` are applied to CPU copies, and only
    the arithmetic runs on the device via ``mul_complex_kernel``.
    """
    a_device, b_device = A.device, B.device
    # resolve_conj(): view_as_real rejects tensors with the conjugate bit set.
    Ar = torch.view_as_real(A.cpu().resolve_conj())
    Br = torch.view_as_real(B.cpu().resolve_conj())
    common_dtype = torch.promote_types(Ar.dtype, Br.dtype)
    ar = Ar[..., 0].to(a_device).to(common_dtype)
    ai = Ar[..., 1].to(a_device).to(common_dtype)
    br = Br[..., 0].to(b_device).to(common_dtype)
    bi = Br[..., 1].to(b_device).to(common_dtype)

    # Outputs must have the broadcast shape of BOTH operands; sizing them like
    # ``ar`` alone breaks when B broadcasts to a larger shape than A.
    out_shape = torch.broadcast_shapes(ar.shape, br.shape)
    real_out = torch.empty(out_shape, dtype=common_dtype, device=a_device)
    imag_out = torch.empty(out_shape, dtype=common_dtype, device=a_device)
    mul_complex_kernel(ar, ai, br, bi, out0=real_out, out1=imag_out)

    out = torch.view_as_complex(torch.stack((real_out, imag_out), dim=-1).cpu())
    return out.to(out_dtype).to(a_device)


def _mul_complex_real(Z, R, out_dtype):
    """Multiply complex tensor ``Z`` by real ``R`` (tensor or Python number).

    The result is placed on ``Z``'s device and cast to ``out_dtype``.
    """
    device = Z.device
    # Float view of Z with a trailing (real, imag) axis, moved back to Z's device.
    Zr = torch.view_as_real(Z.cpu().resolve_conj()).to(device)
    if isinstance(R, torch.Tensor):
        # A trailing singleton axis broadcasts R over the (real, imag) pair.
        prod = mul_func(Zr, R.unsqueeze(-1))
    else:
        prod = mul_func_scalar(Zr, R)
    return torch.view_as_complex(prod.cpu()).to(out_dtype).to(device)


def mul(A, B):
    """Elementwise ``A * B`` with broadcasting, dispatched to Triton kernels.

    ``A`` and ``B`` may each be a ``torch.Tensor`` or a Python number
    (including ``complex``).

    * Real operands: tensor*tensor uses ``mul_func``. Tensor*scalar uses
      ``mul_func_scalar``, with the tensor always passed first.
    * Complex operands: complex tensors are split into float real/imaginary
      parts and multiplied on the device. The ``view_as_real`` /
      ``view_as_complex`` steps run on CPU copies, presumably because they
      are unavailable on this backend (TODO confirm). The result is cast to
      ``torch.result_type(A, B)``.
    * Two Python numbers: returns ``torch.tensor(A * B)``.

    Args:
        A: Left operand, tensor or Python number.
        B: Right operand, tensor or Python number.

    Returns:
        torch.Tensor: The broadcast product. For complex inputs, the result
        is on the device of the complex operand (``A``'s when both are
        complex). A Python complex scalar takes the tensor operand's device.
    """
    logger.debug("GEMS MUL")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)

    if not (A_is_tensor or B_is_tensor):
        # Both operands are Python numbers (possibly complex); no kernel needed.
        return torch.tensor(A * B)

    A_is_complex = A.is_complex() if A_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)

    if not (A_is_complex or B_is_complex):
        if A_is_tensor and B_is_tensor:
            return mul_func(A, B)
        if A_is_tensor:
            return mul_func_scalar(A, B)
        # The scalar kernel expects the tensor first; multiplication commutes.
        return mul_func_scalar(B, A)

    # Take the output dtype from the ORIGINAL operands so Python scalars take
    # part in type promotion exactly as they do in torch.mul.
    out_dtype = torch.result_type(A, B)

    # A Python complex scalar has no ``.device`` and no real view. Lift it to
    # a 0-dim tensor of the output dtype on the other (tensor) operand's device.
    if not A_is_tensor and A_is_complex:
        A = torch.tensor(A, dtype=out_dtype, device=B.device)
    if not B_is_tensor and B_is_complex:
        B = torch.tensor(B, dtype=out_dtype, device=A.device)

    if A_is_complex and B_is_complex:
        return _mul_complex_complex(A, B, out_dtype)
    if A_is_complex:
        return _mul_complex_real(A, B, out_dtype)
    return _mul_complex_real(B, A, out_dtype)

124 

125 

def mul_(A, B):
    """In-place multiply: write ``A * B`` into tensor ``A``.

    A tensor ``B`` uses the broadcasting tensor-tensor kernel. Any other
    ``B`` uses the tensor-scalar kernel. In both cases ``A`` is passed as
    ``out0`` and the kernel's return value is returned.
    """
    logger.debug("GEMS MUL_")
    kernel = mul_func if isinstance(B, torch.Tensor) else mul_func_scalar
    return kernel(A, B, out0=A)