Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/masked_select.py: 0%

43 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2import os 

3 

4import torch 

5import triton 

6import triton.language as tl 

7 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import broadcastable, libentry 

10from flag_gems.utils import triton_lang_extension as ext 

11 

12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

13 

14 

@libentry()
@triton.jit
def masked_select_kernel(
    inp_ptr,
    select_mask_ptr,
    prefix_sum_ptr,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    # Compacts the selected elements of a flat input into `out_ptr`.
    #
    # inp_ptr:         flat input values.
    # select_mask_ptr: flat selection mask, one entry per input element;
    #                  converted to int1 after loading.
    # prefix_sum_ptr:  running sum of the mask; the host wrapper in this file
    #                  passes `mask.cumsum(0)`, so entry i minus one is the
    #                  output slot of element i when that element is selected.
    # out_ptr:         destination buffer for the compacted values.
    # n_elements:      number of valid entries in the flat input.
    # BLOCK_SIZE:      compile-time count of elements handled per program.

    # Each program covers one BLOCK_SIZE-wide window of the flat index space.
    block_id = ext.program_id(axis=0)
    idx = block_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements

    values = tl.load(inp_ptr + idx, mask=in_bounds, other=0.0)
    keep = tl.load(select_mask_ptr + idx, mask=in_bounds, other=0.0).to(tl.int1)

    # A lane writes only when it is both selected and inside the input range.
    write_mask = keep & in_bounds

    # Lanes that do not write read the `other` value and end up at -1; the
    # store below is masked, so that offset is never used.
    dst = tl.load(prefix_sum_ptr + idx, mask=write_mask, other=0.0) - 1

    tl.store(out_ptr + dst, values, mask=write_mask)

36 

37 

def masked_select(inp, mask):
    """Return a 1-D tensor holding the elements of `inp` where `mask` is set.

    `inp` and `mask` are broadcast against each other, made contiguous and
    flattened. A cumulative sum over the flattened mask gives every selected
    element its slot in the output, and `masked_select_kernel` scatters the
    selected values into those slots.

    Args:
        inp: Input tensor.
        mask: Selection mask, broadcastable with `inp`.

    Returns:
        A new 1-D tensor with the dtype and device of `inp` containing the
        selected elements in flattened order. It is empty when the broadcast
        input has no elements.

    Raises:
        AssertionError: If the two shapes are not broadcastable.
    """
    logger.debug("GEMS_KUNLUNXIN MASKED SELECT")

    inp_shape = tuple(inp.shape)
    mask_shape = tuple(mask.shape)

    assert broadcastable(
        inp_shape, mask_shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"
    inp, mask = torch.broadcast_tensors(inp, mask)

    inp = inp.contiguous()
    mask = mask.contiguous()

    mask_flattened = mask.ravel()
    n_elements = inp.numel()

    # With zero elements the cumulative sum is empty and `prefix_sum[-1]`
    # would raise IndexError; there is nothing to select, so return early.
    if n_elements == 0:
        return torch.empty(0, dtype=inp.dtype, device=inp.device)

    prefix_sum = mask_flattened.cumsum(axis=0)
    # The last entry of the running sum is the total number of selected
    # elements; `.item()` copies it to the host to size the output.
    out = torch.empty(prefix_sum[-1].item(), dtype=inp.dtype, device=inp.device)

    # Use larger block size for better memory throughput on kunlunxin
    BLOCK_SIZE = 2048
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

    # NOTE(review): these flags are presumably read by the Triton XPU
    # toolchain while the kernel is compiled/launched - confirm with the
    # backend documentation. They are enabled only for the launch, and any
    # value the caller had already set is restored afterwards rather than
    # being deleted.
    sim_flags = ("TRITONXPU_OTHER_SIM", "TRITONXPU_STORE_MASK_SIM")
    previous_values = {name: os.environ.get(name) for name in sim_flags}

    try:
        for name in sim_flags:
            os.environ[name] = "1"
        with torch_device_fn.device(inp.device):
            masked_select_kernel[grid](
                inp, mask_flattened, prefix_sum, out, n_elements, BLOCK_SIZE=BLOCK_SIZE
            )
    finally:
        for name, previous in previous_values.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous

    return out