Coverage for src/flag_gems/runtime/backend/_sunrise/ops/exponential_.py: 0%

200 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import libentry, libtuner 

9from flag_gems.utils.random_utils import ( 

10 philox_backend_seed_offset, 

11 uint_to_uniform_float, 

12) 

13 

14logger = logging.getLogger(__name__) 

15 

16 

@triton.jit
def safe_fast_log_f32(x):
    """Approximate the natural logarithm of ``x`` in float32 via bit tricks.

    The input is clamped to ``[1.17549435e-38, 0.99999994]`` first. Its bit
    pattern is then reinterpreted as int32 and split into the unbiased
    exponent and the 23-bit mantissa fraction. The result is a four-term
    polynomial in the fraction plus ``exponent * ln(2)``.
    """
    # `x * 0.0 + c` produces the constant c with the same shape as x
    # (for finite x).
    lower = (x * 0.0 + 1.17549435e-38).to(tl.float32)
    upper = x * 0.0 + 0.99999994
    clamped = tl.minimum(tl.maximum(x, lower), upper)

    raw = clamped.to(tl.int32, bitcast=True)
    unbiased_exp = (raw >> 23) - 127
    significand = (raw & 0x7FFFFF).to(tl.float32) * (1.0 / 8388608.0) + 1.0
    frac = significand - 1.0

    # Nested (Horner-style) evaluation of frac - frac^2/2 + frac^3/3 - frac^4/4.
    poly = 0.3333333333 - frac * 0.25
    poly = -0.5 + frac * poly
    poly = 1.0 + frac * poly
    return frac * poly + unbiased_exp.to(tl.float32) * 0.6931471805599453

30 

31 

@triton.jit
def safe_fast_log_f64(x):
    """Approximate the natural logarithm of ``x`` in float64 via bit tricks.

    The input is clamped to
    ``[2.2250738585072014e-308, 1.0 - 2.220446049250313e-16]`` first. Its bit
    pattern is then reinterpreted as int64 and split into the unbiased
    exponent and the 52-bit mantissa fraction. The result is a four-term
    polynomial in the fraction plus ``exponent * ln(2)``.
    """
    # `x * 0.0 + c` produces the constant c with the same shape as x
    # (for finite x).
    lower = x * 0.0 + 2.2250738585072014e-308
    upper = x * 0.0 + (1.0 - 2.220446049250313e-16)
    clamped = tl.minimum(tl.maximum(x, lower), upper)

    raw = clamped.to(tl.int64, bitcast=True)
    unbiased_exp = (raw >> 52) - 1023
    significand = (raw & 0x000FFFFFFFFFFFFF).to(tl.float64) * (
        1.0 / 4503599627370496.0
    ) + 1.0
    frac = significand - 1.0

    # Nested (Horner-style) evaluation of frac - frac^2/2 + frac^3/3 - frac^4/4.
    poly = 0.3333333333333333 - frac * 0.25
    poly = -0.5 + frac * poly
    poly = 1.0 + frac * poly
    return frac * poly + unbiased_exp.to(tl.float64) * 0.6931471805599453

47 

48 

@triton.jit
def paste_u64(hi: tl.uint32, lo: tl.uint32):
    """Join two 32-bit words into one uint64 (``hi`` in the upper 32 bits)."""
    upper_word = hi.to(tl.uint64) << 32
    lower_word = lo.to(tl.uint64)
    return upper_word | lower_word

52 

53 

@triton.jit
def transform_exponential_f32_precise(u, inv_lambd, eps_minus):
    """Map uniform samples ``u`` to exponential samples using ``tl.math.log``.

    Returns ``-inv_lambd * log(u)``, except that wherever
    ``u >= 1.0 + eps_minus`` the logarithm is replaced by ``eps_minus``.
    """
    # NOTE(review): the substitution presumably keeps the result away from
    # zero when u is within eps of 1 -- confirm against the caller, which
    # passes a negative eps_minus.
    near_one = u >= 1.0 + eps_minus
    log_u = tl.where(near_one, eps_minus, tl.math.log(u))
    return -inv_lambd * log_u

59 

60 

@triton.jit
def transform_exponential_f32_fast(u, inv_lambd, eps_minus):
    """Map uniform samples ``u`` to exponential samples using the fast log.

    Returns ``-inv_lambd * safe_fast_log_f32(u)``, except that wherever
    ``u >= 1.0 + eps_minus`` the logarithm is replaced by ``eps_minus``.
    """
    near_one = u >= 1.0 + eps_minus
    log_u = tl.where(near_one, eps_minus, safe_fast_log_f32(u))
    return -inv_lambd * log_u

66 

67 

# Sunrise/PTPU shows a large statistical drift on the fast float32 log path.
# The fast approximation (transform_exponential_f32_fast) stays available, but
# this backend selects the precise log transform so that the generated
# exponential distribution matches CPU. Every float32 kernel in this module
# calls the transform through this alias.
transform_exponential_f32 = transform_exponential_f32_precise

72 

73 

@triton.jit
def transform_exponential_f64(u, inv_lambd, eps_minus):
    """Map uniform float64 samples ``u`` to exponential samples.

    Returns ``-inv_lambd * safe_fast_log_f64(u)``, except that wherever
    ``u >= 1.0 + eps_minus`` the logarithm is replaced by ``eps_minus``.
    """
    near_one = u >= 1.0 + eps_minus
    log_u = tl.where(near_one, eps_minus, safe_fast_log_f64(u))
    return -inv_lambd * log_u

78 

79 

@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 512}, num_warps=8, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f32_unroll8(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Write up to ``8 * BLOCK`` exponential samples per program.

    Each program issues two Philox calls of ``BLOCK`` lanes. Every call yields
    four random words per lane, and each word is converted with
    ``uint_to_uniform_float`` and ``transform_exponential_f32``.

    Args:
        out_ptr: Pointer to the output buffer.
        N: Number of elements to write; stores beyond ``N`` are masked out.
        inv_lambd: Scale applied to the negated logarithm.
        eps_minus: Threshold/substitute value used by the transform.
        philox_seed: Philox seed.
        philox_offset: 64-bit Philox offset, split into counter words c0/c1.
        BLOCK: Number of lanes per Philox call (compile-time constant).
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Split the 64-bit offset into the low and high 32-bit counter words.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)

    # Each program owns a private window of 2 * BLOCK consecutive counters:
    # the first call uses [w, w + BLOCK), the second [w + BLOCK, w + 2*BLOCK).
    # Striding the window by only BLOCK * 4 per program while the second call
    # sits BLOCK * 4 after the first would make program p's second call reuse
    # the counters of program p + 1's first call and duplicate samples.
    lane_counter = pid * BLOCK * 2 + tl.arange(0, BLOCK)
    c0_first = c0 + lane_counter
    c0_second = c0_first + BLOCK
    z = c0_first * 0

    r0_0, r1_0, r2_0, r3_0 = tl.philox(philox_seed, c0_first, c1, z, z)
    r0_1, r1_1, r2_1, r3_1 = tl.philox(philox_seed, c0_second, c1, z, z)

    y0_0 = transform_exponential_f32(uint_to_uniform_float(r0_0), inv_lambd, eps_minus)
    y1_0 = transform_exponential_f32(uint_to_uniform_float(r1_0), inv_lambd, eps_minus)
    y2_0 = transform_exponential_f32(uint_to_uniform_float(r2_0), inv_lambd, eps_minus)
    y3_0 = transform_exponential_f32(uint_to_uniform_float(r3_0), inv_lambd, eps_minus)

    y0_1 = transform_exponential_f32(uint_to_uniform_float(r0_1), inv_lambd, eps_minus)
    y1_1 = transform_exponential_f32(uint_to_uniform_float(r1_1), inv_lambd, eps_minus)
    y2_1 = transform_exponential_f32(uint_to_uniform_float(r2_1), inv_lambd, eps_minus)
    y3_1 = transform_exponential_f32(uint_to_uniform_float(r3_1), inv_lambd, eps_minus)

    # 64-bit store offsets: eight BLOCK-sized slabs per program.
    base_off = pid.to(tl.uint64) * BLOCK * 8
    off0 = base_off + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK
    off2 = off1 + BLOCK
    off3 = off2 + BLOCK
    off4 = off3 + BLOCK
    off5 = off4 + BLOCK
    off6 = off5 + BLOCK
    off7 = off6 + BLOCK

    tl.store(out_ptr + off0, y0_0, mask=off0 < N)
    tl.store(out_ptr + off1, y1_0, mask=off1 < N)
    tl.store(out_ptr + off2, y2_0, mask=off2 < N)
    tl.store(out_ptr + off3, y3_0, mask=off3 < N)
    tl.store(out_ptr + off4, y0_1, mask=off4 < N)
    tl.store(out_ptr + off5, y1_1, mask=off5 < N)
    tl.store(out_ptr + off6, y2_1, mask=off6 < N)
    tl.store(out_ptr + off7, y3_1, mask=off7 < N)

138 

139 

@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 512}, num_warps=8, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f32(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Write up to ``4 * BLOCK`` exponential samples per program.

    One Philox call of ``BLOCK`` lanes yields four random words per lane.
    Each word is converted with ``uint_to_uniform_float`` and
    ``transform_exponential_f32`` and stored to its own BLOCK-sized slab.

    Args:
        out_ptr: Pointer to the output buffer.
        N: Number of elements to write; stores beyond ``N`` are masked out.
        inv_lambd: Scale applied to the negated logarithm.
        eps_minus: Threshold/substitute value used by the transform.
        philox_seed: Philox seed.
        philox_offset: 64-bit Philox offset, split into counter words.
        BLOCK: Number of lanes per program (compile-time constant).
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Low / high 32-bit halves of the offset become the first two counters.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    lane = pid * BLOCK + tl.arange(0, BLOCK)
    # One distinct counter per (program, lane) pair.
    c0 = c0 + lane
    zeros = c0 * 0
    w0, w1, w2, w3 = tl.philox(philox_seed, c0, c1, zeros, zeros)

    s0 = transform_exponential_f32(uint_to_uniform_float(w0), inv_lambd, eps_minus)
    s1 = transform_exponential_f32(uint_to_uniform_float(w1), inv_lambd, eps_minus)
    s2 = transform_exponential_f32(uint_to_uniform_float(w2), inv_lambd, eps_minus)
    s3 = transform_exponential_f32(uint_to_uniform_float(w3), inv_lambd, eps_minus)

    # 64-bit store offsets: four BLOCK-sized slabs per program.
    first = pid.to(tl.uint64) * BLOCK * 4
    idx0 = first + tl.arange(0, BLOCK)
    idx1 = idx0 + BLOCK
    idx2 = idx1 + BLOCK
    idx3 = idx2 + BLOCK

    tl.store(out_ptr + idx0, s0, mask=idx0 < N)
    tl.store(out_ptr + idx1, s1, mask=idx1 < N)
    tl.store(out_ptr + idx2, s2, mask=idx2 < N)
    tl.store(out_ptr + idx3, s3, mask=idx3 < N)

180 

181 

@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 512}, num_warps=8, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f32_small(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Write up to ``4 * BLOCK`` exponential samples per program (small inputs).

    One Philox call of ``BLOCK`` lanes yields four random words per lane.
    Each word is converted with ``uint_to_uniform_float`` and
    ``transform_exponential_f32`` and stored to its own BLOCK-sized slab.
    Store offsets use 32-bit arithmetic.

    Args:
        out_ptr: Pointer to the output buffer.
        N: Number of elements to write; stores beyond ``N`` are masked out.
        inv_lambd: Scale applied to the negated logarithm.
        eps_minus: Threshold/substitute value used by the transform.
        philox_seed: Philox seed.
        philox_offset: 64-bit Philox offset, split into counter words.
        BLOCK: Number of lanes per program (compile-time constant).
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Low / high 32-bit halves of the offset become the first two counters.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    base_idx = pid * BLOCK * 4
    # The counter must depend on the program id. Without the `pid * BLOCK`
    # term every program feeds Philox the same counters and writes the same
    # values, so the output repeats every 4 * BLOCK elements.
    lane = pid * BLOCK + tl.arange(0, BLOCK)
    c0_i = c0 + lane
    z = c0_i * 0

    r0, r1, r2, r3 = tl.philox(philox_seed, c0_i, c1, z, z)

    y0 = transform_exponential_f32(uint_to_uniform_float(r0), inv_lambd, eps_minus)
    y1 = transform_exponential_f32(uint_to_uniform_float(r1), inv_lambd, eps_minus)
    y2 = transform_exponential_f32(uint_to_uniform_float(r2), inv_lambd, eps_minus)
    y3 = transform_exponential_f32(uint_to_uniform_float(r3), inv_lambd, eps_minus)

    off0 = base_idx + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK
    off2 = off1 + BLOCK
    off3 = off2 + BLOCK

    tl.store(out_ptr + off0, y0, mask=off0 < N)
    tl.store(out_ptr + off1, y1, mask=off1 < N)
    tl.store(out_ptr + off2, y2, mask=off2 < N)
    tl.store(out_ptr + off3, y3, mask=off3 < N)

222 

223 

@libentry()
@libtuner(
    configs=[
        triton.Config({"BLOCK": 64}, num_warps=4, num_stages=2),
        triton.Config({"BLOCK": 128}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 256}, num_warps=4, num_stages=3),
        triton.Config({"BLOCK": 512}, num_warps=8, num_stages=3),
    ],
    key=["N"],
)
@triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
def fused_exponential_kernel_f64(
    out_ptr, N, inv_lambd, eps_minus, philox_seed, philox_offset, BLOCK: tl.constexpr
):
    """Write up to ``4 * BLOCK`` float64 exponential samples per program.

    Each program issues two Philox calls of ``BLOCK`` lanes. The four 32-bit
    words of a call are paired into two 64-bit words with ``paste_u64``,
    converted with ``uint_to_uniform_float`` and mapped through
    ``transform_exponential_f64``.

    Args:
        out_ptr: Pointer to the output buffer.
        N: Number of elements to write; stores beyond ``N`` are masked out.
        inv_lambd: Scale applied to the negated logarithm.
        eps_minus: Threshold/substitute value used by the transform.
        philox_seed: Philox seed.
        philox_offset: 64-bit Philox offset, split into counter words.
        BLOCK: Number of lanes per Philox call (compile-time constant).
    """
    philox_seed = philox_seed.to(tl.int64)
    philox_offset = philox_offset.to(tl.int64)
    # Low / high 32-bit halves of the offset become the first two counters.
    c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)
    c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)

    pid = tl.program_id(0)
    # Each program consumes 2 * BLOCK counters (two calls of BLOCK lanes), so
    # its counter window is shifted by pid * 2 * BLOCK. Without that shift
    # every program feeds Philox the same counters and writes the same
    # values, so the output repeats every 4 * BLOCK elements.
    block_offset = pid * BLOCK * 2 + tl.arange(0, BLOCK)
    c0_base = c0 + block_offset
    z = c0_base * 0

    r0_0, r1_0, r2_0, r3_0 = tl.philox(philox_seed, c0_base, c1, z, z)
    r0_1, r1_1, r2_1, r3_1 = tl.philox(philox_seed, c0_base + BLOCK, c1, z, z)

    u0_0 = uint_to_uniform_float(paste_u64(r0_0, r2_0))
    u1_0 = uint_to_uniform_float(paste_u64(r1_0, r3_0))
    u0_1 = uint_to_uniform_float(paste_u64(r0_1, r2_1))
    u1_1 = uint_to_uniform_float(paste_u64(r1_1, r3_1))

    y0_0 = transform_exponential_f64(u0_0, inv_lambd, eps_minus)
    y1_0 = transform_exponential_f64(u1_0, inv_lambd, eps_minus)
    y0_1 = transform_exponential_f64(u0_1, inv_lambd, eps_minus)
    y1_1 = transform_exponential_f64(u1_1, inv_lambd, eps_minus)

    # 64-bit store offsets, matching fused_exponential_kernel_f32, so that
    # pid * BLOCK * 4 cannot overflow 32-bit arithmetic on very large tensors.
    base_idx = pid.to(tl.uint64) * BLOCK * 4
    off0 = base_idx + tl.arange(0, BLOCK)
    off1 = off0 + BLOCK
    off2 = off1 + BLOCK
    off3 = off2 + BLOCK

    tl.store(out_ptr + off0, y0_0, mask=off0 < N)
    tl.store(out_ptr + off1, y1_0, mask=off1 < N)
    tl.store(out_ptr + off2, y0_1, mask=off2 < N)
    tl.store(out_ptr + off3, y1_1, mask=off3 < N)

271 

272 

def exponential_(x, lambd: float = 1.0, *, generator=None):
    """Fill ``x`` in place with samples from an exponential distribution.

    Args:
        x: Tensor to overwrite. Its dtype must be float16, bfloat16, float32
            or float64.
        lambd: Rate parameter; the kernels scale the negated log of a uniform
            sample by ``1.0 / lambd``.
        generator: Optional generator forwarded to
            ``philox_backend_seed_offset`` to obtain the Philox seed/offset.

    Returns:
        ``x`` itself, after it has been filled.

    Raises:
        AssertionError: If ``x.dtype`` is not one of the supported dtypes.
        ZeroDivisionError: If ``lambd`` is a Python zero.
    """
    logger.debug("GEMS EXPONENTIAL_")

    dtype = x.dtype
    device = x.device
    inplace = x.is_contiguous()
    assert dtype in (torch.float16, torch.bfloat16, torch.float32, torch.float64)

    N = x.numel()
    inv_lambd = 1.0 / lambd
    eps_minus = -0.5 * torch.finfo(dtype).eps

    # Nothing to generate for an empty tensor; avoid launching an empty grid.
    if N == 0:
        return x

    # The kernels write through a flat pointer, so a non-contiguous tensor is
    # filled via a contiguous scratch buffer and copied back afterwards.
    out = x if inplace else torch.empty_like(x)

    # Pick the kernel and how many output values one Philox counter yields:
    # 2 for float64 (two 64-bit words), 4 for the float32-based kernels.
    if dtype is torch.float64:
        kernel = fused_exponential_kernel_f64
        values_per_counter = 2
    elif dtype in (torch.float16, torch.bfloat16) and N < 65536:
        kernel = fused_exponential_kernel_f32_small
        values_per_counter = 4
    else:
        kernel = fused_exponential_kernel_f32
        values_per_counter = 4

    def grid(meta):
        # Every kernel above stores 4 * BLOCK elements per program (including
        # the float64 one), so this many programs cover all N elements.
        return (triton.cdiv(N, meta["BLOCK"] * 4),)

    increment = triton.cdiv(N, values_per_counter)
    philox_seed, philox_offset = philox_backend_seed_offset(
        increment, generator=generator
    )
    with torch_device_fn.device(device):
        kernel[grid](out, N, inv_lambd, eps_minus, philox_seed, philox_offset)

    if not inplace:
        x.copy_(out)
    return x