Coverage for src/flag_gems/runtime/backend/_sunrise/ops/mul.py: 0%

99 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils import pointwise_dynamic 

7from flag_gems.utils.codegen_config_utils import CodeGenConfig 

8from flag_gems.utils.pointwise_dynamic import ComplexMode 

9 

logger = logging.getLogger(__name__)

# Code-generation settings used by the tensor-tensor `mul_func` kernel below.
# NOTE(review): the four positional values are forwarded verbatim to
# CodeGenConfig; their meaning is defined by CodeGenConfig itself — confirm
# there before tuning. Only `prefer_1d_tile` is passed by keyword.
config_for_broadcast = CodeGenConfig(
    8192,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=False,
)

21 

22 

# Elementwise product of two tensors.
# `pointwise_dynamic` wraps this scalar Triton function into a pointwise op:
# both arguments are tensors (is_tensor=[True, True]), the output dtype comes
# from "DEFAULT" promotion over arguments 0 and 1, and code generation uses
# `config_for_broadcast` defined above.
@pointwise_dynamic(
    is_tensor=[True, True],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_for_broadcast,
)
@triton.jit
def mul_func(x, y):
    return x * y

31 

32 

# Elementwise product of a tensor (argument 0) and a non-tensor value
# (argument 1, is_tensor=False). The output dtype comes from "DEFAULT"
# promotion over both arguments. No `config=` is passed, so this kernel uses
# whatever default pointwise_dynamic applies.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def mul_func_scalar(x, y):
    return x * y

37 

38 

# Complex multiplication on split real/imaginary parts:
#   (ar + i*ai) * (br + i*bi) = (ar*br - ai*bi) + i*(ar*bi + ai*br)
# Takes four tensors and produces two outputs (real, imag). Each output's
# dtype comes from "DEFAULT" promotion over all four inputs.
@pointwise_dynamic(
    is_tensor=[True, True, True, True],  # ar, ai, br, bi
    num_outputs=2,
    promotion_methods=[(0, 1, 2, 3, "DEFAULT"), (0, 1, 2, 3, "DEFAULT")],
)
@triton.jit
def mul_complex_kernel(ar, ai, br, bi):
    real = ar * br - ai * bi
    imag = ar * bi + ai * br
    return real, imag

49 

50 

# Attach complex-number handling to the two real-valued kernels.
# - mul_func: CROSS mode, with mul_complex_kernel as the explicit
#   cross-term kernel.
# - mul_func_scalar: CROSS mode with scalar operands tensorized.
#   NOTE(review): `fallback_target=mul_func` presumably redirects the
#   tensorized call to the tensor-tensor kernel — confirm against
#   ComplexMode / register_complex.
mul_func.register_complex(
    mode=ComplexMode.CROSS,
    cross_kernel=mul_complex_kernel,
)
mul_func_scalar.register_complex(
    mode=ComplexMode.CROSS,
    fallback_target=mul_func,
    tensorize_scalars=True,
)

56 

57 

def _view_as_real_ptpu_safe(x: torch.Tensor) -> torch.Tensor:
    """Return ``torch.view_as_real(x)``, bouncing through the CPU on PTPU.

    ``torch.view_as_real`` rejects tensors whose lazy conjugate bit is set
    (for example the result of ``x.conj()``), so that conjugation is
    materialized first.

    If the backend does not implement the view (``NotImplementedError``) and
    ``x`` lives on a ``ptpu`` device, the conversion is done on the CPU and
    the result is copied back to ``x.device``. On any other device the error
    is re-raised.

    Args:
        x: A complex tensor.

    Returns:
        A real tensor of shape ``x.shape + (2,)`` holding (real, imag) pairs.
        On the CPU-bounce path, or when a conjugation had to be resolved, this
        is a copy rather than a view of ``x``.
    """
    if x.is_conj():
        # view_as_real raises RuntimeError on unresolved conjugated tensors.
        x = x.resolve_conj()
    try:
        return torch.view_as_real(x)
    except NotImplementedError:
        if x.device.type != "ptpu":
            raise
        return torch.view_as_real(x.cpu()).to(x.device)

66 

67 

def _view_as_complex_ptpu_safe(x: torch.Tensor) -> torch.Tensor:
    """Return ``torch.view_as_complex(x)``, bouncing through the CPU on PTPU.

    ``x`` must satisfy ``torch.view_as_complex``'s layout rules: a real tensor
    whose last dimension has size 2. If the backend raises
    ``NotImplementedError`` and ``x`` is on a ``ptpu`` device, the conversion
    is done on the CPU and copied back to ``x.device``. Otherwise the error
    propagates.
    """
    try:
        result = torch.view_as_complex(x)
    except NotImplementedError:
        on_ptpu = x.device.type == "ptpu"
        if not on_ptpu:
            raise
        result = torch.view_as_complex(x.cpu()).to(x.device)
    return result

76 

77 

def _scalar_complex_as_real_ptpu_safe(
    scalar, complex_dtype: torch.dtype, target_shape, device: torch.device
) -> torch.Tensor:
    """Materialize a Python scalar as a ``view_as_real``-shaped tensor.

    The scalar is converted to ``complex_dtype`` on the CPU and broadcast to
    ``target_shape``. It is then split into (real, imag) pairs and made
    contiguous, giving shape ``target_shape + (2,)``. The result stays on the
    CPU when ``device`` is a CPU device; otherwise it is copied to ``device``.
    """
    host_value = torch.tensor(scalar, dtype=complex_dtype, device="cpu")
    broadcast = host_value.expand(target_shape)
    pairs = torch.view_as_real(broadcast).contiguous()
    return pairs if device.type == "cpu" else pairs.to(device)

89 

90 

def _operand_as_real_ptpu_safe(
    value, complex_dtype: torch.dtype, target_shape, device: torch.device
) -> torch.Tensor:
    """Return the (real, imag)-pair representation of one ``mul`` operand.

    Non-tensor values are broadcast to ``target_shape`` on ``device`` via
    ``_scalar_complex_as_real_ptpu_safe``. A real tensor is first cast to
    ``complex_dtype``. A complex tensor is used as-is, keeping its own
    dtype, and viewed as real.
    """
    if not isinstance(value, torch.Tensor):
        return _scalar_complex_as_real_ptpu_safe(
            value, complex_dtype, target_shape, device
        )
    if not value.is_complex():
        value = value.to(complex_dtype)
    return _view_as_real_ptpu_safe(value)

98 

99 

def _complex_mul(A, B):
    """Compute ``A * B`` when at least one operand is complex.

    ``mul`` only routes here when ``A`` or ``B`` is a complex tensor or a
    Python ``complex``. At least one operand must be a tensor, because the
    device is taken from it. The result dtype is ``torch.result_type(A, B)``.

    * Both complex: each operand is split into real/imag float tensors, all
      four are cast to a common dtype, ``mul_complex_kernel`` is run, and the
      parts are re-stacked into a complex tensor.
    * Exactly one complex: the ``(..., 2)`` real view of the complex operand
      is multiplied by the real operand. A real tensor gets a trailing unit
      dimension so it broadcasts over the (real, imag) pair. This uses
      ``mul_func`` for tensors and ``mul_func_scalar`` for scalars.
    """
    out_dtype = torch.result_type(A, B)
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)
    shape = torch.broadcast_shapes(
        A.shape if a_is_tensor else torch.Size([]),
        B.shape if b_is_tensor else torch.Size([]),
    )
    device = A.device if a_is_tensor else B.device

    def complex_like(v):
        return isinstance(v, complex) or (
            isinstance(v, torch.Tensor) and v.is_complex()
        )

    def as_real(v):
        return _operand_as_real_ptpu_safe(v, out_dtype, shape, device)

    def to_complex(pairs):
        return _view_as_complex_ptpu_safe(pairs.contiguous()).to(out_dtype)

    a_cplx = complex_like(A)
    b_cplx = complex_like(B)

    if a_cplx and b_cplx:
        a_pairs = as_real(A)
        b_pairs = as_real(B)
        # The pair tensors may differ in float width; compute in the wider one.
        wide = torch.promote_types(a_pairs.dtype, b_pairs.dtype)
        re_part, im_part = mul_complex_kernel(
            a_pairs[..., 0].to(wide),
            a_pairs[..., 1].to(wide),
            b_pairs[..., 0].to(wide),
            b_pairs[..., 1].to(wide),
        )
        return to_complex(torch.stack((re_part, im_part), dim=-1))

    if a_cplx:
        pairs = as_real(A)
        if b_is_tensor:
            product = mul_func(pairs, B.unsqueeze(-1))
        else:
            product = mul_func_scalar(pairs, B)
        return to_complex(product)

    pairs = as_real(B)
    if a_is_tensor:
        product = mul_func(A.unsqueeze(-1), pairs)
    else:
        product = mul_func_scalar(pairs, A)
    return to_complex(product)

142 

143 

def mul(A, B):
    """Elementwise ``A * B`` for any mix of tensors and Python scalars.

    Dispatch order:
      * neither operand is a tensor -> computed in Python and wrapped with
        ``torch.tensor``; this includes Python ``complex`` scalars;
      * either operand complex (complex tensor or Python ``complex``) ->
        ``_complex_mul``;
      * tensor * tensor -> ``mul_func``;
      * tensor * scalar, in either order -> ``mul_func_scalar`` with the
        tensor first.
    """
    logger.debug("GEMS MUL")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if not a_is_tensor and not b_is_tensor:
        # Must run before the complex branch: _complex_mul takes the device
        # from a tensor operand and would fail (AttributeError on
        # `B.device`) for e.g. mul(1 + 2j, 3j).
        return torch.tensor(A * B)

    A_is_complex = (a_is_tensor and A.is_complex()) or isinstance(A, complex)
    B_is_complex = (b_is_tensor and B.is_complex()) or isinstance(B, complex)
    if A_is_complex or B_is_complex:
        return _complex_mul(A, B)
    if a_is_tensor and b_is_tensor:
        return mul_func(A, B)
    if a_is_tensor:
        return mul_func_scalar(A, B)
    return mul_func_scalar(B, A)

163 

164 

def mul_(A, B):
    """In-place ``A *= B``; ``A`` is passed as the kernel's ``out0`` buffer.

    A tensor ``B`` goes through ``mul_func``; any other ``B`` goes through
    ``mul_func_scalar``. Returns whatever the kernel returns.
    """
    logger.debug("GEMS MUL_")
    kernel = mul_func if isinstance(B, torch.Tensor) else mul_func_scalar
    return kernel(A, B, out0=A)