Coverage for src/flag_gems/ops/clip.py: 86%
35 statements
Report generated by coverage.py v7.6.9, created at 2026-05-06 06:51 +0800.
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic
# Module-level logger; clip() and clip_() emit a debug message through it on each call.
logger = logging.getLogger(__name__)
# Elementwise clamp of tensor ``x`` into ``[mini, maxi]``.
# ``mini`` and ``maxi`` are scalars (is_tensor=[True, False, False]); type
# promotion across all three arguments is delegated to pointwise_dynamic
# ("DEFAULT" method).
@pointwise_dynamic(
    is_tensor=[True, False, False], promotion_methods=[(0, 1, 2, "DEFAULT")]
)
@triton.jit
def clip_func(x, mini, maxi):
    # Widen only 16-bit floats to fp32 for the comparison. Converting every
    # dtype to fp32 would round float64 values and integers with magnitude
    # above 2**24 before they are written back; min/max is exact natively.
    # These dtype checks are resolved at compile time.
    if x.dtype.is_fp16():
        x = x.to(tl.float32)
    elif x.dtype.is_bf16():
        x = x.to(tl.float32)
    # NOTE(review): tl.maximum/tl.minimum are called with Triton's default NaN
    # policy, while torch.clamp propagates NaN -- confirm behaviour for NaN x.
    return tl.minimum(maxi, tl.maximum(mini, x))
# Elementwise lower-bound clamp: max(x, mini), with ``mini`` a scalar
# (is_tensor=[True, False]); promotion is delegated to pointwise_dynamic.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clip_func_min(x, mini):
    # Widen only 16-bit floats; a blanket fp32 cast would lose precision for
    # float64 and for large-magnitude integer inputs.
    if x.dtype.is_fp16():
        x = x.to(tl.float32)
    elif x.dtype.is_bf16():
        x = x.to(tl.float32)
    # NOTE(review): default NaN policy of tl.maximum -- confirm against
    # torch.clamp's NaN propagation.
    return tl.maximum(mini, x)
# Elementwise upper-bound clamp: min(x, maxi), with ``maxi`` a scalar
# (is_tensor=[True, False]); promotion is delegated to pointwise_dynamic.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clip_func_max(x, maxi):
    # Widen only 16-bit floats; a blanket fp32 cast would lose precision for
    # float64 and for large-magnitude integer inputs.
    if x.dtype.is_fp16():
        x = x.to(tl.float32)
    elif x.dtype.is_bf16():
        x = x.to(tl.float32)
    # NOTE(review): default NaN policy of tl.minimum -- confirm against
    # torch.clamp's NaN propagation.
    return tl.minimum(maxi, x)
def clip(A, mini=None, maxi=None):
    """Clamp ``A`` elementwise into ``[mini, maxi]`` (out-of-place).

    A bound given as ``None`` leaves that side unbounded, and the matching
    single-bound kernel is used instead of the two-sided one. No ``out0`` is
    passed, so the output buffer is chosen by the pointwise_dynamic wrapper.

    Args:
        A: Input tensor.
        mini: Scalar lower bound, or ``None``.
        maxi: Scalar upper bound, or ``None``.

    Returns:
        The result of the selected ``clip_func*`` kernel call.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLIP")
    if mini is None:
        # Only an upper bound (or no bound at all) was supplied.
        if maxi is None:
            raise ValueError("At least one of mini or maxi must not be None")
        return clip_func_max(A, maxi)
    if maxi is None:
        return clip_func_min(A, mini)
    return clip_func(A, mini, maxi)
def clip_(A, mini=None, maxi=None):
    """In-place variant of ``clip``: the result is written into ``A``.

    Every kernel call passes ``out0=A``, so ``A`` is used as the output
    buffer. The return value is whatever the kernel wrapper returns
    (presumably ``A`` itself -- confirm against pointwise_dynamic).

    Args:
        A: Tensor to clamp in place.
        mini: Scalar lower bound, or ``None``.
        maxi: Scalar upper bound, or ``None``.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLIP_")
    if maxi is None:
        # Only a lower bound (or no bound at all) was supplied.
        if mini is None:
            raise ValueError("At least one of mini or maxi must not be None")
        return clip_func_min(A, mini, out0=A)
    if mini is None:
        return clip_func_max(A, maxi, out0=A)
    return clip_func(A, mini, maxi, out0=A)