Coverage for src/flag_gems/runtime/backend/_ascend/ops/masked_scatter.py: 0%

98 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import broadcastable, libentry 

9from flag_gems.utils.shape_utils import bracket_next_power_of_2 

10 

11logger = logging.getLogger(__name__) 

12 

13 

@libentry()
@triton.jit
def masked_scatter_single_pass_kernel(
    inp_ptr, mask_ptr, src_ptr, N, BLOCK_SIZE: tl.constexpr
):
    """Scatter leading ``src_ptr`` elements into the masked slots of one tile.

    The program covers ``BLOCK_SIZE`` consecutive positions starting at
    ``program_id(0) * BLOCK_SIZE``. Source indices are derived from a prefix
    sum over this tile only (no carry-in from earlier tiles), so the k-th
    selected position of the tile receives ``src_ptr[k]``.

    Args:
        inp_ptr: destination buffer, written only at selected positions.
        mask_ptr: selection flags; non-zero means "selected".
        src_ptr: values to scatter, consumed in order.
        N: number of valid positions in ``inp_ptr`` / ``mask_ptr``.
        BLOCK_SIZE: compile-time tile width.
    """
    idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < N

    # Lanes past N read as 0 so they can never count as selected.
    selected = tl.load(mask_ptr + idx, mask=in_bounds, other=0).to(tl.int1)

    # Inclusive running count of selected lanes, shifted to be zero-based:
    # this is each selected lane's position in the source stream.
    src_pos = tl.cumsum(selected.to(tl.int32), axis=0) - 1

    write_mask = in_bounds & selected
    values = tl.load(src_ptr + src_pos, mask=write_mask)
    tl.store(inp_ptr + idx, values, mask=write_mask)

32 

33 

@libentry()
@triton.jit
def count_mask_per_block_kernel(mask_ptr, counts_ptr, N, BLOCK_SIZE: tl.constexpr):
    """Store the sum of mask values of each ``BLOCK_SIZE`` tile.

    Program ``p`` sums the mask entries at positions
    ``[p * BLOCK_SIZE, (p + 1) * BLOCK_SIZE)`` that lie below ``N`` and writes
    the total to ``counts_ptr[p]``.
    """
    block_idx = tl.program_id(0)
    idx = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Positions at or beyond N contribute 0 to the tile's total.
    flags = tl.load(mask_ptr + idx, mask=idx < N, other=0)
    tile_total = tl.sum(flags.to(tl.int32))

    tl.store(counts_ptr + block_idx, tile_total)

43 

44 

@libentry()
@triton.jit(do_not_specialize=["N", "num_blocks", "num_blocks_per_row"])
def masked_scatter_kernel(
    inp_ptr,
    mask_ptr,
    src_ptr,
    part_sums_ptr,
    N,
    num_blocks,
    num_blocks_per_row,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter source values into masked positions for one row of blocks.

    Program ``r`` handles blocks ``r * num_blocks_per_row`` up to (and
    including) ``min(num_blocks - 1, (r + 1) * num_blocks_per_row - 1)``.
    ``part_sums_ptr[r]`` is read as the source index at which this row starts
    consuming ``src_ptr``.

    Args:
        inp_ptr: destination buffer, written only at selected positions.
        mask_ptr: selection flags; non-zero means "selected".
        src_ptr: values to scatter, consumed in order.
        part_sums_ptr: per-row starting offsets into ``src_ptr``.
        N: number of valid positions in ``inp_ptr`` / ``mask_ptr``.
        num_blocks: total number of ``BLOCK_SIZE`` tiles.
        num_blocks_per_row: tiles assigned to each program.
        BLOCK_SIZE: compile-time tile width.
    """
    row = tl.program_id(0)

    first_block = row * num_blocks_per_row
    idx = first_block * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)

    # Running count of source elements already consumed before the current tile.
    consumed = tl.load(part_sums_ptr + row)

    final_block = min(num_blocks - 1, first_block + num_blocks_per_row - 1)

    # All tiles before the row's final one are loaded without a bounds mask.
    for _block in range(first_block, final_block):
        chosen = tl.load(mask_ptr + idx).to(tl.int1)
        chosen_i32 = chosen.to(tl.int32)

        # Source index = carry-in + zero-based rank of the lane within the tile.
        src_pos = consumed + (tl.cumsum(chosen_i32, axis=0) - 1)

        # Update the carry only after the indices for this tile were formed.
        consumed += tl.sum(chosen_i32, axis=0)

        vals = tl.load(src_ptr + src_pos, mask=chosen)
        tl.store(inp_ptr + idx, vals, mask=chosen)

        idx += BLOCK_SIZE

    # Final tile of the row: guard against positions at or beyond N.
    in_bounds = idx < N
    chosen = tl.load(mask_ptr + idx, mask=in_bounds, other=0).to(tl.int1)

    src_pos = consumed + (tl.cumsum(chosen.to(tl.int32), axis=0) - 1)

    write_mask = in_bounds & chosen
    vals = tl.load(src_ptr + src_pos, mask=write_mask)
    tl.store(inp_ptr + idx, vals, mask=write_mask)

90 

91 

def masked_scatter_impl(inp, mask, source, N):
    """Copy leading elements of ``source`` into ``inp`` where ``mask`` is set.

    ``inp`` is modified in place. Inputs with at most 4096 elements are handled
    by a single-program kernel; larger inputs use a per-block count, an
    exclusive prefix sum of those counts computed on the host, and a scatter
    kernel that starts each block at its prefix offset.

    Args:
        inp: destination tensor, updated in place.
        mask: selection tensor covering the same ``N`` positions as ``inp``.
            NOTE(review): the element count below assumes 0/1 (boolean)
            mask values - confirm callers never pass other integer masks.
        source: tensor whose elements are consumed in order.
        N: number of positions to process.

    Returns:
        ``inp`` (the same object that was passed in).

    Raises:
        RuntimeError: if ``source`` holds fewer elements than the mask selects.
    """
    # .item() synchronizes with the device; the count is needed on the host
    # for both the early exit and the source-size validation.
    selected_total = mask.sum().item()
    if selected_total == 0:
        return inp

    # Without this check the kernels would index past the end of `source`.
    if source.numel() < selected_total:
        raise RuntimeError(
            "masked_scatter: number of elements of source "
            f"({source.numel()}) is smaller than the number of selected "
            f"mask elements ({selected_total})"
        )

    # Both launch paths run under the device of `inp`, so the small-input path
    # targets the same device as the multi-block path.
    with torch_device_fn.device(inp.device):
        if N <= 4096:
            BLOCK_SIZE = triton.next_power_of_2(N)
            masked_scatter_single_pass_kernel[(1,)](
                inp, mask, source, N, BLOCK_SIZE=BLOCK_SIZE
            )
            return inp

        BLOCK_SIZE = bracket_next_power_of_2(N, 128, 4096)
        n_blocks = triton.cdiv(N, BLOCK_SIZE)

        block_counts = torch.empty(n_blocks, dtype=torch.int64, device=mask.device)
        count_mask_per_block_kernel[(n_blocks,)](
            mask, block_counts, N, BLOCK_SIZE=BLOCK_SIZE
        )

        # Exclusive prefix sum on the host: entry i is the number of selected
        # elements in blocks 0..i-1, i.e. where block i starts reading `source`.
        counts_cpu = block_counts.cpu()
        prefix_sum = torch.zeros(n_blocks, dtype=torch.int64)
        torch.cumsum(counts_cpu[:-1], dim=0, out=prefix_sum[1:])
        part_sums = prefix_sum.to(mask.device)

        # One program per block (num_blocks_per_row == 1).
        masked_scatter_kernel[(n_blocks,)](
            inp,
            mask,
            source,
            part_sums,
            N,
            n_blocks,
            1,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    return inp

130 

131 

def masked_scatter(inp, mask, source):
    """Return a copy of ``inp`` with masked positions filled from ``source``.

    ``inp`` itself is left untouched; the result is always a contiguous tensor.

    Args:
        inp: input tensor.
        mask: selection tensor whose shape must be broadcastable with
            ``inp.shape``.
        source: tensor whose elements are consumed in order for every selected
            position.

    Returns:
        A new contiguous tensor holding the scattered result.

    Raises:
        AssertionError: if the shapes of ``inp`` and ``mask`` are not
            broadcastable.
    """
    logger.debug("GEMS_ASCEND MASKED SCATTER")

    assert broadcastable(
        inp.shape, mask.shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"

    # NOTE(review): only the broadcast mask is kept. If `inp` itself would need
    # broadcasting, the mask ends up with more elements than `out` and only its
    # first `out.numel()` entries are used - confirm callers never rely on that.
    _, mask = torch.broadcast_tensors(inp, mask)

    # Clone straight into a contiguous layout: a non-contiguous input is copied
    # once instead of being cloned and then copied again by `.contiguous()`.
    out = inp.clone(memory_format=torch.contiguous_format)

    # `.contiguous()` returns the tensor itself when it is already contiguous.
    mask = mask.contiguous()
    source = source.contiguous()

    masked_scatter_impl(out, mask, source, out.numel())

    return out

154 

155 

def masked_scatter_(inp, mask, source):
    """Fill the masked positions of ``inp`` from ``source``, in place.

    Args:
        inp: tensor to update; must be contiguous.
        mask: selection tensor whose shape must be broadcastable with
            ``inp.shape``.
        source: tensor whose elements are consumed in order for every selected
            position.

    Returns:
        ``inp`` (the same object that was passed in).

    Raises:
        AssertionError: if the shapes of ``inp`` and ``mask`` are not
            broadcastable.
        RuntimeError: if ``inp`` is not contiguous.
    """
    logger.debug("GEMS_ASCEND MASKED SCATTER_")

    # Same message as the out-of-place variant, so a shape failure is
    # self-explanatory in both entry points.
    assert broadcastable(
        inp.shape, mask.shape
    ), "The shapes of the `mask` and the `input` tensor must be broadcastable"
    _, mask = torch.broadcast_tensors(inp, mask)

    # The kernels address `inp` linearly, so a strided tensor cannot be updated
    # in place here.
    if not inp.is_contiguous():
        raise RuntimeError(
            "in-place operation currently requires contiguous input tensor. "
        )

    # `.contiguous()` returns the tensor itself when it is already contiguous.
    mask = mask.contiguous()
    source = source.contiguous()

    masked_scatter_impl(inp, mask, source, inp.numel())

    return inp