Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/eq.py: 0%
33 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
2import os
4import triton
5import triton.language as tl
6from _kunlunxin.utils.codegen_config_utils import CodeGenConfig
8from flag_gems.runtime import device
10from ..utils.pointwise_dynamic import pointwise_dynamic
# Module logger, created as a child of the "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Rebinds the imported `device` object to its `.name` attribute. From here on,
# `device` in this module is that name value. `eq` compares it against
# `tensor.device.type`, so it is presumably the backend's device-type string
# (TODO confirm).
device = device.name

# Code-generation settings passed to `pointwise_dynamic` for `eq_func` only;
# `eq_func_scalar` below does not pass a config.
# NOTE(review): the positional arguments are presumably
# (max_tile_size, max_grid_size, max_num_warps_per_cta, prefer_block_pointer).
# Confirm against CodeGenConfig's signature before changing them.
config_ = CodeGenConfig(
    512,
    (65536, 65536, 65536),
    32,
    True,
    prefer_1d_tile=True,
    isCloseMemoryAsync=False,
    kunlunAutoGrid=True,
    unroll_num=8,
)
# Elementwise tensor-tensor equality kernel. The output dtype follows the
# ALWAYS_BOOL promotion rule, and code generation uses the module's `config_`.
@pointwise_dynamic(promotion_methods=[(0, 1, "ALWAYS_BOOL")], config=config_)
@triton.jit
def eq_func(x, y):
    # Both operands are upcast to float32 before comparing.
    # NOTE(review): float32 has a 24-bit significand. Integer values above
    # 2**24 in magnitude, and float64 values, can therefore compare equal
    # after this cast even when the originals differ. Confirm this is
    # acceptable for the Kunlunxin backend.
    lhs = x.to(tl.float32)
    rhs = y.to(tl.float32)
    return lhs == rhs
def eq(A, B):
    """Compute elementwise ``A == B`` for two tensors on the Kunlunxin backend.

    If the operands are on different devices, they are first moved onto a
    common device:

    * If ``A`` is on this backend's device type, ``B`` is moved to ``A``'s
      device.
    * Otherwise, ``A`` is moved to ``B``'s device.

    During the kernel launch, the environment variables
    ``TRITONXPU_COMPARE_FUSION`` and ``TRITONXPU_FP16_FAST`` are set to
    ``"1"``. These are presumably read by the Triton-XPU compiler (TODO
    confirm). Afterwards, each variable is restored to exactly what it was
    before the call, even if the launch raises:

    * a variable that was unset is removed;
    * a value the caller had already exported is put back.

    Args:
        A: First tensor operand.
        B: Second tensor operand.

    Returns:
        The result of ``eq_func(A, B)``, with dtype given by its
        ``ALWAYS_BOOL`` promotion rule.
    """
    if A.device != B.device:
        if A.device.type == device:
            B = B.to(A.device)
        else:
            A = A.to(B.device)
    logger.debug("GEMS_KUNLUNXIN EQ")

    flags = {"TRITONXPU_COMPARE_FUSION": "1", "TRITONXPU_FP16_FAST": "1"}
    # Snapshot the caller's values (None = unset) so they can be restored.
    # The original bare `del` discarded pre-existing values, and it leaked the
    # flags whenever the kernel raised.
    saved = {name: os.environ.get(name) for name in flags}
    os.environ.update(flags)
    try:
        return eq_func(A, B)
    finally:
        for name, previous in saved.items():
            if previous is None:
                # pop() with a default avoids KeyError if the variable was
                # already removed during the launch.
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
# Elementwise tensor-vs-scalar equality kernel. `is_tensor=[True, False]`
# marks the second operand as a non-tensor argument. Unlike `eq_func`, no
# `config=` is passed here.
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "ALWAYS_BOOL")],
)
@triton.jit
def eq_func_scalar(x, y):
    # Same float32-upcast comparison as `eq_func`.
    # NOTE(review): large integers and float64 values lose precision in this
    # cast. Confirm this is acceptable.
    lhs = x.to(tl.float32)
    rhs = y.to(tl.float32)
    return lhs == rhs
def eq_scalar(A, B):
    """Compare tensor ``A`` elementwise against the non-tensor operand ``B``.

    This delegates to ``eq_func_scalar``, whose ``is_tensor=[True, False]``
    declaration treats ``B`` as a non-tensor argument. Unlike ``eq``, this
    path does not set any ``TRITONXPU_*`` environment variables.

    Args:
        A: Tensor operand.
        B: Non-tensor operand compared against every element of ``A``.

    Returns:
        The result of ``eq_func_scalar(A, B)``.
    """
    logger.debug("GEMS_KUNLUNXIN EQ SCALAR")
    result = eq_func_scalar(A, B)
    return result