Coverage for src/flag_gems/ops/fmod.py: 71%
34 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import triton
4import triton.language as tl
6from flag_gems.utils import pointwise_dynamic
7from flag_gems.utils.triton_lang_extension import fmod as _fmod
9logger = logging.getLogger(__name__)
@pointwise_dynamic(promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def fmod_func(x, y):
    # Tensor/tensor fmod kernel body.
    #
    # Both operands are widened to float32 before the fmod helper is called,
    # which avoids libdevice float16/bfloat16 issues. The result is then cast
    # back to the dtype that `x` had on entry.
    #
    # NOTE(review): the cast target is x.dtype, and every input goes through
    # float32 regardless of its original dtype -- confirm this is intended
    # for float64 / wide-integer inputs and for mixed-dtype operands.
    in_dtype = x.dtype
    remainder = _fmod(x.to(tl.float32), y.to(tl.float32))
    return remainder.to(in_dtype)
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, 1, "DEFAULT")])
@triton.jit
def fmod_func_tensor_scalar(x, y):
    # Tensor/scalar fmod kernel body (`y` is declared as a non-tensor
    # argument via is_tensor=[True, False]).
    #
    # Both operands are widened to float32 before the fmod helper is called,
    # which avoids libdevice float16/bfloat16 issues. The result is then cast
    # back to the dtype that `x` had on entry.
    #
    # NOTE(review): every input goes through float32 regardless of its
    # original dtype -- confirm this is intended for float64 / wide-integer
    # inputs.
    in_dtype = x.dtype
    remainder = _fmod(x.to(tl.float32), y.to(tl.float32))
    return remainder.to(in_dtype)
def fmod_tensor(A, B):
    """Compute fmod of ``A`` by ``B`` through the ``fmod_func`` kernel wrapper.

    Emits a debug log line, then forwards both arguments positionally to
    ``fmod_func`` and returns whatever it returns.

    Presumably the tensor/tensor counterpart of ``torch.fmod`` -- confirm
    against the op registration.
    """
    logger.debug("GEMS FMOD_TENSOR")
    remainder = fmod_func(A, B)
    return remainder
def fmod_scalar(A, B):
    """Compute fmod of ``A`` by ``B`` through ``fmod_func_tensor_scalar``.

    Emits a debug log line, then forwards both arguments positionally to
    ``fmod_func_tensor_scalar`` (declared with ``is_tensor=[True, False]``,
    so ``B`` is treated as a non-tensor argument) and returns its result.
    """
    logger.debug("GEMS FMOD_SCALAR")
    remainder = fmod_func_tensor_scalar(A, B)
    return remainder
def fmod_tensor_(A, B):
    """Compute fmod of ``A`` by ``B`` with ``A`` passed as the output.

    Emits a debug log line, then calls ``fmod_func`` with ``out0=A`` and
    returns whatever that call returns.

    Presumably this makes the operation in-place on ``A`` (the trailing
    underscore and ``out0=A`` suggest it) -- confirm against
    ``pointwise_dynamic``.
    """
    logger.debug("GEMS FMOD_TENSOR_")
    result = fmod_func(A, B, out0=A)
    return result
def fmod_scalar_(A, B):
    """Compute fmod of ``A`` by ``B`` with ``A`` passed as the output.

    Emits a debug log line, then calls ``fmod_func_tensor_scalar`` with
    ``out0=A`` and returns whatever that call returns.

    Presumably this makes the operation in-place on ``A`` (the trailing
    underscore and ``out0=A`` suggest it) -- confirm against
    ``pointwise_dynamic``.
    """
    logger.debug("GEMS FMOD_SCALAR_")
    result = fmod_func_tensor_scalar(A, B, out0=A)
    return result