Coverage for src/flag_gems/runtime/backend/_cambricon/ops/unique.py: 0%

94 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils.libentry import libentry 

9 

10from ..utils import TOTAL_CORE_NUM 

11 

# Module logger: a child of the shared "flag_gems" logger, named after this
# module's import path with any leading dots stripped.
logger = logging.getLogger("flag_gems").getChild(
    __name__.lstrip("."),
)

13 

14 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 2**exponent}, num_stages=stages, num_warps=1)
        for exponent in range(11, 17, 1)
        for stages in [1, 3]
    ],
    key=["tile_size"],
)
@triton.jit
def get_ne_kernel(
    sorted_data_ptr: tl.tensor,
    sorted_data_2: tl.tensor,
    ne_out_ptr: tl.tensor,
    tile_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Write an element-wise "differs from the other buffer" flag.

    For every position ``p < tile_size`` the kernel stores
    ``(p > 0) * (sorted_data_ptr[p] != sorted_data_2[p])`` into
    ``ne_out_ptr[p]``. Position 0 is therefore always stored as zero,
    whatever the two buffers contain there.

    Args:
        sorted_data_ptr: first input buffer.
        sorted_data_2: second input buffer, compared position by position
            with the first one.
        ne_out_ptr: output buffer receiving the flags.
        tile_size: number of valid elements (compile-time constant).
        BLOCK_SIZE: number of elements handled per loop step (autotuned).
    """
    program_idx = tl.program_id(axis=0)
    program_count = tl.num_programs(axis=0)

    # Ceil-divide the work so the per-program chunks cover all of tile_size.
    chunk_len = (tile_size + program_count - 1) // program_count
    chunk_begin = program_idx * chunk_len
    lanes = tl.arange(0, BLOCK_SIZE)

    for step in range(0, chunk_len, BLOCK_SIZE):
        pos = chunk_begin + step + lanes
        # The mask only guards against running past tile_size.
        in_bounds = pos < tile_size
        lhs = tl.load(sorted_data_ptr + pos, mask=in_bounds)
        rhs = tl.load(sorted_data_2 + pos, mask=in_bounds)
        # The (pos > 0) factor forces the flag at position 0 to zero.
        differs = (pos > 0) * (lhs != rhs)
        tl.store(ne_out_ptr + pos, differs, mask=in_bounds)

48 

49 

@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": block}, num_stages=stages, num_warps=1)
        for block in [32, 256, 1024, 2048, 4096]
        for stages in [1, 3]
    ],
    key=["tile_size"],
)
@triton.jit
def get_unique_out_kernel(
    sorted_data_ptr: tl.tensor,
    sorted_indices_ptr: tl.tensor,  # in
    ne_result_ptr: tl.tensor,
    pre_sum_ptr: tl.tensor,
    idx_ptr: tl.tensor,
    data_out_ptr: tl.tensor,
    inverse_indices_ptr: tl.tensor,
    return_inverse: tl.constexpr,
    return_counts: tl.constexpr,
    tile_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Scatter sorted values (and optional bookkeeping) to their output slots.

    For every position ``p < tile_size``:

    * ``data_out_ptr[pre_sum_ptr[p]] = sorted_data_ptr[p]``
    * if ``return_inverse``:
      ``inverse_indices_ptr[sorted_indices_ptr[p]] = pre_sum_ptr[p]``
    * if ``return_counts``: ``idx_ptr[pre_sum_ptr[p]] = p``, but only where
      ``p == 0`` or ``ne_result_ptr[p]`` is set.

    Args:
        sorted_data_ptr: input values.
        sorted_indices_ptr: per-position destination index for the inverse
            scatter; only read when ``return_inverse`` is true.
        ne_result_ptr: per-position flag; only read when ``return_counts``
            is true.
        pre_sum_ptr: per-position destination slot in the outputs.
        idx_ptr: output for the flagged positions; only written when
            ``return_counts`` is true.
        data_out_ptr: output for the scattered values.
        inverse_indices_ptr: output for the inverse mapping; only written
            when ``return_inverse`` is true.
        return_inverse: compile-time switch for the inverse scatter.
        return_counts: compile-time switch for the position scatter.
        tile_size: number of valid elements (compile-time constant).
        BLOCK_SIZE: number of elements handled per loop step (autotuned).
    """
    program_idx = tl.program_id(axis=0)
    program_count = tl.num_programs(axis=0)

    # Ceil-divide the work so the per-program chunks cover all of tile_size.
    chunk_len = (tile_size + program_count - 1) // program_count
    chunk_begin = program_idx * chunk_len
    lanes = tl.arange(0, BLOCK_SIZE)

    for step in range(0, chunk_len, BLOCK_SIZE):
        pos = chunk_begin + step + lanes
        in_bounds = pos < tile_size
        values = tl.load(sorted_data_ptr + pos, mask=in_bounds)
        out_slot = tl.load(pre_sum_ptr + pos, mask=in_bounds)

        # data_out: scatter the values to their output slot.
        tl.store(data_out_ptr + out_slot, values, mask=in_bounds)

        # inverse_indices: scatter the output slot to the index loaded from
        # sorted_indices_ptr.
        if return_inverse:
            source_index = tl.load(sorted_indices_ptr + pos, mask=in_bounds)
            tl.store(inverse_indices_ptr + source_index, out_slot, mask=in_bounds)

        # idx: record the position of each flagged element (and position 0)
        # at its output slot.
        if return_counts:
            flags = tl.load(ne_result_ptr + pos, mask=in_bounds)
            first_of_run = ((pos == 0) | flags.to(tl.int1)) & in_bounds
            tl.store(idx_ptr + out_slot, pos, mask=first_of_run)

101 

102 

# Wrapped with libentry like the other two kernels in this file, which are
# launched the same way (grid callable plus a `tile_size` keyword).
@libentry()
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": 2**k}, num_stages=s, num_warps=1)
        for k in range(7, 14, 1)
        for s in [1, 3]
    ],
    key=[
        "tile_size",
    ],
)
@triton.jit
def get_output_counts_kernel(
    idx_ptr: tl.tensor,
    idx_next_ptr: tl.tensor,
    counts_ptr: tl.tensor,  # out
    tile_size: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Store the element-wise difference ``idx_next - idx`` into ``counts``.

    For every position ``p < tile_size`` the kernel writes
    ``counts_ptr[p] = idx_next_ptr[p] - idx_ptr[p]``.

    Args:
        idx_ptr: input buffer that is subtracted.
        idx_next_ptr: input buffer that is subtracted from.
        counts_ptr: output buffer receiving the differences.
        tile_size: number of valid elements (compile-time constant).
        BLOCK_SIZE: number of elements handled per loop step (autotuned).
    """
    pid = tl.program_id(axis=0)
    num_jobs = tl.num_programs(axis=0)
    # Ceil-divide the work so the per-program chunks cover all of tile_size.
    split_n = (tile_size + num_jobs - 1) // num_jobs
    start_offset = pid * split_n

    i0 = tl.arange(0, BLOCK_SIZE)

    for i in range(0, split_n, BLOCK_SIZE):
        offset = start_offset + i + i0
        # The mask only guards against running past tile_size.
        mask = offset < tile_size
        idx = tl.load(idx_ptr + offset, mask=mask)
        idx_next = tl.load(idx_next_ptr + offset, mask=mask)
        counts = idx_next - idx
        tl.store(counts_ptr + offset, counts, mask=mask)

139 

140 

def sorted_unique_flat(
    sorted_data: torch.Tensor,
    sorted_indices: torch.Tensor,
    return_inverse: bool,
    return_counts: bool,
):
    """Compute unique values (and optional extras) from already-sorted data.

    Steps, for non-empty input:

    1. ``get_ne_kernel`` flags every position whose value differs from the
       value before it (position 0 is never flagged).
    2. A cumulative sum of those flags gives each position its output slot.
    3. ``get_unique_out_kernel`` scatters the values, and optionally the
       inverse mapping and the flagged positions, to those slots.
    4. If counts are requested, ``get_output_counts_kernel`` subtracts each
       recorded position from the following one, with the input length
       appended as the final boundary.

    Args:
        sorted_data: values to deduplicate. The slicing and indexing here
            treat it as one-dimensional; presumably the caller has already
            sorted it (TODO confirm against callers other than ``_unique2``).
        sorted_indices: indices paired with ``sorted_data``; only passed on
            to the scatter kernel, which reads it when ``return_inverse`` is
            true.
        return_inverse: whether to produce the inverse-index tensor.
        return_counts: whether to produce the per-value counts.

    Returns:
        A tuple ``(data_out, inverse_indices, counts)``. ``inverse_indices``
        and ``counts`` are ``None`` when not requested; otherwise they are
        int64 tensors.
    """
    num_tasks = sorted_data.numel()

    # Empty input: nothing to launch, and `pre_sum[-1]` below would raise
    # IndexError on a zero-length tensor. Return correctly typed empties.
    if num_tasks == 0:
        empty_inverse = (
            torch.empty_like(sorted_data, dtype=torch.int64)
            if return_inverse
            else None
        )
        empty_counts = (
            torch.empty_like(sorted_data, dtype=torch.int64)
            if return_counts
            else None
        )
        return torch.empty_like(sorted_data), empty_inverse, empty_counts

    def grid(meta):
        # One program per BLOCK_SIZE chunk, capped at TOTAL_CORE_NUM.
        return (min(triton.cdiv(num_tasks, meta["BLOCK_SIZE"]), TOTAL_CORE_NUM),)

    # allocate tensor
    ne_out = torch.empty_like(sorted_data, dtype=torch.bool)
    data_out = torch.empty_like(sorted_data)
    if return_inverse:
        inverse_indices = torch.empty_like(sorted_data, dtype=torch.int64)
    else:
        inverse_indices = None
    if return_counts:
        idx = torch.empty_like(sorted_data, dtype=torch.int64)
    else:
        idx = None
    # Copy of the data shifted right by one. Element 0 stays uninitialised,
    # which is fine because get_ne_kernel forces the flag at position 0 to zero.
    sorted_data_2 = torch.empty_like(sorted_data)
    sorted_data_2[1:] = sorted_data[:-1]

    # launch kernel
    with torch_device_fn.device(sorted_data.device.index):
        get_ne_kernel[grid](
            sorted_data,
            sorted_data_2,
            ne_out,
            tile_size=num_tasks,
        )
        # Running count of flags = output slot of every position.
        pre_sum = ne_out.cumsum(axis=0)
        get_unique_out_kernel[grid](
            sorted_data,
            sorted_indices,
            ne_out,
            pre_sum,
            idx,
            data_out,
            inverse_indices,
            return_inverse,
            return_counts,
            tile_size=num_tasks,
        )

    # The last slot index plus one is the number of unique values.
    out_size = pre_sum[-1].item() + 1
    counts = None
    if return_counts:
        idx = idx[:out_size]
        sorted_data_size = len(sorted_data)
        # idx_next[k] is the recorded position for slot k + 1; the last
        # entry is replaced by the input length.
        idx_next = torch.roll(idx, -1)
        idx_next[-1] = sorted_data_size
        counts = torch.zeros_like(idx)
        with torch_device_fn.device(sorted_data.device.index):
            get_output_counts_kernel[grid](
                idx,
                idx_next,
                counts,  # out
                tile_size=out_size,
            )
    return data_out[:out_size], inverse_indices, counts

204 

205 

def _unique2(
    in0: torch.Tensor,
    sorted: bool = True,
    return_inverse: bool = False,
    return_counts: bool = False,
):
    """Return the unique values of ``in0`` with optional inverse and counts.

    The input is flattened and sorted (non-stable sort), then handed to
    ``sorted_unique_flat``.

    Args:
        in0: input tensor.
        sorted: accepted as part of the signature but not read by this
            function.
        return_inverse: whether to also return the inverse indices.
        return_counts: whether to also return per-value counts.

    Returns:
        A tuple ``(values, inverse_indices, counts)``. ``inverse_indices``
        is reshaped to match ``in0`` when present; it and ``counts`` are
        whatever ``sorted_unique_flat`` returned (``None`` when not
        requested).
    """
    logger.debug("GEMS_CAMBRICON _UNIQUE2")

    flattened = in0.ravel()
    sorted_data, sorted_indices = torch.sort(flattened, stable=False)

    unique_values, inverse_indices, counts = sorted_unique_flat(
        sorted_data, sorted_indices, return_inverse, return_counts
    )

    # Give the inverse mapping the same shape as the caller's tensor.
    if inverse_indices is not None:
        inverse_indices = inverse_indices.view_as(in0)

    return unique_values, inverse_indices, counts