Coverage for src/flag_gems/ops/fmod.py: 71%

34 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-27 08:02 +0800

1import logging 

2 

3import triton 

4import triton.language as tl 

5 

6from flag_gems.utils import pointwise_dynamic 

7from flag_gems.utils.triton_lang_extension import fmod as _fmod 

8 

9logger = logging.getLogger(__name__) 

10 

11 

12@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")]) 

13@triton.jit 

14def fmod_func(x, y): 

15 # Convert to float32 for computation to avoid libdevice float16/bfloat16 issues 

16 dtype = x.dtype 

17 x_fp32 = x.to(tl.float32) 

18 y_fp32 = y.to(tl.float32) 

19 result = _fmod(x_fp32, y_fp32) 

20 return result.to(dtype) 

21 

22 

23@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")]) 

24@triton.jit 

25def fmod_func_tensor_scalar(x, y): 

26 # Convert to float32 for computation to avoid libdevice float16/bfloat16 issues 

27 dtype = x.dtype 

28 x_fp32 = x.to(tl.float32) 

29 y_fp32 = y.to(tl.float32) 

30 result = _fmod(x_fp32, y_fp32) 

31 return result.to(dtype) 

32 

33 

34def fmod_tensor(A, B): 

35 logger.debug("GEMS FMOD_TENSOR") 

36 return fmod_func(A, B) 

37 

38 

39def fmod_scalar(A, B): 

40 logger.debug("GEMS FMOD_SCALAR") 

41 return fmod_func_tensor_scalar(A, B) 

42 

43 

44def fmod_tensor_(A, B): 

45 logger.debug("GEMS FMOD_TENSOR_") 

46 return fmod_func(A, B, out0=A) 

47 

48 

49def fmod_scalar_(A, B): 

50 logger.debug("GEMS FMOD_SCALAR_") 

51 return fmod_func_tensor_scalar(A, B, out0=A)