Coverage for src/flag_gems/runtime/backend/_sunrise/ops/clamp.py: 0%
135 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
8from flag_gems.utils.pointwise_dynamic import CodeGenConfig
# Child of the package-wide "flag_gems" logger; leading dots are stripped from
# the module name before it is used as the child-logger suffix.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# Grid-size cap handed to CodeGenConfig below (presumably one limit per grid
# dimension -- TODO confirm against CodeGenConfig's definition).
MAX_GRID_SIZES = (65535,) * 3

# Code-generation settings for the *_f16 kernels; the dispatchers in this
# module pick those kernels when the input tensor's dtype is torch.half.
config_f16 = CodeGenConfig(
    max_tile_size=1024,
    max_grid_size=MAX_GRID_SIZES,
    max_num_warps_per_cta=32,
    prefer_block_pointer=True,
    prefer_1d_tile=True,
)
@pointwise_dynamic(promotion_methods=[(0, 1, 2, "DEFAULT")])
@triton.jit
def clamp_func_tensor(x, mini, maxi):
    # Two-sided clamp with tensor bounds. x is lifted to fp32 first, then
    # floored at mini and capped at maxi; because the cap is applied last,
    # maxi wins wherever mini > maxi.
    x_fp32 = x.to(tl.float32)
    floored = tl.maximum(mini, x_fp32)
    return tl.minimum(maxi, floored)
@pointwise_dynamic(promotion_methods=[(0, 1, 2, "DEFAULT")], config=config_f16)
@triton.jit
def clamp_func_tensor_f16(x, mini, maxi):
    # Same math as clamp_func_tensor, but generated with config_f16 (used by
    # the dispatchers for torch.half inputs). Lower bound first, then upper.
    widened = x.to(tl.float32)
    lower_bounded = tl.maximum(mini, widened)
    return tl.minimum(maxi, lower_bounded)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_min_tensor(x, mini):
    # Lower-bound-only clamp with a tensor bound: max(mini, x) in fp32.
    x_fp32 = x.to(tl.float32)
    return tl.maximum(mini, x_fp32)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], config=config_f16)
@triton.jit
def clamp_func_min_tensor_f16(x, mini):
    # config_f16 build of clamp_func_min_tensor: max(mini, x) in fp32.
    widened = x.to(tl.float32)
    return tl.maximum(mini, widened)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_max_tensor(x, maxi):
    # Upper-bound-only clamp with a tensor bound: min(maxi, x) in fp32.
    x_fp32 = x.to(tl.float32)
    return tl.minimum(maxi, x_fp32)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")], config=config_f16)
@triton.jit
def clamp_func_max_tensor_f16(x, maxi):
    # config_f16 build of clamp_func_max_tensor: min(maxi, x) in fp32.
    widened = x.to(tl.float32)
    return tl.minimum(maxi, widened)
def clamp_tensor(A, mini=None, maxi=None):
    """Clamp ``A`` elementwise using the tensor-bound kernels.

    ``torch.half`` inputs are routed to the kernels generated with
    ``config_f16``; every other dtype uses the default code-generation config.

    Args:
        A: input tensor.
        mini: lower bound, or ``None`` to apply only the upper bound
            (presumably a tensor, since the ``*_tensor`` kernels declare no
            ``is_tensor`` override -- confirm with callers).
        maxi: upper bound, or ``None`` to apply only the lower bound.

    Returns:
        The result of the selected kernel (no ``out0`` is passed).

    Raises:
        ValueError: if both ``mini`` and ``maxi`` are ``None``.
    """
    # Use the module logger like every other entry point in this file. The
    # root-level ``logging.debug`` bypasses the "flag_gems" logger hierarchy
    # and implicitly calls ``logging.basicConfig()`` when the root logger has
    # no handlers, mutating the application's logging setup.
    logger.debug("GEMS CLAMP TENSOR")
    is_half = A.dtype == torch.half
    if mini is None and maxi is None:
        raise ValueError("At least one of mini or maxi must not be None")
    if mini is None:
        kernel = clamp_func_max_tensor_f16 if is_half else clamp_func_max_tensor
        return kernel(A, maxi)
    if maxi is None:
        kernel = clamp_func_min_tensor_f16 if is_half else clamp_func_min_tensor
        return kernel(A, mini)
    kernel = clamp_func_tensor_f16 if is_half else clamp_func_tensor
    return kernel(A, mini, maxi)
def clamp_tensor_(A, mini=None, maxi=None):
    """In-place clamp of ``A`` using the tensor-bound kernels.

    The selected kernel is called with ``out0=A`` so the result is written
    back into ``A``; the kernel's return value is passed through. Half
    precision inputs use the ``config_f16`` kernel variants.

    Raises:
        ValueError: if neither ``mini`` nor ``maxi`` is given.
    """
    logger.debug("GEMS CLAMP_ TENSOR")
    use_f16 = A.dtype == torch.half
    if maxi is None and mini is None:
        raise ValueError("At least one of mini or maxi must not be None")
    if maxi is None:
        fn = clamp_func_min_tensor_f16 if use_f16 else clamp_func_min_tensor
        return fn(A, mini, out0=A)
    if mini is None:
        fn = clamp_func_max_tensor_f16 if use_f16 else clamp_func_max_tensor
        return fn(A, maxi, out0=A)
    fn = clamp_func_tensor_f16 if use_f16 else clamp_func_tensor
    return fn(A, mini, maxi, out0=A)
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def clamp_func(x, mini, maxi):
    # Two-sided clamp where mini/maxi are scalars (is_tensor=False). Floor at
    # mini first, then cap at maxi, all on the fp32-lifted input.
    x_fp32 = x.to(tl.float32)
    floored = tl.maximum(mini, x_fp32)
    return tl.minimum(maxi, floored)
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_f16(x, mini, maxi):
    # config_f16 build of clamp_func (scalar bounds): lower bound, then upper.
    widened = x.to(tl.float32)
    lower_bounded = tl.maximum(mini, widened)
    return tl.minimum(maxi, lower_bounded)
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_min(x, mini):
    # Lower-bound-only clamp with a scalar bound: max(mini, x) in fp32.
    x_fp32 = x.to(tl.float32)
    return tl.maximum(mini, x_fp32)
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_min_f16(x, mini):
    # config_f16 build of clamp_func_min (scalar lower bound).
    widened = x.to(tl.float32)
    return tl.maximum(mini, widened)
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_max(x, maxi):
    # Upper-bound-only clamp with a scalar bound: min(maxi, x) in fp32.
    x_fp32 = x.to(tl.float32)
    return tl.minimum(maxi, x_fp32)
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_max_f16(x, maxi):
    # config_f16 build of clamp_func_max (scalar upper bound).
    widened = x.to(tl.float32)
    return tl.minimum(maxi, widened)
def clamp_min(A, mini):
    """Clamp ``A`` from below by the scalar ``mini``.

    ``torch.half`` inputs use ``clamp_func_min_f16`` (the ``config_f16``
    build), matching how ``clamp(A, mini=...)`` dispatches; other dtypes use
    ``clamp_func_min``.

    Args:
        A: input tensor.
        mini: scalar lower bound (the kernels mark it ``is_tensor=False``).

    Returns:
        The result of the selected kernel (no ``out0`` is passed).

    Raises:
        ValueError: if ``mini`` is ``None``.
    """
    logger.debug("GEMS CLAMP MIN")
    if mini is None:
        raise ValueError("Mini must not be None")
    # Previously half inputs always used the default config, so clamp_min and
    # clamp(A, mini=...) ran the same op through different kernels.
    if A.dtype == torch.half:
        return clamp_func_min_f16(A, mini)
    return clamp_func_min(A, mini)
def clamp_min_(A, mini):
    """In-place clamp of ``A`` from below by the scalar ``mini``.

    The kernel is called with ``out0=A`` so the result is written back into
    ``A``. ``torch.half`` inputs use ``clamp_func_min_f16``, matching
    ``clamp_(A, mini=...)``; other dtypes use ``clamp_func_min``.

    Args:
        A: tensor to clamp in place.
        mini: scalar lower bound (the kernels mark it ``is_tensor=False``).

    Returns:
        The value returned by the selected kernel.

    Raises:
        ValueError: if ``mini`` is ``None``.
    """
    logger.debug("GEMS CLAMP_ MIN")
    if mini is None:
        raise ValueError("Mini must not be None")
    # Route fp16 through the config_f16 build, consistent with clamp_().
    if A.dtype == torch.half:
        return clamp_func_min_f16(A, mini, out0=A)
    return clamp_func_min(A, mini, out0=A)
def clamp(A, mini=None, maxi=None):
    """Clamp ``A`` elementwise to the scalar range [``mini``, ``maxi``].

    Either bound may be ``None`` (but not both); the bounds go to kernels that
    declare them ``is_tensor=False``. ``torch.half`` inputs use the
    ``config_f16`` kernel variants.

    Raises:
        ValueError: when ``mini`` and ``maxi`` are both ``None``.
    """
    logger.debug("GEMS CLAMP")
    fp16 = A.dtype == torch.half
    if mini is None and maxi is None:
        raise ValueError("At least one of mini or maxi must not be None")
    if mini is not None and maxi is not None:
        op = clamp_func_f16 if fp16 else clamp_func
        return op(A, mini, maxi)
    if mini is None:
        op = clamp_func_max_f16 if fp16 else clamp_func_max
        return op(A, maxi)
    op = clamp_func_min_f16 if fp16 else clamp_func_min
    return op(A, mini)
def clamp_(A, mini=None, maxi=None):
    """In-place clamp of ``A`` to the scalar range [``mini``, ``maxi``].

    The selected kernel is called with ``out0=A`` so the result is written
    back into ``A``. Either bound may be ``None`` (but not both); the bounds
    go to kernels that declare them ``is_tensor=False``. ``torch.half``
    inputs use the ``config_f16`` kernel variants.

    Returns:
        The value returned by the selected kernel.

    Raises:
        ValueError: when ``mini`` and ``maxi`` are both ``None``.
    """
    # In-place tag, consistent with "GEMS CLAMP_ TENSOR" / "GEMS CLAMP_ MIN".
    # The previous "GEMS CLAMP" duplicated clamp()'s tag, so the two calls
    # were indistinguishable in debug logs.
    logger.debug("GEMS CLAMP_")
    is_half = A.dtype == torch.half
    if mini is None and maxi is None:
        raise ValueError("At least one of mini or maxi must not be None")
    if mini is None:
        kernel = clamp_func_max_f16 if is_half else clamp_func_max
        return kernel(A, maxi, out0=A)
    if maxi is None:
        kernel = clamp_func_min_f16 if is_half else clamp_func_min
        return kernel(A, mini, out0=A)
    kernel = clamp_func_f16 if is_half else clamp_func
    return kernel(A, mini, maxi, out0=A)