Coverage for src/flag_gems/ops/tril.py: 49%

424 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8 

# Module-level logger; tril, tril_ and tril_out emit a debug record per call.
logger = logging.getLogger(__name__)

10 

11 

@triton.jit
def _tril_tile_kernel(
    in_ptr,
    out_ptr,
    diag: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Write the lower triangle of a batch of densely packed M x N matrices.

    Each program handles one BLOCK_M x BLOCK_N tile: grid axis 0 selects the
    row tile, axis 1 the column tile and axis 2 the matrix in the batch.
    Element (i, j) of ``in_ptr`` is copied to ``out_ptr`` when
    ``j <= i + diag``; every other in-bounds element of ``out_ptr`` is set
    to zero. Both buffers are addressed as ``pid_b * M * N + i * N + j``.

    ``diag``, ``M`` and ``N`` are constexprs, so each distinct value compiles
    a separate kernel specialisation.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_b = tl.program_id(2)

    # 2-D tile coordinates: rows vary along axis 0, columns along axis 1.
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    # Edge tiles may extend past the matrix; those lanes are masked off.
    mask = (offs_m < M) & (offs_n < N)

    # NOTE(review): offsets derive from int32 program ids; presumably this
    # overflows for tensors with more than 2**31 elements -- confirm.
    base = pid_b * (M * N)
    idxs = base + offs_m * N + offs_n

    # Bounding box of the tile, used to classify it against the diagonal.
    row_start = pid_m * BLOCK_M
    row_end = row_start + BLOCK_M - 1
    col_start = pid_n * BLOCK_N
    col_end = col_start + BLOCK_N - 1

    # Whole tile is strictly above the diagonal: store zeros, never read input.
    if col_start > row_end + diag:
        tl.store(out_ptr + idxs, 0.0, mask=mask)
        return

    # Whole tile is on or below the diagonal: plain masked copy.
    if col_end <= row_start + diag:
        x = tl.load(in_ptr + idxs, mask=mask, other=0.0)
        tl.store(out_ptr + idxs, x, mask=mask)
        return

    # Tile straddles the diagonal: dropped elements are not read and load as
    # 0.0, so the store below writes zeros for them.
    keep = offs_n <= (offs_m + diag)
    x = tl.load(in_ptr + idxs, mask=mask & keep, other=0.0)
    tl.store(out_ptr + idxs, x, mask=mask)

50 

51 

@triton.jit
def _tril_rows_kernel(
    in_ptr,
    out_ptr,
    diag,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Row-block tril copy for a batch of densely packed M x N matrices.

    Each program owns BLOCK_M consecutive rows of one matrix (grid axis 0 is
    the row block, axis 1 the matrix) and walks across all N columns in
    BLOCK_N-wide chunks. Element (i, j) is copied when ``j <= i + diag`` and
    written as zero otherwise. ``diag`` is a runtime (non-constexpr) argument.
    """
    pid_m = tl.program_id(0)
    pid_b = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    # Rows past M (last row block) are masked off.
    row_mask = offs_m < M
    base = pid_b * (M * N)
    row_base = base + offs_m * N
    row_start = pid_m * BLOCK_M
    row_end = row_start + BLOCK_M - 1

    for col_start in range(0, N, BLOCK_N):
        offs_n = col_start + tl.arange(0, BLOCK_N)[None, :]
        mask = row_mask & (offs_n < N)
        idxs = row_base + offs_n

        col_end = col_start + BLOCK_N - 1
        # Same three-way split as the tile kernel, per column chunk:
        if col_start > row_end + diag:
            # chunk entirely above the diagonal -> zeros only
            tl.store(out_ptr + idxs, 0.0, mask=mask)
        elif col_end <= row_start + diag:
            # chunk entirely on/below the diagonal -> plain copy
            x = tl.load(in_ptr + idxs, mask=mask, other=0.0)
            tl.store(out_ptr + idxs, x, mask=mask)
        else:
            # chunk straddles the diagonal -> dropped lanes load 0.0
            keep = offs_n <= (offs_m + diag)
            x = tl.load(in_ptr + idxs, mask=mask & keep, other=0.0)
            tl.store(out_ptr + idxs, x, mask=mask)

87 

88 

@triton.jit
def _tril_flat_kernel(
    in_ptr,
    out_ptr,
    total,
    diag,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Elementwise tril over a flat, densely packed buffer of ``total`` elements.

    Each program handles BLOCK_SIZE consecutive elements on a 1-D grid. The
    row and column of every element are recovered from its offset inside its
    M x N matrix; elements with ``col <= row + diag`` are copied and all
    other in-bounds elements are written as zero.
    """
    pid = tl.program_id(0)
    # NOTE(review): int32 offsets; presumably overflows past 2**31 elements.
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < total

    # Position inside the current matrix (the batch index is discarded).
    matrix_offsets = offsets % (M * N)
    rows = matrix_offsets // N
    cols = matrix_offsets - rows * N
    keep = cols <= rows + diag

    # Dropped elements are not read and load as 0.0, so the store zeroes them.
    x = tl.load(in_ptr + offsets, mask=mask & keep, other=0.0)
    tl.store(out_ptr + offsets, x, mask=mask)

110 

111 

@triton.jit
def _tril_exact_row_kernel(
    in_ptr,
    out_ptr,
    diag,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Process exactly one full matrix row per program.

    Grid axis 0 is the row index and axis 1 the matrix. There is no bounds
    mask on the store, so the kernel relies on ``BLOCK_N == N``
    (``_launch_exact_row`` passes ``BLOCK_N=N``); as a ``tl.arange`` extent,
    BLOCK_N must also be a power of two.
    """
    pid_m = tl.program_id(0)
    pid_b = tl.program_id(1)

    offs_n = tl.arange(0, BLOCK_N)
    idxs = pid_b * (M * N) + pid_m * N + offs_n
    # Columns beyond row + diag are not read and load as 0.0.
    keep = offs_n <= pid_m + diag
    x = tl.load(in_ptr + idxs, mask=keep, other=0.0)
    tl.store(out_ptr + idxs, x)

129 

130 

@triton.jit
def _tril_exact_diag0_tile_kernel(
    in_ptr,
    out_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Unmasked tile tril kernel specialised for ``diagonal == 0``.

    Same tiling and grid layout as ``_tril_tile_kernel`` (axis 0 row tile,
    axis 1 column tile, axis 2 matrix), but without any bounds mask. It
    therefore assumes M and N are exact multiples of BLOCK_M and BLOCK_N.
    ``_launch_tril`` only uses it for 1024x1024 (32x64 tiles) and 512x512
    (16x128 tiles) matrices.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_b = tl.program_id(2)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    idxs = pid_b * (M * N) + offs_m * N + offs_n

    row_start = pid_m * BLOCK_M
    row_end = row_start + BLOCK_M - 1
    col_start = pid_n * BLOCK_N
    col_end = col_start + BLOCK_N - 1

    # Tile entirely above the main diagonal: zeros only.
    if col_start > row_end:
        tl.store(out_ptr + idxs, 0.0)
        return

    # Tile entirely on/below the main diagonal: plain copy.
    if col_end <= row_start:
        x = tl.load(in_ptr + idxs)
        tl.store(out_ptr + idxs, x)
        return

    # Diagonal tile: elements above the diagonal load as 0.0.
    keep = offs_n <= offs_m
    x = tl.load(in_ptr + idxs, mask=keep, other=0.0)
    tl.store(out_ptr + idxs, x)

165 

166 

@triton.jit
def _tril_inplace_zero_tile_kernel(
    ptr,
    diag: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """In place, zero every element with ``j > i + diag`` of packed M x N matrices.

    ``ptr`` is both source and destination and only zeros are ever written.
    Tiles that lie entirely on or below the diagonal return without touching
    memory. Grid layout: axis 0 row tile, axis 1 column tile, axis 2 matrix.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_b = tl.program_id(2)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    mask = (offs_m < M) & (offs_n < N)
    idxs = pid_b * (M * N) + offs_m * N + offs_n

    # Whole tile is kept: nothing to do.
    row_start = pid_m * BLOCK_M
    col_end = pid_n * BLOCK_N + BLOCK_N - 1
    if col_end <= row_start + diag:
        return

    # Whole tile is above the diagonal: zero all in-bounds elements.
    row_end = row_start + BLOCK_M - 1
    col_start = pid_n * BLOCK_N
    if col_start > row_end + diag:
        tl.store(ptr + idxs, 0.0, mask=mask)
        return

    # Straddling tile: zero only the elements above the diagonal.
    zero = offs_n > offs_m + diag
    tl.store(ptr + idxs, 0.0, mask=mask & zero)

198 

199 

@triton.jit
def _tril_inplace_zero_strided_tile_kernel(
    ptr,
    diag: tl.constexpr,
    M: tl.constexpr,
    N: tl.constexpr,
    B0: tl.constexpr,
    B1: tl.constexpr,
    B2: tl.constexpr,
    B3: tl.constexpr,
    B4: tl.constexpr,
    B5: tl.constexpr,
    S0: tl.constexpr,
    S1: tl.constexpr,
    S2: tl.constexpr,
    S3: tl.constexpr,
    S4: tl.constexpr,
    S5: tl.constexpr,
    STRIDE_M: tl.constexpr,
    STRIDE_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Strided variant of ``_tril_inplace_zero_tile_kernel``.

    The flat batch id ``pid_b`` is unravelled into six batch coordinates
    (``B5`` varies fastest). These are combined with the element strides
    ``S0``..``S5``; rows and columns are addressed through ``STRIDE_M`` and
    ``STRIDE_N``. ``_launch_tril_inplace_strided`` pads unused trailing batch
    slots with size 1 / stride 0. All shapes and strides are constexprs.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_b = tl.program_id(2)

    # Unravel pid_b (row-major over B0..B5) into per-dimension indices.
    b = pid_b
    i5 = b % B5
    b = b // B5
    i4 = b % B4
    b = b // B4
    i3 = b % B3
    b = b // B3
    i2 = b % B2
    b = b // B2
    i1 = b % B1
    i0 = b // B1
    batch_offset = i0 * S0 + i1 * S1 + i2 * S2 + i3 * S3 + i4 * S4 + i5 * S5

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    mask = (offs_m < M) & (offs_n < N)
    idxs = batch_offset + offs_m * STRIDE_M + offs_n * STRIDE_N

    # Whole tile is kept: nothing to do.
    row_start = pid_m * BLOCK_M
    col_end = pid_n * BLOCK_N + BLOCK_N - 1
    if col_end <= row_start + diag:
        return

    # Whole tile is above the diagonal: zero all in-bounds elements.
    row_end = row_start + BLOCK_M - 1
    col_start = pid_n * BLOCK_N
    if col_start > row_end + diag:
        tl.store(ptr + idxs, 0.0, mask=mask)
        return

    # Straddling tile: zero only the elements above the diagonal.
    zero = offs_n > offs_m + diag
    tl.store(ptr + idxs, 0.0, mask=mask & zero)

258 

259 

@triton.jit
def _tril_strided_out_tile_kernel(
    in_ptr,
    out_ptr,
    diag,
    M: tl.constexpr,
    N: tl.constexpr,
    B0: tl.constexpr,
    B1: tl.constexpr,
    B2: tl.constexpr,
    B3: tl.constexpr,
    B4: tl.constexpr,
    B5: tl.constexpr,
    S0: tl.constexpr,
    S1: tl.constexpr,
    S2: tl.constexpr,
    S3: tl.constexpr,
    S4: tl.constexpr,
    S5: tl.constexpr,
    STRIDE_M: tl.constexpr,
    STRIDE_N: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Tril copy from a densely packed input into an arbitrarily strided output.

    The input is addressed as contiguous ``batch x M x N``. The output address
    is built from the unravelled batch coordinates (``B0``..``B5`` with
    strides ``S0``..``S5``, ``B5`` fastest) plus ``STRIDE_M`` / ``STRIDE_N``.
    Element (i, j) is copied when ``j <= i + diag``; otherwise it is written
    as zero.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    pid_b = tl.program_id(2)

    # Unravel pid_b (row-major over B0..B5) into output batch indices.
    b = pid_b
    i5 = b % B5
    b = b // B5
    i4 = b % B4
    b = b // B4
    i3 = b % B3
    b = b // B3
    i2 = b % B2
    b = b // B2
    i1 = b % B1
    i0 = b // B1
    out_batch_offset = i0 * S0 + i1 * S1 + i2 * S2 + i3 * S3 + i4 * S4 + i5 * S5

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)[None, :]
    mask = (offs_m < M) & (offs_n < N)
    # Contiguous input addressing vs. strided output addressing.
    in_idxs = pid_b * (M * N) + offs_m * N + offs_n
    out_idxs = out_batch_offset + offs_m * STRIDE_M + offs_n * STRIDE_N

    row_start = pid_m * BLOCK_M
    row_end = row_start + BLOCK_M - 1
    col_start = pid_n * BLOCK_N
    col_end = col_start + BLOCK_N - 1

    # Whole tile above the diagonal: zeros only.
    if col_start > row_end + diag:
        tl.store(out_ptr + out_idxs, 0.0, mask=mask)
        return

    # Whole tile on/below the diagonal: plain copy.
    if col_end <= row_start + diag:
        x = tl.load(in_ptr + in_idxs, mask=mask, other=0.0)
        tl.store(out_ptr + out_idxs, x, mask=mask)
        return

    # Straddling tile: dropped elements load as 0.0.
    keep = offs_n <= (offs_m + diag)
    x = tl.load(in_ptr + in_idxs, mask=mask & keep, other=0.0)
    tl.store(out_ptr + out_idxs, x, mask=mask)

324 

325 

326def _check_input(input: torch.Tensor): 

327 if input.dim() < 2: 

328 raise RuntimeError("tril: input tensor must have at least 2 dimensions") 

329 

330 

331def _empty_contiguous_like(input: torch.Tensor): 

332 if input.is_contiguous(): 

333 return torch.empty_like(input) 

334 return torch.empty_like(input, memory_format=torch.contiguous_format) 

335 

336 

337def _zero_out(out: torch.Tensor): 

338 if out.numel() == 0: 

339 return out 

340 if out.is_contiguous(): 

341 return out.zero_() 

342 return out.fill_(0) 

343 

344 

345def _is_power_of_2(value: int): 

346 return value > 0 and (value & (value - 1)) == 0 

347 

348 

349def _has_internal_overlap_from_strides(tensor: torch.Tensor): 

350 span = 1 

351 strides_and_sizes = sorted( 

352 (stride, size) 

353 for size, stride in zip(tensor.shape, tensor.stride()) 

354 if size > 1 

355 ) 

356 for stride, size in strides_and_sizes: 

357 if stride < span: 

358 return True 

359 span += stride * (size - 1) 

360 return False 

361 

362 

363def _tensors_overlap(left: torch.Tensor, right: torch.Tensor): 

364 try: 

365 return torch._C._overlaps(left, right) 

366 except AttributeError: 

367 return True 

368 

369 

def _can_use_strided_out_kernel(input: torch.Tensor, out: torch.Tensor):
    """Decide whether ``_launch_tril_strided_out`` may write ``out`` directly.

    Requires a non-empty, non-contiguous ``out`` with at most six batch
    dimensions (at most 8 dims in total) and no internal overlap. When
    ``input`` is contiguous it is read in place, so it must also not overlap
    ``out``.
    """
    if out.is_contiguous() or out.numel() == 0:
        return False
    # The kernel has six fixed batch slots plus the two matrix dimensions.
    if out.dim() > 8 or _has_internal_overlap_from_strides(out):
        return False
    return not (input.is_contiguous() and _tensors_overlap(input, out))

380 

381 

# Dispatch thresholds used by the kernel-selection heuristics below.
# The exact-row kernel (one program per row, BLOCK_N == N) is only considered
# for power-of-two widths N in [_WIDE_EXACT_ROW_MIN_N, _WIDE_EXACT_ROW_MAX_N].
_WIDE_EXACT_ROW_MIN_N = 2048
_WIDE_EXACT_ROW_MAX_N = 8192
# For matrices with fewer than _WIDE_EXACT_ROW_ALWAYS_ROW_M rows, the
# exact-row kernel also needs N <= 4096 and at least this many rows in total
# (M * batch).
_WIDE_EXACT_ROW_MIN_ROWS = 256
# Matrices with at least this many rows use the exact-row kernel whenever the
# width qualifies.
_WIDE_EXACT_ROW_ALWAYS_ROW_M = 512
# Minimum batch size for the small-matrix (M, N <= 32) batched tile path.
_TINY_BATCHED_TILE_MIN_BATCH = 128

387 

388 

def _use_wide_exact_row(M: int, N: int, batch: int):
    """Return True when the exact-row kernel should handle this shape.

    Each exact-row program covers a whole matrix row with BLOCK_N == N. It is
    used for wide power-of-two rows, where it avoids the flat kernel's div/mod
    indexing, but only when there are enough row programs to keep occupancy
    reasonable.
    """
    width_ok = _WIDE_EXACT_ROW_MIN_N <= N <= _WIDE_EXACT_ROW_MAX_N
    if not width_ok or not _is_power_of_2(N):
        return False
    if M >= _WIDE_EXACT_ROW_ALWAYS_ROW_M:
        return True
    return N <= 4096 and M * batch >= _WIDE_EXACT_ROW_MIN_ROWS

400 

401 

def _use_tiny_batched_tile(M: int, N: int, batch: int):
    """Return True for many small (at most 32 x 32) matrices."""
    small_matrix = M <= 32 and N <= 32
    return small_matrix and batch >= _TINY_BATCHED_TILE_MIN_BATCH

404 

405 

406def _wide_exact_row_warps(N: int): 

407 if N <= 4096: 

408 return 2 

409 return 4 

410 

411 

def _launch_tile(
    input: torch.Tensor,
    out: torch.Tensor,
    diagonal: int,
    block_m: int = 32,
    block_n: int = 32,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Launch ``_tril_tile_kernel`` with one program per tile per matrix.

    The grid is ``(ceil(M / block_m), ceil(N / block_n), batch)``. ``input``
    and ``out`` are treated as densely packed buffers. Returns ``out``.
    """
    M, N = input.shape[-2:]
    numel = input.numel()
    if numel == 0:
        return out

    matrices = numel // (M * N)
    launch_grid = (triton.cdiv(M, block_m), triton.cdiv(N, block_n), matrices)
    launch_opts = {
        "BLOCK_M": block_m,
        "BLOCK_N": block_n,
        "num_warps": num_warps,
        "num_stages": num_stages,
    }
    with torch_device_fn.device(input.device):
        _tril_tile_kernel[launch_grid](
            input, out, int(diagonal), M, N, **launch_opts
        )
    return out

441 

442 

def _launch_flat(
    input: torch.Tensor,
    out: torch.Tensor,
    diagonal: int,
    block_size: int = 1024,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Launch ``_tril_flat_kernel`` over every element of ``input``.

    Uses a 1-D grid of ``ceil(numel / block_size)`` programs; ``input`` and
    ``out`` are indexed as flat contiguous buffers. Returns ``out``.
    """
    M, N = input.shape[-2:]
    numel = input.numel()
    if numel == 0:
        return out

    num_programs = triton.cdiv(numel, block_size)
    launch_opts = {
        "BLOCK_SIZE": block_size,
        "num_warps": num_warps,
        "num_stages": num_stages,
    }
    with torch_device_fn.device(input.device):
        _tril_flat_kernel[(num_programs,)](
            input, out, numel, int(diagonal), M, N, **launch_opts
        )
    return out

470 

471 

def _launch_rows(
    input: torch.Tensor,
    out: torch.Tensor,
    diagonal: int,
    block_m: int = 32,
    block_n: int = 64,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Launch ``_tril_rows_kernel``: one program per ``block_m`` rows per matrix.

    The grid is ``(ceil(M / block_m), batch)``; each program loops over the
    columns in ``block_n`` chunks. Returns ``out``.
    """
    M, N = input.shape[-2:]
    numel = input.numel()
    if numel == 0:
        return out

    matrices = numel // (M * N)
    launch_grid = (triton.cdiv(M, block_m), matrices)
    launch_opts = {
        "BLOCK_M": block_m,
        "BLOCK_N": block_n,
        "num_warps": num_warps,
        "num_stages": num_stages,
    }
    with torch_device_fn.device(input.device):
        _tril_rows_kernel[launch_grid](
            input, out, int(diagonal), M, N, **launch_opts
        )
    return out

501 

502 

def _launch_exact_row(
    input: torch.Tensor,
    out: torch.Tensor,
    diagonal: int,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Launch ``_tril_exact_row_kernel`` with one program per matrix row.

    The grid is ``(M, batch)`` and ``BLOCK_N`` is set to ``N`` itself, so ``N``
    must be a valid ``tl.arange`` extent (a power of two). Returns ``out``.
    """
    M, N = input.shape[-2:]
    numel = input.numel()
    if numel == 0:
        return out

    launch_grid = (M, numel // (M * N))
    launch_opts = {"BLOCK_N": N, "num_warps": num_warps, "num_stages": num_stages}
    with torch_device_fn.device(input.device):
        _tril_exact_row_kernel[launch_grid](
            input, out, int(diagonal), M, N, **launch_opts
        )
    return out

529 

530 

def _launch_exact_diag0_tile(
    input: torch.Tensor,
    out: torch.Tensor,
    block_m: int,
    block_n: int,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Launch the unmasked ``diagonal == 0`` tile kernel.

    The kernel has no bounds mask, so callers must choose ``block_m`` and
    ``block_n`` that divide M and N exactly. Returns ``out``.
    """
    M, N = input.shape[-2:]
    numel = input.numel()
    if numel == 0:
        return out

    matrices = numel // (M * N)
    launch_grid = (triton.cdiv(M, block_m), triton.cdiv(N, block_n), matrices)
    launch_opts = {
        "BLOCK_M": block_m,
        "BLOCK_N": block_n,
        "num_warps": num_warps,
        "num_stages": num_stages,
    }
    with torch_device_fn.device(input.device):
        _tril_exact_diag0_tile_kernel[launch_grid](input, out, M, N, **launch_opts)
    return out

558 

559 

def _launch_tril_inplace_contiguous(
    input: torch.Tensor,
    diagonal: int,
    block_m: int = 16,
    block_n: int = 64,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Zero, in place, every element with ``j > i + diagonal`` of a contiguous tensor.

    Row ``i`` only has something to zero when ``i + diagonal < N - 1``, so the
    grid covers just the first ``min(M, N - 1 - diagonal)`` rows.

    Returns:
        ``input`` (modified in place).
    """
    M, N = input.shape[-2:]
    if input.numel() == 0:
        return input

    active_rows = min(M, max(0, N - 1 - diagonal))
    if active_rows == 0:
        return input

    batch = input.numel() // (M * N)
    grid = (triton.cdiv(active_rows, block_m), triton.cdiv(N, block_n), batch)
    if grid[1] > 65535 or grid[2] > 65535:
        # CUDA caps grid axes 1 and 2 at 65535 blocks. The flat kernel uses a
        # 1-D grid and is safe in place: every lane loads and stores only its
        # own element.
        return _launch_flat(input, input, diagonal, block_size=4096, num_warps=4)

    with torch_device_fn.device(input.device):
        _tril_inplace_zero_tile_kernel[grid](
            input,
            int(diagonal),
            M,
            N,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    return input

590 

591 

def _launch_tril_inplace_strided(
    input: torch.Tensor,
    diagonal: int,
    block_m: int = 16,
    block_n: int = 64,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Zero, in place, every element with ``j > i + diagonal`` of a strided tensor.

    Up to six batch dimensions are addressed directly through ``input``'s
    strides. Two cases go through a contiguous temporary that is then copied
    back: more than six batch dimensions, or a launch grid that would exceed
    the 65535 limit on axes 1/2.

    Returns:
        ``input`` (modified in place).
    """
    M, N = input.shape[-2:]
    if input.numel() == 0:
        return input

    active_rows = min(M, max(0, N - 1 - diagonal))
    if active_rows == 0:
        return input

    batch_shape = list(input.shape[:-2])
    batch_strides = list(input.stride()[:-2])
    batch = 1
    for size in batch_shape:
        batch *= size

    grid = (triton.cdiv(active_rows, block_m), triton.cdiv(N, block_n), batch)
    if grid[1] > 65535 or grid[2] > 65535:
        # CUDA caps grid axes 1 and 2 at 65535 blocks. Run the 1-D flat
        # kernel in place on a contiguous copy (each lane touches only its own
        # element) and write the result back.
        tmp = input.contiguous()
        _launch_flat(tmp, tmp, diagonal, block_size=4096, num_warps=4)
        input.copy_(tmp)
        return input

    if len(batch_shape) > 6:
        # The strided kernel only has six batch slots.
        tmp = _empty_contiguous_like(input)
        _launch_tril(input, tmp, diagonal)
        input.copy_(tmp)
        return input

    # Pad unused trailing batch slots with size 1 / stride 0.
    batch_shape.extend([1] * (6 - len(batch_shape)))
    batch_strides.extend([0] * (6 - len(batch_strides)))
    stride_m, stride_n = input.stride()[-2:]

    with torch_device_fn.device(input.device):
        _tril_inplace_zero_strided_tile_kernel[grid](
            input,
            int(diagonal),
            M,
            N,
            B0=batch_shape[0],
            B1=batch_shape[1],
            B2=batch_shape[2],
            B3=batch_shape[3],
            B4=batch_shape[4],
            B5=batch_shape[5],
            S0=batch_strides[0],
            S1=batch_strides[1],
            S2=batch_strides[2],
            S3=batch_strides[3],
            S4=batch_strides[4],
            S5=batch_strides[5],
            STRIDE_M=stride_m,
            STRIDE_N=stride_n,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    return input

651 

652 

def _launch_tril_strided_out(
    input: torch.Tensor,
    out: torch.Tensor,
    diagonal: int,
    block_m: int = 32,
    block_n: int = 64,
    num_warps: int = 4,
    num_stages: int = 2,
):
    """Write ``tril(input, diagonal)`` directly into a non-contiguous ``out``.

    ``input`` is made contiguous if needed. ``out`` is addressed through its
    own strides, with at most six batch dimensions (callers check this via
    ``_can_use_strided_out_kernel``). If the launch grid would exceed the
    65535 limit on axes 1/2, the result is computed into a contiguous
    temporary and copied into ``out`` instead.

    Returns:
        ``out``.
    """
    M, N = input.shape[-2:]
    if input.numel() == 0:
        return out

    # ``Tensor.contiguous()`` returns the tensor itself when already contiguous.
    input_to_use = input.contiguous()
    batch_shape = list(out.shape[:-2])
    batch_strides = list(out.stride()[:-2])
    batch = 1
    for size in batch_shape:
        batch *= size

    grid = (triton.cdiv(M, block_m), triton.cdiv(N, block_n), batch)
    if grid[1] > 65535 or grid[2] > 65535:
        # CUDA caps grid axes 1 and 2 at 65535 blocks: use the 1-D flat
        # kernel on a contiguous temporary, then copy into ``out``.
        tmp = _empty_contiguous_like(input_to_use)
        _launch_flat(input_to_use, tmp, diagonal, block_size=4096, num_warps=4)
        out.copy_(tmp)
        return out

    # Pad unused trailing batch slots with size 1 / stride 0.
    batch_shape.extend([1] * (6 - len(batch_shape)))
    batch_strides.extend([0] * (6 - len(batch_strides)))
    stride_m, stride_n = out.stride()[-2:]

    with torch_device_fn.device(input.device):
        _tril_strided_out_tile_kernel[grid](
            input_to_use,
            out,
            int(diagonal),
            M,
            N,
            B0=batch_shape[0],
            B1=batch_shape[1],
            B2=batch_shape[2],
            B3=batch_shape[3],
            B4=batch_shape[4],
            B5=batch_shape[5],
            S0=batch_strides[0],
            S1=batch_strides[1],
            S2=batch_strides[2],
            S3=batch_strides[3],
            S4=batch_strides[4],
            S5=batch_strides[5],
            STRIDE_M=stride_m,
            STRIDE_N=stride_n,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            num_warps=num_warps,
            num_stages=num_stages,
        )
    return out

705 

706 

def _launch_tril(input: torch.Tensor, out: torch.Tensor, diagonal: int):
    """Write ``tril(input, diagonal)`` into ``out`` and return ``out``.

    ``out`` is expected to be a densely packed tensor with ``input``'s shape
    and dtype (the kernels index it as ``batch x M x N``). ``input`` may be
    non-contiguous; a contiguous copy is made when needed. The kernel and its
    tile shape are chosen by hand-tuned shape heuristics.

    Args:
        input: tensor with at least two dims; the last two form the matrix.
        out: destination, fully overwritten.
        diagonal: element ``(i, j)`` is kept when ``j <= i + diagonal``.

    Returns:
        ``out``.
    """
    M, N = input.shape[-2:]
    if input.numel() == 0:
        return out

    # Nothing survives / everything survives: no kernel needed.
    if diagonal <= -M:
        return _zero_out(out)
    if diagonal >= N - 1:
        out.copy_(input)
        return out

    # ``Tensor.contiguous()`` returns the tensor itself when already contiguous.
    input_to_use = input.contiguous()
    batch = input_to_use.numel() // (M * N)

    # Every specialised launcher below places ``batch`` on grid axis 1 or 2,
    # which CUDA caps at 65535 blocks. The flat kernel's 1-D grid depends
    # only on the total element count, so it handles arbitrarily large batches.
    if batch > 65535:
        return _launch_flat(
            input_to_use,
            out,
            diagonal,
            block_size=4096,
            num_warps=4,
        )
    if _use_wide_exact_row(M, N, batch):
        return _launch_exact_row(
            input_to_use,
            out,
            diagonal,
            num_warps=_wide_exact_row_warps(N),
        )
    if batch == 1 and M == 1024 and N == 1024 and diagonal == 0:
        return _launch_exact_diag0_tile(
            input_to_use,
            out,
            block_m=32,
            block_n=64,
            num_warps=4,
        )
    if M == 512 and N == 512 and diagonal == 0:
        return _launch_exact_diag0_tile(
            input_to_use,
            out,
            block_m=16,
            block_n=128,
            num_warps=4,
        )
    if _use_tiny_batched_tile(M, N, batch):
        return _launch_tile(
            input_to_use,
            out,
            diagonal,
            block_m=16,
            block_n=64,
            num_warps=2,
        )
    if M <= 64 and N <= 64:
        return _launch_rows(
            input_to_use,
            out,
            diagonal,
            block_m=2,
            block_n=64,
            num_warps=1,
        )
    if N >= 2048:
        return _launch_flat(
            input_to_use,
            out,
            diagonal,
            block_size=4096,
            num_warps=4,
        )
    if batch > 1:
        if M >= 256 and N >= 256:
            return _launch_tile(
                input_to_use,
                out,
                diagonal,
                block_m=16,
                block_n=64,
                num_warps=4,
            )
        return _launch_rows(
            input_to_use,
            out,
            diagonal,
            block_m=8,
            block_n=512,
            num_warps=1,
        )
    if N >= 512:
        return _launch_tile(
            input_to_use,
            out,
            diagonal,
            block_m=64,
            block_n=64,
            num_warps=4,
        )
    if M == 256 and N == 256:
        return _launch_rows(
            input_to_use,
            out,
            diagonal,
            block_m=8,
            block_n=256,
            num_warps=2,
        )
    return _launch_rows(
        input_to_use,
        out,
        diagonal,
        block_m=8,
        block_n=512,
        num_warps=1,
    )

813 

814 

def tril(input: torch.Tensor, diagonal: int = 0):
    """Return the lower-triangular part of ``input`` as a new contiguous tensor.

    In every trailing M x N matrix, elements ``(i, j)`` with
    ``j > i + diagonal`` are set to zero.

    Raises:
        RuntimeError: if ``input`` has fewer than two dimensions.
    """
    logger.debug("GEMS TRIL")
    _check_input(input)
    return _launch_tril(input, _empty_contiguous_like(input), int(diagonal))

821 

822 

def tril_(input: torch.Tensor, diagonal: int = 0):
    """In-place ``tril``: zero elements with ``j > i + diagonal`` and return ``input``.

    Contiguous tensors use the contiguous in-place launcher; any other layout
    uses the strided in-place launcher.

    Raises:
        RuntimeError: if ``input`` has fewer than two dimensions.
    """
    logger.debug("GEMS TRIL_")
    _check_input(input)

    diagonal = int(diagonal)
    if input.numel() == 0:
        return input

    M, N = input.shape[-2:]
    # Entire matrix kept -> nothing to do; entire matrix dropped -> zero it.
    if diagonal >= N - 1:
        return input
    if diagonal <= -M:
        return _zero_out(input)

    launcher = (
        _launch_tril_inplace_contiguous
        if input.is_contiguous()
        else _launch_tril_inplace_strided
    )
    return launcher(input, diagonal)

841 

842 

def tril_out(input: torch.Tensor, diagonal: int = 0, *, out: torch.Tensor = None):
    """``out=`` variant of ``tril``: write the lower triangle of ``input`` into ``out``.

    When ``out`` is None this simply returns ``tril(input, diagonal)``.
    Otherwise ``out`` must match ``input``'s dtype and device, and it is
    resized to ``input.shape`` when the shapes differ. How ``out`` is written
    depends on its layout:

    * contiguous ``out``: written directly;
    * non-contiguous ``out``: the strided-output kernel is used when
      ``_can_use_strided_out_kernel`` allows it;
    * otherwise: the result goes into a contiguous temporary that is copied
      into ``out``.

    Raises:
        RuntimeError: if ``input`` has fewer than two dimensions, or ``out``
            has a different dtype or device than ``input``.
    """
    logger.debug("GEMS TRIL.OUT")

    if out is None:
        return tril(input, diagonal)

    _check_input(input)
    if out.dtype != input.dtype:
        raise RuntimeError(
            f"Expected out tensor to have dtype {input.dtype}, but got {out.dtype} instead"
        )
    if out.device != input.device:
        raise RuntimeError(
            f"Expected out tensor to be on device {input.device}, but got {out.device} instead"
        )
    if out.shape != input.shape:
        out.resize_(input.shape)

    # Normalise once (as ``tril_`` does) so the shortcut checks below and the
    # kernels all see the same integer diagonal.
    diagonal = int(diagonal)

    if out.is_contiguous():
        return _launch_tril(input, out, diagonal)

    if input.numel() == 0:
        return out
    M, N = input.shape[-2:]
    if diagonal <= -M:
        return _zero_out(out)
    if diagonal >= N - 1:
        out.copy_(input)
        return out

    if _can_use_strided_out_kernel(input, out):
        batch = input.numel() // (M * N)
        # Tuned tile configurations for the strided-output kernel.
        if M <= 64 and N <= 64:
            block_m, num_warps = 16, 2
        elif batch > 1 and M >= 256 and N >= 256:
            block_m, num_warps = 16, 4
        else:
            block_m, num_warps = 32, 4
        return _launch_tril_strided_out(
            input,
            out,
            diagonal,
            block_m=block_m,
            block_n=64,
            num_warps=num_warps,
        )

    # Generic fallback: compute into a contiguous temporary, then copy.
    tmp = _empty_contiguous_like(input)
    _launch_tril(input, tmp, diagonal)
    out.copy_(tmp)
    return out