Coverage for src/flag_gems/ops/assert_async.py: 82%
11 statements
coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import torch
2import triton
3import triton.language as tl
@triton.jit(debug=True)
def _assert_async_kernel(x_ptr, MSG: tl.constexpr):
    # Device-side check: load the element at x_ptr and assert it is non-zero,
    # reporting MSG (a compile-time string) when the assertion fails.
    #
    # debug=True is required: tl.device_assert is stripped at compile time
    # unless the kernel is built in debug mode (TRITON_DEBUG=1 or
    # triton.jit(debug=True)). Without it this assertion would be a silent
    # no-op in the default configuration.
    val = tl.load(x_ptr)
    tl.device_assert(val != 0, MSG)
def _assert_async(tensor: torch.Tensor, msg: str = "Assertion failed"):
    """Assert that a single-element tensor is non-zero, checking it on the device.

    The value itself is never read back on the host. A one-program Triton
    kernel (``_assert_async_kernel``) loads the element and fails its
    ``tl.device_assert`` with ``msg`` when the element equals zero.

    Args:
        tensor: Tensor that must contain exactly one element.
        msg: Message attached to the device-side assertion on failure.

    Raises:
        RuntimeError: If ``tensor`` has no elements or more than one element,
            using the same wording as PyTorch's ``torch.is_nonzero``.
    """
    numel = tensor.numel()
    # The shape check needs only metadata, so it is done eagerly on the host.
    # The empty and multi-element cases are distinguished exactly as aten does.
    if numel == 0:
        raise RuntimeError("Boolean value of Tensor with no values is ambiguous")
    if numel > 1:
        raise RuntimeError(
            "Boolean value of Tensor with more than one value is ambiguous"
        )
    _assert_async_kernel[(1,)](tensor, MSG=msg)