Coverage for src/flag_gems/runtime/backend/_arm/ops/sub.py: 0%

279 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.utils import pointwise_dynamic 

9 

# Module-level logger used by every entry point in this file.
logger = logging.getLogger(__name__)

# Flipped to True by _maybe_prewarm_sub_kernels() once it has run or been skipped.
_PREWARM_SUB_DONE = False

# Floating-point dtypes accepted by the hand-written fast-path kernels.
_SUPPORTED_FAST_DTYPES = (torch.float16, torch.bfloat16, torch.float32, torch.float64)

# Integer dtypes accepted by the alpha == 1 integer fast-path kernels.
_SUPPORTED_INT_FAST_DTYPES = (torch.int8, torch.int16, torch.int32, torch.int64)

25 

26 

@triton.jit(do_not_specialize=["alpha", "n_elements"])
def _sub_contiguous_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``x - y * alpha`` for the BLOCK_SIZE-wide tile owned by this program."""
    tile_start = tl.program_id(0) * BLOCK_SIZE
    idx = tile_start + tl.arange(0, BLOCK_SIZE)
    # Lanes past the end of the buffers are masked out of loads and stores.
    in_bounds = idx < n_elements
    lhs = tl.load(x_ptr + idx, mask=in_bounds, other=0.0)
    rhs = tl.load(y_ptr + idx, mask=in_bounds, other=0.0)
    tl.store(out_ptr + idx, lhs - rhs * alpha, mask=in_bounds)

42 

43 

@triton.jit(do_not_specialize=["alpha", "n_elements"])
def _sub_contiguous_single_program_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``x - y * alpha`` for all elements by looping over tiles in one program."""
    lane = tl.arange(0, BLOCK_SIZE)
    for tile_start in range(0, n_elements, BLOCK_SIZE):
        pos = tile_start + lane
        # Only the final tile can be partial; the mask covers that case.
        valid = pos < n_elements
        lhs = tl.load(x_ptr + pos, mask=valid, other=0.0)
        rhs = tl.load(y_ptr + pos, mask=valid, other=0.0)
        tl.store(out_ptr + pos, lhs - rhs * alpha, mask=valid)

60 

61 

@triton.jit(do_not_specialize=["alpha", "rows", "cols"])
def _sub_broadcast_lastdim1_kernel(
    x_ptr,
    y_ptr,
    out_ptr,
    alpha,
    rows,
    cols,
    BLOCK_SIZE: tl.constexpr,
):
    """Subtract one ``y`` value per row from a row of ``x``.

    Each program handles the row given by its program id: it reads
    ``y[row]`` once and writes ``x[row, :] - y[row] * alpha`` to ``out``,
    addressing both buffers as ``row * cols + col``.
    """
    row_id = tl.program_id(0)
    if row_id >= rows:
        return

    row_value = tl.load(y_ptr + row_id)
    lane = tl.arange(0, BLOCK_SIZE)
    row_offset = row_id * cols
    for tile_start in range(0, cols, BLOCK_SIZE):
        col_idx = tile_start + lane
        valid = col_idx < cols
        lhs = tl.load(x_ptr + row_offset + col_idx, mask=valid, other=0.0)
        tl.store(out_ptr + row_offset + col_idx, lhs - row_value * alpha, mask=valid)

84 

85 

@triton.jit(do_not_specialize=["scalar", "alpha", "n_elements"])
def _sub_tensor_scalar_kernel(
    x_ptr,
    scalar,
    out_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``x - scalar * alpha`` for the BLOCK_SIZE-wide tile owned by this program."""
    tile_start = tl.program_id(0) * BLOCK_SIZE
    idx = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(x_ptr + idx, mask=in_bounds, other=0.0)
    tl.store(out_ptr + idx, lhs - scalar * alpha, mask=in_bounds)

100 

101 

@triton.jit(do_not_specialize=["scalar", "alpha", "n_elements"])
def _sub_tensor_scalar_single_program_kernel(
    x_ptr,
    scalar,
    out_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``x - scalar * alpha`` for all elements by looping over tiles in one program."""
    lane = tl.arange(0, BLOCK_SIZE)
    for tile_start in range(0, n_elements, BLOCK_SIZE):
        pos = tile_start + lane
        valid = pos < n_elements
        lhs = tl.load(x_ptr + pos, mask=valid, other=0.0)
        tl.store(out_ptr + pos, lhs - scalar * alpha, mask=valid)

117 

118 

@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _sub_tensor_scalar_int_kernel(
    x_ptr,
    scalar,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``x - scalar`` (no alpha factor) for the tile owned by this program."""
    tile_start = tl.program_id(0) * BLOCK_SIZE
    idx = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(x_ptr + idx, mask=in_bounds, other=0)
    tl.store(out_ptr + idx, lhs - scalar, mask=in_bounds)

132 

133 

@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _sub_tensor_scalar_int_single_program_kernel(
    x_ptr,
    scalar,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``x - scalar`` (no alpha factor) for all elements using one looping program."""
    lane = tl.arange(0, BLOCK_SIZE)
    for tile_start in range(0, n_elements, BLOCK_SIZE):
        pos = tile_start + lane
        valid = pos < n_elements
        lhs = tl.load(x_ptr + pos, mask=valid, other=0)
        tl.store(out_ptr + pos, lhs - scalar, mask=valid)

148 

149 

@triton.jit(do_not_specialize=["scalar", "alpha", "n_elements"])
def _sub_scalar_tensor_kernel(
    scalar,
    y_ptr,
    out_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``scalar - y * alpha`` for the BLOCK_SIZE-wide tile owned by this program."""
    tile_start = tl.program_id(0) * BLOCK_SIZE
    idx = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    rhs = tl.load(y_ptr + idx, mask=in_bounds, other=0.0)
    tl.store(out_ptr + idx, scalar - rhs * alpha, mask=in_bounds)

164 

165 

@triton.jit(do_not_specialize=["scalar", "alpha", "n_elements"])
def _sub_scalar_tensor_single_program_kernel(
    scalar,
    y_ptr,
    out_ptr,
    alpha,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``scalar - y * alpha`` for all elements by looping over tiles in one program."""
    lane = tl.arange(0, BLOCK_SIZE)
    for tile_start in range(0, n_elements, BLOCK_SIZE):
        pos = tile_start + lane
        valid = pos < n_elements
        rhs = tl.load(y_ptr + pos, mask=valid, other=0.0)
        tl.store(out_ptr + pos, scalar - rhs * alpha, mask=valid)

181 

182 

@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _sub_scalar_tensor_int_kernel(
    scalar,
    y_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``scalar - y`` (no alpha factor) for the tile owned by this program."""
    tile_start = tl.program_id(0) * BLOCK_SIZE
    idx = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    rhs = tl.load(y_ptr + idx, mask=in_bounds, other=0)
    tl.store(out_ptr + idx, scalar - rhs, mask=in_bounds)

196 

197 

@triton.jit(do_not_specialize=["scalar", "n_elements"])
def _sub_scalar_tensor_int_single_program_kernel(
    scalar,
    y_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Store ``scalar - y`` (no alpha factor) for all elements using one looping program."""
    lane = tl.arange(0, BLOCK_SIZE)
    for tile_start in range(0, n_elements, BLOCK_SIZE):
        pos = tile_start + lane
        valid = pos < n_elements
        rhs = tl.load(y_ptr + pos, mask=valid, other=0)
        tl.store(out_ptr + pos, scalar - rhs, mask=valid)

212 

213 

# Generic fallback for ``x - y * alpha``; the first two arguments are flagged
# as tensors through ``is_tensor``.
@pointwise_dynamic(
    is_tensor=[True, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func(x, y, alpha):
    scaled = y * alpha
    return x - scaled

218 

219 

# Generic fallback for ``x - y * alpha``; only the first argument is flagged
# as a tensor through ``is_tensor``.
@pointwise_dynamic(
    is_tensor=[True, False, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func_tensor_scalar(x, y, alpha):
    scaled = y * alpha
    return x - scaled

226 

227 

# Generic fallback for ``x - y * alpha``; only the second argument is flagged
# as a tensor through ``is_tensor``.
@pointwise_dynamic(
    is_tensor=[False, True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
)
@triton.jit
def sub_func_scalar_tensor(x, y, alpha):
    scaled = y * alpha
    return x - scaled

234 

235 

def _select_block_size(n_elements, dtype):
    """Pick the BLOCK_SIZE for the multi-program kernels.

    Args:
        n_elements: Number of elements the kernel will process.
        dtype: Element dtype; only consulted for inputs above 8192 elements.

    Returns:
        32 for up to 1024 elements, 64 for up to 8192, otherwise 256 for
        float16/bfloat16 and 128 for every other dtype.
    """
    if n_elements <= 1024:
        return 32
    if n_elements <= 8192:
        return 64
    is_half_precision = dtype in (torch.float16, torch.bfloat16)
    return 256 if is_half_precision else 128

244 

245 

def _single_program_block(n_elements):
    """Pick the BLOCK_SIZE for the single-program (looping) kernels.

    Returns 32 for up to 256 elements, 128 for up to 2048, and 256 beyond.
    """
    for upper_bound, block in ((256, 32), (2048, 128)):
        if n_elements <= upper_bound:
            return block
    return 256

252 

253 

def _launch_sub_tensor_tensor(x, y, out, alpha, n_elements, block_size):
    """Launch a kernel that writes ``x - y * alpha`` into ``out``.

    Args:
        x: Left operand buffer.
        y: Right operand buffer, indexed identically to ``x``.
        out: Destination buffer.
        alpha: Multiplier applied to ``y``.
        n_elements: Number of elements to process.
        block_size: BLOCK_SIZE for the multi-program kernel; ignored when the
            single-program kernel is selected.
    """
    if n_elements == 0:
        # Nothing to compute; skip the launch instead of using an empty grid.
        return

    if 1 < n_elements <= 8192:
        # Small inputs: one program loops over every tile.
        _sub_contiguous_single_program_kernel[(1,)](
            x,
            y,
            out,
            alpha,
            n_elements,
            BLOCK_SIZE=_single_program_block(n_elements),
            num_warps=1,
            num_stages=1,
        )
        return

    grid = (triton.cdiv(n_elements, block_size),)
    _sub_contiguous_kernel[grid](
        x,
        y,
        out,
        alpha,
        n_elements,
        BLOCK_SIZE=block_size,
        num_warps=1,
        num_stages=1,
    )

280 

281 

def _launch_sub_tensor_scalar(x, scalar, out, alpha, n_elements, block_size):
    """Launch a kernel that writes ``x - scalar * alpha`` into ``out``.

    Args:
        x: Left operand buffer.
        scalar: Scalar right operand.
        out: Destination buffer.
        alpha: Multiplier applied to ``scalar``.
        n_elements: Number of elements to process.
        block_size: BLOCK_SIZE for the multi-program kernel; ignored when the
            single-program kernel is selected.
    """
    if n_elements == 0:
        # Nothing to compute; skip the launch instead of using an empty grid.
        return

    if 1 < n_elements <= 8192:
        # Small inputs: one program loops over every tile.
        _sub_tensor_scalar_single_program_kernel[(1,)](
            x,
            scalar,
            out,
            alpha,
            n_elements,
            BLOCK_SIZE=_single_program_block(n_elements),
            num_warps=1,
            num_stages=1,
        )
        return

    grid = (triton.cdiv(n_elements, block_size),)
    _sub_tensor_scalar_kernel[grid](
        x,
        scalar,
        out,
        alpha,
        n_elements,
        BLOCK_SIZE=block_size,
        num_warps=1,
        num_stages=1,
    )

308 

309 

def _launch_sub_tensor_scalar_int(x, scalar, out, n_elements, block_size):
    """Launch an integer kernel that writes ``x - scalar`` into ``out``.

    Args:
        x: Left operand buffer.
        scalar: Integer right operand.
        out: Destination buffer.
        n_elements: Number of elements to process.
        block_size: BLOCK_SIZE for the multi-program kernel; ignored when the
            single-program kernel is selected.
    """
    if n_elements == 0:
        # Nothing to compute; skip the launch instead of using an empty grid.
        return

    if 1 < n_elements <= 8192:
        # Small inputs: one program loops over every tile.
        _sub_tensor_scalar_int_single_program_kernel[(1,)](
            x,
            scalar,
            out,
            n_elements,
            BLOCK_SIZE=_single_program_block(n_elements),
            num_warps=1,
            num_stages=1,
        )
        return

    grid = (triton.cdiv(n_elements, block_size),)
    _sub_tensor_scalar_int_kernel[grid](
        x,
        scalar,
        out,
        n_elements,
        BLOCK_SIZE=block_size,
        num_warps=1,
        num_stages=1,
    )

334 

335 

def _launch_sub_broadcast_lastdim1(x, y, out, alpha):
    """Launch the row-broadcast kernel writing ``x - y[row] * alpha`` into ``out``.

    ``x`` is treated as ``rows x cols`` with ``cols = x.shape[-1]``; one
    program is launched per row and reads a single value of ``y`` for it.

    Args:
        x: Left operand tensor with at least one dimension.
        y: Buffer holding one value per row of ``x``.
        out: Destination buffer laid out like ``x``.
        alpha: Multiplier applied to the ``y`` values.
    """
    cols = x.shape[-1]
    if cols == 0:
        # Check before dividing: an empty last dimension would otherwise
        # raise ZeroDivisionError when deriving the row count.
        return
    rows = x.numel() // cols
    if rows == 0:
        return

    if cols <= 1024:
        block_size = 64
    elif cols <= 4096:
        block_size = 128
    else:
        block_size = 256

    _sub_broadcast_lastdim1_kernel[(rows,)](
        x,
        y,
        out,
        alpha,
        rows,
        cols,
        BLOCK_SIZE=block_size,
        num_warps=1,
        num_stages=1,
    )

359 

360 

def _launch_sub_scalar_tensor(scalar, y, out, alpha, n_elements, block_size):
    """Launch a kernel that writes ``scalar - y * alpha`` into ``out``.

    Args:
        scalar: Scalar left operand.
        y: Right operand buffer.
        out: Destination buffer.
        alpha: Multiplier applied to ``y``.
        n_elements: Number of elements to process.
        block_size: BLOCK_SIZE for the multi-program kernel; ignored when the
            single-program kernel is selected.
    """
    if n_elements == 0:
        # Nothing to compute; skip the launch instead of using an empty grid.
        return

    if 1 < n_elements <= 8192:
        # Small inputs: one program loops over every tile.
        _sub_scalar_tensor_single_program_kernel[(1,)](
            scalar,
            y,
            out,
            alpha,
            n_elements,
            BLOCK_SIZE=_single_program_block(n_elements),
            num_warps=1,
            num_stages=1,
        )
        return

    grid = (triton.cdiv(n_elements, block_size),)
    _sub_scalar_tensor_kernel[grid](
        scalar,
        y,
        out,
        alpha,
        n_elements,
        BLOCK_SIZE=block_size,
        num_warps=1,
        num_stages=1,
    )

387 

388 

def _launch_sub_scalar_tensor_int(scalar, y, out, n_elements, block_size):
    """Launch an integer kernel that writes ``scalar - y`` into ``out``.

    Args:
        scalar: Integer left operand.
        y: Right operand buffer.
        out: Destination buffer.
        n_elements: Number of elements to process.
        block_size: BLOCK_SIZE for the multi-program kernel; ignored when the
            single-program kernel is selected.
    """
    if n_elements == 0:
        # Nothing to compute; skip the launch instead of using an empty grid.
        return

    if 1 < n_elements <= 8192:
        # Small inputs: one program loops over every tile.
        _sub_scalar_tensor_int_single_program_kernel[(1,)](
            scalar,
            y,
            out,
            n_elements,
            BLOCK_SIZE=_single_program_block(n_elements),
            num_warps=1,
            num_stages=1,
        )
        return

    grid = (triton.cdiv(n_elements, block_size),)
    _sub_scalar_tensor_int_kernel[grid](
        scalar,
        y,
        out,
        n_elements,
        BLOCK_SIZE=block_size,
        num_warps=1,
        num_stages=1,
    )

413 

414 

def _can_use_contiguous_fastpath(a, b):
    """Return True when ``a`` and ``b`` qualify for the same-shape contiguous kernel.

    Both must be contiguous CPU tensors on the same device with equal shape
    and equal dtype, and that dtype must be in ``_SUPPORTED_FAST_DTYPES``.
    """
    if not (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)):
        return False
    if a.device.type != "cpu" or b.device != a.device:
        return False
    if not (a.is_contiguous() and b.is_contiguous()):
        return False
    if a.shape != b.shape or a.dtype != b.dtype:
        return False
    return a.dtype in _SUPPORTED_FAST_DTYPES

427 

428 

def _can_use_broadcast_lastdim1_fastpath(a, b):
    """Return True when ``b`` broadcasts over the last dimension of ``a``.

    Requires contiguous CPU tensors on the same device with the same rank
    (at least 1), identical leading dimensions, ``b.shape[-1] == 1``, and a
    shared dtype drawn from ``_SUPPORTED_FAST_DTYPES``.
    """
    if not (isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor)):
        return False
    if a.device.type != "cpu" or b.device != a.device:
        return False
    if not (a.is_contiguous() and b.is_contiguous()):
        return False
    if a.ndim < 1 or b.ndim != a.ndim:
        return False
    if a.shape[:-1] != b.shape[:-1] or b.shape[-1] != 1:
        return False
    if a.dtype != b.dtype:
        return False
    return a.dtype in _SUPPORTED_FAST_DTYPES

444 

445 

def _can_use_tensor_scalar_int_fastpath(a, scalar, alpha):
    """Return True when ``a - scalar`` can use the integer tensor-scalar kernel.

    Requires a contiguous CPU tensor whose dtype is in
    ``_SUPPORTED_INT_FAST_DTYPES``, an ``int`` scalar, and ``alpha`` equal to
    exactly 1 (the integer kernels take no alpha argument).
    """
    if not isinstance(a, torch.Tensor):
        return False
    if a.device.type != "cpu" or not a.is_contiguous():
        return False
    if a.dtype not in _SUPPORTED_INT_FAST_DTYPES:
        return False
    if not isinstance(scalar, int):
        return False
    return int(alpha) == 1 and float(alpha) == 1.0

456 

457 

def _can_use_scalar_tensor_fastpath(b, scalar):
    """Return True when ``scalar - b * alpha`` can use the float scalar-tensor kernel.

    Requires a contiguous CPU tensor whose dtype is in
    ``_SUPPORTED_FAST_DTYPES`` and a Python ``int`` or ``float`` scalar.
    """
    if not isinstance(b, torch.Tensor):
        return False
    if b.device.type != "cpu" or not b.is_contiguous():
        return False
    if b.dtype not in _SUPPORTED_FAST_DTYPES:
        return False
    return isinstance(scalar, (int, float))

466 

467 

def _can_use_scalar_tensor_int_fastpath(b, scalar, alpha):
    """Return True when ``scalar - b`` can use the integer scalar-tensor kernel.

    Requires a contiguous CPU tensor whose dtype is in
    ``_SUPPORTED_INT_FAST_DTYPES``, an ``int`` scalar, and ``alpha`` equal to
    exactly 1 (the integer kernels take no alpha argument).
    """
    if not isinstance(b, torch.Tensor):
        return False
    if b.device.type != "cpu" or not b.is_contiguous():
        return False
    if b.dtype not in _SUPPORTED_INT_FAST_DTYPES:
        return False
    if not isinstance(scalar, int):
        return False
    return int(alpha) == 1 and float(alpha) == 1.0

478 

479 

def _maybe_scalar(v):
    """Extract a Python scalar from ``v`` if it has one.

    Returns ``v.item()`` for a single-element tensor, ``v`` itself for an
    ``int`` or ``float``, and ``None`` for anything else.
    """
    if isinstance(v, torch.Tensor):
        return v.item() if v.numel() == 1 else None
    return v if isinstance(v, (int, float)) else None

486 

487 

def _maybe_prewarm_sub_kernels():
    """Run each fast-path launcher once on tiny CPU tensors, at most one time.

    The warm-up is skipped when the ``GEMS_ARM_SUB_PREWARM`` environment
    variable is set to anything other than ``"1"``. Failures are logged at
    debug level and otherwise ignored. In every case the done flag is set so
    later calls return immediately.
    """
    global _PREWARM_SUB_DONE
    if _PREWARM_SUB_DONE:
        return

    prewarm_enabled = os.environ.get("GEMS_ARM_SUB_PREWARM", "1") == "1"
    if prewarm_enabled:
        try:
            # Float launchers on an 8-element vector.
            zeros = torch.zeros(8, dtype=torch.float32, device="cpu")
            ones = torch.ones(8, dtype=torch.float32, device="cpu")
            float_out = torch.empty_like(zeros)
            _launch_sub_tensor_tensor(zeros, ones, float_out, 1.0, zeros.numel(), 32)
            _launch_sub_tensor_scalar(zeros, 1.0, float_out, 1.0, zeros.numel(), 32)
            _launch_sub_scalar_tensor(1.0, zeros, float_out, 1.0, zeros.numel(), 32)

            # Integer launchers on an 8-element int64 vector.
            ints = torch.arange(8, dtype=torch.int64, device="cpu")
            int_out = torch.empty_like(ints)
            _launch_sub_tensor_scalar_int(ints, 1, int_out, ints.numel(), 32)
            _launch_sub_scalar_tensor_int(1, ints, int_out, ints.numel(), 32)

            # Row-broadcast launcher on a (1, 5, 32) / (1, 5, 1) pair.
            wide = torch.zeros((1, 5, 32), dtype=torch.float32, device="cpu")
            per_row = torch.zeros((1, 5, 1), dtype=torch.float32, device="cpu")
            wide_out = torch.empty_like(wide)
            _launch_sub_broadcast_lastdim1(wide, per_row.view(-1), wide_out, 1.0)
        except Exception:
            logger.debug("GEMS ARM sub prewarm failed", exc_info=True)

    _PREWARM_SUB_DONE = True

515 

516 

def sub(A, B, *, alpha=1):
    """Return ``A - B * alpha``, using hand-written kernels when the inputs qualify.

    Dispatch order:
      * tensor - tensor: same-shape contiguous kernel, then the last-dim
        broadcast kernel, then the generic ``sub_func``.
      * tensor - scalar: float kernel, then integer kernel (alpha == 1 only),
        then the generic ``sub_func_tensor_scalar``.
      * scalar - tensor: float kernel, then integer kernel (alpha == 1 only),
        then the generic ``sub_func_scalar_tensor``.
      * scalar - scalar: computed in Python and wrapped in ``torch.tensor``.

    NOTE(review): the float fast paths pass ``float(alpha)`` / ``float(scalar)``
    as plain kernel arguments; confirm the precision those arguments get is
    sufficient for float64 tensors.
    """
    logger.debug("GEMS SUB")
    _maybe_prewarm_sub_kernels()

    a_is_tensor = isinstance(A, torch.Tensor)
    b_is_tensor = isinstance(B, torch.Tensor)

    if a_is_tensor and b_is_tensor:
        if _can_use_contiguous_fastpath(A, B):
            result = torch.empty_like(A)
            count = A.numel()
            tile = _select_block_size(count, A.dtype)
            _launch_sub_tensor_tensor(A, B, result, float(alpha), count, tile)
            return result
        if _can_use_broadcast_lastdim1_fastpath(A, B):
            result = torch.empty_like(A)
            # B has a trailing dimension of 1, so flattening yields one value per row.
            _launch_sub_broadcast_lastdim1(A, B.view(-1), result, float(alpha))
            return result
        return sub_func(A, B, alpha)

    if a_is_tensor:
        value = _maybe_scalar(B)
        float_path_ok = (
            value is not None
            and A.device.type == "cpu"
            and A.is_contiguous()
            and A.dtype in _SUPPORTED_FAST_DTYPES
        )
        if float_path_ok:
            result = torch.empty_like(A)
            count = A.numel()
            tile = _select_block_size(count, A.dtype)
            _launch_sub_tensor_scalar(
                A, float(value), result, float(alpha), count, tile
            )
            return result
        if _can_use_tensor_scalar_int_fastpath(A, value, alpha):
            result = torch.empty_like(A)
            count = A.numel()
            tile = _select_block_size(count, A.dtype)
            _launch_sub_tensor_scalar_int(A, int(value), result, count, tile)
            return result
        return sub_func_tensor_scalar(A, B, alpha)

    if b_is_tensor:
        value = _maybe_scalar(A)
        if _can_use_scalar_tensor_fastpath(B, value):
            result = torch.empty_like(B)
            count = B.numel()
            tile = _select_block_size(count, B.dtype)
            _launch_sub_scalar_tensor(
                float(value), B, result, float(alpha), count, tile
            )
            return result
        if _can_use_scalar_tensor_int_fastpath(B, value, alpha):
            result = torch.empty_like(B)
            count = B.numel()
            tile = _select_block_size(count, B.dtype)
            _launch_sub_scalar_tensor_int(int(value), B, result, count, tile)
            return result
        return sub_func_scalar_tensor(A, B, alpha)

    # Neither operand is a tensor.
    return torch.tensor(A - B * alpha)

571 

572 

def sub_(A, B, *, alpha=1):
    """Compute ``A -= B * alpha`` in place and return ``A``.

    Tensor ``B``: uses the contiguous or last-dim broadcast kernel when the
    inputs qualify and ``A`` and ``B`` do not share storage; otherwise falls
    back to the generic ``sub_func`` writing into ``A``.

    Scalar ``B``: uses the float kernel, then the integer kernel
    (alpha == 1 only), then the generic ``sub_func_tensor_scalar``.
    """
    logger.debug("GEMS SUB_")
    _maybe_prewarm_sub_kernels()

    if isinstance(B, torch.Tensor):
        use_contiguous = _can_use_contiguous_fastpath(A, B)
        use_broadcast = not use_contiguous and _can_use_broadcast_lastdim1_fastpath(
            A, B
        )
        if use_contiguous or use_broadcast:
            # The fast kernels write A while reading B. If both live in the
            # same storage, those writes could clobber values of B that have
            # not been read yet, so both fast paths defer to the generic
            # implementation in that case.
            shares_storage = (
                A.untyped_storage().data_ptr() == B.untyped_storage().data_ptr()
            )
            if not shares_storage:
                if use_contiguous:
                    block_size = _select_block_size(A.numel(), A.dtype)
                    _launch_sub_tensor_tensor(
                        A, B, A, float(alpha), A.numel(), block_size
                    )
                else:
                    _launch_sub_broadcast_lastdim1(A, B.view(-1), A, float(alpha))
                return A
        return sub_func(A, B, alpha, out0=A)

    scalar = _maybe_scalar(B)
    if (
        scalar is not None
        and isinstance(A, torch.Tensor)
        and A.device.type == "cpu"
        and A.is_contiguous()
        and A.dtype in _SUPPORTED_FAST_DTYPES
    ):
        block_size = _select_block_size(A.numel(), A.dtype)
        _launch_sub_tensor_scalar(
            A, float(scalar), A, float(alpha), A.numel(), block_size
        )
        return A
    if _can_use_tensor_scalar_int_fastpath(A, scalar, alpha):
        block_size = _select_block_size(A.numel(), A.dtype)
        _launch_sub_tensor_scalar_int(A, int(scalar), A, A.numel(), block_size)
        return A

    return sub_func_tensor_scalar(A, B, alpha, out0=A)

608 

609 

# Run the prewarm when the module is imported; the helper sets its done flag,
# so the calls made from sub() and sub_() afterwards return immediately.
_maybe_prewarm_sub_kernels()