Coverage for src/flag_gems/ops/round.py: 57%
115 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
9logger = logging.getLogger(__name__)
@triton.jit
def round_half_to_even_impl(x):
    """Round ``x`` to the nearest integer, breaking exact ties toward the even integer.

    This is the "round half to even" rule used by ``torch.round``.
    ``x`` must be a floating-point tensor; callers in this module upcast
    fp16/bf16 to fp32 before calling.

    NaN and +/-inf pass through unchanged. A zero result takes the sign of
    ``x`` (e.g. -0.3 -> -0.0), as C ``nearbyint`` does.
    """
    r = tl.floor(x)
    # d = x - floor(x) is exact whenever the true fraction is <= 0.5
    # (Sterbenz lemma), so an exact comparison with 0.5 is a reliable tie test.
    # The only inexact case is -0.5 < x < 0: there the true fraction lies in
    # (0.5, 1) and may round to 0.5 or 1.0. Both still round up to 0, because
    # r == -1 is odd.
    d = x - r
    # r is integral, so r - 2 * floor(r / 2) is exactly 0 (even) or 1 (odd).
    is_odd = (r - 2.0 * tl.floor(r / 2.0)) > 0.5
    # Round up when strictly above the midpoint, or on a tie with odd r
    # (moving to the even neighbour). Otherwise keep r.
    rounded = tl.where((d > 0.5) | ((d == 0.5) & is_odd), r + 1.0, r)
    # Rounding a value in (-1, 0) upward yields +0.0; restore -0.0 so the sign
    # follows x. For x < 0 the result is always <= 0, so -|rounded| only ever
    # changes the sign of a zero.
    return tl.where(x < 0.0, -tl.abs(rounded), rounded)
@triton.jit
def _round_decimals_impl(x, decimals: tl.constexpr):
    """Round floating-point tensor ``x`` to ``decimals`` decimal places, ties to even.

    Follows PyTorch's formulation for ``round(decimals=...)``:
    - ``decimals > 0``: scale up by 10**decimals, round, then scale back.
    - ``decimals < 0``: divide by 10**(-decimals), round, then multiply back.

    The negative case therefore uses the exact power of ten rather than an
    inexact reciprocal such as 0.01.
    """
    if decimals == 0:
        out = round_half_to_even_impl(x)
    elif decimals > 0:
        scale = 10.0**decimals
        out = round_half_to_even_impl(x * scale) / scale
    else:
        scale = 10.0 ** (-decimals)
        out = round_half_to_even_impl(x / scale) * scale
    return out


@triton.jit
def round_kernel(
    x_ptr,
    out_ptr,
    n_elements,
    decimals: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    IS_FP32: tl.constexpr,
    IS_FP16: tl.constexpr,
    IS_BF16: tl.constexpr,
):
    """Element-wise round of a flat buffer of ``n_elements`` values into ``out_ptr``.

    Each program handles one ``BLOCK_SIZE`` chunk; the tail is masked.

    Dtype handling:
    - fp16/bf16 are computed in fp32 and cast back.
    - fp32 and fp64 are computed in their own precision.
    - Any other dtype is copied through unchanged.
    """
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask)

    if IS_FP32:
        out = _round_decimals_impl(x, decimals)
    elif IS_FP16:
        out = tl.cast(_round_decimals_impl(tl.cast(x, tl.float32), decimals), tl.float16)
    elif IS_BF16:
        out = tl.cast(
            _round_decimals_impl(tl.cast(x, tl.float32), decimals), tl.bfloat16
        )
    elif x.dtype == tl.float64:
        # Previously fell through to the copy branch, leaving fp64 unrounded.
        out = _round_decimals_impl(x, decimals)
    else:
        out = x

    tl.store(out_ptr + offsets, out, mask=mask)
def round_func(input, decimals=0):
    """Return a new tensor holding ``input`` rounded to ``decimals`` decimal places.

    Exact ties are rounded to the nearest even value.

    Args:
        input: Tensor to round. A non-contiguous input is first copied to a
            contiguous buffer, because the kernel indexes memory linearly.
        decimals: Number of decimal places. May be negative to round to tens,
            hundreds, etc.

    Returns:
        A newly allocated, contiguous tensor with ``input``'s shape and dtype.
        Non-floating-point inputs (integers, bool) are returned as an unchanged
        copy. The result never aliases ``input``.

    Raises:
        TypeError: if ``input`` is not a tensor or is complex.
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError("round expects a torch.Tensor.")

    # Checked before is_floating_point(): complex tensors are not floating point.
    if input.is_complex():
        raise TypeError("round is not supported for complex tensors.")

    # Integral dtypes (signed, unsigned, bool) are already "rounded"; return a
    # copy (array-api convention) without launching a kernel.
    if not input.is_floating_point():
        return input.clone()

    src = input.contiguous()
    output = torch.empty_like(src)

    n_elements = src.numel()
    if n_elements == 0:
        # Return a fresh (empty) tensor rather than the input itself.
        return output

    dtype = src.dtype
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(src.device):
        round_kernel[grid](
            src,
            output,
            n_elements,
            decimals,
            BLOCK_SIZE=1024,
            IS_FP32=dtype == torch.float32,
            IS_FP16=dtype == torch.float16,
            IS_BF16=dtype == torch.bfloat16,
        )
    return output
def round(input, decimals=0):
    """Out-of-place ``round`` entry point; forwards to ``round_func``.

    Note: this module-level definition shadows the Python builtin ``round``
    inside this module.
    """
    logger.debug("GEMS ROUND")
    rounded = round_func(input, decimals=decimals)
    return rounded
def round_out(input, *, decimals=0, out=None):
    """Round ``input`` to ``decimals`` decimal places, writing the result into ``out``.

    Args:
        input: Tensor to round. May be non-contiguous.
        decimals: Number of decimal places (may be negative).
        out: Destination tensor. If None, a new tensor is returned, as from
            ``round_func``. If its shape differs from ``input``'s, it is
            resized first, following PyTorch ``out=`` semantics.

    Returns:
        ``out`` (or a new tensor when ``out`` is None).

    Raises:
        TypeError: if ``input`` is not a tensor or is complex.
    """
    logger.debug("GEMS ROUND_OUT")
    if out is None:
        return round_func(input, decimals=decimals)

    if not isinstance(input, torch.Tensor):
        raise TypeError("round expects a torch.Tensor.")

    if input.is_complex():
        raise TypeError("round is not supported for complex tensors.")

    # Without this, a smaller `out` would receive out-of-bounds writes from
    # the flat kernel.
    if out.shape != input.shape:
        out.resize_(input.shape)

    # Integral dtypes (signed, unsigned, bool) are already "rounded": copy.
    if not input.is_floating_point():
        out.copy_(input)
        return out

    n_elements = input.numel()
    if n_elements == 0:
        return out

    src = input.contiguous()
    # The kernel does a flat store in the source dtype. It may target `out`
    # directly only when `out` is contiguous with the same dtype; otherwise
    # round into a scratch buffer and let copy_ handle strides/dtype.
    write_direct = out.is_contiguous() and out.dtype == src.dtype
    dst = out if write_direct else torch.empty_like(src)

    dtype = src.dtype
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(src.device):
        round_kernel[grid](
            src,
            dst,
            n_elements,
            decimals,
            BLOCK_SIZE=1024,
            IS_FP32=dtype == torch.float32,
            IS_FP16=dtype == torch.float16,
            IS_BF16=dtype == torch.bfloat16,
        )

    if not write_direct:
        out.copy_(dst)
    return out
def round_(input, *, decimals=0):
    """In-place ``round``: overwrite ``input`` with its values rounded to ``decimals`` places.

    Args:
        input: Tensor to modify. May be non-contiguous; in that case a
            contiguous copy is rounded and written back through ``input``'s
            strides.
        decimals: Number of decimal places (may be negative).

    Returns:
        ``input`` itself, after modification. Non-floating-point tensors are
        returned untouched.

    Raises:
        TypeError: if ``input`` is not a tensor or is complex.
    """
    logger.debug("GEMS ROUND_")
    if not isinstance(input, torch.Tensor):
        raise TypeError("round expects a torch.Tensor.")

    if input.is_complex():
        raise TypeError("round is not supported for complex tensors.")

    # Integral dtypes (signed, unsigned, bool) already hold rounded values
    # (array-api convention for integer round).
    if not input.is_floating_point():
        return input

    n_elements = input.numel()
    if n_elements == 0:
        return input

    # The kernel indexes memory linearly, so it can only work in place on a
    # contiguous buffer.
    work = input if input.is_contiguous() else input.contiguous()

    dtype = work.dtype
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    with torch_device_fn.device(work.device):
        round_kernel[grid](
            work,
            work,
            n_elements,
            decimals,
            BLOCK_SIZE=1024,
            IS_FP32=dtype == torch.float32,
            IS_FP16=dtype == torch.float16,
            IS_BF16=dtype == torch.bfloat16,
        )

    if work is not input:
        input.copy_(work)
    return input