Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/round.py: 0%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6from triton.language.extra.xpu.libdevice import rint as _rint
8from ..utils.pointwise_dynamic import pointwise_dynamic
10logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
13# rint(fp32) implements round-half-to-even, matching torch.round semantics.
14# XPU libdevice rint only supports fp32, so always cast to fp32 for computation.
15# The scale trick handles non-zero decimals: round(x, d) = rint(x * 10^d) / 10^d.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def round_func(x, scale):
    # Element-wise decimal rounding: scale the value up, snap it to the
    # nearest integer with rint (round-half-to-even), then scale back down.
    # `x` is the tensor operand and `scale` the scalar operand, as declared
    # by is_tensor=[True, False] in the decorator above.
    # x is cast to fp32 first because the XPU libdevice rint only accepts
    # fp32 input.
    scaled = x.to(tl.float32) * scale
    return _rint(scaled) / scale
23def _scale(decimals):
24 return 10.0**decimals
def round(input, decimals=0):
    """Return a new tensor holding ``input`` rounded to ``decimals`` places.

    Args:
        input: Tensor to round.
        decimals: Power-of-ten exponent handed to ``_scale`` to build the
            scale factor used by ``round_func``. Defaults to 0.

    Returns:
        A clone of ``input`` for int8/int16/int32/int64 tensors, an
        ``empty_like`` tensor when ``input`` has no elements, and otherwise
        the result of ``round_func``.

    Raises:
        TypeError: If ``input`` is not a ``torch.Tensor`` or is complex.
    """
    logger.debug("GEMS_KUNLUNXIN ROUND")

    if not isinstance(input, torch.Tensor):
        raise TypeError("round expects a torch.Tensor.")
    if input.is_complex():
        raise TypeError("round is not supported for complex tensors.")

    # Signed integer values are already whole numbers; hand back a copy.
    passthrough_dtypes = (torch.int32, torch.int64, torch.int16, torch.int8)
    if input.dtype in passthrough_dtypes:
        return input.clone()

    # Nothing to compute for an empty tensor.
    if input.numel() == 0:
        return torch.empty_like(input)

    # The kernel is only ever given a contiguous operand.
    source = input if input.is_contiguous() else input.contiguous()
    return round_func(source, _scale(decimals))
def round_out(input, *, decimals=0, out=None):
    """Round ``input`` to ``decimals`` places, writing into ``out`` if given.

    Args:
        input: Tensor to round.
        decimals: Power-of-ten exponent handed to ``_scale``. Defaults to 0.
        out: Optional destination tensor. When ``None`` the call is
            forwarded to ``round`` and a fresh tensor is returned.

    Returns:
        ``out`` when it was supplied, otherwise the tensor produced by
        ``round``. For int8/int16/int32/int64 inputs ``out`` receives a copy
        of ``input``; for an empty ``input`` ``out`` is returned unmodified.

    Raises:
        TypeError: If ``input`` is not a ``torch.Tensor`` or is complex.
    """
    logger.debug("GEMS_KUNLUNXIN ROUND_OUT")

    # Without a destination this is just the out-of-place op.
    if out is None:
        return round(input, decimals=decimals)

    if not isinstance(input, torch.Tensor):
        raise TypeError("round expects a torch.Tensor.")
    if input.is_complex():
        raise TypeError("round is not supported for complex tensors.")

    passthrough_dtypes = (torch.int32, torch.int64, torch.int16, torch.int8)
    if input.dtype in passthrough_dtypes:
        # Integer values are already rounded; just transfer them.
        out.copy_(input)
    elif input.numel() != 0:
        source = input if input.is_contiguous() else input.contiguous()
        round_func(source, _scale(decimals), out0=out)
    # An empty input falls through with `out` left as it was.
    return out
def round_(input, *, decimals=0):
    """Round ``input`` in place to ``decimals`` places and return it.

    Args:
        input: Tensor to round; modified in place.
        decimals: Power-of-ten exponent handed to ``_scale``. Defaults to 0.

    Returns:
        The same ``input`` tensor. int8/int16/int32/int64 tensors and empty
        tensors are returned without being touched.

    Raises:
        TypeError: If ``input`` is not a ``torch.Tensor`` or is complex.
    """
    logger.debug("GEMS_KUNLUNXIN ROUND_")
    if not isinstance(input, torch.Tensor):
        raise TypeError("round expects a torch.Tensor.")
    if input.is_complex():
        raise TypeError("round is not supported for complex tensors.")
    if input.dtype in (torch.int32, torch.int64, torch.int16, torch.int8):
        # Integer values are already rounded.
        return input
    if input.numel() == 0:
        return input

    scale = _scale(decimals)
    if input.is_contiguous():
        # Fast path: the kernel writes straight back into the input buffer.
        round_func(input, scale, out0=input)
    else:
        # Non-contiguous tensors used to be rejected with a ValueError even
        # though round() and round_out() accept them. Round a contiguous
        # copy and write the result back through the original strides so
        # the in-place variant accepts the same layouts.
        input.copy_(round_func(input.contiguous(), scale))
    return input