Coverage for src/flag_gems/ops/segment_reduce.py: 35%

591 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2import math 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as tle 

11 

12logger = logging.getLogger(__name__) 

13 

14_BLOCK_SIZE = 1024 

15_NPU_BLOCK_SIZE = 256 

16_UNIFORM_FAST_PATH_MIN_NUMEL = 1 << 20 

17_UNIFORM_KERNEL_MAX_SEGMENT_LENGTH = 1024 

18_UNIFORM_LENGTHS_CACHE = {} 

19_SUPPORTED_REDUCES = ("sum", "mean", "max", "min", "prod") 

20_SUPPORTED_DATA_DTYPES = ( 

21 torch.float16, 

22 torch.bfloat16, 

23 torch.float32, 

24 torch.float64, 

25) 

26_SUPPORTED_INDEX_DTYPES = (torch.int32, torch.int64) 

27 

28 

29def _prod(shape): 

30 return math.prod(shape) if shape else 1 

31 

32 

def _get_block_size(device):
    """Return ``_NPU_BLOCK_SIZE`` when ``device.type`` is "npu", else ``_BLOCK_SIZE``."""
    if device.type == "npu":
        return _NPU_BLOCK_SIZE
    return _BLOCK_SIZE

35 

36 

37def _get_uniform_kernel_config(device, inner_size): 

38 if device.type == "npu": 

39 return 4, 16 if inner_size > 1 else 1 

40 if inner_size == 1: 

41 return 16, 1 

42 return 4, 64 

43 

44 

45def _get_uniform_backward_tile_config(device, inner_size, reduce, dtype): 

46 if device.type == "npu": 

47 return 1, 16 if inner_size > 1 else 1 

48 if reduce == "prod" and inner_size > 1 and dtype in (torch.float16, torch.bfloat16): 

49 return 4, 256 

50 return 4, 64 if inner_size > 1 else 1 

51 

52 

@triton.jit
def _mul_combine(a, b):
    # Binary combiner handed to ``tl.reduce`` to form a product reduction.
    product = a * b
    return product

56 

57 

def _all_lengths_equal(lengths, value):
    """Return whether every element of ``lengths`` equals ``value`` (memoised).

    The check copies ``lengths`` to the host, so the answer is cached in
    ``_UNIFORM_LENGTHS_CACHE``.  An entry is only trusted when it was recorded
    for this very tensor object: the raw address, shape and version counter
    alone cannot tell apart a new tensor that was allocated at a recycled
    address, nor a differently strided view that starts at the same pointer.

    Args:
        lengths: index tensor to inspect.
        value: scalar every element is compared against.

    Returns:
        The Python bool produced by ``torch.all(...).item()``.
    """
    import weakref

    cache_key = (
        lengths.device.type,
        lengths.device.index,
        lengths.data_ptr(),
        tuple(lengths.shape),
        tuple(lengths.stride()),
        lengths.dtype,
        getattr(lengths, "_version", None),
        value,
    )
    entry = _UNIFORM_LENGTHS_CACHE.get(cache_key)
    if entry is not None:
        owner_ref, cached = entry
        # A dead or foreign referent means the key now describes another tensor.
        if owner_ref() is lengths:
            return cached

    is_equal = torch.all(lengths.detach().cpu() == value).item()
    # Crude bound on the table size: drop everything once it grows past 128.
    if len(_UNIFORM_LENGTHS_CACHE) > 128:
        _UNIFORM_LENGTHS_CACHE.clear()
    _UNIFORM_LENGTHS_CACHE[cache_key] = (weakref.ref(lengths), is_equal)
    return is_equal

73 

74 

75def _wrap_axis(axis, ndim): 

76 if ndim == 0: 

77 raise IndexError( 

78 "segment_reduce(): input tensor must have at least one dimension." 

79 ) 

80 if axis < -ndim or axis >= ndim: 

81 raise IndexError( 

82 f"segment_reduce(): axis {axis} is out of bounds for tensor of dimension {ndim}." 

83 ) 

84 return axis % ndim 

85 

86 

def _check_reduce_and_dtype(data, reduce):
    """Validate the reduction name and the dtype of ``data``.

    Raises:
        RuntimeError: if ``reduce`` is not in ``_SUPPORTED_REDUCES``.
        NotImplementedError: if ``data.dtype`` is not in ``_SUPPORTED_DATA_DTYPES``.
    """
    reduce_ok = reduce in _SUPPORTED_REDUCES
    if not reduce_ok:
        raise RuntimeError(
            "segment_reduce(): reduce must be one of 'sum', 'mean', 'max', 'min', or 'prod'."
        )
    dtype_ok = data.dtype in _SUPPORTED_DATA_DTYPES
    if not dtype_ok:
        raise NotImplementedError(f'"segment_reduce" not implemented for {data.dtype}.')

94 

95 

def _check_index_tensor(data, index_tensor, name, axis):
    """Validate a ``lengths``/``offsets`` tensor against ``data`` and ``axis``.

    ``name`` is only used to word the error messages.

    Raises:
        NotImplementedError: if the index dtype is not int32/int64.
        RuntimeError: if the devices differ, the index tensor has more
            dimensions than ``data``, or ``axis`` is not its last dimension.
    """
    if index_tensor.dtype not in _SUPPORTED_INDEX_DTYPES:
        raise NotImplementedError(f"segment_reduce(): {name} must be int32 or int64.")

    same_device = index_tensor.device == data.device
    if not same_device:
        raise RuntimeError(
            f"segment_reduce(): Expected data and {name} on the same device."
        )

    if index_tensor.dim() > data.dim():
        raise RuntimeError(
            f"segment_reduce(): Expected data.dim() >= {name}.dim(), got "
            f"{data.dim()} and {index_tensor.dim()}."
        )

    last_index_dim = index_tensor.dim() - 1
    if axis != last_index_dim:
        raise RuntimeError(
            f"segment_reduce(): Expected axis to be the last dimension of {name} "
            f"but got {axis}."
        )

113 

114 

def _validate_lengths(data, lengths, axis, unsafe):
    """Check ``lengths`` structurally and, unless ``unsafe``, by value.

    The value checks run on a host copy of ``lengths``: every entry must be
    non-negative and every row must sum to ``data.size(axis)``.

    Raises:
        RuntimeError: if a value check fails (plus whatever
            ``_check_index_tensor`` raises).
    """
    _check_index_tensor(data, lengths, "lengths", axis)
    if not unsafe:
        host_lengths = lengths.detach().cpu()
        has_negative = torch.any(host_lengths < 0).item()
        if has_negative:
            raise RuntimeError("lengths contains negative value!")
        row_totals = host_lengths.sum(dim=-1)
        if not torch.all(row_totals == data.size(axis)).item():
            raise RuntimeError(
                "segment_reduce(): Expected all rows of lengths along axis to sum to "
                "data.size(lengths.dim()-1) when !unsafe."
            )

128 

129 

130def _make_initial(reduce, initial): 

131 if initial is not None: 

132 return True, initial 

133 if reduce == "max": 

134 return False, float("-inf") 

135 if reduce == "min": 

136 return False, float("inf") 

137 if reduce == "prod": 

138 return False, 1.0 

139 return False, 0.0 

140 

141 

def _get_uniform_segment_length(data, lengths, axis):
    """Return the common segment length if the uniform fast path applies, else None.

    The fast path needs: at least ``_UNIFORM_FAST_PATH_MIN_NUMEL`` elements in
    ``data``; leading dims of ``lengths`` matching ``data.shape[:axis]``; a
    positive segment count that divides ``data.shape[axis]``; and every entry
    of ``lengths`` equal to that quotient.
    """
    if data.numel() < _UNIFORM_FAST_PATH_MIN_NUMEL:
        return None

    leading_match = tuple(lengths.shape[:-1]) == tuple(data.shape[:axis])
    if not leading_match:
        return None

    segment_count = lengths.shape[-1]
    if segment_count <= 0:
        return None

    candidate, remainder = divmod(data.shape[axis], segment_count)
    if remainder != 0:
        return None
    if candidate <= 0:
        return None

    return candidate if _all_lengths_equal(lengths, candidate) else None

160 

161 

def _is_unit_lengths(data, lengths, axis):
    """Return whether ``lengths`` describes one length-1 segment per element along ``axis``."""
    leading_match = tuple(lengths.shape[:-1]) == tuple(data.shape[:axis])
    if not leading_match:
        return False
    one_segment_per_element = lengths.shape[-1] == data.shape[axis]
    if not one_segment_per_element:
        return False
    return _all_lengths_equal(lengths, 1)

168 

169 

@libentry()
@triton.jit
def _segment_reduce_uniform_other_backward_kernel(
    grad,
    output,
    data,
    grad_input,
    total_rows,
    segment_count,
    segment_length,
    inner_size,
    data_size_axis,
    IS_MAX: tl.constexpr,
    IS_MIN: tl.constexpr,
    IS_PROD: tl.constexpr,
    INITIAL_PROD_VALUE: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Backward of max/min/prod for equal-length segments.  Each program handles
    # a (BLOCK_M rows, BLOCK_N segment positions, BLOCK_K inner columns) tile,
    # where a "row" is one output segment.  The segment axis is covered by a
    # single tile: positions >= BLOCK_N are never visited.
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    data_dtype = data.dtype.element_ty
    compute_dtype = tl.float64 if data_dtype is tl.float64 else tl.float32

    row_ids = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    col_ids = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)

    # 3-D index grids for the data tile.
    rows = row_ids[:, None, None]
    seg_offsets = tl.arange(0, BLOCK_N)[None, :, None]
    k_offsets = col_ids[None, None, :]
    mask = (rows < total_rows) & (seg_offsets < segment_length) & (k_offsets < inner_size)

    # 2-D index grids for the per-segment grad/output tile.
    output_rows = row_ids[:, None]
    output_k_offsets = col_ids[None, :]
    output_mask = (output_rows < total_rows) & (output_k_offsets < inner_size)
    output_offsets = output_rows * inner_size + output_k_offsets

    # Split the flat row id into (outer index, segment index within the row).
    outer_idx = rows // segment_count
    dim_idx = rows - outer_idx * segment_count
    data_offsets = (
        outer_idx * data_size_axis * inner_size
        + (dim_idx * segment_length + seg_offsets) * inner_size
        + k_offsets
    )

    values = tl.load(data + data_offsets, mask=mask, other=0.0).to(compute_dtype)
    grad_value = tl.load(grad + output_offsets, mask=output_mask, other=0.0).to(
        compute_dtype
    )
    output_value = tl.load(output + output_offsets, mask=output_mask, other=0.0).to(
        compute_dtype
    )

    if IS_MAX or IS_MIN:
        # An element receives gradient when it is NaN or equals the output.
        hit = ((values != values) | (values == output_value[:, None, :])) & mask
        counter = tl.sum(hit.to(tl.int64), axis=1)
        # Ties share the gradient, but only when it is positive.
        share = tl.where(
            (counter >= 2) & (grad_value > 0),
            grad_value / counter,
            grad_value,
        )
        tl.store(grad_input + data_offsets, share[:, None, :], mask=hit)
    elif IS_PROD:
        nan_mask = (values != values) & mask
        zero_mask = (values == 0) & mask & ~nan_mask
        zero_count = tl.sum(zero_mask.to(tl.int64), axis=1)
        nan_count = tl.sum(nan_mask.to(tl.int64), axis=1)
        # Product of the ordinary (non-zero, non-NaN) elements of each segment.
        product_values = tl.where(nan_mask | zero_mask | ~mask, 1.0, values)
        product = tl.reduce(product_values, axis=1, combine_fn=_mul_combine)
        product *= INITIAL_PROD_VALUE

        zero_scalar = tl.full((BLOCK_M, BLOCK_K), 0.0, dtype=compute_dtype)
        nan_scalar = zero_scalar / zero_scalar
        normal_prefix = grad_value * output_value
        normal_grad = normal_prefix[:, None, :] / values
        # "Product of all the others" as seen from a zero element ...
        zero_exclusive = tl.where(
            nan_count > 0,
            nan_scalar,
            tl.where(zero_count > 1, zero_scalar, product),
        )
        # ... and as seen from a NaN element.
        nan_exclusive = tl.where(
            nan_count > 1,
            nan_scalar,
            tl.where(zero_count > 0, zero_scalar, product),
        )
        exclusive = tl.where(
            nan_mask, nan_exclusive[:, None, :], zero_exclusive[:, None, :]
        )
        grad_result = tl.where(
            nan_mask | zero_mask,
            grad_value[:, None, :] * exclusive,
            normal_grad,
        )
        tl.store(grad_input + data_offsets, grad_result, mask=mask)

264 

265 

@libentry()
@triton.jit
def _segment_reduce_uniform_sum_mean_backward_kernel(
    grad,
    grad_input,
    total_numel,
    segment_count,
    segment_length,
    inner_size,
    data_size_axis,
    IS_MEAN: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Backward of sum/mean for equal-length segments: each input element reads
    # the gradient of the segment it belongs to (divided by the segment length
    # for mean).  One program covers BLOCK_SIZE consecutive flat input indices.
    pid = tle.program_id(0)
    flat_idx = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = flat_idx < total_numel

    # Decompose the flat index into (outer, position along axis, inner).
    inner_idx = flat_idx % inner_size
    axis_idx = (flat_idx // inner_size) % data_size_axis
    outer_idx = flat_idx // (data_size_axis * inner_size)
    segment_idx = axis_idx // segment_length
    grad_offsets = (outer_idx * segment_count + segment_idx) * inner_size + inner_idx

    grad_value = tl.load(grad + grad_offsets, mask=in_bounds, other=0.0)
    if IS_MEAN:
        grad_value = grad_value / segment_length
    tl.store(grad_input + flat_idx, grad_value, mask=in_bounds)

293 

294 

@libentry()
@triton.jit
def _segment_reduce_uniform_inner1_forward_kernel(
    data,
    output,
    total_rows,
    segment_count,
    segment_length,
    data_size_axis,
    IS_SUM: tl.constexpr,
    IS_MEAN: tl.constexpr,
    IS_MAX: tl.constexpr,
    IS_MIN: tl.constexpr,
    IS_PROD: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Forward reduction for equal-length segments when there is no inner
    # dimension.  Each program reduces BLOCK_M segments, loading each segment
    # as a single BLOCK_N-wide tile (positions >= BLOCK_N are never visited).
    pid = tle.program_id(0)
    data_dtype = data.dtype.element_ty
    compute_dtype = tl.float64 if data_dtype is tl.float64 else tl.float32

    row_ids = pid * BLOCK_M + tl.arange(0, BLOCK_M)
    rows = row_ids[:, None]
    cols = tl.arange(0, BLOCK_N)[None, :]
    mask = (rows < total_rows) & (cols < segment_length)

    # Split the flat row id into (outer index, segment index within the row).
    outer_idx = rows // segment_count
    dim_idx = rows - outer_idx * segment_count
    data_offsets = outer_idx * data_size_axis + dim_idx * segment_length + cols

    if IS_SUM or IS_MEAN:
        values = tl.load(data + data_offsets, mask=mask, other=0.0).to(compute_dtype)
        result = tl.sum(values, axis=1)
        if IS_MEAN:
            result = result / segment_length
    elif IS_PROD:
        values = tl.load(data + data_offsets, mask=mask, other=1.0).to(compute_dtype)
        result = tl.reduce(values, axis=1, combine_fn=_mul_combine)
    elif IS_MAX:
        values = tl.load(data + data_offsets, mask=mask, other=float("-inf")).to(
            compute_dtype
        )
        # NaNs are excluded from the max and re-injected afterwards.
        nan_mask = (values != values) & mask
        has_nan = tl.sum(nan_mask.to(tl.int32), axis=1) > 0
        nan_value = tl.sum(tl.where(nan_mask, values, 0.0), axis=1)
        result = tl.max(tl.where(mask & ~nan_mask, values, float("-inf")), axis=1)
        result = tl.where(has_nan, nan_value, result)
    elif IS_MIN:
        values = tl.load(data + data_offsets, mask=mask, other=float("inf")).to(
            compute_dtype
        )
        # NaNs are excluded from the min and re-injected afterwards.
        nan_mask = (values != values) & mask
        has_nan = tl.sum(nan_mask.to(tl.int32), axis=1) > 0
        nan_value = tl.sum(tl.where(nan_mask, values, 0.0), axis=1)
        result = tl.min(tl.where(mask & ~nan_mask, values, float("inf")), axis=1)
        result = tl.where(has_nan, nan_value, result)

    tl.store(output + row_ids, result, mask=row_ids < total_rows)

355 

356 

@libentry()
@triton.jit
def _segment_reduce_uniform_forward_kernel(
    data,
    output,
    total_rows,
    segment_count,
    segment_length,
    inner_size,
    data_size_axis,
    IS_SUM: tl.constexpr,
    IS_MEAN: tl.constexpr,
    IS_MAX: tl.constexpr,
    IS_MIN: tl.constexpr,
    IS_PROD: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Forward reduction for equal-length segments.  Each program owns a
    # (BLOCK_M segments, BLOCK_K inner columns) tile of the output and walks
    # the segment one position at a time.
    pid_m = tle.program_id(0)
    pid_k = tle.program_id(1)
    data_dtype = data.dtype.element_ty
    compute_dtype = tl.float64 if data_dtype is tl.float64 else tl.float32

    rows = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    k_offsets = pid_k * BLOCK_K + tl.arange(0, BLOCK_K)[None, :]
    mask = (rows < total_rows) & (k_offsets < inner_size)

    # Split the flat row id into (outer index, segment index within the row).
    outer_idx = rows // segment_count
    dim_idx = rows - outer_idx * segment_count
    segment_start = dim_idx * segment_length
    base_offsets = (
        outer_idx * data_size_axis * inner_size + segment_start * inner_size + k_offsets
    )

    # Seed the accumulator with the identity of the selected reduction.
    if IS_MAX:
        acc = tl.full((BLOCK_M, BLOCK_K), float("-inf"), dtype=compute_dtype)
    elif IS_MIN:
        acc = tl.full((BLOCK_M, BLOCK_K), float("inf"), dtype=compute_dtype)
    elif IS_PROD:
        acc = tl.full((BLOCK_M, BLOCK_K), 1.0, dtype=compute_dtype)
    else:
        acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=compute_dtype)

    # For max/min: remember whether (and which) NaN was seen per output slot.
    has_nan = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.int1)
    nan_value = tl.zeros((BLOCK_M, BLOCK_K), dtype=compute_dtype)

    step = 0
    while step < segment_length:
        data_offsets = base_offsets + step * inner_size
        if IS_SUM or IS_MEAN:
            values = tl.load(data + data_offsets, mask=mask, other=0.0).to(
                compute_dtype
            )
            acc += values
        elif IS_PROD:
            values = tl.load(data + data_offsets, mask=mask, other=1.0).to(
                compute_dtype
            )
            acc *= values
        elif IS_MAX:
            values = tl.load(data + data_offsets, mask=mask, other=float("-inf")).to(
                compute_dtype
            )
            nan_mask = (values != values) & mask
            has_nan |= nan_mask
            nan_value = tl.where(nan_mask, values, nan_value)
            acc = tl.maximum(acc, tl.where(mask & ~nan_mask, values, float("-inf")))
        elif IS_MIN:
            values = tl.load(data + data_offsets, mask=mask, other=float("inf")).to(
                compute_dtype
            )
            nan_mask = (values != values) & mask
            has_nan |= nan_mask
            nan_value = tl.where(nan_mask, values, nan_value)
            acc = tl.minimum(acc, tl.where(mask & ~nan_mask, values, float("inf")))
        step += 1

    if IS_MEAN:
        acc = acc / segment_length
    if IS_MAX or IS_MIN:
        # A NaN anywhere in the segment overrides the numeric extremum.
        acc = tl.where(has_nan, nan_value, acc)

    output_offsets = rows * inner_size + k_offsets
    tl.store(output + output_offsets, acc, mask=mask)

443 

444 

def _segment_reduce_uniform_lengths(data, reduce, lengths, axis):
    """Forward fast path for segments that all have the same length.

    Args:
        data: contiguous input tensor.
        reduce: one of "sum", "mean", "max", "min", "prod".
        lengths: segment lengths; only inspected, never read by the kernels.
        axis: already-wrapped reduction axis.

    Returns:
        The reduced tensor, or ``None`` when the fast path does not apply
        (non-uniform lengths, or a long segment on an "npu" device).
    """
    segment_count = lengths.shape[-1]
    segment_length = _get_uniform_segment_length(data, lengths, axis)
    if segment_length is None:
        return None

    trailing_shape = data.shape[axis + 1 :]
    inner_size = _prod(trailing_shape)

    if segment_length <= _UNIFORM_KERNEL_MAX_SEGMENT_LENGTH:
        output = torch.empty(
            lengths.shape + trailing_shape, dtype=data.dtype, device=data.device
        )
        if output.numel() == 0:
            return output

        total_rows = _prod(lengths.shape)
        reduce_flags = (
            reduce == "sum",
            reduce == "mean",
            reduce == "max",
            reduce == "min",
            reduce == "prod",
        )

        # The inner1 kernel loads a whole segment as one BLOCK_N tile and has no
        # loop over the segment, so it may only be used when the segment fits
        # in the device block size.  Clamping BLOCK_N below the segment length
        # (as happens with the 256-wide "npu" block and segments up to 1024)
        # would silently drop the tail of every segment.
        block_n = triton.next_power_of_2(segment_length)
        if inner_size == 1 and block_n <= _get_block_size(data.device):
            block_m = 4 if data.device.type == "npu" else 32
            grid = (triton.cdiv(total_rows, block_m),)
            with torch_device_fn.device(data.device):
                _segment_reduce_uniform_inner1_forward_kernel[grid](
                    data,
                    output,
                    total_rows,
                    segment_count,
                    segment_length,
                    data.shape[axis],
                    *reduce_flags,
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                )
            return output

        # General tiled kernel: iterates over the segment, any length is fine.
        block_m, block_k = _get_uniform_kernel_config(data.device, inner_size)
        grid = (triton.cdiv(total_rows, block_m), triton.cdiv(inner_size, block_k))
        with torch_device_fn.device(data.device):
            _segment_reduce_uniform_forward_kernel[grid](
                data,
                output,
                total_rows,
                segment_count,
                segment_length,
                inner_size,
                data.shape[axis],
                *reduce_flags,
                BLOCK_M=block_m,
                BLOCK_K=block_k,
            )
        return output

    # Long segments: no torch fallback on "npu"; let the caller use the
    # generic path instead.
    if data.device.type == "npu":
        return None

    # Expose the segments as their own dimension and reduce over it with
    # torch.  segment_length is > _UNIFORM_KERNEL_MAX_SEGMENT_LENGTH here, so
    # no length-1 special case is needed.
    view_shape = (
        data.shape[:axis] + (segment_count, segment_length) + data.shape[axis + 1 :]
    )
    reshaped = data.reshape(view_shape)
    reduce_dim = axis + 1

    if reduce == "sum":
        return torch.sum(reshaped, dim=reduce_dim)
    if reduce == "mean":
        return torch.mean(reshaped, dim=reduce_dim)
    if reduce == "max":
        return torch.amax(reshaped, dim=reduce_dim)
    if reduce == "min":
        return torch.amin(reshaped, dim=reduce_dim)
    return torch.prod(reshaped, dim=reduce_dim)

524 

525 

def _segment_reduce_uniform_sum_mean_backward(data, grad, reduce, lengths, axis):
    """Backward fast path of sum/mean for segments that all have the same length.

    Returns the gradient w.r.t. ``data`` (same shape as ``data``, dtype of
    ``grad``), or ``None`` when the lengths do not qualify for the fast path.
    """
    uniform_length = _get_uniform_segment_length(data, lengths, axis)
    if uniform_length is None:
        return None

    grad_input = torch.empty_like(data, dtype=grad.dtype)
    if grad_input.numel() == 0:
        return grad_input

    total_numel = data.numel()
    block_size = _get_block_size(data.device)
    launch_grid = (triton.cdiv(total_numel, block_size),)
    with torch_device_fn.device(data.device):
        _segment_reduce_uniform_sum_mean_backward_kernel[launch_grid](
            grad,
            grad_input,
            total_numel,
            lengths.shape[-1],
            uniform_length,
            _prod(data.shape[axis + 1 :]),
            data.shape[axis],
            reduce == "mean",
            BLOCK_SIZE=block_size,
        )
    return grad_input

552 

553 

def _segment_reduce_uniform_other_backward(
    data, output, grad, reduce, lengths, axis, initial
):
    """Backward fast path of max/min/prod for segments that all have the same length.

    Args:
        data: contiguous forward input.
        output: contiguous forward output.
        grad: contiguous gradient w.r.t. ``output``.
        reduce: "max", "min" or "prod".
        lengths: segment lengths; only inspected, never read by the kernel.
        axis: already-wrapped reduction axis.
        initial: optional initial value of the forward reduction.

    Returns:
        The gradient w.r.t. ``data``, or ``None`` when the fast path does not
        apply and the caller must fall back to the generic kernel.
    """
    segment_count = lengths.shape[-1]
    segment_length = _get_uniform_segment_length(data, lengths, axis)
    if segment_length is None or segment_length > _UNIFORM_KERNEL_MAX_SEGMENT_LENGTH:
        return None

    # The kernel covers a segment with one BLOCK_N tile and has no loop over
    # the segment.  If the segment does not fit in the device block size
    # (e.g. the 256-wide "npu" block with segments up to 1024), clamping
    # BLOCK_N would leave the tail of every segment unprocessed, so decline.
    block_n = triton.next_power_of_2(segment_length)
    if block_n > _get_block_size(data.device):
        return None

    # max/min only write the matching positions, so the rest must start at 0;
    # prod writes every position.
    if reduce in ("max", "min"):
        grad_input = torch.zeros_like(data, dtype=grad.dtype)
    else:
        grad_input = torch.empty_like(data, dtype=grad.dtype)
    if grad_input.numel() == 0:
        return grad_input

    inner_size = _prod(data.shape[axis + 1 :])
    total_rows = _prod(lengths.shape)
    block_m, block_k = _get_uniform_backward_tile_config(
        data.device, inner_size, reduce, data.dtype
    )
    _, initial_prod_value = _make_initial("prod", initial)
    grid = (triton.cdiv(total_rows, block_m), triton.cdiv(inner_size, block_k))
    with torch_device_fn.device(data.device):
        _segment_reduce_uniform_other_backward_kernel[grid](
            grad,
            output,
            data,
            grad_input,
            total_rows,
            segment_count,
            segment_length,
            inner_size,
            data.shape[axis],
            reduce == "max",
            reduce == "min",
            reduce == "prod",
            initial_prod_value,
            BLOCK_M=block_m,
            BLOCK_N=block_n,
            BLOCK_K=block_k,
        )
    return grad_input

597 

598 

@libentry()
@triton.jit
def _lengths_to_offsets_kernel(
    lengths,
    offsets,
    outer_count,
    segment_count,
):
    # One program per row of ``lengths``: writes the running sum of that row,
    # preceded by 0, into the matching row of ``offsets`` (segment_count + 1
    # entries).  ``outer_count`` is accepted but not used in the body.
    pid = tle.program_id(0)
    running = tl.full((), 0, dtype=tl.int64)
    lengths_row = pid * segment_count
    offsets_row = pid * (segment_count + 1)
    tl.store(offsets + offsets_row, running)

    seg = 0
    while seg < segment_count:
        length = tl.load(lengths + lengths_row + seg)
        running += length
        tl.store(offsets + offsets_row + seg + 1, running)
        seg += 1

619 

620 

@libentry()
@triton.jit
def _segment_reduce_forward_kernel(
    data,
    offsets,
    output,
    segment_count,
    inner_size,
    data_size_axis,
    IS_SUM: tl.constexpr,
    IS_MEAN: tl.constexpr,
    IS_MAX: tl.constexpr,
    IS_MIN: tl.constexpr,
    IS_PROD: tl.constexpr,
    HAS_INITIAL: tl.constexpr,
    INITIAL_VALUE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Generic forward reduction: one program per output element.  The segment
    # bounds come from ``offsets`` (segment_count + 1 entries per row) and the
    # segment is consumed in chunks of BLOCK_SIZE positions (one position at a
    # time for prod).
    pid = tle.program_id(0)
    data_dtype = data.dtype.element_ty
    compute_dtype = tl.float64 if data_dtype is tl.float64 else tl.float32

    # Decompose the flat output index into (outer, segment, inner).
    inner_idx = pid % inner_size
    row_idx = pid // inner_size
    dim_idx = row_idx % segment_count
    outer_idx = row_idx // segment_count

    offsets_base = outer_idx * (segment_count + 1) + dim_idx
    segment_start = tl.load(offsets + offsets_base)
    segment_end = tl.load(offsets + offsets_base + 1)
    segment_length = segment_end - segment_start

    acc = tl.full((), INITIAL_VALUE, dtype=compute_dtype)
    if IS_PROD:
        pos = segment_start
        while pos < segment_end:
            data_offset = (
                outer_idx * data_size_axis * inner_size + pos * inner_size + inner_idx
            )
            value = tl.load(data + data_offset).to(compute_dtype)
            acc *= value
            pos += 1
    else:
        pos = segment_start
        while pos < segment_end:
            segment_offsets = pos + tl.arange(0, BLOCK_SIZE)
            mask = segment_offsets < segment_end
            data_offsets = (
                outer_idx * data_size_axis * inner_size
                + segment_offsets * inner_size
                + inner_idx
            )

            if IS_SUM or IS_MEAN:
                values = tl.load(data + data_offsets, mask=mask, other=0.0).to(
                    compute_dtype
                )
                acc += tl.sum(tl.where(mask, values, 0.0), axis=0)
            elif IS_MAX:
                values = tl.load(
                    data + data_offsets, mask=mask, other=float("-inf")
                ).to(compute_dtype)
                nan_mask = (values != values) & mask
                has_nan = tl.sum(nan_mask.to(tl.int32), axis=0) > 0
                nan_value = tl.sum(tl.where(nan_mask, values, 0.0), axis=0)
                chunk = tl.max(
                    tl.where(mask & ~nan_mask, values, float("-inf")), axis=0
                )
                chunk = tl.where(has_nan, nan_value, chunk)
                merged = tl.where(has_nan, chunk, tl.maximum(acc, chunk))
                # Keep NaN sticky: tl.maximum does not propagate NaN by
                # default, so without this guard a NaN found in an earlier
                # chunk would be overwritten by a later NaN-free chunk.
                acc = tl.where(acc != acc, acc, merged)
            elif IS_MIN:
                values = tl.load(data + data_offsets, mask=mask, other=float("inf")).to(
                    compute_dtype
                )
                nan_mask = (values != values) & mask
                has_nan = tl.sum(nan_mask.to(tl.int32), axis=0) > 0
                nan_value = tl.sum(tl.where(nan_mask, values, 0.0), axis=0)
                chunk = tl.min(tl.where(mask & ~nan_mask, values, float("inf")), axis=0)
                chunk = tl.where(has_nan, nan_value, chunk)
                merged = tl.where(has_nan, chunk, tl.minimum(acc, chunk))
                # Same NaN stickiness as in the max branch.
                acc = tl.where(acc != acc, acc, merged)
            pos += BLOCK_SIZE

    if IS_MEAN:
        acc_is_nan = acc != acc
        nan_value = acc / acc
        # Without an initial value an empty segment takes acc / acc.
        if not HAS_INITIAL:
            acc = tl.where(segment_length == 0, nan_value, acc)
        acc = tl.where((segment_length > 0) & ~acc_is_nan, acc / segment_length, acc)

    tl.store(output + pid, acc)

711 

712 

@libentry()
@triton.jit
def _segment_reduce_backward_kernel(
    grad,
    output,
    data,
    offsets,
    grad_input,
    segment_count,
    inner_size,
    data_size_axis,
    IS_SUM: tl.constexpr,
    IS_MEAN: tl.constexpr,
    IS_MAX: tl.constexpr,
    IS_MIN: tl.constexpr,
    IS_PROD: tl.constexpr,
    INITIAL_PROD_VALUE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Generic backward: one program per output element.  It scatters that
    # element's gradient over the positions of its segment; empty segments
    # write nothing.
    pid = tle.program_id(0)
    data_dtype = data.dtype.element_ty
    compute_dtype = tl.float64 if data_dtype is tl.float64 else tl.float32

    # Decompose the flat output index into (outer, segment, inner).
    inner_idx = pid % inner_size
    row_idx = pid // inner_size
    dim_idx = row_idx % segment_count
    outer_idx = row_idx // segment_count

    offsets_base = outer_idx * (segment_count + 1) + dim_idx
    segment_start = tl.load(offsets + offsets_base)
    segment_end = tl.load(offsets + offsets_base + 1)
    segment_length = segment_end - segment_start

    if segment_length > 0:
        grad_value = tl.load(grad + pid).to(compute_dtype)
        output_value = tl.load(output + pid).to(compute_dtype)

        if IS_SUM or IS_MEAN:
            if IS_MEAN:
                grad_value = grad_value / segment_length
            # Every position of the segment receives the same gradient.
            pos = segment_start
            while pos < segment_end:
                segment_offsets = pos + tl.arange(0, BLOCK_SIZE)
                mask = segment_offsets < segment_end
                data_offsets = (
                    outer_idx * data_size_axis * inner_size
                    + segment_offsets * inner_size
                    + inner_idx
                )
                tl.store(grad_input + data_offsets, grad_value, mask=mask)
                pos += BLOCK_SIZE
        elif IS_MAX or IS_MIN:
            # Pass 1: count positions that are NaN or equal the output.
            counter = tl.full((), 0, dtype=tl.int64)
            pos = segment_start
            while pos < segment_end:
                segment_offsets = pos + tl.arange(0, BLOCK_SIZE)
                mask = segment_offsets < segment_end
                data_offsets = (
                    outer_idx * data_size_axis * inner_size
                    + segment_offsets * inner_size
                    + inner_idx
                )
                values = tl.load(data + data_offsets, mask=mask, other=0.0).to(
                    compute_dtype
                )
                hit = ((values != values) | (values == output_value)) & mask
                counter += tl.sum(hit.to(tl.int64), axis=0)
                pos += BLOCK_SIZE

            # Ties share the gradient, but only when it is positive.
            store_value = tl.where(
                (counter >= 2) & (grad_value > 0),
                grad_value / counter,
                grad_value,
            )
            # Pass 2: write the share to exactly those positions.
            pos = segment_start
            while pos < segment_end:
                segment_offsets = pos + tl.arange(0, BLOCK_SIZE)
                mask = segment_offsets < segment_end
                data_offsets = (
                    outer_idx * data_size_axis * inner_size
                    + segment_offsets * inner_size
                    + inner_idx
                )
                values = tl.load(data + data_offsets, mask=mask, other=0.0).to(
                    compute_dtype
                )
                hit = ((values != values) | (values == output_value)) & mask
                tl.store(grad_input + data_offsets, store_value, mask=hit)
                pos += BLOCK_SIZE
        elif IS_PROD:
            # Pass 1 (scalar walk): count zeros and NaNs and multiply the
            # remaining elements, starting from INITIAL_PROD_VALUE.
            zero_count = tl.full((), 0, dtype=tl.int64)
            nan_count = tl.full((), 0, dtype=tl.int64)
            product = tl.full((), INITIAL_PROD_VALUE, dtype=compute_dtype)
            pos = segment_start
            while pos < segment_end:
                data_offset = (
                    outer_idx * data_size_axis * inner_size
                    + pos * inner_size
                    + inner_idx
                )
                value = tl.load(data + data_offset).to(compute_dtype)
                if value != value:
                    nan_count += 1
                elif value == 0:
                    zero_count += 1
                else:
                    product *= value
                pos += 1

            zero_scalar = tl.full((), 0.0, dtype=compute_dtype)
            nan_scalar = zero_scalar / zero_scalar
            normal_prefix = grad_value * output_value
            # Pass 2: ordinary elements get grad * output / value; zero and NaN
            # elements get grad times the product of all the other elements.
            pos = segment_start
            while pos < segment_end:
                segment_offsets = pos + tl.arange(0, BLOCK_SIZE)
                mask = segment_offsets < segment_end
                data_offsets = (
                    outer_idx * data_size_axis * inner_size
                    + segment_offsets * inner_size
                    + inner_idx
                )
                values = tl.load(data + data_offsets, mask=mask, other=1.0).to(
                    compute_dtype
                )
                nan_mask = (values != values) & mask
                zero_mask = (values == 0) & mask & ~nan_mask
                normal_grad = normal_prefix / values
                zero_exclusive = tl.where(
                    nan_count > 0,
                    nan_scalar,
                    tl.where(zero_count > 1, zero_scalar, product),
                )
                nan_exclusive = tl.where(
                    nan_count > 1,
                    nan_scalar,
                    tl.where(zero_count > 0, zero_scalar, product),
                )
                exclusive = tl.where(nan_mask, nan_exclusive, zero_exclusive)
                grad_result = tl.where(
                    nan_mask | zero_mask,
                    grad_value * exclusive,
                    normal_grad,
                )
                tl.store(grad_input + data_offsets, grad_result, mask=mask)
                pos += BLOCK_SIZE

858 

859 

def _lengths_to_offsets(lengths):
    """Build the per-row offsets tensor for ``lengths``.

    The result has the dtype and device of ``lengths`` and one more entry than
    ``lengths`` along the last dimension; it is filled by
    ``_lengths_to_offsets_kernel``, launched with one program per row.
    """
    segment_count = lengths.shape[-1]
    leading_shape = lengths.shape[:-1]
    offsets = torch.empty(
        leading_shape + (segment_count + 1,),
        dtype=lengths.dtype,
        device=lengths.device,
    )
    outer_count = _prod(leading_shape)
    if offsets.numel() <= 0:
        return offsets
    with torch_device_fn.device(lengths.device):
        _lengths_to_offsets_kernel[(outer_count,)](
            lengths,
            offsets,
            outer_count,
            segment_count,
        )
    return offsets

874 

875 

def _prepare_common(data, reduce, lengths, offsets, indices, axis, unsafe):
    """Validate the arguments and derive what the generic kernels need.

    Returns:
        ``(axis, offsets, output_shape, from_offsets)``: the wrapped axis, a
        contiguous offsets tensor, the shape of the reduction output, and a
        flag telling whether the caller supplied ``offsets`` directly.

    Raises:
        RuntimeError: if ``indices`` is given, or neither ``lengths`` nor
            ``offsets`` is given (plus whatever the validators raise).
    """
    _check_reduce_and_dtype(data, reduce)
    axis = _wrap_axis(axis, data.dim())
    if indices is not None:
        raise RuntimeError(
            "segment_reduce(): indices based reduction is not supported yet."
        )

    if offsets is None:
        if lengths is None:
            raise RuntimeError(
                "segment_reduce(): Either lengths or offsets must be defined."
            )
        _validate_lengths(data, lengths, axis, unsafe)
        lengths_contig = lengths.contiguous()
        derived_offsets = _lengths_to_offsets(lengths_contig)
        output_shape = lengths_contig.shape + data.shape[axis + 1 :]
        return axis, derived_offsets, output_shape, False

    # Explicit offsets win over lengths: N + 1 offsets describe N segments.
    _check_index_tensor(data, offsets, "offsets", axis)
    offsets_contig = offsets.contiguous()
    segment_count = offsets_contig.shape[-1] - 1
    output_shape = (
        offsets_contig.shape[:-1] + (segment_count,) + data.shape[axis + 1 :]
    )
    return axis, offsets_contig, output_shape, True

903 

904 

def segment_reduce(
    data,
    reduce,
    *,
    lengths=None,
    indices=None,
    offsets=None,
    axis=0,
    unsafe=False,
    initial=None,
):
    """Reduce ``data`` segment by segment along ``axis``.

    Args:
        data: floating-point input tensor.
        reduce: one of "sum", "mean", "max", "min", "prod".
        lengths: segment lengths (int32/int64); used when ``offsets`` is None.
        indices: not supported; must be None.
        offsets: segment boundaries (int32/int64); takes precedence over
            ``lengths``.
        axis: reduction axis; must be the last dimension of the index tensor.
        unsafe: skip the value checks on ``lengths``.
        initial: optional starting value of each segment's accumulator.

    Returns:
        A new tensor holding one reduced value per segment.

    Raises:
        RuntimeError, NotImplementedError, IndexError: on invalid arguments,
            as raised by the validation helpers.
    """
    logger.debug("GEMS SEGMENT_REDUCE")
    _check_reduce_and_dtype(data, reduce)
    axis = _wrap_axis(axis, data.dim())
    if indices is not None:
        raise RuntimeError(
            "segment_reduce(): indices based reduction is not supported yet."
        )

    # Fast paths only apply to lengths-based calls without an initial value.
    if initial is None and lengths is not None and offsets is None:
        _check_index_tensor(data, lengths, "lengths", axis)
        if _is_unit_lengths(data, lengths, axis):
            # Every segment is a single element, so the result equals the
            # input.  Return a fresh contiguous copy: ``data.contiguous()``
            # would hand back ``data`` itself when it is already contiguous,
            # making the result alias the caller's input.
            return data.clone(memory_format=torch.contiguous_format)

        data_contig = data.contiguous()
        uniform_result = _segment_reduce_uniform_lengths(
            data_contig, reduce, lengths, axis
        )
        if uniform_result is not None:
            return uniform_result

    axis, offsets_contig, output_shape, _ = _prepare_common(
        data, reduce, lengths, offsets, indices, axis, unsafe
    )

    data_contig = data.contiguous()
    output = torch.empty(output_shape, dtype=data.dtype, device=data.device)
    if output.numel() == 0:
        return output

    segment_count = output_shape[axis]
    inner_size = _prod(data_contig.shape[axis + 1 :])
    data_size_axis = data_contig.shape[axis]
    has_initial, initial_value = _make_initial(reduce, initial)
    # One program per output element.
    grid = (output.numel(),)

    with torch_device_fn.device(data.device):
        _segment_reduce_forward_kernel[grid](
            data_contig,
            offsets_contig,
            output,
            segment_count,
            inner_size,
            data_size_axis,
            reduce == "sum",
            reduce == "mean",
            reduce == "max",
            reduce == "min",
            reduce == "prod",
            has_initial,
            initial_value,
            BLOCK_SIZE=_get_block_size(data.device),
        )
    return output

969 

970 

def segment_reduce_out(
    data,
    reduce,
    *,
    lengths=None,
    indices=None,
    offsets=None,
    axis=0,
    unsafe=False,
    initial=None,
    out,
):
    """``out=`` variant of ``segment_reduce``.

    Computes the result with ``segment_reduce``, resizes ``out`` if its shape
    differs, copies the result into it and returns ``out``.
    """
    logger.debug("GEMS SEGMENT_REDUCE_OUT")
    forwarded = dict(
        lengths=lengths,
        indices=indices,
        offsets=offsets,
        axis=axis,
        unsafe=unsafe,
        initial=initial,
    )
    result = segment_reduce(data, reduce, **forwarded)
    shape_matches = out.shape == result.shape
    if not shape_matches:
        out.resize_(result.shape)
    out.copy_(result)
    return out

998 

999 

def _segment_reduce_backward(
    grad,
    output,
    data,
    reduce,
    *,
    lengths=None,
    offsets=None,
    axis=0,
    initial=None,
):
    """Gradient of ``segment_reduce`` with respect to ``data``.

    Args:
        grad: gradient w.r.t. the forward output.
        output: forward output.
        data: forward input.
        reduce: one of "sum", "mean", "max", "min", "prod".
        lengths: segment lengths; used when ``offsets`` is None.
        offsets: segment boundaries; takes precedence over ``lengths``.
        axis: reduction axis.
        initial: optional initial value used in the forward pass.

    Returns:
        A new tensor with the shape of ``data`` and the dtype of ``grad``.
    """
    logger.debug("GEMS _SEGMENT_REDUCE_BACKWARD")

    # Fast paths are only available for lengths-based calls.  An unsupported
    # ``reduce`` skips them and is rejected by _prepare_common below.
    if lengths is not None and offsets is None and reduce in _SUPPORTED_REDUCES:
        _check_reduce_and_dtype(data, reduce)
        axis = _wrap_axis(axis, data.dim())
        _check_index_tensor(data, lengths, "lengths", axis)

        if initial is None and _is_unit_lengths(data, lengths, axis):
            # Single-element segments: the gradient passes through unchanged.
            # Return a fresh contiguous copy; ``grad.contiguous()`` would
            # return ``grad`` itself when it is already contiguous, making
            # the result alias the incoming gradient.
            return grad.clone(memory_format=torch.contiguous_format)

        data_contig = data.contiguous()
        grad_contig = grad.contiguous()
        if reduce in ("sum", "mean"):
            uniform_result = _segment_reduce_uniform_sum_mean_backward(
                data_contig, grad_contig, reduce, lengths, axis
            )
        else:
            output_contig = output.contiguous()
            uniform_result = _segment_reduce_uniform_other_backward(
                data_contig, output_contig, grad_contig, reduce, lengths, axis, initial
            )
        if uniform_result is not None:
            return uniform_result

    # Generic path.  The value checks on lengths are skipped (unsafe=True).
    axis, offsets_contig, output_shape, _ = _prepare_common(
        data, reduce, lengths, offsets, None, axis, True
    )
    data_contig = data.contiguous()
    grad_contig = grad.contiguous()
    output_contig = output.contiguous()
    # Zero-filled because the kernel only writes the positions it covers.
    grad_input = torch.zeros(data_contig.shape, dtype=grad.dtype, device=grad.device)

    if output_contig.numel() == 0:
        return grad_input

    segment_count = output_shape[axis]
    inner_size = _prod(data_contig.shape[axis + 1 :])
    data_size_axis = data_contig.shape[axis]
    _, initial_prod_value = _make_initial("prod", initial)
    # One program per output element.
    grid = (output_contig.numel(),)

    with torch_device_fn.device(data.device):
        _segment_reduce_backward_kernel[grid](
            grad_contig,
            output_contig,
            data_contig,
            offsets_contig,
            grad_input,
            segment_count,
            inner_size,
            data_size_axis,
            reduce == "sum",
            reduce == "mean",
            reduce == "max",
            reduce == "min",
            reduce == "prod",
            initial_prod_value,
            BLOCK_SIZE=_get_block_size(data.device),
        )
    return grad_input

1084 

1085 

def _segment_reduce_backward_out(
    grad,
    output,
    data,
    reduce,
    *,
    lengths=None,
    offsets=None,
    axis=0,
    initial=None,
    out,
):
    """``out=`` variant of ``_segment_reduce_backward``.

    Computes the gradient with ``_segment_reduce_backward``, resizes ``out`` if
    its shape differs, copies the gradient into it and returns ``out``.
    """
    logger.debug("GEMS _SEGMENT_REDUCE_BACKWARD_OUT")
    forwarded = dict(lengths=lengths, offsets=offsets, axis=axis, initial=initial)
    result = _segment_reduce_backward(grad, output, data, reduce, **forwarded)
    shape_matches = out.shape == result.shape
    if not shape_matches:
        out.resize_(result.shape)
    out.copy_(result)
    return out