Coverage for src/flag_gems/ops/clamp_max.py: 93%

15 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1# Generated by KernelGen: https://github.com/flagos-ai/KernelGen 

2import logging 

3 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import pointwise_dynamic 

8 

# Module logger; clamp_max and clamp_max_ below emit one debug line per call.
9logger = logging.getLogger(__name__) 

10 

11 

@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def clamp_max_func(x, maxi):
    """Elementwise ``min(x, maxi)`` used by ``clamp_max`` / ``clamp_max_``.

    Args:
        x: tile of the input tensor (the only tensor operand, per ``is_tensor``).
        maxi: scalar upper bound (non-tensor operand, per ``is_tensor``).

    Returns:
        The clamped tile. The stored output dtype is chosen by the DEFAULT
        promotion of ``(x, maxi)`` declared in the decorator.
    """
    # Only half-precision floats are upcast to float32. Presumably this is
    # kept for backends without native fp16/bf16 min. A min is exact in any
    # dtype, so other inputs stay in their own dtype. The previous
    # unconditional float32 cast truncated float64 values and integers
    # above 2**24 before they were stored back.
    if x.dtype.is_fp16():
        x = x.to(tl.float32)
    elif x.dtype.is_bf16():
        x = x.to(tl.float32)
    # PropagateNan.ALL returns NaN when either operand is NaN, matching
    # torch.clamp_max. The default (PropagateNan.NONE) does not propagate
    # NaN, which turned clamp_max(nan, m) into m.
    return tl.minimum(x, maxi, propagate_nan=tl.PropagateNan.ALL)

16 

17 

def clamp_max(A, max_value):
    """Clamp ``A`` from above, out of place: elementwise ``min(A, max_value)``.

    Args:
        A: input tensor.
        max_value: upper bound, passed to the kernel as its non-tensor
            (scalar) operand.

    Returns:
        Whatever the ``clamp_max_func`` pointwise kernel returns when no
        output buffer is given. This is presumably a newly allocated tensor;
        confirm against ``pointwise_dynamic``.
    """
    logger.debug("GEMS CLAMP_MAX")
    clamped = clamp_max_func(A, max_value)
    return clamped

21 

22 

def clamp_max_(A, max_value):
    """In-place clamp from above: write ``min(A, max_value)`` back into ``A``.

    The kernel is told to use ``A`` itself as its output buffer (``out0``),
    and its return value is handed back unchanged (presumably ``A``).

    NOTE(review): unlike torch's in-place ops, nothing here rejects a promoted
    result dtype that cannot be stored back into ``A`` (e.g. an integer ``A``
    with a float ``max_value``). Confirm how pointwise_dynamic handles that.
    """
    logger.debug("GEMS CLAMP_MAX_")
    destination = A
    clamped = clamp_max_func(A, max_value, out0=destination)
    return clamped