Coverage for src/flag_gems/runtime/backend/_sunrise/ops/add.py: 0%

68 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5 

6from flag_gems.utils import pointwise_dynamic 

7from flag_gems.utils.codegen_config_utils import CodeGenConfig 

8 

# Module-level logger, created as a child of the shared "flag_gems" logger.
# Any leading dots are stripped from the module name before it is used as
# the child suffix.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip("."),
)

10 

11 

# Code-generation settings used by the general tensor + tensor kernel
# (`add_func`).
# NOTE(review): the positional values are presumably the maximum tile size
# (1024), the per-axis grid limit, a warp-count limit (32) and a
# block-pointer preference (True) -- confirm against CodeGenConfig's field
# order before relying on this.
config_for_general = CodeGenConfig(
    1024,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=False,  # 1D tiling is not requested for the general path
    # num_warps=2  (option left disabled in the original source)
)

20 

21 

@pointwise_dynamic(
    config=config_for_general,
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func(x, y, alpha):
    # General tensor + tensor kernel body: x + alpha * y.
    # Per `is_tensor`, x and y are tensor operands and alpha is a non-tensor
    # argument. NOTE(review): the output dtype is presumably derived from
    # arguments 0 and 1 via the "DEFAULT" promotion rule -- confirm in
    # pointwise_dynamic.
    scaled_y = y * alpha
    return x + scaled_y

30 

31 

# Code-generation settings used by the broadcast-specialised kernel
# (`add_func_broadcast`): 1D tiling with a smaller tile-size limit (128)
# than the general configuration.
# NOTE(review): the remaining positional values are presumably the per-axis
# grid limit, a warp-count limit (32) and a block-pointer preference (True)
# -- confirm against CodeGenConfig's field order.
config_for_broadcast = CodeGenConfig(
    128,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,  # request 1D tiling for the broadcast path
    # num_warps=4  (option left disabled in the original source)
)

40 

41 

@pointwise_dynamic(
    config=config_for_broadcast,
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def add_func_broadcast(x, y, alpha):
    # Same arithmetic as the general kernel (x + alpha * y); the only
    # difference is that it is generated with `config_for_broadcast`.
    scaled_y = y * alpha
    return x + scaled_y

50 

51 

@pointwise_dynamic(
    promotion_methods=[(0, 1, "DEFAULT")],
    is_tensor=[True, False, False],
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha):
    # Tensor + scalar kernel body: x + alpha * y.
    # Per `is_tensor`, only x is a tensor operand; y and alpha are
    # non-tensor arguments. No explicit codegen config is passed here.
    scaled_y = y * alpha
    return x + scaled_y

58 

59 

@pointwise_dynamic(
    promotion_methods=[(0, 1, "DEFAULT")],
    is_tensor=[False, True, False],
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha):
    # Scalar + tensor kernel body: x + alpha * y.
    # Per `is_tensor`, only y is a tensor operand; x and alpha are
    # non-tensor arguments. No explicit codegen config is passed here.
    scaled_y = y * alpha
    return x + scaled_y

66 

67 

def get_best_strided_output_tensor(A, B):
    """Allocate the float32 output buffer for a broadcasted ``A + B``.

    The buffer has the broadcast shape of ``A`` and ``B`` and lives on
    ``A``'s device.  When one operand (``A`` is checked first) already has
    the broadcast shape, the buffer mirrors that operand's memory layout.

    The layout is mirrored through ``torch.empty_like`` with
    ``preserve_format`` rather than by copying raw strides onto a fresh
    allocation.  Copying raw strides is unsafe in two cases:

    * an expanded operand has zero strides, which would make output elements
      alias one another and corrupt the result written by the kernel;
    * a sliced/strided view has strides that reach beyond a freshly
      allocated ``numel``-sized storage, which makes ``as_strided`` raise.

    In both cases ``empty_like`` falls back to a dense, non-overlapping
    layout; for dense operands the strides are reproduced exactly.

    Args:
        A: First tensor operand; its device is used for the allocation.
        B: Second tensor operand.

    Returns:
        An uninitialised ``torch.float32`` tensor with the broadcast shape.
        The dtype is fixed to float32 here; the caller is expected to cast
        to the final result type.
    """
    broadcast_shape = torch.broadcast_shapes(A.shape, B.shape)
    dtype = torch.float32

    # Prefer A's layout, then B's, when that operand needs no broadcasting.
    for operand in (A, B):
        if operand.shape == broadcast_shape:
            return torch.empty_like(
                operand,
                dtype=dtype,
                device=A.device,
                memory_format=torch.preserve_format,
            )

    # Neither operand has the broadcast shape: use a plain contiguous buffer.
    return torch.empty(broadcast_shape, device=A.device, dtype=dtype)

83 

84 

def is_power_of_two(n):
    """Return True when ``n`` is a positive integer power of two."""
    if n <= 0:
        return False
    # A power of two has a single set bit, so clearing the lowest set bit
    # (n & (n - 1)) leaves zero.
    return (n & (n - 1)) == 0

87 

88 

def should_use_broadcast_configs(A, B):
    """Decide whether the broadcast-specialised kernel should be used.

    The specialised path (1D tiling with a smaller maximum tile size) is
    chosen only when all of the following hold:

    * the operand shapes differ, i.e. broadcasting is required;
    * both operands have at least two dimensions and their last two
      dimensions are equal;
    * the size of the last dimension is not a power of two;
    * the promoted result dtype is float16 or float32.

    Args:
        A: First tensor operand.
        B: Second tensor operand.

    Returns:
        True when the broadcast configuration should be used, else False.
    """
    shape_a, shape_b = A.shape, B.shape

    # Identical shapes: nothing is broadcast.
    if shape_a == shape_b:
        return False
    # The trailing-two-dimension comparison needs rank >= 2 on both sides.
    if len(shape_a) < 2 or len(shape_b) < 2:
        return False
    if shape_a[-2:] != shape_b[-2:]:
        return False
    if is_power_of_two(shape_a[-1]):
        return False
    return torch.result_type(A, B) in [torch.float16, torch.float32]

103 

104 

def add(A, B, *, alpha=1):
    """Compute ``A + alpha * B`` out of place.

    Dispatches on which operands are tensors:

    * neither: the result is computed in Python and wrapped in a tensor;
    * only ``A``: tensor + scalar kernel;
    * only ``B``: scalar + tensor kernel;
    * both: ``B`` is first moved to ``A``'s device if needed, then either
      the broadcast-specialised kernel or the general kernel is used.

    Args:
        A: Tensor or Python scalar.
        B: Tensor or Python scalar.
        alpha: Multiplier applied to ``B`` (keyword-only, default 1).

    Returns:
        The result of the selected kernel, or a new tensor in the
        scalar + scalar case.
    """
    logger.debug("GEMS ADD")
    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if not a_is_tensor and not b_is_tensor:
        return torch.tensor(A + B * alpha)
    if not b_is_tensor:
        return add_func_tensor_scalar(A, B, alpha)
    if not a_is_tensor:
        return add_func_scalar_tensor(A, B, alpha)

    # Both operands are tensors from here on.
    if B.device != A.device:
        B = B.to(A.device)
    if not should_use_broadcast_configs(A, B):
        return add_func(A, B, alpha)

    # Broadcast path: the kernel writes into a pre-allocated buffer, which is
    # then converted to the promoted result dtype of A and B.
    out = get_best_strided_output_tensor(A, B)
    add_func_broadcast(A, B, alpha, out0=out)
    return out.to(torch.result_type(A, B))

122 

123 

def add_(A, B, *, alpha=1):
    """Compute ``A + alpha * B`` in place, writing the result into ``A``.

    Args:
        A: Destination tensor; passed to the kernel as ``out0``.
        B: Tensor or Python scalar. A tensor on a different device is moved
            to ``A``'s device first.
        alpha: Multiplier applied to ``B`` (keyword-only, default 1).

    Returns:
        Whatever the selected kernel returns when called with ``out0=A``.

    Raises:
        ValueError: If ``A`` is not a tensor; there is no in-place variant
            with a scalar destination.
    """
    logger.debug("GEMS ADD_")

    # A scalar destination cannot be updated in place.
    if not isinstance(A, torch.Tensor):
        raise ValueError("Unreachable.")

    if not isinstance(B, torch.Tensor):
        return add_func_tensor_scalar(A, B, alpha, out0=A)

    if B.device != A.device:
        B = B.to(A.device)
    return add_func(A, B, alpha, out0=A)