Coverage for src/flag_gems/ops/assert_async.py: 82%

11 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

@triton.jit
def _assert_async_kernel(x_ptr, MSG: tl.constexpr):
    """Device-side check that the element at ``x_ptr`` is nonzero.

    Loads exactly one element (``x_ptr`` is used as a scalar pointer, no
    offsets or mask) and triggers ``tl.device_assert`` with ``MSG`` when that
    element compares equal to zero.

    ``MSG`` is a ``tl.constexpr``, so presumably every distinct message string
    is specialized into its own compiled kernel -- TODO confirm cache impact.

    NOTE(review): per the Triton docs, ``tl.device_assert`` only takes effect
    when Triton debug mode is enabled (e.g. ``TRITON_DEBUG=1``); otherwise this
    kernel may compile to a no-op. Confirm whether that is acceptable here.
    """
    # Single scalar load compared against zero; a zero value fires the assert.
    tl.device_assert(tl.load(x_ptr) != 0, MSG)

10 

11 

def _assert_async(tensor: torch.Tensor, msg: str = "Assertion failed"):
    """Assert, on the device, that a single-element tensor holds a nonzero value.

    The host side only validates the element count. The value itself is
    inspected by ``_assert_async_kernel``, so this function never reads the
    tensor's data on the host.

    Args:
        tensor: Tensor expected to contain exactly one element. Presumably it
            must live on a device Triton can launch on; this is not checked
            here.
        msg: Message forwarded to the device-side assertion as ``MSG``.

    Returns:
        None.

    Raises:
        RuntimeError: if ``tensor`` has zero or more than one element.
    """
    element_count = tensor.numel()
    if element_count == 1:
        # A single program is enough to inspect a single element.
        _assert_async_kernel[(1,)](tensor, MSG=msg)
        return
    # Zero or several elements: there is no single truth value to check.
    shape_desc = list(tensor.shape)
    raise RuntimeError(
        f"Boolean value of Tensor with shape {shape_desc} is ambiguous"
    )