Coverage for src/flag_gems/runtime/backend/_cambricon/ops/logical_and.py: 0%
49 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import libentry, libtuner
10from ..utils import TOTAL_CORE_NUM
# Module logger, registered as a child of the "flag_gems" logger.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Element-wise logical AND: OUT[i] = (X[i] != 0) & (Y[i] != 0) for every
# flat index i < n_elements.
#
# Program `pid` first handles the tile starting at pid * BLOCK_SIZE. It then
# jumps ahead by (number of programs) * BLOCK_SIZE until it passes n_elements.
# This lets a grid smaller than cdiv(n_elements, BLOCK_SIZE) still cover every
# element. The tile start is widened to int64 so the flat offsets stay exact
# for very large tensors.
@libentry()
@libtuner(
    configs=[
        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_stages=3, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_stages=3, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 65536}, num_stages=3, num_warps=1),
        triton.Config(kwargs={"BLOCK_SIZE": 131072}, num_stages=3, num_warps=1),
    ],
    key=["n_elements"],
)
@triton.jit
def logical_and_kernel(
    X_ptr,
    Y_ptr,
    OUT_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    first_tile = (tl.program_id(0) * BLOCK_SIZE).to(tl.int64)
    stride = tl.num_programs(0) * BLOCK_SIZE
    # Per-lane offsets within a tile; identical on every iteration.
    lane = tl.arange(0, BLOCK_SIZE)
    for tile_start in range(first_tile, n_elements, stride):
        idx = tile_start + lane
        # The last tile may run past the end; mask both loads and the store.
        in_bounds = idx < n_elements
        lhs = tl.load(X_ptr + idx, mask=in_bounds)
        rhs = tl.load(Y_ptr + idx, mask=in_bounds)
        tl.store(OUT_ptr + idx, (lhs != 0) & (rhs != 0), mask=in_bounds)
def logical_and(A, B):
    """Compute the element-wise logical AND of two tensors (Cambricon backend).

    An element counts as true when it is non-zero. ``A`` and ``B`` are first
    broadcast to a common shape, as ``torch.logical_and`` does.

    Args:
        A: First input tensor.
        B: Second input tensor; must be broadcast-compatible with ``A``.

    Returns:
        A new ``torch.bool`` tensor of the broadcast shape on ``A.device``.

    Raises:
        RuntimeError: if ``A`` and ``B`` cannot be broadcast together.
    """
    logger.debug("GEMS_CAMBRICON LOGICAL_AND")
    # The kernel reads both inputs at the same flat offsets for N = numel
    # elements. The inputs must therefore share one shape and a dense layout.
    # Without this step, a smaller B would be read out of bounds.
    A, B = torch.broadcast_tensors(A, B)
    A = A.contiguous()
    B = B.contiguous()
    out = torch.empty(A.shape, dtype=torch.bool, device=A.device)
    N = A.numel()
    if N == 0:
        return out
    # Launch no more programs than there are cores; the kernel loops over the
    # remaining tiles.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    with torch_device_fn.device(A.device):
        logical_and_kernel[grid_fn](A, B, out, N)
    return out
def logical_and_(A, B):
    """Compute ``A = logical_and(A, B)`` in place (Cambricon backend).

    An element counts as true when it is non-zero. ``B`` is broadcast to
    ``A``'s shape, because an in-place op cannot change the shape of ``A``.
    The kernel writes the boolean result through ``A``'s storage, so the
    values end up in ``A``'s own dtype.

    Args:
        A: Tensor to update in place. It may be non-contiguous, in which case
            a contiguous copy is computed and then copied back into ``A``.
        B: Second operand; must be broadcastable to ``A.shape``.

    Returns:
        ``A`` itself.

    Raises:
        RuntimeError: if ``B`` cannot be broadcast to ``A.shape``.
    """
    logger.debug("GEMS_CAMBRICON LOGICAL_AND_")
    A_contig = A.contiguous()
    # The kernel indexes B with A's flat offsets. Expanding to A's shape first
    # avoids out-of-bounds reads when B is smaller. It also rejects a B that
    # does not broadcast to A.
    B = B.expand_as(A).contiguous()
    N = A_contig.numel()
    if N == 0:
        return A
    # Launch no more programs than there are cores; the kernel loops over the
    # remaining tiles.
    grid_fn = lambda meta: (min(triton.cdiv(N, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)
    with torch_device_fn.device(A.device):
        logical_and_kernel[grid_fn](A_contig, B, A_contig, N)
    # When A was not contiguous, A_contig is a separate buffer; write it back.
    if not A.is_contiguous():
        A.copy_(A_contig)
    return A