Coverage for src/flag_gems/ops/gcd.py: 42%

172 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6import triton.language.extra.libdevice as libdevice 

7 

# Module-level logger used by the public entry points for debug tracing.
logger = logging.getLogger(__name__)

# Per-device cache of the int16 lookup table, keyed by
# (device.type, device.index); filled lazily by _get_i16_min_lut.
_I16_MIN_LUT_CACHE = dict()

10 

11 

@triton.jit
def _ctz(x):
    """Return ``libdevice.ffs(x) - 1`` for every lane of ``x``.

    NOTE(review): ``ffs`` is expected to be the 1-based "find first set"
    intrinsic, which makes this a count of trailing zero bits -- confirm
    against the libdevice reference. Callers in this file substitute 1 for
    lanes they do not care about, so a zero input is never relied upon.
    """
    first_set_position = libdevice.ffs(x)
    return first_set_position - 1

15 

16 

@triton.jit
def _abs_u32(x):
    """Return the magnitude of ``x`` as a ``uint32`` value.

    ``x`` is cast to ``uint32``; lanes where ``x`` is negative yield
    ``0 - cast`` (unsigned negation), all other lanes yield the cast value.
    """
    as_unsigned = x.to(tl.uint32)
    negated = 0 - as_unsigned
    is_negative = x < 0
    return tl.where(is_negative, negated, as_unsigned)

21 

22 

@triton.jit
def _abs_u64(x):
    """Return the magnitude of ``x`` as a ``uint64`` value.

    ``x`` is cast to ``uint64``; lanes where ``x`` is negative yield
    ``0 - cast`` (unsigned negation), all other lanes yield the cast value.
    """
    as_unsigned = x.to(tl.uint64)
    negated = 0 - as_unsigned
    is_negative = x < 0
    return tl.where(is_negative, negated, as_unsigned)

27 

28 

@triton.jit
def _c_rem_i32(a, b):
    """Remainder of ``a`` by ``b`` whose sign follows the dividend ``a``.

    The remainder is taken on the unsigned magnitudes of both operands,
    cast back to ``int32`` and negated only when ``a`` is negative and the
    remainder is non-zero.
    """
    dividend_mag = _abs_u32(a)
    divisor_mag = _abs_u32(b)
    rem_mag = dividend_mag % divisor_mag
    rem_signed = rem_mag.to(tl.int32)
    needs_negation = (a < 0) & (rem_mag != 0)
    return tl.where(needs_negation, -rem_signed, rem_signed)

34 

35 

@triton.jit
def _c_rem_i64(a, b):
    """Remainder of ``a`` by ``b`` whose sign follows the dividend ``a``.

    The remainder is taken on the unsigned magnitudes of both operands,
    cast back to ``int64`` and negated only when ``a`` is negative and the
    remainder is non-zero.
    """
    dividend_mag = _abs_u64(a)
    divisor_mag = _abs_u64(b)
    rem_mag = dividend_mag % divisor_mag
    rem_signed = rem_mag.to(tl.int64)
    needs_negation = (a < 0) & (rem_mag != 0)
    return tl.where(needs_negation, -rem_signed, rem_signed)

41 

42 

@triton.jit
def _binary_gcd(ax, ay, normal):
    """Binary GCD of ``ax`` and ``ay`` on the lanes selected by ``normal``.

    Args:
        ax: First operand block (callers in this file pass magnitudes).
        ay: Second operand block.
        normal: Boolean block; only lanes that are set here *and* have both
            operands non-zero take part in the iterative reduction.

    Returns:
        For participating lanes, the reduced value shifted back up by the
        trailing-zero count of ``ax | ay``. For every other lane, ``ay``
        when ``ax == 0`` and ``ax`` otherwise.
    """
    ax_is_zero = ax == 0
    ay_is_zero = ay == 0
    # Value returned for lanes that never enter the loop.
    fallback = tl.where(ax_is_zero, ay, ax)
    both_nonzero = normal & (~ax_is_zero) & (~ay_is_zero)

    # Non-participating lanes are fed 1 so _ctz always sees a non-zero value.
    combined = tl.where(both_nonzero, ax | ay, 1)
    common = _ctz(combined)

    ax_safe = tl.where(both_nonzero, ax, 1)
    u = tl.where(both_nonzero, ax >> _ctz(ax_safe), ax)
    v = ay
    active = both_nonzero

    # Keep iterating while any lane of the block still has work to do.
    while tl.sum(active.to(tl.int32), axis=0) > 0:
        v_safe = tl.where(active, v, 1)
        v_odd = tl.where(active, v >> _ctz(v_safe), v)
        swap = active & (u > v_odd)
        smaller = tl.where(swap, v_odd, u)
        larger = tl.where(swap, u, v_odd)
        u = tl.where(active, smaller, u)
        v = tl.where(active, larger - smaller, v)
        # A lane is finished once its difference reaches zero.
        active = active & (v != 0)

    return tl.where(both_nonzero, u << common, fallback)

64 

65 

@triton.jit
def gcd_kernel_i16(x_ptr, y_ptr, lut_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    """Element-wise GCD kernel used for the int16 launch path.

    Each program processes ``BLOCK`` consecutive elements. Operands are
    widened to int32. Lanes where either operand equals -32768 are handled
    separately: if both operands are -32768 the result is -32768, otherwise
    the result is read from ``lut_ptr`` at the absolute value of the other
    operand. All remaining lanes go through ``_binary_gcd``.

    Args:
        x_ptr: Pointer to the first operand.
        y_ptr: Pointer to the second operand.
        lut_ptr: Pointer to the lookup table for the -32768 case.
        out_ptr: Pointer to the output; results are cast to its element type.
        n_elements: Number of valid elements.
        BLOCK: Elements handled per program (compile-time constant).
    """
    block_start = tl.program_id(0) * BLOCK
    offsets = block_start + tl.arange(0, BLOCK)
    mask = offsets < n_elements

    x_wide = tl.load(x_ptr + offsets, mask=mask, other=0).to(tl.int32)
    y_wide = tl.load(y_ptr + offsets, mask=mask, other=0).to(tl.int32)

    min_value: tl.constexpr = -32768
    x_is_min = x_wide == min_value
    y_is_min = y_wide == min_value
    special_mask = mask & (x_is_min | y_is_min)
    normal = mask & (~special_mask)

    abs_x = tl.abs(x_wide)
    abs_y = tl.abs(y_wide)
    normal_res = _binary_gcd(abs_x, abs_y, normal)

    both_min = special_mask & x_is_min & y_is_min
    one_min = special_mask & (~both_min)
    # Index the table by the magnitude of whichever operand is not the minimum.
    lut_index = tl.where(x_is_min, abs_y, abs_x)
    looked_up = tl.load(lut_ptr + lut_index, mask=one_min, other=0).to(tl.int32)
    special_res = tl.where(both_min, min_value, looked_up)

    result = tl.where(special_mask, special_res, normal_res)
    tl.store(out_ptr + offsets, result.to(out_ptr.type.element_ty), mask=mask)

93 

94 

@triton.jit
def gcd_kernel_i32(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    """Element-wise GCD kernel used for the int32 launch path.

    Each program processes ``BLOCK`` consecutive elements. Lanes where
    neither operand equals ``-(1 << 31)`` are reduced with ``_binary_gcd``
    on the absolute values. Lanes where at least one operand is that
    minimum keep the minimum as-is and run a remainder-based Euclidean loop
    built on ``_c_rem_i32``.

    Args:
        x_ptr: Pointer to the first operand.
        y_ptr: Pointer to the second operand.
        out_ptr: Pointer to the output; results are cast to its element type.
        n_elements: Number of valid elements.
        BLOCK: Elements handled per program (compile-time constant).
    """
    block_start = tl.program_id(0) * BLOCK
    offsets = block_start + tl.arange(0, BLOCK)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0)

    min_value: tl.constexpr = -(1 << 31)
    x_is_min = x == min_value
    y_is_min = y == min_value
    # The minimum is passed through untouched; every other value is made
    # non-negative with tl.abs.
    mag_x = tl.where(x_is_min, x, tl.abs(x)).to(tl.int32)
    mag_y = tl.where(y_is_min, y, tl.abs(y)).to(tl.int32)

    special_mask = mask & (x_is_min | y_is_min)
    normal = mask & (~special_mask)
    normal_res = _binary_gcd(mag_x, mag_y, normal)

    # Euclidean loop for the special lanes: (divisor, dividend) becomes
    # (dividend rem divisor, divisor) until the divisor is zero.
    divisor = mag_x
    dividend = mag_y
    pending = special_mask & (divisor != 0)
    while tl.sum(pending.to(tl.int32), axis=0) > 0:
        # Finished lanes divide by 1 so no lane ever divides by zero.
        safe_divisor = tl.where(pending, divisor, 1)
        remainder = tl.where(pending, _c_rem_i32(dividend, safe_divisor), divisor)
        dividend = tl.where(pending, divisor, dividend)
        divisor = remainder
        pending = pending & (divisor != 0)

    result = tl.where(mask & (~normal), dividend, normal_res)
    tl.store(out_ptr + offsets, result.to(out_ptr.type.element_ty), mask=mask)

126 

127 

@triton.jit
def gcd_kernel_i64(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    """Element-wise GCD kernel used for the int64 launch path.

    Each program processes ``BLOCK`` consecutive elements. Lanes where
    neither operand equals ``-(1 << 63)`` are reduced with ``_binary_gcd``
    on the absolute values. Lanes where at least one operand is that
    minimum keep the minimum as-is and run a remainder-based Euclidean loop
    built on ``_c_rem_i64``.

    Args:
        x_ptr: Pointer to the first operand.
        y_ptr: Pointer to the second operand.
        out_ptr: Pointer to the output; results are cast to its element type.
        n_elements: Number of valid elements.
        BLOCK: Elements handled per program (compile-time constant).
    """
    block_start = tl.program_id(0) * BLOCK
    offsets = block_start + tl.arange(0, BLOCK)
    mask = offsets < n_elements

    x = tl.load(x_ptr + offsets, mask=mask, other=0)
    y = tl.load(y_ptr + offsets, mask=mask, other=0)

    x_is_min = x == -(1 << 63)
    y_is_min = y == -(1 << 63)
    # The minimum is passed through untouched; every other value is made
    # non-negative with tl.abs.
    mag_x = tl.where(x_is_min, x, tl.abs(x)).to(tl.int64)
    mag_y = tl.where(y_is_min, y, tl.abs(y)).to(tl.int64)

    special_mask = mask & (x_is_min | y_is_min)
    normal = mask & (~special_mask)
    normal_res = _binary_gcd(mag_x, mag_y, normal)

    # Euclidean loop for the special lanes: (divisor, dividend) becomes
    # (dividend rem divisor, divisor) until the divisor is zero.
    divisor = mag_x
    dividend = mag_y
    pending = special_mask & (divisor != 0)
    while tl.sum(pending.to(tl.int32), axis=0) > 0:
        # Finished lanes divide by 1 so no lane ever divides by zero.
        safe_divisor = tl.where(pending, divisor, 1)
        remainder = tl.where(pending, _c_rem_i64(dividend, safe_divisor), divisor)
        dividend = tl.where(pending, divisor, dividend)
        divisor = remainder
        pending = pending & (divisor != 0)

    result = tl.where(mask & (~normal), dividend, normal_res)
    tl.store(out_ptr + offsets, result.to(out_ptr.type.element_ty), mask=mask)

156 

157 

def _kernel_meta(dtype):
    """Select the kernel and launch parameters for ``dtype``.

    Args:
        dtype: Torch dtype of the tensors being processed.

    Returns:
        A ``(kernel, block, num_warps)`` tuple: block 512 for int16 and
        int32, block 256 for int64, and 4 warps in every case.

    Raises:
        TypeError: If ``dtype`` is not int16, int32 or int64.
    """
    if dtype not in (torch.int16, torch.int32, torch.int64):
        raise TypeError(f"unsupported dtype for gcd: {dtype}")

    if dtype == torch.int64:
        return gcd_kernel_i64, 256, 4

    kernel = gcd_kernel_i16 if dtype == torch.int16 else gcd_kernel_i32
    return kernel, 512, 4

166 

167 

def _get_i16_min_lut(device):
    """Return the cached int16 lookup table for ``device``, building it once.

    Entry ``k`` of the table is ``torch.gcd(int16_min, k)`` for every ``k``
    from 0 to the int16 maximum. The table is computed from tensors created
    without an explicit device, moved to ``device`` and stored in
    ``_I16_MIN_LUT_CACHE`` under ``(device.type, device.index)``.

    Args:
        device: Torch device the table must live on.

    Returns:
        The int16 tensor holding the table on ``device``.
    """
    cache_key = (device.type, device.index)
    cached = _I16_MIN_LUT_CACHE.get(cache_key)
    if cached is not None:
        return cached

    limits = torch.iinfo(torch.int16)
    table_size = limits.max + 1
    min_column = torch.full((table_size,), limits.min, dtype=torch.int16)
    magnitudes = torch.arange(table_size, dtype=torch.int16)
    table = torch.gcd(min_column, magnitudes).to(device=device)
    _I16_MIN_LUT_CACHE[cache_key] = table
    return table

178 

179 

def _materialize_inputs(self, other):
    """Promote, broadcast and make contiguous the two operands.

    Args:
        self: First operand tensor.
        other: Second operand tensor.

    Returns:
        A ``(lhs, rhs, promoted_dtype)`` tuple where both tensors share the
        promoted dtype and the broadcast shape and are contiguous.
    """
    promoted_dtype = torch.promote_types(self.dtype, other.dtype)
    # Only cast operands that are not already in the promoted dtype.
    operands = [
        tensor if tensor.dtype == promoted_dtype else tensor.to(promoted_dtype)
        for tensor in (self, other)
    ]
    lhs, rhs = torch.broadcast_tensors(*operands)
    return lhs.contiguous(), rhs.contiguous(), promoted_dtype

186 

187 

def _launch_gcd(lhs, rhs, out):
    """Launch the GCD kernel matching ``out.dtype`` and return ``out``.

    Nothing is launched when ``out`` has no elements. For int16 outputs the
    lookup table from ``_get_i16_min_lut`` is passed to the kernel as an
    extra argument.

    Args:
        lhs: First operand tensor handed to the kernel.
        rhs: Second operand tensor handed to the kernel.
        out: Output tensor; its dtype, device and element count drive the
            kernel choice and the grid size.

    Returns:
        The ``out`` tensor.
    """
    numel = out.numel()
    if numel != 0:
        kernel, block, num_warps = _kernel_meta(out.dtype)
        grid = (triton.cdiv(numel, block),)
        if out.dtype == torch.int16:
            launch_args = (lhs, rhs, _get_i16_min_lut(out.device), out, numel)
        else:
            launch_args = (lhs, rhs, out, numel)
        kernel[grid](*launch_args, BLOCK=block, num_warps=num_warps)
    return out

201 

202 

def gcd(self, other, *, out=None):
    """Compute the element-wise greatest common divisor of two tensors.

    The operands are promoted to a common dtype, broadcast against each
    other and made contiguous; the kernel then runs on the flattened data.

    Args:
        self: First operand tensor.
        other: Second operand tensor.
        out: Optional destination tensor. When its shape differs from the
            broadcast result shape it is resized to that shape before the
            result is copied in, so a placeholder such as ``torch.empty(0)``
            is accepted instead of failing or receiving a broadcast copy.

    Returns:
        ``out`` when it is given, otherwise a newly allocated tensor with
        the promoted dtype and the broadcast shape.

    Raises:
        TypeError: If the inputs are non-empty and the promoted dtype has no
            matching kernel.
    """
    logger.debug("GEMS GCD")
    lhs, rhs, promoted_dtype = _materialize_inputs(self, other)
    # lhs is contiguous, so the flattened views below share storage with
    # the tensors they were taken from and the kernel writes into `result`.
    result = torch.empty_like(lhs, dtype=promoted_dtype)
    _launch_gcd(lhs.reshape(-1), rhs.reshape(-1), result.reshape(-1))
    if out is None:
        return result

    # copy_ alone would raise on a mismatched shape or silently broadcast
    # the result into a larger `out`; resize first so `out` ends up with
    # exactly the result shape.
    if out.shape != result.shape:
        out.resize_(result.shape)
    out.copy_(result)
    return out

214 

215 

def gcd_out(self, other, *, out=None):
    """Out-variant entry point for the element-wise GCD.

    Logs the call and delegates to ``gcd``, which itself returns a fresh
    tensor when ``out`` is ``None`` and fills and returns ``out`` otherwise.

    Args:
        self: First operand tensor.
        other: Second operand tensor.
        out: Optional destination tensor.

    Returns:
        Whatever ``gcd`` returns for the same arguments.
    """
    logger.debug("GEMS GCD_OUT")
    return gcd(self, other, out=out)

220 return gcd(self, other, out=out)