Coverage for src/flag_gems/runtime/backend/_sunrise/ops/masked_select.py: 0%
105 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import broadcastable, libentry
9from flag_gems.utils.shape_utils import bracket_next_power_of_2
# Module logger: a child of the package-wide "flag_gems" logger, named after
# this module (any leading dots stripped).
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def masked_select_single_pass_kernel(
    inp_ptr, mask_ptr, out_ptr, N, BLOCK_SIZE: tl.constexpr
):
    # Compacts the elements of one BLOCK_SIZE-wide tile whose mask is set
    # into out_ptr. Destination slots are (inclusive running count of set
    # mask lanes within the tile) - 1, i.e. they are tile-local, so this is a
    # complete compaction only when a single program covers all N elements.
    tile_start = tl.program_id(0) * BLOCK_SIZE
    idx = tile_start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < N

    values = tl.load(inp_ptr + idx, mask=in_bounds)
    keep = tl.load(mask_ptr + idx, mask=in_bounds).to(tl.int1)

    # Inclusive prefix count of kept lanes; minus one gives the write slot.
    dest = tl.cumsum(keep.to(tl.int32), axis=0) - 1

    # Lanes past N and lanes whose mask is unset are never written.
    tl.store(out_ptr + dest, values, mask=in_bounds & keep)
def masked_select_single_pass(inp, mask, out, N):
    """Select the masked elements of ``inp`` into ``out`` with one program.

    The tile width is the next power of two >= ``N``, so a single launched
    program covers every element.

    Args:
        inp: input tensor; expected to be contiguous with ``N`` elements
            (the caller in this module makes it contiguous).
        mask: mask tensor with the same element count as ``inp``.
        out: preallocated output tensor sized to the number of selected
            elements; it is filled in place.
        N: number of elements in ``inp`` / ``mask``.

    Returns:
        ``out``.

    NOTE(review): for N == 0, ``triton.next_power_of_2`` appears to return 0,
    giving an invalid BLOCK_SIZE. Callers should skip empty inputs.
    """
    tile = triton.next_power_of_2(N)
    # Scale warps with tile width: 4 up to 512 lanes, 8 up to 2048, else 16.
    warps = 4 if tile <= 512 else (8 if tile <= 2048 else 16)
    masked_select_single_pass_kernel[(1,)](
        inp, mask, out, N, BLOCK_SIZE=tile, num_warps=warps
    )
    return out
@libentry()
@triton.jit(do_not_specialize=["N", "nr", "row_stride"])
def mask_part_sum_kernel(
    inp_ptr,
    mask_ptr,
    part_sums_ptr,
    counter_ptr,
    N,
    num_blocks,
    num_blocks_per_row,
    NP_BLOCK: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 1 of the two-pass masked_select.
    # Each program counts the set mask elements in its row of
    # `num_blocks_per_row` BLOCK_SIZE-wide blocks and stores the count at
    # part_sums_ptr[row_id]. The last program to increment `counter_ptr` then
    # rewrites part_sums_ptr[0:num_programs] in place as exclusive prefix
    # sums and stores the grand total at part_sums_ptr[num_programs].
    # part_sums_ptr therefore needs num_programs + 1 slots, and counter_ptr
    # must start at 0.
    # NOTE(review): inp_ptr is unused here, and "nr"/"row_stride" in
    # do_not_specialize are not parameters of this kernel.
    row_id = tl.program_id(0)
    start_block = row_id * num_blocks_per_row
    offset = start_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    acc = tl.zeros((BLOCK_SIZE,), dtype=part_sums_ptr.dtype.element_ty)

    # Final block of this row, clamped to the last block overall.
    last_block_id = min(num_blocks - 1, start_block + num_blocks_per_row - 1)

    # Blocks before the row's last block are loaded without a mask. This
    # assumes they are fully in bounds, i.e. num_blocks == cdiv(N, BLOCK_SIZE).
    for block_id in range(start_block, last_block_id):
        select = tl.load(mask_ptr + offset)
        select_ints = select.to(part_sums_ptr.dtype.element_ty)
        acc += select_ints
        offset += BLOCK_SIZE
    # Peeled last block: may be partial, so lanes >= N are masked and read as 0.
    select = tl.load(mask_ptr + offset, mask=offset < N, other=0)
    select_ints = select.to(part_sums_ptr.dtype.element_ty)
    acc += select_ints

    part_sum = tl.sum(acc, axis=0)
    tl.store(part_sums_ptr + row_id, part_sum)

    # Cumsum the part sums: the program that observes count == np - 1 is the
    # last to arrive (the acq_rel atomic orders the stores above with it).
    count = tl.atomic_add(counter_ptr, 1, sem="acq_rel")
    np = tl.num_programs(0)
    if count == np - 1:
        mask = tl.arange(0, NP_BLOCK) < np
        # other=0 is required. Lanes >= np are masked off, and without `other`
        # their loaded values are undefined. tl.sum below reduces over all
        # NP_BLOCK lanes, so undefined lanes would corrupt the total that
        # sizes the output.
        part_sums = tl.load(
            part_sums_ptr + tl.arange(0, NP_BLOCK), mask=mask, other=0
        )
        final_sum = tl.sum(part_sums, axis=0)
        pre_sums = tl.cumsum(part_sums, axis=0)
        # Exclusive prefix sum = inclusive prefix sum - own value.
        tl.store(
            part_sums_ptr + tl.arange(0, NP_BLOCK), pre_sums - part_sums, mask=mask
        )
        tl.store(part_sums_ptr + np, final_sum)
@libentry()
@triton.jit(do_not_specialize=["N", "nr", "row_stride"])
def write_back_kernel(
    inp_ptr,
    mask_ptr,
    part_sums_ptr,
    out_ptr,
    N,
    num_blocks,
    num_blocks_per_row,
    NP_BLOCK: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Pass 2 of the two-pass masked_select.
    # Each program walks its row of `num_blocks_per_row` BLOCK_SIZE-wide blocks
    # and scatters the selected elements of inp_ptr into out_ptr.
    # part_sums_ptr[row_id] is expected to hold the number of selected
    # elements in all earlier rows (the exclusive prefix sum written by
    # mask_part_sum_kernel). That value is this row's first output offset.
    # NOTE(review): NP_BLOCK is unused here, and "nr"/"row_stride" in
    # do_not_specialize are not parameters of this kernel.
    row_id = tl.program_id(0)

    start_block = row_id * num_blocks_per_row
    offset = start_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Output offset at which this row starts writing.
    advance = tl.load(part_sums_ptr + row_id)

    # Final block of this row, clamped to the last block overall.
    last_block_id = min(num_blocks - 1, start_block + num_blocks_per_row - 1)

    # Blocks before the row's last block are loaded without a mask. This
    # assumes they are fully in bounds, i.e. num_blocks == cdiv(N, BLOCK_SIZE).
    for block_id in range(start_block, last_block_id):
        inp = tl.load(inp_ptr + offset)
        select_mask = tl.load(mask_ptr + offset).to(tl.int1)
        select_ints = select_mask.to(tl.constexpr(part_sums_ptr.dtype.element_ty))
        # Move the output cursor past what was written before this block, then
        # carry this block's selected count forward to the next iteration.
        out_ptr += advance
        advance = tl.sum(select_ints, axis=0)
        # Inclusive running count - 1 = slot within this block's output run.
        pre_sums = tl.cumsum(select_ints, axis=0) - 1
        tl.store(out_ptr + pre_sums, inp, mask=select_mask)
        offset += BLOCK_SIZE
    # Peeled last block: may be partial, so lanes >= N are masked off and
    # treated as not selected.
    inp = tl.load(inp_ptr + offset, mask=offset < N)
    select_mask = tl.load(mask_ptr + offset, mask=offset < N, other=0).to(tl.int1)
    select_ints = select_mask.to(tl.constexpr(part_sums_ptr.dtype.element_ty))
    out_ptr += advance
    pre_sums = tl.cumsum(select_ints, axis=0) - 1
    tl.store(out_ptr + pre_sums, inp, mask=(offset < N) & select_mask)
def masked_select(inp, mask):
    """Return a 1-D tensor of the elements of ``inp`` where ``mask`` is set.

    ``inp`` and ``mask`` are first broadcast against each other and made
    contiguous. Inputs with at most 4096 elements use a single-program kernel.
    Larger inputs use two passes:

    1. ``mask_part_sum_kernel`` counts selected elements per row of blocks
       and prefix-sums those counts.
    2. ``write_back_kernel`` scatters the selected elements.

    Args:
        inp: input tensor.
        mask: mask tensor whose shape is broadcastable with ``inp``
            (presumably bool; other dtypes are not checked here).

    Returns:
        A new 1-D tensor with ``inp``'s dtype on ``inp``'s device. It is
        empty when the broadcast shape has no elements.

    Raises:
        AssertionError: if the two shapes are not broadcastable.
    """
    logger.debug("GEMS MASKED SELECT")

    inp_shape = tuple(inp.shape)
    mask_shape = tuple(mask.shape)

    assert broadcastable(
        inp_shape, mask_shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"
    inp, mask = torch.broadcast_tensors(inp, mask)

    inp = inp.contiguous()
    mask = mask.contiguous()

    N = inp.numel()
    if N == 0:
        # Nothing to select. The kernels cannot be launched for an empty
        # input: triton.next_power_of_2(0) == 0 would give a zero-width tile.
        return torch.empty(0, dtype=inp.dtype, device=inp.device)

    if N <= 4096:
        # [sunrise fix] Count the selected elements on the host, then allocate
        # the output directly on the device. Allocating on the CPU and copying
        # would waste an allocation and a transfer for uninitialized memory.
        num_selected = int(mask.cpu().sum())
        out = torch.empty(num_selected, dtype=inp.dtype, device=inp.device)
        return masked_select_single_pass(inp, mask, out, N)

    BLOCK_SIZE = bracket_next_power_of_2(N, 128, 4096)
    num_warps = min(16, BLOCK_SIZE // 32)

    # Max degree of parallelism. The natural choice would be
    #   torch_device_fn.get_device_properties(mask.device).multi_processor_count
    # but [DIPU] torch.cuda.get_device_properties(mask.device)
    # .multi_processor_count can't get information, so a fixed value is used.
    num_programs = 32

    # Arrange the blocks as `num_programs` rows of `n_blocks_per_row` blocks.
    n_blocks = triton.cdiv(N, BLOCK_SIZE)
    num_programs = min(n_blocks, num_programs)
    n_blocks_per_row = triton.cdiv(n_blocks, num_programs)
    # Recompute so that every launched row owns at least one block.
    num_programs = triton.cdiv(n_blocks, n_blocks_per_row)
    NP_BLOCK = triton.next_power_of_2(num_programs)

    with torch_device_fn.device(inp.device):
        # Per-row counts, turned in place into exclusive prefix sums, plus the
        # grand total in the extra trailing slot.
        dtype = torch.int32 if N < 2**31 else torch.int64
        part_sums = torch.empty(num_programs + 1, dtype=dtype, device=mask.device)
        # Arrival counter used to elect the last program; must start at zero.
        barrier = torch.zeros([], dtype=torch.int, device=mask.device)
        mask_part_sum_kernel[(num_programs,)](
            inp,
            mask,
            part_sums,
            barrier,
            N,
            n_blocks,
            n_blocks_per_row,
            NP_BLOCK=NP_BLOCK,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

        # Write back selected data. Reading the total synchronizes with the
        # device.
        out = torch.empty(
            int(part_sums[-1].item()), dtype=inp.dtype, device=mask.device
        )
        write_back_kernel[(num_programs,)](
            inp,
            mask,
            part_sums,
            out,
            N,
            n_blocks,
            n_blocks_per_row,
            NP_BLOCK=NP_BLOCK,
            BLOCK_SIZE=BLOCK_SIZE,
            num_warps=num_warps,
        )

    return out