Coverage for src/flag_gems/ops/clip.py: 86%

35 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic 

7 

8logger = logging.getLogger(__name__) 

9 

10 

@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, 2, "DEFAULT")],
)
@triton.jit
def clip_func(x, mini, maxi):
    # Element-wise two-sided clip: bound x from below by `mini`, then from
    # above by `maxi`. Per the decorator, only `x` is declared as a tensor;
    # `mini` and `maxi` are declared as non-tensor arguments.
    #
    # The input is cast to float32 before the comparisons.
    # NOTE(review): this cast may lose precision for float64 or large integer
    # inputs -- confirm it is intentional.
    x_f32 = x.to(tl.float32)
    lower_bounded = tl.maximum(mini, x_f32)
    return tl.minimum(maxi, lower_bounded)

17 

18 

@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clip_func_min(x, mini):
    # Element-wise lower-bound clip: returns the larger of `mini` and x.
    # Per the decorator, `x` is declared as a tensor and `mini` as a
    # non-tensor argument.
    #
    # The input is cast to float32 before the comparison.
    # NOTE(review): this cast may lose precision for float64 or large integer
    # inputs -- confirm it is intentional.
    x_f32 = x.to(tl.float32)
    return tl.maximum(mini, x_f32)

23 

24 

@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def clip_func_max(x, maxi):
    # Element-wise upper-bound clip: returns the smaller of `maxi` and x.
    # Per the decorator, `x` is declared as a tensor and `maxi` as a
    # non-tensor argument.
    #
    # The input is cast to float32 before the comparison.
    # NOTE(review): this cast may lose precision for float64 or large integer
    # inputs -- confirm it is intentional.
    x_f32 = x.to(tl.float32)
    return tl.minimum(maxi, x_f32)

29 

30 

def clip(A, mini=None, maxi=None):
    """Clip the values of ``A`` to the given bound(s).

    Dispatches to one of three pointwise functions depending on which
    bounds were supplied.

    Args:
        A: Input forwarded as the first argument of the selected function.
        mini: Lower bound, or ``None`` to leave the lower side unbounded.
        maxi: Upper bound, or ``None`` to leave the upper side unbounded.

    Returns:
        Whatever the selected function (``clip_func``, ``clip_func_min`` or
        ``clip_func_max``) returns.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLIP")

    if mini is None:
        # No lower bound: either nothing was given at all, or only an upper
        # bound was.
        if maxi is None:
            raise ValueError("At least one of mini or maxi must not be None")
        return clip_func_max(A, maxi)

    if maxi is None:
        # Only a lower bound was given.
        return clip_func_min(A, mini)

    # Both bounds were given.
    return clip_func(A, mini, maxi)

41 

42 

def clip_(A, mini=None, maxi=None):
    """Variant of ``clip`` that passes ``A`` as the ``out0`` argument.

    Dispatches to the same three pointwise functions as ``clip``, but every
    call additionally receives ``out0=A``.
    NOTE(review): presumably ``out0`` makes the result get written back into
    ``A`` (in-place semantics) -- confirm against ``pointwise_dynamic``.

    Args:
        A: Input forwarded both as the first argument and as ``out0``.
        mini: Lower bound, or ``None`` to leave the lower side unbounded.
        maxi: Upper bound, or ``None`` to leave the upper side unbounded.

    Returns:
        Whatever the selected function (``clip_func``, ``clip_func_min`` or
        ``clip_func_max``) returns.

    Raises:
        ValueError: If both ``mini`` and ``maxi`` are ``None``.
    """
    logger.debug("GEMS CLIP_")

    has_lower = mini is not None
    has_upper = maxi is not None

    if not (has_lower or has_upper):
        raise ValueError("At least one of mini or maxi must not be None")

    if has_lower and has_upper:
        return clip_func(A, mini, maxi, out0=A)

    if has_upper:
        # Only an upper bound was given.
        return clip_func_max(A, maxi, out0=A)

    # Only a lower bound was given.
    return clip_func_min(A, mini, out0=A)