Coverage for src/flag_gems/runtime/backend/_sunrise/ops/add.py: 0%

119 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils import pointwise_dynamic 

7from flag_gems.utils.codegen_config_utils import CodeGenConfig 

8from flag_gems.utils.pointwise_dynamic import ComplexMode 

9 

10logger = logging.getLogger(__name__) 

11 

12 

# Code-generation settings used by the default tensor + tensor kernel
# (`add_func`). The positional values are forwarded to CodeGenConfig as-is;
# their meaning is defined by CodeGenConfig, which is not visible in this file.
# NOTE(review): the first value is presumably the max tile size (the comment in
# `should_use_broadcast_configs` mentions a "smaller max_tile_size" for the
# broadcast variant, which passes 128 in this position) — confirm.
config_for_general = CodeGenConfig(
    1024,
    (65536,) * 3,
    32,
    True,
    prefer_1d_tile=False,
    # num_warps=2  (left disabled by the original author)
)

21 

22 

@pointwise_dynamic(
    config=config_for_general,
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func(x, y, alpha):
    # Tensor + tensor variant: x and y are tensor operands and alpha is a
    # non-tensor scalar (see `is_tensor` above). Computes x + alpha * y.
    scaled_y = y * alpha
    return x + scaled_y

31 

32 

# Code-generation settings used by `add_func_broadcast`. Compared with
# `config_for_general` this passes 128 instead of 1024 as the first value and
# sets prefer_1d_tile=True. The positional values are forwarded to
# CodeGenConfig as-is; their meaning is defined there, not in this file.
# NOTE(review): the first value is presumably the max tile size — confirm.
config_for_broadcast = CodeGenConfig(
    128,
    (65536,) * 3,
    32,
    True,
    prefer_1d_tile=True,
    # num_warps=4  (left disabled by the original author)
)

41 

42 

@pointwise_dynamic(
    config=config_for_broadcast,
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func_broadcast(x, y, alpha):
    # Same arithmetic as `add_func` (x + alpha * y); only the code-generation
    # config differs (`config_for_broadcast`).
    scaled_y = y * alpha
    return x + scaled_y

51 

52 

@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha):
    # Tensor + scalar variant: only x is a tensor (see `is_tensor` above).
    # Computes x + alpha * y.
    scaled_y = y * alpha
    return x + scaled_y

59 

60 

@pointwise_dynamic(
    is_tensor=[False, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha):
    # Scalar + tensor variant: only y is a tensor (see `is_tensor` above).
    # Computes x + alpha * y.
    scaled_y = y * alpha
    return x + scaled_y

67 

68 

def get_best_strided_output_tensor(A, B):
    """Allocate an uninitialised float32 output for the broadcast of A and B.

    The output has the broadcast shape of ``A`` and ``B`` and lives on
    ``A.device``. When one of the inputs already has the broadcast shape and
    a dense, non-overlapping memory layout, the output reuses that input's
    strides; otherwise the output is contiguous.

    Strides are only reused for dense layouts because the output owns exactly
    ``numel`` elements of storage: copying the strides of a sliced input would
    make ``as_strided`` reach past the end of that storage (RuntimeError), and
    copying the zero strides of an expanded input would make several output
    elements alias the same memory.

    Args:
        A: first input tensor; ``A`` is tried first as the stride donor.
        B: second input tensor.

    Returns:
        A ``torch.float32`` tensor of the broadcast shape.
    """

    def is_dense_layout(shape, strides):
        # Walk the dimensions from the fastest-varying to the slowest: in a
        # dense, non-overlapping layout each stride equals the product of the
        # extents of all faster dimensions. Extent-1 dimensions may carry any
        # stride, so they are ignored.
        expected_stride = 1
        for stride, extent in sorted(
            (st, sz) for sz, st in zip(shape, strides) if sz != 1
        ):
            if stride != expected_stride:
                return False
            expected_stride *= extent
        return True

    def get_best_strides(A, B, broadcast_shape):
        for candidate in (A, B):
            if candidate.shape == broadcast_shape and is_dense_layout(
                candidate.shape, candidate.stride()
            ):
                return candidate.stride()
        return None

    broadcast_shape = torch.broadcast_shapes(A.shape, B.shape)
    out = torch.empty(broadcast_shape, device=A.device, dtype=torch.float32)
    best_stride = get_best_strides(A, B, broadcast_shape)
    if best_stride is not None:
        out = out.as_strided(broadcast_shape, best_stride)
    return out

84 

85 

def is_power_of_two(n):
    """Return True when ``n`` is positive and has exactly one bit set."""
    if not n > 0:
        return False
    # A power of two shares no set bit with its predecessor.
    return (n & (n - 1)) == 0

88 

89 

def should_use_broadcast_configs(A, B):
    """Decide whether ``A + B`` should run with the broadcast-specific config.

    Per the original author's note: when broadcasting is involved and the last
    two dimensions of both inputs are the same, 1D tiling with a smaller
    max_tile_size config performs better.

    Returns True only when all of the following hold:
      * the shapes differ (broadcasting is needed),
      * both inputs have at least two dimensions and equal last two extents,
      * the last extent of ``A`` is not a power of two,
      * the promoted result dtype is float16 or float32.
    """
    if A.shape == B.shape:
        return False
    if len(A.shape) < 2 or len(B.shape) < 2:
        return False
    if A.shape[-2:] != B.shape[-2:]:
        return False
    if is_power_of_two(A.shape[-1]):
        return False
    return torch.result_type(A, B) in [torch.float16, torch.float32]

104 

105 

# Register complex support (elementwise).
add_func.register_complex(mode=ComplexMode.ELEMENTWISE)
# Both scalar-operand variants are registered with identical options:
# tensorize_scalars=True and `add_func` as the fallback target.
for _scalar_variant in (add_func_tensor_scalar, add_func_scalar_tensor):
    _scalar_variant.register_complex(
        mode=ComplexMode.ELEMENTWISE,
        tensorize_scalars=True,
        fallback_target=add_func,
    )
del _scalar_variant  # keep the module namespace unchanged

114 

115 

116def _view_as_real_ptpu_safe(x: torch.Tensor) -> torch.Tensor: 

117 """`torch.view_as_real(x)` with a CPU bounce when x is on PTPU. 

118 

119 [sunrise fix] PTPU lacks `aten::view_as_real`. The surrounding complex 

120 branch uses the result only as a read-only input to the triton `add_func` 

121 kernel (which IS PTPU-native), and the subsequent `.to(common_dtype)` would 

122 materialize a non-aliasing copy anyway — so it is safe to break alias 

123 semantics here. Per the FlagGems Sunrise skill, do not generically 

124 monkey-patch view_as_real (alias/view primitive). Compute stays on PTPU. 

125 """ 

126 try: 

127 return torch.view_as_real(x) 

128 except NotImplementedError: 

129 if x.device.type != "ptpu": 

130 raise 

131 return torch.view_as_real(x.cpu()).to(x.device) 

132 

133 

134def _view_as_complex_ptpu_safe(x: torch.Tensor) -> torch.Tensor: 

135 """`torch.view_as_complex(x)` with a CPU bounce when x is on PTPU. 

136 

137 See `_view_as_real_ptpu_safe` above. Used here to recompose the complex 

138 output after the PTPU-native real-domain `add_func(Ar, Br, alpha)` finishes. 

139 """ 

140 try: 

141 return torch.view_as_complex(x) 

142 except NotImplementedError: 

143 if x.device.type != "ptpu": 

144 raise 

145 return torch.view_as_complex(x.cpu()).to(x.device) 

146 

147 

148def _scalar_complex_as_real_ptpu_safe( 

149 scalar, complex_dtype: torch.dtype, target_shape, device: torch.device 

150) -> torch.Tensor: 

151 """Broadcast a python scalar to `view_as_real`-shaped tensor on `device`. 

152 

153 [sunrise fix] The natural code path is 

154 

155 torch.view_as_real( 

156 torch.tensor(scalar, dtype=complex_dtype, device=device).expand_as(ref) 

157 ) 

158 

159 On PTPU this dies at the `view_as_real` step (no kernel) and the obvious 

160 CPU fallback (`.cpu()`) also dies because PTPU's `direct_copy_kernel_ptpu` 

161 has no entry for `ComplexHalf` / `ComplexFloat`. So instead we build the 

162 complex scalar AND take its real view ENTIRELY on CPU, then only move the 

163 final real-dtype tensor onto PTPU (which the device's copy_ DOES support). 

164 """ 

165 cpu_scalar = torch.tensor(scalar, dtype=complex_dtype, device="cpu").expand( 

166 target_shape 

167 ) 

168 cpu_real = torch.view_as_real(cpu_scalar).contiguous() 

169 if device.type == "cpu": 

170 return cpu_real 

171 return cpu_real.to(device) 

172 

173 

def _complex_operand_as_real(operand, other, complex_dtype):
    """Return one operand of a complex add as a real tensor with a (real, imag) axis.

    Args:
        operand: the operand to convert; a tensor (real or complex) or a
            Python scalar (int, float or complex).
        other: the opposite operand. It is only used when ``operand`` is a
            Python scalar, in which case it must be a tensor and supplies the
            shape and device for the broadcast scalar.
        complex_dtype: complex dtype of the overall result; real tensors and
            scalars are converted to it before the real view is taken.
    """
    if not isinstance(operand, torch.Tensor):
        return _scalar_complex_as_real_ptpu_safe(
            operand, complex_dtype, other.shape, other.device
        )
    if not operand.is_complex():
        # Promote with the result dtype (not the other operand's dtype) so a
        # float64 tensor added to a complex64 tensor keeps its precision.
        operand = operand.to(complex_dtype)
    return _view_as_real_ptpu_safe(operand)


def add(A, B, *, alpha=1):
    """Compute ``A + alpha * B`` for any mix of tensors and Python scalars.

    Dispatch:
      * At least one tensor and at least one complex operand (complex tensor
        or Python ``complex``): both operands are converted to real tensors
        with a trailing (real, imag) axis, added with ``add_func`` and viewed
        back as complex in ``torch.result_type(A, B)``.
      * Two tensors: ``B`` is moved to ``A.device`` if needed, then either the
        broadcast-tuned kernel (writing into a float32 buffer that is cast to
        the result dtype) or ``add_func`` is used.
      * One tensor and one scalar: the matching tensor/scalar kernel.
      * Two scalars: plain Python arithmetic wrapped in ``torch.tensor``.

    Args:
        A: tensor or Python scalar.
        B: tensor or Python scalar.
        alpha: multiplier applied to ``B``.
            NOTE(review): in the complex path alpha scales the real and
            imaginary parts independently, which matches complex
            multiplication only for a real alpha — confirm whether a complex
            alpha must be supported.

    Returns:
        A tensor holding ``A + alpha * B``.
    """
    logger.debug("GEMS ADD")
    A_is_tensor = isinstance(A, torch.Tensor)
    B_is_tensor = isinstance(B, torch.Tensor)
    A_is_complex = A.is_complex() if A_is_tensor else isinstance(A, complex)
    B_is_complex = B.is_complex() if B_is_tensor else isinstance(B, complex)

    # The real-view path needs a tensor to supply shape and device; two Python
    # scalars (complex or not) are handled by the final branch instead.
    if (A_is_complex or B_is_complex) and (A_is_tensor or B_is_tensor):
        result_dtype = torch.result_type(A, B)
        Ar = _complex_operand_as_real(A, B, result_dtype)
        Br = _complex_operand_as_real(B, A, result_dtype)
        common_dtype = torch.promote_types(Ar.dtype, Br.dtype)
        out_real = add_func(Ar.to(common_dtype), Br.to(common_dtype), alpha)
        return _view_as_complex_ptpu_safe(out_real).to(result_dtype)

    if A_is_tensor and B_is_tensor:
        if B.device != A.device:
            B = B.to(A.device)
        if should_use_broadcast_configs(A, B):
            out = get_best_strided_output_tensor(A, B)
            add_func_broadcast(A, B, alpha, out0=out)
            # The buffer is float32; cast to the promoted dtype of the inputs.
            return out.to(torch.result_type(A, B))
        return add_func(A, B, alpha)
    if A_is_tensor:
        return add_func_tensor_scalar(A, B, alpha)
    if B_is_tensor:
        return add_func_scalar_tensor(A, B, alpha)
    return torch.tensor(A + B * alpha)

225 

226 

def add_(A, B, *, alpha=1):
    """In-place add: write ``A + alpha * B`` into ``A`` and return the kernel result.

    Args:
        A: destination tensor; it is passed to the kernel as ``out0``.
        B: tensor or Python scalar. A tensor on another device is first moved
            to ``A.device``.
        alpha: multiplier applied to ``B``.

    Raises:
        ValueError: if ``A`` is not a tensor (an in-place add needs a tensor
            destination, so the scalar-first variant is intentionally unused).
    """
    logger.debug("GEMS ADD_")
    if not isinstance(A, torch.Tensor):
        raise ValueError("Unreachable.")
    if not isinstance(B, torch.Tensor):
        return add_func_tensor_scalar(A, B, alpha, out0=A)
    if B.device != A.device:
        B = B.to(A.device)
    return add_func(A, B, alpha, out0=A)