Coverage for src/flag_gems/runtime/backend/_sunrise/ops/bitwise_and.py: 0%
30 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
3import triton
5from flag_gems.utils import pointwise_dynamic
6from flag_gems.utils.pointwise_dynamic import CodeGenConfig
# Module logger: a child of the shared "flag_gems" logger, named after this
# module's dotted path with any leading dots stripped.
logger = logging.getLogger(
    "flag_gems",
).getChild(
    __name__.lstrip("."),
)
# Grid-size limit handed to the code generator, one entry per grid axis.
MAX_GRID_SIZES = (65535,) * 3

# Code-generation settings shared by every pointwise kernel declared in this
# module (both the tensor/tensor and the tensor/scalar variants).
config = CodeGenConfig(
    prefer_1d_tile=True,
    prefer_block_pointer=True,
    max_num_warps_per_cta=32,
    max_grid_size=MAX_GRID_SIZES,
    max_tile_size=2048,
)
# Pointwise kernel body: bitwise AND of the two operands. Wrapped by
# `pointwise_dynamic` with the module-level `config`; type promotion is
# declared over arguments 0 and 1 using the "DEFAULT" method.
@pointwise_dynamic(
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config,
)
@triton.jit
def bitwise_and_func(x, y):
    anded = x & y
    return anded
def bitwise_and_tensor(A, B):
    """Log the call, then return ``bitwise_and_func(A, B)``."""
    logger.debug("GEMS BITWISE AND")
    result = bitwise_and_func(A, B)
    return result
def bitwise_and_tensor_(A, B):
    """Log the call, then run ``bitwise_and_func`` with ``A`` as ``out0``.

    ``A`` is passed both as the first operand and as the ``out0`` keyword
    argument; whatever the wrapped kernel returns is returned unchanged.
    """
    logger.debug("GEMS BITWISE AND_")
    # NOTE(review): presumably `out0=A` makes the kernel write into A
    # (in-place variant) -- confirm against pointwise_dynamic.
    result = bitwise_and_func(A, B, out0=A)
    return result
# Pointwise kernel body for the tensor/non-tensor variant: `is_tensor` marks
# the first argument as a tensor and the second as not a tensor. Promotion and
# code-generation settings match the tensor/tensor kernel.
@pointwise_dynamic(
    is_tensor=[True, False],
    promotion_methods=[(0, 1, "DEFAULT")],
    config=config,
)
@triton.jit
def bitwise_and_func_scalar(x, y):
    anded = x & y
    return anded
def bitwise_and_scalar(A, B):
    """Log the call, then return ``bitwise_and_func_scalar(A, B)``."""
    logger.debug("GEMS BITWISE AND SCALAR")
    result = bitwise_and_func_scalar(A, B)
    return result
def bitwise_and_scalar_(A, B):
    """Log the call, then run ``bitwise_and_func_scalar`` with ``A`` as ``out0``.

    ``A`` is passed both as the first operand and as the ``out0`` keyword
    argument; whatever the wrapped kernel returns is returned unchanged.
    """
    logger.debug("GEMS BITWISE AND_ SCALAR")
    # NOTE(review): presumably `out0=A` makes the kernel write into A
    # (in-place variant) -- confirm against pointwise_dynamic.
    result = bitwise_and_func_scalar(A, B, out0=A)
    return result
def bitwise_and_scalar_tensor(A, B):
    """Log the call, then return ``bitwise_and_func_scalar(B, A)``.

    The arguments are forwarded in swapped order: ``B`` goes into the
    kernel's first slot (declared as the tensor operand) and ``A`` into the
    second.
    """
    logger.debug("GEMS BITWISE AND SCALAR TENSOR")
    result = bitwise_and_func_scalar(B, A)
    return result