Coverage for src/flag_gems/runtime/backend/_ascend/ops/masked_scatter.py: 0%
98 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import broadcastable, libentry
9from flag_gems.utils.shape_utils import bracket_next_power_of_2
# Module-level logger; the public entry points below emit debug traces here.
logger = logging.getLogger(name=__name__)
@libentry()
@triton.jit
def masked_scatter_single_pass_kernel(
    inp_ptr, mask_ptr, src_ptr, N, BLOCK_SIZE: tl.constexpr
):
    """Scatter ``src`` into ``inp`` at every set ``mask`` position, in one block.

    Lane i that is in range and has its mask set receives ``src[k]``, where k
    is the number of set mask entries strictly before i within this block.
    The prefix count is block-local, so this kernel is only correct when a
    single program covers all N elements.
    """
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < N

    # Out-of-range lanes read 0, so they neither contribute to the count nor write.
    selected = tl.load(mask_ptr + idx, mask=in_bounds, other=0).to(tl.int1)

    # Inclusive prefix count minus one gives the src index for each selected lane.
    src_pos = tl.cumsum(selected.to(tl.int32), axis=0) - 1

    write = in_bounds & selected
    vals = tl.load(src_ptr + src_pos, mask=write)
    tl.store(inp_ptr + idx, vals, mask=write)
@libentry()
@triton.jit
def count_mask_per_block_kernel(mask_ptr, counts_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Store the number of set mask entries of block ``pid`` into ``counts_ptr[pid]``.

    Lanes past N are loaded as 0 and therefore not counted.
    """
    block = tl.program_id(0)
    idx = block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    flags = tl.load(mask_ptr + idx, mask=idx < N, other=0).to(tl.int32)
    tl.store(counts_ptr + block, tl.sum(flags))
@libentry()
@triton.jit(do_not_specialize=["N", "num_blocks", "num_blocks_per_row"])
def masked_scatter_kernel(
    inp_ptr,
    mask_ptr,
    src_ptr,
    part_sums_ptr,
    N,
    num_blocks,
    num_blocks_per_row,
    BLOCK_SIZE: tl.constexpr,
):
    """Multi-block masked scatter driven by precomputed per-program offsets.

    Program ``row_id`` walks ``num_blocks_per_row`` consecutive blocks of
    BLOCK_SIZE elements, starting at block ``row_id * num_blocks_per_row``.
    It writes consecutive ``src`` values into ``inp`` wherever the mask is set.

    ``part_sums_ptr[row_id]`` must be the number of set mask entries before
    this program's first block (an exclusive prefix count). It is the src
    index of the first element this program writes.

    ``N`` is the total element count and ``num_blocks`` is the total number of
    blocks.

    NOTE(review): the unmasked loads in the loop assume every block before
    ``num_blocks - 1`` lies fully inside [0, N), i.e. num_blocks ==
    cdiv(N, BLOCK_SIZE). masked_scatter_impl in this file passes exactly that.
    """
    row_id = tl.program_id(0)

    start_block = row_id * num_blocks_per_row
    offset = start_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Running src index: number of set mask entries consumed so far.
    advance = tl.load(part_sums_ptr + row_id)

    # Last block owned by this program, clamped to the global last block.
    # That final block is handled by the bounds-masked tail below.
    last_block_id = min(num_blocks - 1, start_block + num_blocks_per_row - 1)

    # Full blocks, so no bounds mask is needed. This runs zero times when
    # num_blocks_per_row == 1.
    for block_id in range(start_block, last_block_id):
        select_mask = tl.load(mask_ptr + offset).to(tl.int1)
        select_ints = select_mask.to(tl.int32)

        # Block-local position of each set entry (inclusive cumsum - 1),
        # shifted by the running global count.
        block_cumsum = tl.cumsum(select_ints, axis=0) - 1
        global_src_idx = advance + block_cumsum

        advance += tl.sum(select_ints, axis=0)

        src_val = tl.load(src_ptr + global_src_idx, mask=select_mask)
        tl.store(inp_ptr + offset, src_val, mask=select_mask)

        offset += BLOCK_SIZE

    # Tail: `offset` now points at block last_block_id, which may extend
    # past N, so every access here is bounds-masked.
    block_mask = offset < N
    select_mask = tl.load(mask_ptr + offset, mask=block_mask, other=0).to(tl.int1)

    select_ints = select_mask.to(tl.int32)
    block_cumsum = tl.cumsum(select_ints, axis=0) - 1
    global_src_idx = advance + block_cumsum

    active = block_mask & select_mask
    src_val = tl.load(src_ptr + global_src_idx, mask=active)
    tl.store(inp_ptr + offset, src_val, mask=active)
def masked_scatter_impl(inp, mask, source, N):
    """Write consecutive ``source`` elements into ``inp`` where ``mask`` is set.

    Operates in place on flat (row-major) views: the k-th set mask position
    receives ``source.view(-1)[k]``.

    Args:
        inp: tensor written in place. Expected contiguous (the callers in this
            file make it so).
        mask: mask with ``N`` elements, read flat. Expected contiguous.
        source: values to scatter. Expected contiguous; it must hold at least
            as many elements as ``mask`` has set entries.
        N: number of elements to process (``inp.numel()`` at the call sites).

    Returns:
        ``inp``, modified in place.

    Raises:
        RuntimeError: if ``source`` has fewer elements than set mask entries.
            The kernels would otherwise read past the end of ``source``.
    """
    # Host sync. Used both for the early exit and for the bounds check.
    true_count = mask.sum().item()
    if true_count == 0:
        return inp
    if source.numel() < true_count:
        raise RuntimeError(
            "masked_scatter: expected source to have at least as many elements "
            f"as ones in mask, but got {source.numel()} < {true_count}"
        )

    # Every launch (including the single-pass one) targets inp's device.
    with torch_device_fn.device(inp.device):
        if N <= 4096:
            # Small input: one program handles everything with a block-local scan.
            BLOCK_SIZE = triton.next_power_of_2(N)
            masked_scatter_single_pass_kernel[(1,)](
                inp, mask, source, N, BLOCK_SIZE=BLOCK_SIZE
            )
            return inp

        BLOCK_SIZE = bracket_next_power_of_2(N, 128, 4096)
        n_blocks = triton.cdiv(N, BLOCK_SIZE)

        # Pass 1: number of set mask entries per block.
        block_counts = torch.empty(n_blocks, dtype=torch.int64, device=mask.device)
        count_mask_per_block_kernel[(n_blocks,)](
            mask, block_counts, N, BLOCK_SIZE=BLOCK_SIZE
        )

        # Exclusive prefix sum of the block counts gives each block's first src index.
        # NOTE(review): computed on the host; the reason is not visible here.
        counts_cpu = block_counts.cpu()
        prefix_sum = torch.zeros(n_blocks, dtype=torch.int64)
        torch.cumsum(counts_cpu[:-1], dim=0, out=prefix_sum[1:])
        part_sums = prefix_sum.to(mask.device)

        # Pass 2: one program per block (num_blocks_per_row == 1).
        masked_scatter_kernel[(n_blocks,)](
            inp,
            mask,
            source,
            part_sums,
            N,
            n_blocks,
            1,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    return inp
def masked_scatter(inp, mask, source):
    """Return a copy of ``inp`` with ``source`` values written where ``mask`` is set.

    ``inp`` and ``mask`` are broadcast against each other, so the result has
    their common broadcast shape. This matches ``torch.masked_scatter``, which
    expands both operands before cloning. Set positions, taken in row-major
    order, receive consecutive elements of ``source``.

    Raises:
        AssertionError: if ``inp`` and ``mask`` shapes are not broadcastable.
        RuntimeError: if ``source`` has too few elements (raised by
            masked_scatter_impl).
    """
    logger.debug("GEMS_ASCEND MASKED SCATTER")

    assert broadcastable(
        inp.shape, mask.shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"

    # Keep the broadcast input too. Otherwise a mask larger than inp would be
    # scattered into a too-small output and N would not match the mask.
    inp, mask = torch.broadcast_tensors(inp, mask)

    # Fresh contiguous copy in one step. This also materializes expanded
    # (stride-0) views and never aliases the caller's tensor.
    out = inp.clone(memory_format=torch.contiguous_format)
    mask = mask.contiguous()
    source = source.contiguous()

    masked_scatter_impl(out, mask, source, out.numel())

    return out
def masked_scatter_(inp, mask, source):
    """In-place masked scatter: write ``source`` into ``inp`` where ``mask`` is set.

    ``mask`` must be expandable to ``inp``'s shape, because an in-place op
    cannot change the shape of ``inp``. Set positions, taken in row-major
    order, receive consecutive elements of ``source``. A non-contiguous
    ``inp`` is handled through a contiguous temporary that is copied back.

    Returns:
        ``inp``, modified in place.

    Raises:
        AssertionError: if ``inp`` and ``mask`` shapes are not broadcastable.
        RuntimeError: if ``mask`` cannot be expanded to ``inp``'s shape, or if
            ``source`` has too few elements.
    """
    logger.debug("GEMS_ASCEND MASKED SCATTER_")

    assert broadcastable(inp.shape, mask.shape)
    # Expand only the mask. Broadcasting both operands could grow the mask
    # beyond inp, leaving its trailing entries silently ignored.
    mask = mask.expand_as(inp).contiguous()
    source = source.contiguous()

    if inp.is_contiguous():
        masked_scatter_impl(inp, mask, source, inp.numel())
    else:
        # The kernels index flat memory, so work on a contiguous copy and write back.
        work = inp.contiguous()
        masked_scatter_impl(work, mask, source, work.numel())
        inp.copy_(work)

    return inp