Coverage for src/flag_gems/runtime/backend/_arm/ops/pow.py: 0%

296 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import os 

3 

4import numpy as np 

5import torch 

6import triton 

7import triton.language as tl 

8 

9from flag_gems.ops.pow import pow_scalar as base_pow_scalar 

10from flag_gems.ops.pow import pow_tensor_scalar as base_pow_tensor_scalar 

11from flag_gems.ops.pow import pow_tensor_scalar_ as base_pow_tensor_scalar_ 

12from flag_gems.ops.pow import pow_tensor_tensor as base_pow_tensor_tensor 

13from flag_gems.ops.pow import pow_tensor_tensor_ as base_pow_tensor_tensor_ 

14 

# For small tensors, bypass Triton entirely via numpy (zero-copy views).
# The out-of-place entry points compare ``A.numel()`` against this value.
_POW_NATIVE_THRESHOLD = 4096

# Flipped to True by the prewarm routine once it has run or been skipped.
_PREWARM_POW_DONE = False
# Feature switches, read once at import time. Each one is on unless the
# environment variable is set to something other than "1".
_POW_SQUARE_HOT_ENABLED = os.environ.get("GEMS_ARM_POW_SQUARE_HOT", "1") == "1"
_POW_TRITON_ENABLED = os.environ.get("GEMS_ARM_POW_TRITON", "1") == "1"
_POW_PREWARM_ENABLED = os.environ.get("GEMS_ARM_POW_PREWARM", "1") == "1"

22 

23 

# Element-wise square (x * x) over a flat buffer, split across programs.
# Each program starts at block `pid` and then advances by
# `num_programs * BLOCK_SIZE` until it passes `n_elements`.
@triton.jit
def _pow_square_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    num_prog = tl.num_programs(0)
    start = pid * BLOCK_SIZE
    step = num_prog * BLOCK_SIZE
    for off in range(start, n_elements, step):
        offsets = off + tl.arange(0, BLOCK_SIZE)
        # Mask the tail so the last block never touches memory past the end.
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        tl.store(out_ptr + offsets, x * x, mask=mask)

35 

36 

# Element-wise square (x * x) where one program walks the whole buffer in
# BLOCK_SIZE-wide steps; no program id is read, so every launched program
# would do the same work.
@triton.jit
def _pow_square_single_program_kernel(
    x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    offs = tl.arange(0, BLOCK_SIZE)
    for base in range(0, n_elements, BLOCK_SIZE):
        idx = base + offs
        # Mask the tail of the final step.
        mask = idx < n_elements
        x = tl.load(x_ptr + idx, mask=mask, other=0.0)
        tl.store(out_ptr + idx, x * x, mask=mask)

47 

48 

# Fixed-size square kernel: processes offsets 0..1023 in four 256-wide steps.
# Loads and stores are unmasked, so both buffers must hold at least 1024
# elements.
@triton.jit
def _pow_square_1024_hot_kernel(
    x_ptr,
    out_ptr,
):
    offs = tl.arange(0, 256)
    for base in range(0, 1024, 256):
        x = tl.load(x_ptr + base + offs)
        tl.store(out_ptr + base + offs, x * x)

58 

59 

# Fixed-size square kernel: processes offsets 0..2047 in eight 256-wide steps.
# Loads and stores are unmasked, so both buffers must hold at least 2048
# elements.
@triton.jit
def _pow_square_2048_hot_kernel(
    x_ptr,
    out_ptr,
):
    offs = tl.arange(0, 256)
    for base in range(0, 2048, 256):
        x = tl.load(x_ptr + base + offs)
        tl.store(out_ptr + base + offs, x * x)

69 

70 

# Square kernel for buffers laid out as `rows` consecutive rows of 128
# elements. The loop bound MAX_ROWS is a compile-time constant while `rows`
# is a runtime value excluded from specialization (presumably so one compiled
# variant serves every row count -- confirm against Triton's
# `do_not_specialize` semantics).
@triton.jit(do_not_specialize=["rows"])
def _pow_square_rows128_hot_kernel(
    x_ptr,
    out_ptr,
    rows,
    MAX_ROWS: tl.constexpr,
):
    offs = tl.arange(0, 128)
    for row in range(0, MAX_ROWS):
        # Rows at or beyond `rows` are skipped; accesses are otherwise unmasked.
        if row < rows:
            base = row * 128
            x = tl.load(x_ptr + base + offs)
            tl.store(out_ptr + base + offs, x * x)

84 

85 

# Square kernel for buffers laid out as `rows` consecutive rows of 1024
# elements, each row handled in four 256-wide steps. MAX_ROWS is the
# compile-time loop bound; `rows` is a runtime value excluded from
# specialization.
@triton.jit(do_not_specialize=["rows"])
def _pow_square_rows1024_hot_kernel(
    x_ptr,
    out_ptr,
    rows,
    MAX_ROWS: tl.constexpr,
):
    offs = tl.arange(0, 256)
    for row in range(0, MAX_ROWS):
        # Rows at or beyond `rows` are skipped; accesses are otherwise unmasked.
        if row < rows:
            base = row * 1024
            for k in range(0, 1024, 256):
                x = tl.load(x_ptr + base + k + offs)
                tl.store(out_ptr + base + k + offs, x * x)

100 

101 

# Fixed-size square kernel: processes offsets 0..3583 in fourteen 256-wide
# steps. Loads and stores are unmasked, so both buffers must hold at least
# 3584 elements.
@triton.jit
def _pow_square_3584_hot_kernel(
    x_ptr,
    out_ptr,
):
    offs = tl.arange(0, 256)
    for base in range(0, 3584, 256):
        x = tl.load(x_ptr + base + offs)
        tl.store(out_ptr + base + offs, x * x)

111 

112 

# Square kernel for buffers laid out as `rows` consecutive rows of 3584
# elements, each row handled in fourteen 256-wide steps. MAX_ROWS is the
# compile-time loop bound; `rows` is a runtime value excluded from
# specialization.
@triton.jit(do_not_specialize=["rows"])
def _pow_square_rows3584_hot_kernel(
    x_ptr,
    out_ptr,
    rows,
    MAX_ROWS: tl.constexpr,
):
    offs = tl.arange(0, 256)
    for row in range(0, MAX_ROWS):
        # Rows at or beyond `rows` are skipped; accesses are otherwise unmasked.
        if row < rows:
            base = row * 3584
            for k in range(0, 3584, 256):
                x = tl.load(x_ptr + base + k + offs)
                tl.store(out_ptr + base + k + offs, x * x)

127 

128 

# Element-wise square root over a flat buffer, split across programs.
# Each program starts at block `pid` and advances by
# `num_programs * BLOCK_SIZE` until it passes `n_elements`.
@triton.jit
def _pow_sqrt_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    num_prog = tl.num_programs(0)
    start = pid * BLOCK_SIZE
    step = num_prog * BLOCK_SIZE
    for off in range(start, n_elements, step):
        offsets = off + tl.arange(0, BLOCK_SIZE)
        # Mask the tail so the last block never touches memory past the end.
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        # The root is evaluated in float32 and then cast to the output's
        # element type. NOTE(review): float64 inputs are narrowed to float32
        # here, so the result carries float32 precision.
        y = tl.sqrt(x.to(tl.float32)).to(out_ptr.dtype.element_ty)
        tl.store(out_ptr + offsets, y, mask=mask)

141 

142 

# Element-wise square root where one program walks the whole buffer in
# BLOCK_SIZE-wide steps; no program id is read.
@triton.jit
def _pow_sqrt_single_program_kernel(
    x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    offs = tl.arange(0, BLOCK_SIZE)
    for base in range(0, n_elements, BLOCK_SIZE):
        idx = base + offs
        # Mask the tail of the final step.
        mask = idx < n_elements
        x = tl.load(x_ptr + idx, mask=mask, other=0.0)
        # Evaluated in float32, then cast to the output's element type.
        # NOTE(review): float64 inputs are narrowed to float32 here.
        y = tl.sqrt(x.to(tl.float32)).to(out_ptr.dtype.element_ty)
        tl.store(out_ptr + idx, y, mask=mask)

154 

155 

# Element-wise reciprocal square root (1 / sqrt(x)) over a flat buffer, split
# across programs. Each program starts at block `pid` and advances by
# `num_programs * BLOCK_SIZE` until it passes `n_elements`.
@triton.jit
def _pow_rsqrt_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    num_prog = tl.num_programs(0)
    start = pid * BLOCK_SIZE
    step = num_prog * BLOCK_SIZE
    for off in range(start, n_elements, step):
        offsets = off + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        # Masked-off lanes read 0.0; whatever they evaluate to below is
        # discarded by the store mask.
        x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
        # Evaluated in float32, then cast to the output's element type.
        # NOTE(review): float64 inputs are narrowed to float32 here.
        y = (1.0 / tl.sqrt(x.to(tl.float32))).to(out_ptr.dtype.element_ty)
        tl.store(out_ptr + offsets, y, mask=mask)

168 

169 

# Element-wise reciprocal square root (1 / sqrt(x)) where one program walks
# the whole buffer in BLOCK_SIZE-wide steps; no program id is read.
@triton.jit
def _pow_rsqrt_single_program_kernel(
    x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr
):
    offs = tl.arange(0, BLOCK_SIZE)
    for base in range(0, n_elements, BLOCK_SIZE):
        idx = base + offs
        mask = idx < n_elements
        # Masked-off lanes read 0.0; whatever they evaluate to below is
        # discarded by the store mask.
        x = tl.load(x_ptr + idx, mask=mask, other=0.0)
        # Evaluated in float32, then cast to the output's element type.
        # NOTE(review): float64 inputs are narrowed to float32 here.
        y = (1.0 / tl.sqrt(x.to(tl.float32))).to(out_ptr.dtype.element_ty)
        tl.store(out_ptr + idx, y, mask=mask)

181 

182 

183def _select_block_size(n_elements, dtype): 

184 # Tuned for Qwen decode hotspot shapes on triton-cpu. 

185 if n_elements <= 32: 

186 return 32 

187 if n_elements <= 1024: 

188 return 128 

189 if n_elements <= 2048: 

190 return 128 

191 if n_elements <= 4096: 

192 return 128 

193 if n_elements <= (1 << 16): 

194 return 128 

195 return 256 if dtype in (torch.float16, torch.bfloat16) else 128 

196 

197 

198def _single_program_block(n_elements): 

199 if n_elements <= 256: 

200 return 32 

201 if n_elements <= 2048: 

202 return 128 

203 return 256 

204 

205 

206def _maybe_scalar(v): 

207 if isinstance(v, torch.Tensor) and v.numel() == 1: 

208 return float(v.item()) 

209 if isinstance(v, (int, float)): 

210 return float(v) 

211 return None 

212 

213 

214def _is_supported_tensor(t): 

215 return ( 

216 isinstance(t, torch.Tensor) 

217 and t.device.type == "cpu" 

218 and t.dtype 

219 in ( 

220 torch.float16, 

221 torch.bfloat16, 

222 torch.float32, 

223 torch.float64, 

224 ) 

225 ) 

226 

227 

def _launch_pow_kernel(
    multi_kernel, single_kernel, x, out_tensor, n_elements, block_size
):
    """Launch either the single-program or the multi-program kernel variant.

    Args:
        multi_kernel: Kernel launched on a grid of
            ``cdiv(n_elements, block_size)`` programs.
        single_kernel: Kernel launched on a one-program grid.
        x: Input buffer handed to the kernel.
        out_tensor: Output buffer handed to the kernel.
        n_elements: Number of elements to process.
        block_size: BLOCK_SIZE for ``multi_kernel``; the single-program
            variant chooses its own via ``_single_program_block``.
    """
    launch_opts = {"num_warps": 1, "num_stages": 1}
    fits_single_program = 1 < n_elements <= 8192
    if not fits_single_program:
        program_count = triton.cdiv(n_elements, block_size)
        multi_kernel[(program_count,)](
            x,
            out_tensor,
            n_elements,
            BLOCK_SIZE=block_size,
            **launch_opts,
        )
        return
    single_kernel[(1,)](
        x,
        out_tensor,
        n_elements,
        BLOCK_SIZE=_single_program_block(n_elements),
        **launch_opts,
    )

251 

252 

def _maybe_launch_pow_square_hotshape(x, out_tensor, n_elements):
    """Try a row-specialised square kernel chosen by the last dimension.

    A kernel is launched only when the hot-shape switch is on, ``x`` is a
    non-empty contiguous tensor with at least one dimension, its last
    dimension is 128, 1024 or 3584, and the row count is within that
    kernel's limit (96, 16 and 128 rows respectively).

    Returns:
        ``True`` if a kernel wrote ``x * x`` into ``out_tensor``, ``False``
        if nothing was launched and the caller must use another path.
    """
    if not _POW_SQUARE_HOT_ENABLED:
        return False
    if x.ndim == 0 or x.numel() == 0 or not x.is_contiguous():
        return False

    row_width = x.shape[-1]
    if row_width == 128:
        hot_kernel, row_limit = _pow_square_rows128_hot_kernel, 96
    elif row_width == 1024:
        hot_kernel, row_limit = _pow_square_rows1024_hot_kernel, 16
    elif row_width == 3584:
        hot_kernel, row_limit = _pow_square_rows3584_hot_kernel, 128
    else:
        return False

    row_count, leftover = divmod(n_elements, row_width)
    if leftover != 0 or not 0 < row_count <= row_limit:
        return False

    hot_kernel[(1,)](
        x,
        out_tensor,
        row_count,
        MAX_ROWS=row_limit,
        num_warps=1,
        num_stages=1,
    )
    return True

298 

299 

def _pow_tensor_scalar_special(x, exponent, out=None):
    """Run a dedicated kernel for exponents 2.0, 0.5 and -0.5.

    Args:
        x: Input tensor. Must be a contiguous CPU tensor of a supported
            floating dtype, otherwise the fast path is declined.
        exponent: Python float exponent.
        out: Optional contiguous destination; callers pass ``x`` itself for
            the in-place variants. A fresh tensor is allocated when omitted.

    Returns:
        The tensor holding the result, or ``None`` when this fast path does
        not apply and the caller must fall back to the generic op.
    """
    if not _POW_TRITON_ENABLED:
        return None
    if not _is_supported_tensor(x) or not x.is_contiguous():
        return None
    if out is not None and not out.is_contiguous():
        return None

    if exponent == 2.0:
        kernel = _pow_square_kernel
        single_kernel = _pow_square_single_program_kernel
    elif exponent == 0.5 or exponent == -0.5:
        # The sqrt/rsqrt kernels evaluate in float32; routing float64 through
        # them would silently drop precision, so let the generic op handle it.
        if x.dtype == torch.float64:
            return None
        # NOTE(review): sqrt-based kernels can differ from pow on IEEE edge
        # inputs such as -0.0 and -inf -- confirm whether that matters here.
        if exponent == 0.5:
            kernel = _pow_sqrt_kernel
            single_kernel = _pow_sqrt_single_program_kernel
        else:
            kernel = _pow_rsqrt_kernel
            single_kernel = _pow_rsqrt_single_program_kernel
    else:
        return None

    n_elements = x.numel()
    if n_elements == 0:
        # Out-of-place calls must not hand the caller's input back as the
        # result, even when there is nothing to compute.
        return torch.empty_like(x) if out is None else out

    block_size = _select_block_size(n_elements, x.dtype)
    out_tensor = torch.empty_like(x) if out is None else out
    if exponent == 2.0:
        # Fixed-size kernels for exact element counts; contiguity was already
        # established above.
        # NOTE(review): these three are not gated by _POW_SQUARE_HOT_ENABLED,
        # unlike the row-shaped kernels below -- confirm that is intended.
        if n_elements == 1024:
            fixed_kernel = _pow_square_1024_hot_kernel
        elif n_elements == 3584:
            fixed_kernel = _pow_square_3584_hot_kernel
        elif n_elements == 2048:
            fixed_kernel = _pow_square_2048_hot_kernel
        else:
            fixed_kernel = None
        if fixed_kernel is not None:
            fixed_kernel[(1,)](
                x,
                out_tensor,
                num_warps=1,
                num_stages=1,
            )
            return out_tensor
        if _maybe_launch_pow_square_hotshape(x, out_tensor, n_elements):
            return out_tensor
    _launch_pow_kernel(kernel, single_kernel, x, out_tensor, n_elements, block_size)
    return out_tensor

357 

358 

def _maybe_prewarm_pow_kernels():
    """Launch the square kernels once on dummy CPU tensors.

    Runs at most once per process: the module-level flag is set afterwards
    whether the warm-up ran, was disabled, or failed. For float32 and
    bfloat16 it launches the 1024, 2048, rows128, 3584 and rows3584 hot
    kernels and then the generic square launcher on the 1024-element tensor.
    Any exception is logged at debug level and otherwise ignored.
    """
    global _PREWARM_POW_DONE
    if _PREWARM_POW_DONE:
        return
    if _POW_PREWARM_ENABLED:
        one_program = (1,)
        launch_opts = {"num_warps": 1, "num_stages": 1}
        try:
            for elem_type in (torch.float32, torch.bfloat16):
                src_1024 = torch.ones((1, 1, 1024), dtype=elem_type, device="cpu")
                dst_1024 = torch.empty_like(src_1024)
                _pow_square_1024_hot_kernel[one_program](
                    src_1024, dst_1024, **launch_opts
                )

                src_2048 = torch.ones((1, 16, 1, 128), dtype=elem_type, device="cpu")
                dst_2048 = torch.empty_like(src_2048)
                _pow_square_2048_hot_kernel[one_program](
                    src_2048, dst_2048, **launch_opts
                )

                # The same 2048-element buffers double as 16 rows of 128.
                row_count = src_2048.numel() // 128
                _pow_square_rows128_hot_kernel[one_program](
                    src_2048, dst_2048, row_count, MAX_ROWS=96, **launch_opts
                )

                src_3584 = torch.ones((1, 1, 3584), dtype=elem_type, device="cpu")
                dst_3584 = torch.empty_like(src_3584)
                _pow_square_3584_hot_kernel[one_program](
                    src_3584, dst_3584, **launch_opts
                )

                src_rows = torch.ones((1, 128, 3584), dtype=elem_type, device="cpu")
                dst_rows = torch.empty_like(src_rows)
                _pow_square_rows3584_hot_kernel[one_program](
                    src_rows, dst_rows, 128, MAX_ROWS=128, **launch_opts
                )

                count_1024 = src_1024.numel()
                _launch_pow_kernel(
                    _pow_square_kernel,
                    _pow_square_single_program_kernel,
                    src_1024,
                    dst_1024,
                    count_1024,
                    _select_block_size(count_1024, src_1024.dtype),
                )
        except Exception:
            logging.debug("GEMS ARM pow prewarm failed", exc_info=True)
    _PREWARM_POW_DONE = True

428 

429 

def pow_tensor_tensor(A, exponent):
    """Compute ``A ** exponent`` out of place for a tensor exponent.

    Small contiguous CPU tensors of a numpy-representable floating dtype are
    computed directly with numpy. Otherwise a one-value exponent is routed
    to the specialised scalar kernels when that cannot change the result's
    shape or dtype, and everything else goes to the generic tensor-tensor op.

    Args:
        A: Base tensor.
        exponent: Exponent tensor (a plain Python number is also tolerated).

    Returns:
        A new tensor holding the element-wise power.
    """
    logging.debug("GEMS_ARM POW_TENSOR_TENSOR")
    exp_is_tensor = isinstance(exponent, torch.Tensor)
    base_is_tensor = isinstance(A, torch.Tensor)
    if (
        base_is_tensor
        and A.device.type == "cpu"
        # numpy has no bfloat16, and non-floating inputs follow different
        # promotion rules than torch, so only these dtypes go through numpy.
        and A.dtype in (torch.float16, torch.float32, torch.float64)
        and A.numel() < _POW_NATIVE_THRESHOLD
        and A.is_contiguous()
        and (
            not exp_is_tensor
            or (exponent.device.type == "cpu" and exponent.dtype == A.dtype)
        )
    ):
        base_np = A.detach().numpy()
        power_np = exponent.detach().numpy() if exp_is_tensor else float(exponent)
        # np.power returns a numpy scalar for 0-dim inputs and may widen the
        # dtype; normalise to an ndarray of the input dtype for from_numpy.
        result_np = np.asarray(np.power(base_np, power_np), dtype=base_np.dtype)
        return torch.from_numpy(result_np)

    _maybe_prewarm_pow_kernels()
    # Treating a one-element exponent tensor as a scalar is only safe when it
    # can neither add broadcast dimensions nor take part in type promotion.
    scalar_shortcut_ok = not exp_is_tensor or (
        base_is_tensor
        and A.is_floating_point()
        and not exponent.is_complex()
        and exponent.ndim <= A.ndim
        and (exponent.ndim == 0 or exponent.dtype == A.dtype)
    )
    if scalar_shortcut_ok:
        scalar_exp = _maybe_scalar(exponent)
        if scalar_exp is not None:
            special = _pow_tensor_scalar_special(A, scalar_exp)
            if special is not None:
                return special
            # Hand a Python-number exponent to the base op unchanged so an
            # int exponent keeps its integer type.
            return base_pow_tensor_scalar(
                A, scalar_exp if exp_is_tensor else exponent
            )
    return base_pow_tensor_tensor(A, exponent)

453 

454 

def pow_tensor_tensor_(A, exponent):
    """Compute ``A ** exponent`` in place for a tensor exponent.

    A one-value exponent is first offered to the specialised scalar kernels
    (writing into ``A``), then to the in-place tensor-scalar base op; any
    other exponent goes to the in-place tensor-tensor base op.
    """
    logging.debug("GEMS_ARM POW_TENSOR_TENSOR_")
    _maybe_prewarm_pow_kernels()
    exp_value = _maybe_scalar(exponent)
    if exp_value is None:
        return base_pow_tensor_tensor_(A, exponent)
    fast_result = _pow_tensor_scalar_special(A, exp_value, out=A)
    if fast_result is None:
        return base_pow_tensor_scalar_(A, exp_value)
    return fast_result

465 

466 

def pow_tensor_scalar(A, exponent):
    """Compute ``A ** exponent`` out of place for a scalar exponent.

    Small contiguous CPU tensors of a numpy-representable floating dtype are
    computed directly with numpy. Otherwise the specialised kernels are tried
    for exponents 2.0, 0.5 and -0.5, and the generic tensor-scalar op handles
    the rest.

    Args:
        A: Base tensor.
        exponent: Python number (a one-element tensor is also tolerated).

    Returns:
        A new tensor holding the element-wise power.
    """
    logging.debug("GEMS_ARM POW_TENSOR_SCALAR")
    exp_is_tensor = isinstance(exponent, torch.Tensor)
    if (
        isinstance(A, torch.Tensor)
        and A.device.type == "cpu"
        # numpy has no bfloat16, and non-floating inputs follow different
        # promotion rules than torch, so only these dtypes go through numpy.
        and A.dtype in (torch.float16, torch.float32, torch.float64)
        and A.numel() < _POW_NATIVE_THRESHOLD
        and A.is_contiguous()
    ):
        # Resolve the exponent to one real number; complex or multi-element
        # exponents skip the numpy path instead of raising.
        if exp_is_tensor:
            usable = exponent.numel() == 1 and not exponent.is_complex()
            exp_value = exponent.item() if usable else None
        elif isinstance(exponent, complex):
            exp_value = None
        else:
            exp_value = float(exponent)
        if exp_value is not None:
            base_np = A.detach().numpy()
            if exp_value == 2.0:
                raw = np.multiply(base_np, base_np)
            else:
                raw = np.power(base_np, exp_value)
            # numpy returns a scalar for 0-dim inputs and may widen the dtype;
            # normalise to an ndarray of the input dtype for from_numpy.
            return torch.from_numpy(np.asarray(raw, dtype=base_np.dtype))

    _maybe_prewarm_pow_kernels()
    scalar_exp = _maybe_scalar(exponent)
    if scalar_exp is not None:
        special = _pow_tensor_scalar_special(A, scalar_exp)
        if special is not None:
            return special
        if exp_is_tensor:
            return base_pow_tensor_scalar(A, scalar_exp)
    # Pass a Python-number exponent through unchanged so an int exponent
    # keeps its integer type in the base op.
    return base_pow_tensor_scalar(A, exponent)

491 

492 

def pow_tensor_scalar_(A, exponent):
    """Compute ``A ** exponent`` in place for a scalar exponent.

    The specialised kernels are tried first (writing into ``A``); otherwise
    the in-place tensor-scalar base op is used.

    Args:
        A: Tensor updated in place.
        exponent: Python number (a one-element tensor is also tolerated).

    Returns:
        The tensor returned by whichever path ran; the specialised path
        returns ``A`` itself.
    """
    logging.debug("GEMS_ARM POW_TENSOR_SCALAR_")
    _maybe_prewarm_pow_kernels()
    scalar_exp = _maybe_scalar(exponent)
    if scalar_exp is not None:
        special = _pow_tensor_scalar_special(A, scalar_exp, out=A)
        if special is not None:
            return special
        if isinstance(exponent, torch.Tensor):
            return base_pow_tensor_scalar_(A, scalar_exp)
    # Pass a Python-number exponent through unchanged so an int exponent
    # keeps its integer type in the base op.
    return base_pow_tensor_scalar_(A, exponent)

503 

504 

def pow_scalar(A, exponent):
    """Compute ``A ** exponent`` by delegating to the base ``pow_scalar`` op.

    No specialised kernel exists for this variant; the call only logs and
    makes sure the kernel warm-up has had its chance to run.
    """
    logging.debug("GEMS_ARM POW_SCALAR")
    _maybe_prewarm_pow_kernels()
    result = base_pow_scalar(A, exponent)
    return result

509 

510 

# Run the warm-up once when the module is imported (presumably so the first
# real call does not pay the kernel compilation cost). It is skipped when the
# prewarm switch is off, and failures are only logged at debug level.
_maybe_prewarm_pow_kernels()