Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/count_nonzero.py: 0%

88 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems.runtime import torch_device_fn 

8from flag_gems.utils import dim_compress, libentry 

9from flag_gems.utils import triton_lang_extension as ext 

10 

# Child of the shared "flag_gems" logger, named after this module.
logger = logging.getLogger("flag_gems." + __name__.lstrip("."))

# Launch-tuning constants for this backend.
# NOTE(review): presumably the device's cluster count and cores per cluster —
# confirm against the hardware documentation.
cluster_num = 12
core_num = 64
# Forwarded as `buffer_size_limit` on every kernel launch in this module.
buf_len_per_core = 2048

16 

17 

def heur_m_block_size(args):
    """Heuristic for ``BLOCK_M``: rows handled by one program.

    The row count ``M`` (0 if absent from ``args``) is divided evenly across
    ``cluster_num``, capped at ``core_num``, and rounded up to a power of two.
    """
    rows_per_cluster = triton.cdiv(args.get("M", 0), cluster_num)
    capped_rows = min(rows_per_cluster, core_num)
    return triton.next_power_of_2(capped_rows)

22 

23 

def heur_n_block_size(args):
    """Heuristic for ``BLOCK_N``: columns processed per loop iteration.

    Uses the reduction length ``N`` (0 if absent from ``args``), capped at 512
    and rounded up to a power of two.
    """
    capped_cols = min(args.get("N", 0), 512)
    return triton.next_power_of_2(capped_cols)

26 

27 

@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def count_nonzero_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Per-row non-zero count over an (M, N) view addressed as inp + row * N + col.
    # Each program owns BLOCK_M consecutive rows and writes one int64 per row.
    program = ext.program_id(0)
    row_idx = program * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_valid = row_idx < M
    row_ptrs = inp + row_idx * N
    out_ptrs = out + row_idx

    # Accumulate per-lane hits in int32; widened to int64 only after the reduction.
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.int32)
    for col_start in range(0, N, BLOCK_N):
        col_idx = col_start + tl.arange(0, BLOCK_N)[None, :]
        col_valid = col_idx < N
        valid = row_valid and col_valid

        # Masked-off lanes read 0, so they never count as non-zero.
        tile = tl.load(row_ptrs + col_idx, valid, other=0)
        acc += (tile != 0).to(tl.int32)

    row_totals = tl.sum(acc, axis=1).to(tl.int64)
    tl.store(out_ptrs, row_totals[:, None], row_valid)

62 

63 

@libentry()
@triton.jit
def count_nonzero_kernel_1d_parallel(
    inp,
    partial_out,
    N,
    BLOCK_N: tl.constexpr,
):
    # Counts non-zeros in a flat buffer of N elements. Program `pid` visits the
    # blocks starting at pid * BLOCK_N and then every (num_programs * BLOCK_N)
    # elements after that, and writes its own subtotal to partial_out[pid].
    pid = ext.program_id(0)
    step = ext.num_programs(0) * BLOCK_N

    # Accumulate per-lane hits in int32; widened to int64 only after the reduction.
    acc = tl.zeros([BLOCK_N], dtype=tl.int32)
    for start in range(pid * BLOCK_N, N, step):
        idx = start + tl.arange(0, BLOCK_N)
        in_bounds = idx < N
        # Out-of-range lanes read 0, so they never count as non-zero.
        chunk = tl.load(inp + idx, in_bounds, other=0)
        acc += (chunk != 0).to(tl.int32)

    subtotal = tl.sum(acc, axis=0).to(tl.int64)
    tl.store(partial_out + pid, subtotal)

85 

86 

@libentry()
@triton.jit
def reduce_partial_counts(
    partial_in,
    out,
    num_partials,
    BLOCK: tl.constexpr,
):
    # Sums the first `num_partials` entries of `partial_in` into the single
    # element at `out`. No program id is read, so every launched program
    # computes and stores the same total.
    acc = tl.zeros([BLOCK], dtype=tl.int64)
    for start in range(0, num_partials, BLOCK):
        pos = start + tl.arange(0, BLOCK)
        in_bounds = pos < num_partials
        # Lanes past the end contribute 0 to the sum.
        chunk = tl.load(partial_in + pos, in_bounds, other=0)
        acc += chunk

    grand_total = tl.sum(acc, axis=0)
    tl.store(out, grand_total)

104 

105 

def count_nonzero(x, dim=None):
    """Count the non-zero elements of ``x``, optionally along one dimension.

    Args:
        x: Input tensor.
        dim: Dimension to reduce, in ``[-x.ndim, x.ndim)``. ``None`` (default)
            counts over the whole tensor.

    Returns:
        An int64 tensor on ``x``'s device. With ``dim`` given, its shape is
        ``x.shape`` with ``dim`` removed; otherwise it is a 0-dim tensor
        holding the total count.

    Raises:
        AssertionError: If ``dim`` is outside ``[-x.ndim, x.ndim)``.
    """
    logger.debug("GEMS_KUNLUNXIN COUNT NONZERO")

    if dim is not None:
        assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
        # Canonicalise a negative dim once so every later use (shape indexing,
        # dim_compress) sees the same non-negative index.
        dim = dim % x.ndim

        shape = x.shape
        numel = x.numel()
        out_shape = list(shape)
        del out_shape[dim]
        out = torch.zeros(out_shape, dtype=torch.int64, device=x.device)

        if numel == 0:
            # Nothing to count: the zero-filled output is already the answer.
            # Returning here also avoids dividing by a zero-length reduced
            # dimension and launching with a zero block size / empty grid.
            return out

        # View the input as (M, N), with the reduced dimension innermost.
        N = shape[dim]
        M = triton.cdiv(numel, N)
        x = dim_compress(x, dim)
        x = x.contiguous().flatten()

        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
        with torch_device_fn.device(x.device):
            count_nonzero_kernel_dim[grid](
                x, out, M, N, buffer_size_limit=buf_len_per_core
            )
        return out

    # Whole-tensor count: per-program partial counts, then a final reduction.
    x = x.contiguous().flatten()
    numel = x.numel()
    out = torch.zeros(1, dtype=torch.int64, device=x.device)

    if numel == 0:
        # Empty input: the count is 0, no kernel launch needed.
        return out[0]

    # Elements processed per program per loop iteration.
    BLOCK_N = 2048
    # At most `cluster_num` programs, and never more than there are blocks.
    num_blocks = max(1, min(cluster_num, triton.cdiv(numel, BLOCK_N)))

    with torch_device_fn.device(x.device):
        if num_blocks == 1:
            # Small tensor: the single program writes its total straight to out[0].
            count_nonzero_kernel_1d_parallel[(1,)](
                x, out, numel, BLOCK_N=BLOCK_N, buffer_size_limit=buf_len_per_core
            )
        else:
            # Large tensor: one partial count per program, then a single-program sum.
            partial = torch.zeros(num_blocks, dtype=torch.int64, device=x.device)
            count_nonzero_kernel_1d_parallel[(num_blocks,)](
                x,
                partial,
                numel,
                BLOCK_N=BLOCK_N,
                buffer_size_limit=buf_len_per_core,
            )
            REDUCE_BLOCK = triton.next_power_of_2(num_blocks)
            reduce_partial_counts[(1,)](
                partial,
                out,
                num_blocks,
                BLOCK=REDUCE_BLOCK,
                buffer_size_limit=buf_len_per_core,
            )

    return out[0]