Coverage for src/flag_gems/ops/gcd.py: 42%
172 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
6import triton.language.extra.libdevice as libdevice
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Lazily filled cache of int16 lookup tables, keyed by (device.type, device.index);
# populated and read by _get_i16_min_lut.
_I16_MIN_LUT_CACHE = {}
@triton.jit
def _ctz(x):
    """Return ``libdevice.ffs(x) - 1``, used here as a count of trailing zero bits.

    NOTE(review): relies on ``ffs`` giving the 1-based position of the lowest
    set bit; callers in this file never pass 0 on lanes whose result is used.
    """
    lowest_set_position = libdevice.ffs(x)
    return lowest_set_position - 1
@triton.jit
def _abs_u32(x):
    """Return the magnitude of ``x`` as a ``uint32`` value.

    Negative inputs are negated after the conversion to unsigned, so the
    subtraction happens in unsigned arithmetic.
    """
    as_unsigned = x.to(tl.uint32)
    negated = 0 - as_unsigned
    return tl.where(x < 0, negated, as_unsigned)
@triton.jit
def _abs_u64(x):
    """Return the magnitude of ``x`` as a ``uint64`` value.

    Negative inputs are negated after the conversion to unsigned, so the
    subtraction happens in unsigned arithmetic.
    """
    as_unsigned = x.to(tl.uint64)
    negated = 0 - as_unsigned
    return tl.where(x < 0, negated, as_unsigned)
@triton.jit
def _c_rem_i32(a, b):
    """Return the int32 remainder of ``a`` by ``b`` with the sign of the dividend ``a``.

    The remainder is taken on unsigned magnitudes and then negated when ``a``
    is negative and the magnitude is non-zero.
    """
    magnitude = _abs_u32(a) % _abs_u32(b)
    signed_magnitude = magnitude.to(tl.int32)
    negate = (a < 0) & (magnitude != 0)
    return tl.where(negate, -signed_magnitude, signed_magnitude)
@triton.jit
def _c_rem_i64(a, b):
    """Return the int64 remainder of ``a`` by ``b`` with the sign of the dividend ``a``.

    The remainder is taken on unsigned magnitudes and then negated when ``a``
    is negative and the magnitude is non-zero.
    """
    magnitude = _abs_u64(a) % _abs_u64(b)
    signed_magnitude = magnitude.to(tl.int64)
    negate = (a < 0) & (magnitude != 0)
    return tl.where(negate, -signed_magnitude, signed_magnitude)
@triton.jit
def _binary_gcd(ax, ay, normal):
    """Lane-wise binary (shift-and-subtract) GCD of ``ax`` and ``ay``.

    Args:
        ax, ay: integer operands; the kernels in this file pass magnitudes on
            the lanes where ``normal`` is set.
        normal: boolean lane mask. Lanes where it is false, or where either
            operand is zero, do not take part in the iteration.

    Returns:
        The GCD on lanes where ``normal`` is set and both operands are
        non-zero; on every other lane ``ay`` when ``ax == 0``, otherwise ``ax``.
    """
    zero_ax = ax == 0
    zero_ay = ay == 0
    # Result for lanes that never iterate: ay when ax is zero, otherwise ax.
    res = tl.where(zero_ax, ay, ax)
    both_nonzero = normal & (~zero_ax) & (~zero_ay)
    # Number of trailing zero bits shared by both operands. Inactive lanes feed
    # the placeholder 1 to _ctz so it is never given 0.
    common = _ctz(tl.where(both_nonzero, ax | ay, 1))
    # Strip the trailing zero bits from u on active lanes only.
    u = tl.where(both_nonzero, ax >> _ctz(tl.where(both_nonzero, ax, 1)), ax)
    v = ay
    active = both_nonzero

    # Keep iterating while at least one lane of the block is still active.
    while tl.sum(active.to(tl.int32), axis=0) > 0:
        # Strip trailing zero bits from v (placeholder 1 on inactive lanes).
        v_shifted = tl.where(active, v >> _ctz(tl.where(active, v, 1)), v)
        # Order the pair so that small <= large before subtracting.
        swap = active & (u > v_shifted)
        small = tl.where(swap, v_shifted, u)
        large = tl.where(swap, u, v_shifted)
        u = tl.where(active, small, u)
        v = tl.where(active, large - small, v)
        # A lane is finished once its difference reaches zero.
        active = active & (v != 0)

    # Re-apply the shared power of two on lanes that iterated.
    return tl.where(both_nonzero, u << common, res)
@triton.jit
def gcd_kernel_i16(x_ptr, y_ptr, lut_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    """Element-wise GCD for one block of ``BLOCK`` elements, computed in int32.

    The launcher in this file selects this kernel for ``torch.int16`` outputs.
    Lanes where either input equals -32768 are "special": when both inputs are
    -32768 the result is -32768; when exactly one is, the result is read from
    ``lut_ptr`` at the magnitude of the other operand. All other lanes use
    ``_binary_gcd`` on the magnitudes.

    Args:
        x_ptr, y_ptr: input pointers, read at ``pid * BLOCK + [0, BLOCK)``.
        lut_ptr: lookup table indexed by the magnitude of the non-minimum
            operand (the launcher passes the table from ``_get_i16_min_lut``).
        out_ptr: output pointer; results are cast to its element type.
        n_elements: number of valid elements; lanes at or beyond it are masked.
        BLOCK: compile-time number of elements handled per program.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0)
    # Widen to int32 before any arithmetic.
    x_i32 = x.to(tl.int32)
    y_i32 = y.to(tl.int32)
    min_value: tl.constexpr = -32768
    min_x = x_i32 == min_value
    min_y = y_i32 == min_value
    # In-bounds lanes with at least one operand equal to -32768.
    special_mask = mask & (min_x | min_y)
    normal = mask & (~special_mask)
    ax = tl.abs(x_i32)
    ay = tl.abs(y_i32)
    normal_res = _binary_gcd(ax, ay, normal)

    both_min = special_mask & min_x & min_y
    one_min = special_mask & (~both_min)
    # Magnitude of the operand that is NOT the minimum value (meaningful on one_min lanes).
    other_abs = tl.where(min_x, tl.abs(y_i32), tl.abs(x_i32))
    # The table is only read on one_min lanes; every other lane gets 0 here.
    special_res = tl.load(lut_ptr + other_abs, mask=one_min, other=0).to(tl.int32)
    special_res = tl.where(both_min, min_value, special_res)

    out = tl.where(special_mask, special_res, normal_res)
    tl.store(out_ptr + offsets, out.to(out_ptr.type.element_ty), mask=mask)
@triton.jit
def gcd_kernel_i32(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    """Element-wise GCD for one block of ``BLOCK`` elements (int32 path).

    Lanes where neither input equals ``-(1 << 31)`` use ``_binary_gcd`` on the
    magnitudes. Lanes where at least one input equals that minimum value run
    a remainder loop built on ``_c_rem_i32`` instead.

    Args:
        x_ptr, y_ptr: input pointers, read at ``pid * BLOCK + [0, BLOCK)``.
        out_ptr: output pointer; results are cast to its element type.
        n_elements: number of valid elements; lanes at or beyond it are masked.
        BLOCK: compile-time number of elements handled per program.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0)
    min_value: tl.constexpr = -(1 << 31)
    min_x = x == min_value
    min_y = y == min_value
    # The minimum value is passed through unchanged instead of going through tl.abs.
    ax_native = tl.where(min_x, x, tl.abs(x))
    ay_native = tl.where(min_y, y, tl.abs(y))
    ax = ax_native.to(tl.int32)
    ay = ay_native.to(tl.int32)

    # In-bounds lanes with at least one operand equal to the minimum value.
    special_mask = mask & (min_x | min_y)
    normal = mask & (~special_mask)
    normal_res = _binary_gcd(ax, ay, normal)

    # Remainder loop for the special lanes: (sa, sb) <- (sb rem sa, sa) until sa == 0.
    sa = ax_native.to(tl.int32)
    sb = ay_native.to(tl.int32)
    # Lanes whose sa is already zero never iterate and keep sb as their result.
    special = special_mask & (sa != 0)
    while tl.sum(special.to(tl.int32), axis=0) > 0:
        # Inactive lanes use divisor 1 so the remainder is never taken by zero.
        next_sa = tl.where(special, _c_rem_i32(sb, tl.where(special, sa, 1)), sa)
        sb = tl.where(special, sa, sb)
        sa = next_sa
        special = special & (sa != 0)

    # mask & ~normal selects the in-bounds special lanes.
    out = tl.where(mask & (~normal), sb, normal_res)
    tl.store(out_ptr + offsets, out.to(out_ptr.type.element_ty), mask=mask)
@triton.jit
def gcd_kernel_i64(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    """Element-wise GCD for one block of ``BLOCK`` elements (int64 path).

    Lanes where neither input equals ``-(1 << 63)`` use ``_binary_gcd`` on the
    magnitudes. Lanes where at least one input equals that minimum value run
    a remainder loop built on ``_c_rem_i64`` instead.

    Args:
        x_ptr, y_ptr: input pointers, read at ``pid * BLOCK + [0, BLOCK)``.
        out_ptr: output pointer; results are cast to its element type.
        n_elements: number of valid elements; lanes at or beyond it are masked.
        BLOCK: compile-time number of elements handled per program.
    """
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0)
    min_x = x == -(1 << 63)
    min_y = y == -(1 << 63)
    # The minimum value is passed through unchanged instead of going through tl.abs.
    ax = tl.where(min_x, x, tl.abs(x)).to(tl.int64)
    ay = tl.where(min_y, y, tl.abs(y)).to(tl.int64)

    # In-bounds lanes with at least one operand equal to the minimum value.
    special_mask = mask & (min_x | min_y)
    normal = mask & (~special_mask)
    normal_res = _binary_gcd(ax, ay, normal)

    # Remainder loop for the special lanes: (sa, sb) <- (sb rem sa, sa) until sa == 0.
    sa = ax
    sb = ay
    # Lanes whose sa is already zero never iterate and keep sb as their result.
    special = special_mask & (sa != 0)
    while tl.sum(special.to(tl.int32), axis=0) > 0:
        # Inactive lanes use divisor 1 so the remainder is never taken by zero.
        next_sa = tl.where(special, _c_rem_i64(sb, tl.where(special, sa, 1)), sa)
        sb = tl.where(special, sa, sb)
        sa = next_sa
        special = special & (sa != 0)

    # mask & ~normal selects the in-bounds special lanes.
    out = tl.where(mask & (~normal), sb, normal_res)
    tl.store(out_ptr + offsets, out.to(out_ptr.type.element_ty), mask=mask)
def _kernel_meta(dtype):
    """Return ``(kernel, block_size, num_warps)`` for a supported integer dtype.

    Raises:
        TypeError: if ``dtype`` is not ``torch.int16``, ``torch.int32`` or
            ``torch.int64``.
    """
    # Reject unsupported dtypes up front so the happy path below stays flat.
    if dtype not in (torch.int16, torch.int32, torch.int64):
        raise TypeError(f"unsupported dtype for gcd: {dtype}")

    if dtype == torch.int64:
        return gcd_kernel_i64, 256, 4

    # The 16- and 32-bit kernels share the same launch configuration.
    kernel = gcd_kernel_i16 if dtype == torch.int16 else gcd_kernel_i32
    return kernel, 512, 4
def _get_i16_min_lut(device):
    """Return the cached table of ``torch.gcd(int16 min, k)`` for every ``k`` in ``[0, int16 max]``.

    The table is built once per ``(device.type, device.index)`` with
    ``torch.gcd`` on default-device tensors, moved to ``device`` and stored in
    ``_I16_MIN_LUT_CACHE``.
    """
    cache_key = (device.type, device.index)
    cached = _I16_MIN_LUT_CACHE.get(cache_key)
    if cached is not None:
        return cached

    limits = torch.iinfo(torch.int16)
    table_size = limits.max + 1
    candidates = torch.arange(table_size, dtype=torch.int16)
    min_filled = torch.full((table_size,), limits.min, dtype=torch.int16)
    table = torch.gcd(min_filled, candidates).to(device=device)
    _I16_MIN_LUT_CACHE[cache_key] = table
    return table
def _materialize_inputs(self, other):
    """Promote both tensors to a common dtype, broadcast them and make them contiguous.

    Returns:
        ``(lhs, rhs, promoted_dtype)`` where ``lhs`` and ``rhs`` are contiguous
        tensors of the broadcast shape and dtype ``promoted_dtype``.
    """
    promoted_dtype = torch.promote_types(self.dtype, other.dtype)
    # Only convert an operand whose dtype actually differs from the promoted one.
    converted = [
        operand if operand.dtype == promoted_dtype else operand.to(promoted_dtype)
        for operand in (self, other)
    ]
    lhs, rhs = torch.broadcast_tensors(*converted)
    return lhs.contiguous(), rhs.contiguous(), promoted_dtype
def _launch_gcd(lhs, rhs, out):
    """Launch the dtype-specific GCD kernel over ``out`` and return ``out``.

    Nothing is launched when ``out`` has no elements. The int16 kernel also
    receives the lookup table from ``_get_i16_min_lut``.
    """
    total = out.numel()
    if not total:
        return out

    kernel, block_size, warps = _kernel_meta(out.dtype)
    launch_grid = (triton.cdiv(total, block_size),)
    if out.dtype == torch.int16:
        tensor_args = (lhs, rhs, _get_i16_min_lut(out.device), out)
    else:
        tensor_args = (lhs, rhs, out)
    kernel[launch_grid](*tensor_args, total, BLOCK=block_size, num_warps=warps)
    return out
def gcd(self, other, *, out=None):
    """Compute the element-wise GCD of ``self`` and ``other``.

    Inputs are promoted, broadcast and made contiguous by
    ``_materialize_inputs``. When ``out`` is given, the result is copied into
    it and ``out`` is returned; otherwise a new tensor is returned.
    """
    logger.debug("GEMS GCD")
    left, right, dtype = _materialize_inputs(self, other)
    buffer = torch.empty_like(left, dtype=dtype)
    # The kernel works on flattened operands and fills the flattened buffer.
    _launch_gcd(left.reshape(-1), right.reshape(-1), buffer.reshape(-1))
    shaped = buffer.view(left.shape)
    if out is not None:
        out.copy_(shaped)
        return out
    return shaped
def gcd_out(self, other, *, out=None):
    """Out-variant entry point: delegate to ``gcd``, forwarding ``out``.

    Passing ``out=None`` is the same as calling ``gcd(self, other)``.
    """
    logger.debug("GEMS GCD_OUT")
    return gcd(self, other, out=out)