Coverage for src/flag_gems/runtime/backend/_arm/ops/masked_fill.py: 0%
119 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-10 07:09 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.utils import broadcastable_to
@triton.jit(do_not_specialize=["value", "n_elements"])
def _masked_fill_kernel(
    inp_ptr,
    mask_ptr,
    value,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Write ``out = where(mask, value, inp)`` for one BLOCK_SIZE tile.

    Program ``i`` handles elements ``[i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE)``.
    Lanes at or past ``n_elements`` are masked off for both loads and store.
    The mask buffer is loaded and converted to ``tl.int1``. ``value`` is a
    scalar broadcast into every selected lane.

    NOTE(review): all three pointers are assumed to address ``n_elements``
    densely packed elements; the launchers in this module pass contiguous
    tensors.
    """
    # Widen the program id before scaling it. The tile offset is then
    # computed in 64-bit arithmetic and cannot wrap for very large tensors.
    # Wrapped (negative) offsets would pass ``offsets < n_elements`` and
    # access memory out of bounds.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(inp_ptr + offsets, mask=mask, other=0.0)
    m = tl.load(mask_ptr + offsets, mask=mask, other=0).to(tl.int1)
    y = tl.where(m, value, x)
    tl.store(out_ptr + offsets, y, mask=mask)
@triton.jit(do_not_specialize=["value", "n_elements"])
def _masked_fill_single_program_kernel(
    inp_ptr,
    mask_ptr,
    value,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Write ``out = where(mask, value, inp)`` with a single program.

    The one program steps through ``[0, n_elements)`` in BLOCK_SIZE chunks.
    Lanes at or past ``n_elements`` are masked off. The launcher in this
    module starts it with a grid of ``(1,)`` for small tensors.
    """
    lane = tl.arange(0, BLOCK_SIZE)
    for start in range(0, n_elements, BLOCK_SIZE):
        pos = start + lane
        live = pos < n_elements
        src = tl.load(inp_ptr + pos, mask=live, other=0.0)
        sel = tl.load(mask_ptr + pos, mask=live, other=0).to(tl.int1)
        tl.store(out_ptr + pos, tl.where(sel, value, src), mask=live)
@triton.jit(do_not_specialize=["value", "n_elements"])
def _masked_fill_inplace_kernel(
    inp_ptr,
    mask_ptr,
    value,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Overwrite ``inp`` with ``value`` where ``mask`` is set, one tile per program.

    Program ``i`` handles elements ``[i * BLOCK_SIZE, (i + 1) * BLOCK_SIZE)``.
    Lanes at or past ``n_elements`` are masked off. Unselected lanes store
    back the value that was loaded, so they are unchanged.

    NOTE(review): both pointers are assumed to address ``n_elements``
    densely packed elements; the launchers in this module pass contiguous
    tensors.
    """
    # Widen the program id before scaling it. The tile offset is then
    # computed in 64-bit arithmetic and cannot wrap for very large tensors.
    # Wrapped (negative) offsets would pass ``offsets < n_elements`` and
    # access memory out of bounds.
    pid = tl.program_id(0).to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(inp_ptr + offsets, mask=mask, other=0.0)
    m = tl.load(mask_ptr + offsets, mask=mask, other=0).to(tl.int1)
    y = tl.where(m, value, x)
    tl.store(inp_ptr + offsets, y, mask=mask)
@triton.jit(do_not_specialize=["value", "n_elements"])
def _masked_fill_inplace_single_program_kernel(
    inp_ptr,
    mask_ptr,
    value,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Overwrite ``inp`` with ``value`` where ``mask`` is set, using one program.

    The one program steps through ``[0, n_elements)`` in BLOCK_SIZE chunks
    and writes each chunk back to ``inp_ptr``. Lanes at or past
    ``n_elements`` are masked off.
    """
    lane = tl.arange(0, BLOCK_SIZE)
    for start in range(0, n_elements, BLOCK_SIZE):
        pos = start + lane
        live = pos < n_elements
        src = tl.load(inp_ptr + pos, mask=live, other=0.0)
        sel = tl.load(mask_ptr + pos, mask=live, other=0).to(tl.int1)
        tl.store(inp_ptr + pos, tl.where(sel, value, src), mask=live)
82def _select_block_size(n_elements):
83 if n_elements <= 32:
84 return 32
85 if n_elements <= 1024:
86 return 32
87 if n_elements <= 8192:
88 return 64
89 return 128
92def _normalize_scalar_value(value):
93 assert (
94 (torch.is_tensor(value) and value.ndim == 0)
95 or isinstance(value, int)
96 or isinstance(value, float)
97 ), "masked_fill only supports scalar/0-d tensor value"
98 if torch.is_tensor(value):
99 return value.item()
100 return value
103def _prepare_mask(mask, inp_shape):
104 if mask.dtype == torch.bool and tuple(mask.shape) == tuple(inp_shape):
105 return mask if mask.is_contiguous() else mask.contiguous()
106 if mask.dtype != torch.bool:
107 mask = mask.to(torch.bool)
108 if tuple(mask.shape) == tuple(inp_shape):
109 return mask if mask.is_contiguous() else mask.contiguous()
110 return mask.expand(inp_shape).contiguous()
113def _launch_masked_fill(inp, expand_mask, value, out):
114 n_elements = inp.numel()
115 if n_elements == 0:
116 return
117 if 1 < n_elements <= 8192:
118 single_block = 32 if n_elements <= 4096 else 64
119 _masked_fill_single_program_kernel[(1,)](
120 inp,
121 expand_mask,
122 value,
123 out,
124 n_elements,
125 BLOCK_SIZE=single_block,
126 num_warps=1,
127 num_stages=1,
128 )
129 return
131 block_size = _select_block_size(n_elements)
132 grid = (triton.cdiv(n_elements, block_size),)
133 _masked_fill_kernel[grid](
134 inp,
135 expand_mask,
136 value,
137 out,
138 n_elements,
139 BLOCK_SIZE=block_size,
140 num_warps=1,
141 num_stages=1,
142 )
145def _launch_masked_fill_inplace(inp, expand_mask, value):
146 n_elements = inp.numel()
147 if n_elements == 0:
148 return
149 if 1 < n_elements <= 8192:
150 single_block = 32 if n_elements <= 4096 else 64
151 _masked_fill_inplace_single_program_kernel[(1,)](
152 inp,
153 expand_mask,
154 value,
155 n_elements,
156 BLOCK_SIZE=single_block,
157 num_warps=1,
158 num_stages=1,
159 )
160 return
162 block_size = _select_block_size(n_elements)
163 grid = (triton.cdiv(n_elements, block_size),)
164 _masked_fill_inplace_kernel[grid](
165 inp,
166 expand_mask,
167 value,
168 n_elements,
169 BLOCK_SIZE=block_size,
170 num_warps=1,
171 num_stages=1,
172 )
def masked_fill(inp, mask, value):
    """Return a new tensor equal to ``inp`` with ``value`` wherever ``mask`` is true.

    Args:
        inp: input tensor; it is not modified.
        mask: tensor whose shape must be broadcastable to ``inp.shape``.
            Non-bool masks are converted to bool.
        value: Python int/float or 0-d tensor used as the fill value.

    Returns:
        A freshly allocated tensor with ``inp``'s shape and dtype.

    Special cases:
    - A 0-d ``inp`` is handled on the host.
    - A 0-d mask fills either everything or nothing.
    - All other cases run the Triton kernel on a contiguous view of ``inp``.
    """
    logging.debug("GEMS MASKED_FILL")
    value = _normalize_scalar_value(value)
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "mask shape must be broadcastable to input shape"

    if inp.ndim == 0:
        if mask.item():
            return torch.tensor(value, dtype=inp.dtype, device=inp.device)
        return inp.clone()

    if mask.ndim == 0:
        return torch.full_like(inp, value) if bool(mask.item()) else inp.clone()

    src = inp if inp.is_contiguous() else inp.contiguous()
    dense_mask = _prepare_mask(mask, src.shape)
    # ``empty_like`` defaults to the source dtype and device.
    result = torch.empty_like(src)
    _launch_masked_fill(src, dense_mask, value, result)
    return result
def masked_fill_(inp, mask, value):
    """Write ``value`` into ``inp`` wherever ``mask`` is true, in place.

    Args:
        inp: tensor to modify.
        mask: tensor whose shape must be broadcastable to ``inp.shape``.
            Non-bool masks are converted to bool.
        value: Python int/float or 0-d tensor used as the fill value.

    Returns:
        ``inp`` itself.

    A non-contiguous ``inp`` is filled through a contiguous staging copy,
    which is then copied back.
    """
    logging.debug("GEMS MASKED_FILL_")
    value = _normalize_scalar_value(value)
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "mask shape must be broadcastable to input shape"

    if inp.ndim == 0:
        if mask.item():
            inp[()] = value
        return inp

    if mask.ndim == 0:
        if bool(mask.item()):
            inp.fill_(value)
        return inp

    if inp.is_contiguous():
        _launch_masked_fill_inplace(inp, _prepare_mask(mask, inp.shape), value)
        return inp

    # The kernel indexes densely, so fill a contiguous copy and write it back.
    staging = inp.contiguous()
    _launch_masked_fill_inplace(staging, _prepare_mask(mask, staging.shape), value)
    inp.copy_(staging)
    return inp