Coverage for src/flag_gems/runtime/backend/_sunrise/ops/eq.py: 0%
36 statements
coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7import flag_gems
8from flag_gems.runtime import device
9from flag_gems.utils import pointwise_dynamic
10from flag_gems.utils.pointwise_dynamic import CodeGenConfig
# Module logger, created as a child of the "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

# NOTE: this rebinds the imported ``device`` module name to its ``name``
# attribute; ``eq`` below compares it against ``torch.device.type``, so it is
# presumably the backend's device-type string — confirm in flag_gems.runtime.
device = device.name

# Code-generation limits shared by every ``pointwise_dynamic`` kernel in this
# module: a 3-tuple grid cap plus tile/warp limits and tiling preferences.
MAX_GRID_SIZES = (65535,) * 3
config = CodeGenConfig(
    max_tile_size=1024,
    max_grid_size=MAX_GRID_SIZES,
    max_num_warps_per_cta=32,
    prefer_block_pointer=True,
    prefer_1d_tile=True,
)
@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")], config=config)
@triton.jit
def eq_func(x, y):
    """Element-wise ``x == y`` for two tensor operands (bool output, ALWAYS_BOOL).

    The previous kernel always cast both operands to float32 before comparing.
    float32 has a 24-bit significand, so that cast is lossy for 32/64-bit
    integers and for float64: e.g. int32 values 16777217 and 16777216 both
    round to 16777216.0 and were reported equal. Such operands are now compared
    in their own dtypes; every other combination keeps the float32 comparison,
    which also means 16-bit floats are still widened before comparing.
    """
    # These conditions only inspect dtypes (``x.type.scalar`` and its ``is_*``
    # predicates), intended as compile-time branches in the same style as other
    # dtype-dispatching Triton helpers.
    if x.type.scalar.is_int_signed() & y.type.scalar.is_int_signed():
        # Signed integers (any widths): per Triton's integer promotion the
        # narrower operand is widened to the larger signed type, so this is exact.
        return x == y
    elif (x.type.scalar == y.type.scalar) & (
        x.type.scalar.is_int() | x.type.scalar.is_fp64()
    ):
        # Identical unsigned/bool or float64 dtypes: no promotion involved.
        return x == y
    else:
        # Mixed-category or low-precision dtypes: original float32 comparison.
        return x.to(tl.float32) == y.to(tl.float32)
def eq(A, B):
    """Element-wise ``A == B`` for two tensors, computed by ``eq_func``.

    If the operands live on different devices, one of them is moved first:
    when ``A`` is already on the backend device type (module-level ``device``),
    ``B`` is copied to ``A.device``; otherwise ``A`` is copied to ``B.device``.
    """
    same_device = A.device == B.device
    if not same_device:
        a_on_backend = A.device.type == device
        if a_on_backend:
            B = B.to(A.device)
        else:
            A = A.to(B.device)
    logger.debug("GEMS EQ")
    return eq_func(A, B)
@pointwise_dynamic(
    is_tensor=[True, False], promotion_methods=[(0, 1, "ALWAYS_BOOL")], config=config
)
@triton.jit
def eq_func_scalar(x, y):
    """Element-wise ``x == y`` for a tensor ``x`` and a non-tensor ``y`` (bool output).

    The previous kernel always cast both sides to float32 (24-bit significand),
    so distinct large integers compared equal, e.g. an int64 tensor element
    16777217 against the scalar 16777216. Signed-integer operands are now
    compared in their own dtypes; all other combinations keep the float32
    comparison (float scalars in particular stay on that path, unchanged).
    """
    # Dtype-only condition, intended as a compile-time branch.
    if x.type.scalar.is_int_signed() & y.type.scalar.is_int_signed():
        # Per Triton's integer promotion the narrower signed operand is
        # widened, so the comparison is exact.
        return x == y
    else:
        # Floats, bools and unsigned ints: original float32 comparison.
        return x.to(tl.float32) == y.to(tl.float32)
def eq_scalar(A, B):
    """Element-wise ``A == B`` where ``A`` is a tensor and ``B`` a non-tensor scalar.

    Thin wrapper around ``eq_func_scalar`` (declared with
    ``is_tensor=[True, False]``); the result dtype follows its ALWAYS_BOOL
    promotion method.
    """
    logger.debug("GEMS EQ SCALAR")
    result = eq_func_scalar(A, B)
    return result
def equal(x: torch.Tensor, y: torch.Tensor) -> bool:
    """Return True when ``x`` and ``y`` have the same shape and every element matches.

    A shape mismatch returns False immediately without calling ``eq``.
    Otherwise the element-wise ``eq`` result is reduced with ``flag_gems.all``
    and converted to a Python bool.
    """
    logger.debug("GEMS EQUAL")
    if x.shape == y.shape:
        all_match = flag_gems.all(eq(x, y))
        return bool(all_match.item())
    return False