Coverage for src/flag_gems/runtime/backend/_sunrise/ops/add.py: 0%
68 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
6from flag_gems.utils import pointwise_dynamic
7from flag_gems.utils.codegen_config_utils import CodeGenConfig
# Module logger, created as a child of the shared "flag_gems" logger so it
# inherits that logger's handlers and level.  Leading dots are stripped from
# the module name before it is used as the child suffix.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Code-generation settings used by the general tensor + tensor kernel
# (`add_func`).  The positional arguments are passed through unchanged.
# NOTE(review): judging from the comment in `should_use_broadcast_configs`,
# the first value is presumably the max tile size -- confirm the positional
# field order against `CodeGenConfig` before editing any of these values.
config_for_general = CodeGenConfig(
    1024,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=False,
    # num_warps=2
)
# General tensor + tensor kernel: computes `x + y * alpha` element-wise.
# `is_tensor` marks `x` and `y` as tensor operands and `alpha` as a
# non-tensor argument; the kernel is generated with `config_for_general`.
# NOTE(review): `promotion_methods=[(0, 1, "DEFAULT")]` presumably selects
# default type promotion over arguments 0 and 1 -- confirm against
# `pointwise_dynamic`.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_for_general,
)
@triton.jit
def add_func(x, y, alpha):
    return x + y * alpha
# Code-generation settings used by the broadcast-specialised kernel
# (`add_func_broadcast`).  Compared with `config_for_general`, the first
# positional value is smaller (128 vs. 1024) and 1D tiling is preferred.
# NOTE(review): the first value is presumably the max tile size -- confirm
# the positional field order against `CodeGenConfig`.
config_for_broadcast = CodeGenConfig(
    128,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
    # num_warps=4
)
# Broadcast-specialised tensor + tensor kernel: same arithmetic as
# `add_func` (`x + y * alpha`) but generated with `config_for_broadcast`.
# `add` selects it when `should_use_broadcast_configs` returns True.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_for_broadcast,
)
@triton.jit
def add_func_broadcast(x, y, alpha):
    return x + y * alpha
# Tensor + scalar kernel: computes `x + y * alpha` where only `x` is marked
# as a tensor operand; `y` and `alpha` are non-tensor arguments.  No explicit
# config is passed, so `pointwise_dynamic` uses whatever its default is.
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_tensor_scalar(x, y, alpha):
    return x + y * alpha
# Scalar + tensor kernel: computes `x + y * alpha` where only `y` is marked
# as a tensor operand; `x` and `alpha` are non-tensor arguments.  No explicit
# config is passed, so `pointwise_dynamic` uses whatever its default is.
@pointwise_dynamic(
    is_tensor=[False, True, False], promotion_methods=[(0, 1, "DEFAULT")]
)
@triton.jit
def add_func_scalar_tensor(x, y, alpha):
    return x + y * alpha
def get_best_strided_output_tensor(A, B):
    """Allocate a float32 output buffer for the broadcast of ``A`` and ``B``.

    The buffer has the broadcast shape of the two inputs and lives on
    ``A.device``.  When one input already has the broadcast shape, the
    buffer follows that input's memory layout (``A`` is preferred over
    ``B``); otherwise a contiguous buffer is returned.

    The dtype is always ``torch.float32`` regardless of the input dtypes;
    callers are expected to cast the result afterwards if needed.

    Args:
        A: First input tensor.
        B: Second input tensor, broadcastable with ``A``.

    Returns:
        An uninitialised ``torch.float32`` tensor of the broadcast shape.
    """
    broadcast_shape = torch.broadcast_shapes(A.shape, B.shape)
    dtype = torch.float32
    for candidate in (A, B):
        if candidate.shape == broadcast_shape:
            # `empty_like` copies the strides only when `candidate` is dense
            # and non-overlapping.  Re-striding a fresh contiguous buffer with
            # arbitrary input strides (as `as_strided` would) runs past the
            # allocation for sliced views and aliases elements for expanded
            # (stride-0) inputs.
            return torch.empty_like(candidate, dtype=dtype, device=A.device)
    return torch.empty(broadcast_shape, device=A.device, dtype=dtype)
def is_power_of_two(n):
    """Return True when the integer ``n`` is a positive power of two.

    A power of two has exactly one bit set, so clearing its lowest set bit
    with ``n & (n - 1)`` leaves zero.  Zero and negative values are rejected
    by the ``n > 0`` guard.
    """
    return n > 0 and (n & (n - 1)) == 0
def should_use_broadcast_configs(A, B):
    """Decide whether ``add`` should use the broadcast-specialised kernel.

    In scenarios where broadcasting is involved and the last two dimensions
    of the two input tensors are the same, 1D tiling with a smaller
    max_tile_size config is used for better performance.

    The broadcast configuration is chosen only when all of the following
    hold:

    * the shapes differ, so broadcasting is required;
    * both tensors have at least two dimensions and their last two
      dimensions are equal;
    * the size of the last dimension of ``A`` is not a power of two;
    * the promoted result dtype is ``torch.float16`` or ``torch.float32``.

    Args:
        A: First input tensor.
        B: Second input tensor.

    Returns:
        True when the broadcast configuration should be used.
    """
    need_broadcast = A.shape != B.shape
    has_equal_last_dimensions = (
        len(A.shape) >= 2 and len(B.shape) >= 2 and A.shape[-2:] == B.shape[-2:]
    )
    # The rank check above runs first, so `A.shape[-1]` is only evaluated
    # for tensors that actually have a last dimension.
    return (
        need_broadcast
        and has_equal_last_dimensions
        and not is_power_of_two(A.shape[-1])
        and torch.result_type(A, B) in (torch.float16, torch.float32)
    )
def add(A, B, *, alpha=1):
    """Compute ``A + B * alpha`` out of place.

    Dispatches on which operands are tensors:

    * tensor + tensor: ``B`` is first moved to ``A.device`` if the devices
      differ.  When ``should_use_broadcast_configs`` approves, the
      broadcast kernel writes into a pre-allocated buffer that is then cast
      to the promoted result dtype; otherwise the general kernel is used.
    * tensor + scalar / scalar + tensor: the matching scalar kernel is used.
    * scalar + scalar: the value is computed in Python and wrapped in a
      tensor.

    Args:
        A: Tensor or Python scalar.
        B: Tensor or Python scalar.
        alpha: Multiplier applied to ``B`` before the addition.

    Returns:
        A tensor holding the result.
    """
    logger.debug("GEMS ADD")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        if B.device != A.device:
            B = B.to(A.device)
        if should_use_broadcast_configs(A, B):
            # The helper allocates a float32 buffer, so the result is cast
            # back to the promoted dtype of the inputs after the kernel ran.
            out = get_best_strided_output_tensor(A, B)
            add_func_broadcast(A, B, alpha, out0=out)
            return out.to(torch.result_type(A, B))
        else:
            return add_func(A, B, alpha)
    elif isinstance(A, torch.Tensor):
        return add_func_tensor_scalar(A, B, alpha)
    elif isinstance(B, torch.Tensor):
        return add_func_scalar_tensor(A, B, alpha)
    else:
        # Neither operand is a tensor: plain Python arithmetic.
        return torch.tensor(A + B * alpha)
def add_(A, B, *, alpha=1):
    """Compute ``A + B * alpha`` in place, writing the result into ``A``.

    ``A`` must be a tensor because it is also the output buffer (it is
    passed as ``out0``).  ``B`` may be a tensor, which is moved to
    ``A.device`` if the devices differ, or a Python scalar.

    Args:
        A: Tensor that is updated in place.
        B: Tensor or Python scalar.
        alpha: Multiplier applied to ``B`` before the addition.

    Returns:
        Whatever the underlying kernel wrapper returns when called with
        ``out0=A``.

    Raises:
        ValueError: If ``A`` is not a tensor.
    """
    logger.debug("GEMS ADD_")
    if isinstance(A, torch.Tensor) and isinstance(B, torch.Tensor):
        if B.device != A.device:
            B = B.to(A.device)
        return add_func(A, B, alpha, out0=A)
    elif isinstance(A, torch.Tensor):
        return add_func_tensor_scalar(A, B, alpha, out0=A)
    else:
        # A non-tensor `A` cannot serve as an in-place destination, so the
        # scalar + tensor kernel is deliberately not used here.
        raise ValueError("Unreachable.")