Coverage for src/flag_gems/runtime/backend/_sunrise/ops/clamp.py: 0%
152 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import pointwise_dynamic
8from flag_gems.utils.pointwise_dynamic import CodeGenConfig
logger = logging.getLogger(__name__)

# Per-dimension limit handed to CodeGenConfig(max_grid_size=...) below;
# the same value is used for all three grid dimensions.
MAX_GRID_SIZES = (65535,) * 3

# Code-generation config shared by the ``*_f16`` kernels in this module.
config_f16 = CodeGenConfig(
    max_tile_size=1024,
    max_grid_size=MAX_GRID_SIZES,
    max_num_warps_per_cta=32,
    prefer_block_pointer=True,
    prefer_1d_tile=True,
)
@pointwise_dynamic(promotion_methods=[(0, 1, 2, "DEFAULT")])
@triton.jit
def clamp_func_tensor(x, mini, maxi):
    # Two-sided clamp with tensor bounds. The input is upcast to float32,
    # raised to ``mini``, and then capped at ``maxi``.
    # NOTE(review): the float32 upcast can drop precision for float64 or
    # large-magnitude integer inputs; confirm this is acceptable.
    x_f32 = x.to(tl.float32)
    lower_clamped = tl.maximum(mini, x_f32)
    return tl.minimum(maxi, lower_clamped)
@pointwise_dynamic(
    promotion_methods=[(0, 1, 2, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_tensor_f16(x, mini, maxi):
    # Same math as clamp_func_tensor, but compiled with ``config_f16``.
    # NOTE(review): values are computed in float32 before promotion.
    upcast = x.to(tl.float32)
    floored = tl.maximum(mini, upcast)
    return tl.minimum(maxi, floored)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_min_tensor(x, mini):
    # Lower-bound-only clamp with a tensor bound, computed in float32.
    # NOTE(review): float32 upcast may lose precision for float64/int64 inputs.
    x_f32 = x.to(tl.float32)
    return tl.maximum(mini, x_f32)
@pointwise_dynamic(
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_min_tensor_f16(x, mini):
    # Lower-bound-only clamp with a tensor bound, compiled with ``config_f16``.
    upcast = x.to(tl.float32)
    return tl.maximum(mini, upcast)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_func_max_tensor(x, maxi):
    # Upper-bound-only clamp with a tensor bound, computed in float32.
    # NOTE(review): float32 upcast may lose precision for float64/int64 inputs.
    x_f32 = x.to(tl.float32)
    return tl.minimum(maxi, x_f32)
@pointwise_dynamic(
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_max_tensor_f16(x, maxi):
    # Upper-bound-only clamp with a tensor bound, compiled with ``config_f16``.
    upcast = x.to(tl.float32)
    return tl.minimum(maxi, upcast)
def clamp_tensor(A, mini=None, maxi=None):
    """Clamp ``A`` elementwise between the bounds ``mini`` and ``maxi``.

    Dispatches to the ``*_tensor`` kernels. Either bound may be ``None`` for
    a one-sided clamp, but not both. ``torch.half`` inputs are routed to the
    ``*_f16`` kernels, which are compiled with ``config_f16``.

    Args:
        A: Input tensor.
        mini: Lower bound, or ``None`` to skip the lower clamp.
        maxi: Upper bound, or ``None`` to skip the upper clamp.

    Returns:
        Whatever the selected kernel returns (the clamped values).

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    # Use the module logger, like every other op in this file, rather than
    # the root ``logging`` module.
    logger.debug("GEMS CLAMP TENSOR")
    is_half = A.dtype == torch.half
    if mini is None and maxi is None:
        raise ValueError("At least one of mini or maxi must not be None")
    if mini is None:
        kernel = clamp_func_max_tensor_f16 if is_half else clamp_func_max_tensor
        return kernel(A, maxi)
    if maxi is None:
        kernel = clamp_func_min_tensor_f16 if is_half else clamp_func_min_tensor
        return kernel(A, mini)
    kernel = clamp_func_tensor_f16 if is_half else clamp_func_tensor
    return kernel(A, mini, maxi)
def clamp_tensor_(A, mini=None, maxi=None):
    """Clamp ``A`` in place between the bounds ``mini`` and ``maxi``.

    In-place counterpart of ``clamp_tensor``. Every kernel is called with
    ``out0=A``, so the output is written into ``A``. Either bound may be
    ``None`` for a one-sided clamp, but not both. ``torch.half`` inputs use
    the ``*_f16`` kernels.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLAMP_ TENSOR")
    use_f16 = A.dtype == torch.half
    no_lower = mini is None
    no_upper = maxi is None
    if no_lower and no_upper:
        raise ValueError("At least one of mini or maxi must not be None")
    if no_lower:
        upper_only = clamp_func_max_tensor_f16 if use_f16 else clamp_func_max_tensor
        return upper_only(A, maxi, out0=A)
    if no_upper:
        lower_only = clamp_func_min_tensor_f16 if use_f16 else clamp_func_min_tensor
        return lower_only(A, mini, out0=A)
    both = clamp_func_tensor_f16 if use_f16 else clamp_func_tensor
    return both(A, mini, maxi, out0=A)
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def clamp_func(x, mini, maxi):
    # Two-sided clamp where ``mini``/``maxi`` are non-tensor (scalar) args.
    # NOTE(review): float32 upcast may lose precision for float64/int64 inputs.
    x_f32 = x.to(tl.float32)
    lower_clamped = tl.maximum(mini, x_f32)
    return tl.minimum(maxi, lower_clamped)
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_f16(x, mini, maxi):
    # Two-sided clamp with scalar bounds, compiled with ``config_f16``.
    upcast = x.to(tl.float32)
    floored = tl.maximum(mini, upcast)
    return tl.minimum(maxi, floored)
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_min(x, mini):
    # Lower-bound-only clamp where ``mini`` is a non-tensor (scalar) arg.
    x_f32 = x.to(tl.float32)
    return tl.maximum(mini, x_f32)
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_min_f16(x, mini):
    # Lower-bound-only clamp with a scalar bound, compiled with ``config_f16``.
    upcast = x.to(tl.float32)
    return tl.maximum(mini, upcast)
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clamp_func_max(x, maxi):
    # Upper-bound-only clamp where ``maxi`` is a non-tensor (scalar) arg.
    x_f32 = x.to(tl.float32)
    return tl.minimum(maxi, x_f32)
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config_f16,
)
@triton.jit
def clamp_func_max_f16(x, maxi):
    # Upper-bound-only clamp with a scalar bound, compiled with ``config_f16``.
    upcast = x.to(tl.float32)
    return tl.minimum(maxi, upcast)
def clamp_min(A, mini):
    """Clamp ``A`` from below by ``mini``.

    A ``torch.Tensor`` bound uses the ``*_tensor`` kernels; anything else is
    passed to the scalar-bound kernels. ``torch.half`` inputs use the
    ``*_f16`` kernels on both paths.

    Args:
        A: Input tensor.
        mini: Lower bound (tensor or scalar).

    Returns:
        Whatever the selected kernel returns (the clamped values).

    Raises:
        ValueError: If ``mini`` is ``None``.
    """
    logger.debug("GEMS CLAMP MIN")
    if mini is None:
        raise ValueError("Mini must not be None")
    is_half = A.dtype == torch.half
    if isinstance(mini, torch.Tensor):
        kernel = clamp_func_min_tensor_f16 if is_half else clamp_func_min_tensor
    else:
        # Previously the scalar path ignored fp16. Route it to the f16-config
        # kernel, matching the tensor path here and clamp()'s scalar path.
        kernel = clamp_func_min_f16 if is_half else clamp_func_min
    return kernel(A, mini)
def clamp_min_(A, mini):
    """Clamp ``A`` from below by ``mini``, in place.

    Every kernel is called with ``out0=A``, so the output is written into
    ``A``. A ``torch.Tensor`` bound uses the ``*_tensor`` kernels; anything
    else is passed to the scalar-bound kernels. ``torch.half`` inputs use the
    ``*_f16`` kernels on both paths.

    Raises:
        ValueError: If ``mini`` is ``None``.
    """
    logger.debug("GEMS CLAMP_ MIN")
    if mini is None:
        raise ValueError("Mini must not be None")
    is_half = A.dtype == torch.half
    if isinstance(mini, torch.Tensor):
        kernel = clamp_func_min_tensor_f16 if is_half else clamp_func_min_tensor
    else:
        # Scalar bound: honour fp16 like the tensor path and clamp_() do.
        kernel = clamp_func_min_f16 if is_half else clamp_func_min
    return kernel(A, mini, out0=A)
def clamp_min_out(A, mini, *, out=None):
    """Clamp ``A`` from below by ``mini``, writing into ``out``.

    ``out`` is forwarded to the kernel as ``out0``. A ``torch.Tensor`` bound
    uses the ``*_tensor`` kernels; anything else is passed to the scalar-bound
    kernels. ``torch.half`` inputs use the ``*_f16`` kernels on both paths.

    Raises:
        ValueError: If ``mini`` is ``None``.
    """
    logger.debug("GEMS CLAMP MIN OUT")
    if mini is None:
        raise ValueError("Mini must not be None")
    is_half = A.dtype == torch.half
    if isinstance(mini, torch.Tensor):
        kernel = clamp_func_min_tensor_f16 if is_half else clamp_func_min_tensor
    else:
        # Scalar bound: honour fp16 like the tensor path and clamp() do.
        kernel = clamp_func_min_f16 if is_half else clamp_func_min
    return kernel(A, mini, out0=out)
def clamp(A, mini=None, maxi=None):
    """Clamp ``A`` elementwise between scalar bounds ``mini`` and ``maxi``.

    Uses the scalar-bound kernels (``mini``/``maxi`` declared non-tensor).
    Either bound may be ``None`` for a one-sided clamp, but not both.
    ``torch.half`` inputs use the ``*_f16`` kernels.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLAMP")
    half_input = A.dtype == torch.half
    if mini is not None and maxi is not None:
        kernel = clamp_func_f16 if half_input else clamp_func
        return kernel(A, mini, maxi)
    if maxi is not None:
        kernel = clamp_func_max_f16 if half_input else clamp_func_max
        return kernel(A, maxi)
    if mini is not None:
        kernel = clamp_func_min_f16 if half_input else clamp_func_min
        return kernel(A, mini)
    raise ValueError("At least one of mini or maxi must not be None")
def clamp_(A, mini=None, maxi=None):
    """Clamp ``A`` in place between scalar bounds ``mini`` and ``maxi``.

    Every kernel is called with ``out0=A``, so the output is written into
    ``A``. Either bound may be ``None`` for a one-sided clamp, but not both.
    ``torch.half`` inputs use the ``*_f16`` kernels.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    # Tag the in-place variant distinctly ("GEMS CLAMP_"), matching the other
    # in-place wrappers; it previously logged the same text as clamp().
    logger.debug("GEMS CLAMP_")
    is_half = A.dtype == torch.half
    if mini is None and maxi is None:
        raise ValueError("At least one of mini or maxi must not be None")
    if mini is None:
        kernel = clamp_func_max_f16 if is_half else clamp_func_max
        return kernel(A, maxi, out0=A)
    if maxi is None:
        kernel = clamp_func_min_f16 if is_half else clamp_func_min
        return kernel(A, mini, out0=A)
    kernel = clamp_func_f16 if is_half else clamp_func
    return kernel(A, mini, maxi, out0=A)