Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/masked_select.py: 0%
43 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import logging
2import os
4import torch
5import triton
6import triton.language as tl
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import broadcastable, libentry
10from flag_gems.utils import triton_lang_extension as ext
# Module logger, created as a child of the package-wide "flag_gems" logger so
# that records from this backend op propagate to the package's handlers.
_package_logger = logging.getLogger("flag_gems")
logger = _package_logger.getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def masked_select_kernel(
    inp_ptr,
    select_mask_ptr,
    prefix_sum_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Scatter the selected elements of a flattened input into a compacted output.
    #
    # Each program instance handles one contiguous BLOCK_SIZE-wide chunk of the
    # flattened arrays. `prefix_sum_ptr` is expected to hold the running
    # (inclusive) cumulative sum of the selection mask, as produced by
    # `masked_select` via `cumsum`. So for a selected element i,
    # `prefix_sum[i] - 1` is its slot in `out_ptr`.
    #
    # Args:
    #   inp_ptr: flattened, contiguous input values.
    #   select_mask_ptr: flattened selection mask, same length as the input.
    #   prefix_sum_ptr: cumulative sum of the selection mask.
    #   out_ptr: output buffer; one slot per selected element.
    #   n_elements: number of elements in the flattened input.
    #   BLOCK_SIZE: elements handled per program instance.
    pid = ext.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guard the last, partially filled chunk: lanes past n_elements are masked off.
    mask = offsets < n_elements

    inp = tl.load(inp_ptr + offsets, mask=mask, other=0.0)
    select_mask = tl.load(select_mask_ptr + offsets, mask=mask, other=0.0).to(tl.int1)
    # Only selected, in-range lanes read their prefix sum. Every other lane gets
    # 0 - 1 = -1, which is harmless because the store below uses the same mask.
    out_offset = (
        tl.load(prefix_sum_ptr + offsets, mask=(select_mask & mask), other=0.0) - 1
    )

    tl.store(out_ptr + out_offset, inp, mask=(select_mask & mask))
# Environment overrides applied only while `masked_select_kernel` is launched.
# NOTE(review): judging by their names, these appear to make the TritonXPU
# compiler emulate masked-load `other` values and masked stores, both of which
# the kernel uses. Confirm against the TritonXPU documentation.
_KUNLUNXIN_SIM_ENV = {
    "TRITONXPU_OTHER_SIM": "1",
    "TRITONXPU_STORE_MASK_SIM": "1",
}


def masked_select(inp, mask):
    """Return a 1-D tensor of the elements of ``inp`` where ``mask`` is set.

    ``inp`` and ``mask`` are broadcast against each other and flattened in
    row-major order. Selected elements are written in that order.

    Args:
        inp: Input tensor.
        mask: Selection mask, broadcastable with ``inp``. Presumably a
            boolean tensor, as for ``torch.masked_select``. The dtype is not
            checked here, so TODO confirm callers enforce it.

    Returns:
        A new 1-D tensor with ``inp``'s dtype and device. Its length is the
        number of selected elements. The result is empty when the broadcast
        input has no elements.

    Raises:
        AssertionError: if the shapes of ``inp`` and ``mask`` are not
            broadcastable. The check is an ``assert``, so it is skipped
            under ``python -O``.
    """
    logger.debug("GEMS_KUNLUNXIN MASKED SELECT")

    inp_shape = tuple(inp.shape)
    mask_shape = tuple(mask.shape)

    assert broadcastable(
        inp_shape, mask_shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"
    inp, mask = torch.broadcast_tensors(inp, mask)

    n_elements = inp.numel()
    if n_elements == 0:
        # Nothing to select. Without this guard, `prefix_sum[-1]` below would
        # raise IndexError on an empty tensor.
        return torch.empty(0, dtype=inp.dtype, device=inp.device)

    inp = inp.contiguous()
    mask_flattened = mask.contiguous().ravel()

    # Running count of selected elements. Its last entry is the output length,
    # and entry i minus one is the output slot of a selected element i.
    prefix_sum = mask_flattened.cumsum(axis=0)
    out = torch.empty(prefix_sum[-1].item(), dtype=inp.dtype, device=inp.device)

    # Use larger block size for better memory throughput on kunlunxin
    BLOCK_SIZE = 2048
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # Remember the caller's values so they are restored afterwards. The
    # previous code unconditionally deleted these variables, which silently
    # discarded any value set by the user.
    saved_env = {key: os.environ.get(key) for key in _KUNLUNXIN_SIM_ENV}
    try:
        os.environ.update(_KUNLUNXIN_SIM_ENV)
        with torch_device_fn.device(inp.device):
            masked_select_kernel[grid](
                inp, mask_flattened, prefix_sum, out, n_elements, BLOCK_SIZE=BLOCK_SIZE
            )
    finally:
        for key, previous in saved_env.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous

    return out