Coverage for src/flag_gems/runtime/backend/_arm/ops/masked_fill.py: 0%

119 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.utils import broadcastable_to 

8 

9 

@triton.jit(do_not_specialize=["value", "n_elements"])
def _masked_fill_kernel(
    inp_ptr,
    mask_ptr,
    value,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Out-of-place masked fill over a flat range, one tile per program.

    Program ``pid`` covers flat offsets
    ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``. For every offset below
    ``n_elements`` it stores ``value`` to ``out_ptr`` where the mask element
    is true and the element loaded from ``inp_ptr`` otherwise.

    ``value`` and ``n_elements`` are listed in ``do_not_specialize``.
    """
    tile_start = tl.program_id(0) * BLOCK_SIZE
    lane_idx = tile_start + tl.arange(0, BLOCK_SIZE)
    # Lanes at or beyond n_elements take part in neither the loads nor the store.
    in_bounds = lane_idx < n_elements

    src = tl.load(inp_ptr + lane_idx, mask=in_bounds, other=0.0)
    raw_mask = tl.load(mask_ptr + lane_idx, mask=in_bounds, other=0)
    fill_here = raw_mask.to(tl.int1)

    tl.store(out_ptr + lane_idx, tl.where(fill_here, value, src), mask=in_bounds)

26 

27 

@triton.jit(do_not_specialize=["value", "n_elements"])
def _masked_fill_single_program_kernel(
    inp_ptr,
    mask_ptr,
    value,
    out_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """Out-of-place masked fill where one program walks the whole range.

    The program id is never read: the body steps through
    ``[0, n_elements)`` in chunks of ``BLOCK_SIZE`` and, for each in-range
    offset, stores ``value`` to ``out_ptr`` where the mask element is true
    and the element loaded from ``inp_ptr`` otherwise.
    """
    lanes = tl.arange(0, BLOCK_SIZE)
    for chunk_start in range(0, n_elements, BLOCK_SIZE):
        pos = chunk_start + lanes
        # Only the final chunk can run past the end; guard every access.
        in_bounds = pos < n_elements

        src = tl.load(inp_ptr + pos, mask=in_bounds, other=0.0)
        raw_mask = tl.load(mask_ptr + pos, mask=in_bounds, other=0)
        fill_here = raw_mask.to(tl.int1)

        tl.store(out_ptr + pos, tl.where(fill_here, value, src), mask=in_bounds)

45 

46 

@triton.jit(do_not_specialize=["value", "n_elements"])
def _masked_fill_inplace_kernel(
    inp_ptr,
    mask_ptr,
    value,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """In-place masked fill over a flat range, one tile per program.

    Program ``pid`` covers flat offsets
    ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``. For every offset below
    ``n_elements`` the element at ``inp_ptr`` is rewritten with ``value``
    where the mask element is true, and with its own loaded value otherwise.
    """
    tile_start = tl.program_id(0) * BLOCK_SIZE
    lane_idx = tile_start + tl.arange(0, BLOCK_SIZE)
    # Lanes at or beyond n_elements take part in neither the loads nor the store.
    in_bounds = lane_idx < n_elements

    current = tl.load(inp_ptr + lane_idx, mask=in_bounds, other=0.0)
    raw_mask = tl.load(mask_ptr + lane_idx, mask=in_bounds, other=0)
    fill_here = raw_mask.to(tl.int1)

    # The store targets the same buffer the values were loaded from.
    tl.store(inp_ptr + lane_idx, tl.where(fill_here, value, current), mask=in_bounds)

62 

63 

@triton.jit(do_not_specialize=["value", "n_elements"])
def _masked_fill_inplace_single_program_kernel(
    inp_ptr,
    mask_ptr,
    value,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    """In-place masked fill where one program walks the whole range.

    The program id is never read: the body steps through
    ``[0, n_elements)`` in chunks of ``BLOCK_SIZE`` and rewrites each
    in-range element at ``inp_ptr`` with ``value`` where the mask element is
    true, and with its own loaded value otherwise.
    """
    lanes = tl.arange(0, BLOCK_SIZE)
    for chunk_start in range(0, n_elements, BLOCK_SIZE):
        pos = chunk_start + lanes
        # Only the final chunk can run past the end; guard every access.
        in_bounds = pos < n_elements

        current = tl.load(inp_ptr + pos, mask=in_bounds, other=0.0)
        raw_mask = tl.load(mask_ptr + pos, mask=in_bounds, other=0)
        fill_here = raw_mask.to(tl.int1)

        # The store targets the same buffer the values were loaded from.
        tl.store(inp_ptr + pos, tl.where(fill_here, value, current), mask=in_bounds)

80 

81 

82def _select_block_size(n_elements): 

83 if n_elements <= 32: 

84 return 32 

85 if n_elements <= 1024: 

86 return 32 

87 if n_elements <= 8192: 

88 return 64 

89 return 128 

90 

91 

92def _normalize_scalar_value(value): 

93 assert ( 

94 (torch.is_tensor(value) and value.ndim == 0) 

95 or isinstance(value, int) 

96 or isinstance(value, float) 

97 ), "masked_fill only supports scalar/0-d tensor value" 

98 if torch.is_tensor(value): 

99 return value.item() 

100 return value 

101 

102 

103def _prepare_mask(mask, inp_shape): 

104 if mask.dtype == torch.bool and tuple(mask.shape) == tuple(inp_shape): 

105 return mask if mask.is_contiguous() else mask.contiguous() 

106 if mask.dtype != torch.bool: 

107 mask = mask.to(torch.bool) 

108 if tuple(mask.shape) == tuple(inp_shape): 

109 return mask if mask.is_contiguous() else mask.contiguous() 

110 return mask.expand(inp_shape).contiguous() 

111 

112 

def _launch_masked_fill(inp, expand_mask, value, out):
    """Pick and launch the out-of-place masked-fill kernel.

    Dispatch on ``inp.numel()``:

    * 0 elements: return without launching anything.
    * 2 to 8192 elements: launch ``_masked_fill_single_program_kernel`` on
      a grid of one program, with ``BLOCK_SIZE`` 32 up to 4096 elements and
      64 above that.
    * 1 element, or more than 8192: launch ``_masked_fill_kernel`` on a 1-D
      grid of ``cdiv(numel, block_size)`` programs, with ``block_size`` taken
      from ``_select_block_size``.

    Both launches pass ``num_warps=1`` and ``num_stages=1``. The tensors are
    handed to the kernels as flat buffers indexed by element offset, so the
    caller is expected to pass contiguous tensors of matching size.
    """
    total = inp.numel()
    if total == 0:
        return

    if total == 1 or total > 8192:
        tile = _select_block_size(total)
        launch_grid = (triton.cdiv(total, tile),)
        _masked_fill_kernel[launch_grid](
            inp,
            expand_mask,
            value,
            out,
            total,
            BLOCK_SIZE=tile,
            num_warps=1,
            num_stages=1,
        )
        return

    # 1 < total <= 8192: a single program loops over the whole tensor.
    tile = 64 if total > 4096 else 32
    _masked_fill_single_program_kernel[(1,)](
        inp,
        expand_mask,
        value,
        out,
        total,
        BLOCK_SIZE=tile,
        num_warps=1,
        num_stages=1,
    )

143 

144 

def _launch_masked_fill_inplace(inp, expand_mask, value):
    """Pick and launch the in-place masked-fill kernel.

    Dispatch on ``inp.numel()``:

    * 0 elements: return without launching anything.
    * 2 to 8192 elements: launch
      ``_masked_fill_inplace_single_program_kernel`` on a grid of one
      program, with ``BLOCK_SIZE`` 32 up to 4096 elements and 64 above that.
    * 1 element, or more than 8192: launch ``_masked_fill_inplace_kernel`` on
      a 1-D grid of ``cdiv(numel, block_size)`` programs, with ``block_size``
      taken from ``_select_block_size``.

    Both launches pass ``num_warps=1`` and ``num_stages=1``. ``inp`` is both
    read and written by the kernels and is indexed as a flat buffer, so the
    caller is expected to pass contiguous tensors of matching size.
    """
    total = inp.numel()
    if total == 0:
        return

    if total == 1 or total > 8192:
        tile = _select_block_size(total)
        launch_grid = (triton.cdiv(total, tile),)
        _masked_fill_inplace_kernel[launch_grid](
            inp,
            expand_mask,
            value,
            total,
            BLOCK_SIZE=tile,
            num_warps=1,
            num_stages=1,
        )
        return

    # 1 < total <= 8192: a single program loops over the whole tensor.
    tile = 64 if total > 4096 else 32
    _masked_fill_inplace_single_program_kernel[(1,)](
        inp,
        expand_mask,
        value,
        total,
        BLOCK_SIZE=tile,
        num_warps=1,
        num_stages=1,
    )

173 

174 

def masked_fill(inp, mask, value):
    """Return a new tensor equal to ``inp`` with ``value`` where ``mask`` is true.

    Args:
        inp: input tensor; it is not modified.
        mask: tensor whose shape must be broadcastable to ``inp.shape``.
        value: ``int``, ``float`` or 0-d tensor fill value.

    Returns:
        A newly allocated tensor with the dtype and device of ``inp``.

    Raises:
        AssertionError: if ``value`` is not an accepted scalar, or if the
            mask shape cannot be broadcast to the input shape.

    Two cases are resolved on the host via ``mask.item()`` without launching
    a kernel: a 0-d input, and a 0-d mask.
    """
    logging.debug("GEMS MASKED_FILL")
    value = _normalize_scalar_value(value)
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "mask shape must be broadcastable to input shape"

    # 0-d input: the result is either the fill value or a copy of the input.
    if inp.ndim == 0:
        if mask.item():
            return torch.tensor(value, dtype=inp.dtype, device=inp.device)
        return inp.clone()

    # 0-d mask: the single flag decides for every element at once.
    if mask.ndim == 0:
        return torch.full_like(inp, value) if bool(mask.item()) else inp.clone()

    # General case: the kernels index flat buffers, so work on contiguous data.
    src = inp if inp.is_contiguous() else inp.contiguous()
    dense_mask = _prepare_mask(mask, src.shape)
    result = torch.empty_like(src, dtype=src.dtype, device=src.device)
    _launch_masked_fill(src, dense_mask, value, result)
    return result

199 

200 

def masked_fill_(inp, mask, value):
    """Write ``value`` into ``inp`` wherever ``mask`` is true, in place.

    Args:
        inp: tensor to modify.
        mask: tensor whose shape must be broadcastable to ``inp.shape``.
        value: ``int``, ``float`` or 0-d tensor fill value.

    Returns:
        ``inp`` itself, after modification.

    Raises:
        AssertionError: if ``value`` is not an accepted scalar, or if the
            mask shape cannot be broadcast to the input shape.

    Two cases are resolved on the host via ``mask.item()`` without launching
    a kernel: a 0-d input, and a 0-d mask.
    """
    logging.debug("GEMS MASKED_FILL_")
    value = _normalize_scalar_value(value)
    assert broadcastable_to(
        mask.shape, inp.shape
    ), "mask shape must be broadcastable to input shape"

    # 0-d input: assign the scalar directly when the mask is set.
    if inp.ndim == 0:
        if mask.item():
            inp[()] = value
        return inp

    # 0-d mask: either every element is filled or none is.
    if mask.ndim == 0:
        if bool(mask.item()):
            inp.fill_(value)
        return inp

    if inp.is_contiguous():
        _launch_masked_fill_inplace(inp, _prepare_mask(mask, inp.shape), value)
        return inp

    # Non-contiguous input: run the kernel on a contiguous copy, then write
    # the result back through the original tensor's layout.
    scratch = inp.contiguous()
    _launch_masked_fill_inplace(scratch, _prepare_mask(mask, scratch.shape), value)
    inp.copy_(scratch)
    return inp

223 return inp