Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/round.py: 0%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6from triton.language.extra.xpu.libdevice import rint as _rint 

7 

8from ..utils.pointwise_dynamic import pointwise_dynamic 

9 

# Per-module child of the package-wide "flag_gems" logger (leading dots are
# stripped from the module name); each op entry point in this file emits a
# debug line through it.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

11 

12 

# Element-wise rounding kernel.
#
# ``rint`` on fp32 rounds half-to-even, which is the tie-breaking rule
# torch.round uses. The XPU libdevice ``rint`` only accepts fp32, so every
# input element is upcast to fp32 before rounding.
#
# ``decimals`` support uses a power-of-ten scale s = 10**decimals:
#     round(x, decimals) = rint(x * s) / s
#
# NOTE(review): for decimals < 0, ``s`` (0.1, 0.01, ...) is not exactly
# representable in binary floating point. PyTorch's own kernel instead
# computes rint(x / 10**-d) * 10**-d (verify against the ATen source), so rare
# tie cases may differ. float64 inputs also lose precision in the fp32 upcast.
@pointwise_dynamic(is_tensor=[True, False], promotion_methods=[(0, "DEFAULT")])
@triton.jit
def round_func(x, scale):
    # ``scale`` is passed as a plain scalar, not a tensor (is_tensor=[True, False]).
    scaled = x.to(tl.float32) * scale
    return _rint(scaled) / scale

21 

22 

23def _scale(decimals): 

24 return 10.0**decimals 

25 

26 

def round(input, decimals=0):
    """Return ``input`` rounded to ``decimals`` decimal places.

    Floating-point tensors are rounded half-to-even by the ``round_func``
    Triton kernel, using the scale ``10 ** decimals``. Integer tensors
    (int8/int16/int32/int64) are already integral, so a clone is returned
    and ``decimals`` is not consulted. Empty tensors yield a fresh empty
    tensor of the same shape/dtype.

    Args:
        input: tensor to round.
        decimals: number of decimal places to keep; negative values are
            accepted and turned into a fractional scale.

    Returns:
        The rounded result. No ``out0`` is passed to the kernel, so
        ``input`` is not used as the output buffer.

    Raises:
        TypeError: if ``input`` is not a tensor, or is a complex tensor.
    """
    logger.debug("GEMS_KUNLUNXIN ROUND")

    # The tensor check must run first: ``is_complex`` only exists on tensors.
    if not isinstance(input, torch.Tensor):
        raise TypeError("round expects a torch.Tensor.")
    if input.is_complex():
        raise TypeError("round is not supported for complex tensors.")

    integer_dtypes = (torch.int32, torch.int64, torch.int16, torch.int8)
    if input.dtype in integer_dtypes:
        return input.clone()

    if input.numel() == 0:
        return torch.empty_like(input)

    # ``contiguous()`` hands back ``input`` itself when it is already
    # contiguous, so only strided inputs are actually copied here.
    return round_func(input.contiguous(), _scale(decimals))

40 

41 

def round_out(input, *, decimals=0, out=None):
    """Round ``input`` to ``decimals`` decimal places, writing into ``out``.

    When ``out`` is None this simply delegates to :func:`round` and returns
    a new tensor. Otherwise:

    * integer tensors (int8/int16/int32/int64) are copied into ``out``
      unchanged (``decimals`` is not consulted);
    * empty tensors produce no kernel launch, but ``out`` is resized to the
      input's (empty) shape if it differs, so the caller never gets back a
      tensor with a stale shape or contents;
    * floating-point tensors are rounded half-to-even by ``round_func`` with
      scale ``10 ** decimals``, writing directly into ``out``.

    Args:
        input: tensor to round.
        decimals: number of decimal places to keep (keyword-only).
        out: destination tensor (keyword-only); may be None.

    Returns:
        ``out`` (or a new tensor when ``out`` is None).

    Raises:
        TypeError: if ``input`` is not a tensor, or is a complex tensor.
    """
    logger.debug("GEMS_KUNLUNXIN ROUND_OUT")
    if out is None:
        return round(input, decimals=decimals)

    # The tensor check must run first: ``is_complex`` only exists on tensors.
    if not isinstance(input, torch.Tensor):
        raise TypeError("round expects a torch.Tensor.")
    if input.is_complex():
        raise TypeError("round is not supported for complex tensors.")

    if input.dtype in [torch.int32, torch.int64, torch.int16, torch.int8]:
        out.copy_(input)
        return out

    if input.numel() == 0:
        # Nothing to compute, but ``out`` must still take on the result's
        # (empty) shape, matching PyTorch's out= resizing behaviour;
        # previously a preallocated ``out`` came back with its old shape.
        if out.shape != input.shape:
            out.resize_(input.shape)
        return out

    # The kernel is fed contiguous data; ``contiguous()`` is a no-op for
    # already-contiguous inputs.
    round_func(input.contiguous(), _scale(decimals), out0=out)
    return out

59 

60 

def round_(input, *, decimals=0):
    """Round ``input`` in place to ``decimals`` decimal places and return it.

    Floating-point tensors are rounded half-to-even by the ``round_func``
    kernel with scale ``10 ** decimals``. Integer tensors
    (int8/int16/int32/int64) are already integral and are returned
    untouched, as are empty tensors.

    Non-contiguous tensors are supported: the kernel runs on a contiguous
    copy and the rounded values are written back into ``input`` with
    ``copy_``. ``copy_`` still rejects destinations whose elements share
    memory (e.g. ``expand``-ed views), which is the correct outcome for an
    in-place op.

    Args:
        input: tensor to round in place.
        decimals: number of decimal places to keep (keyword-only).

    Returns:
        ``input`` itself.

    Raises:
        TypeError: if ``input`` is not a tensor, or is a complex tensor.
    """
    logger.debug("GEMS_KUNLUNXIN ROUND_")

    # The tensor check must run first: ``is_complex`` only exists on tensors.
    if not isinstance(input, torch.Tensor):
        raise TypeError("round expects a torch.Tensor.")
    if input.is_complex():
        raise TypeError("round is not supported for complex tensors.")
    if input.dtype in [torch.int32, torch.int64, torch.int16, torch.int8]:
        return input
    if input.numel() == 0:
        return input

    scale = _scale(decimals)
    if input.is_contiguous():
        # Fast path: the kernel writes its result straight back into ``input``.
        round_func(input, scale, out0=input)
    else:
        # Previously this case raised ValueError. The kernel is only driven
        # with contiguous data here, so round a contiguous copy and scatter
        # the result back through ``input``'s own strides.
        input.copy_(round_func(input.contiguous(), scale))
    return input