Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/count_nonzero.py: 0%

114 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import dim_compress, libentry, libtuner 

10from flag_gems.utils import triton_lang_extension as tle 

11 

# Module logger registered as a child of the "flag_gems" logger, so it
# inherits whatever handlers/levels are configured on that parent.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))

13 

14 

# Whole-tensor non-zero count: every program inspects one BLOCK_SIZE-wide
# slice of the flat input and atomically adds its slice's count into the
# single scalar at out_ptr.
@libentry()
@triton.jit
def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    # Flat element indices owned by this program.
    offs = tle.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    in_range = offs < numel
    # Lanes past the end load 0, so they never count as non-zero.
    vals = tl.load(x_ptr + offs, mask=in_range, other=0)
    block_total = tl.sum((vals != 0).to(tl.int32), axis=0)
    # All programs target the same scalar, hence the atomic accumulation.
    tl.atomic_add(out_ptr, block_total)

26 

27 

# Per-row non-zero count over a flat buffer indexed as row-major (rows, N):
# out_ptr[r] receives the number of non-zero elements in row r.
# Rows are split into contiguous runs of ``rows // num_programs`` rows per
# program; the ``rows % num_programs`` leftover rows are then handled one
# each by the lowest-numbered programs.
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("count_nonzero"),
    key=["numel"],
    strategy=["align32"],
    warmup=1,
    rep=2,
)
@triton.jit
def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    pid_0 = tle.program_id(0)
    num_p = tle.num_programs(0)
    # ceil(numel / N): number of rows of length N in the buffer.
    rows = (numel + N - 1) // N
    rows_per_p = rows // num_p

    # Main pass: this program owns rows
    # [pid_0 * rows_per_p, (pid_0 + 1) * rows_per_p).
    for pid_n in range(0, rows_per_p):
        pid_x = pid_0 * rows_per_p + pid_n

        # Accumulate in the output's element type.
        nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
        # Walk the row in BLOCK_SIZE-wide column tiles.
        for start_n in range(0, N, BLOCK_SIZE):
            cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
            # NOTE(review): pid_x * N is computed in the program-id integer
            # type (int32 in upstream Triton) and may overflow once numel
            # reaches 2**31 -- confirm on this backend.
            offset = pid_x * N + cols_offsets
            # Triton lowers `and` on tensors to an elementwise logical AND:
            # stay inside the buffer and inside the current row.
            mask = offset < numel and cols_offsets < N
            # Masked-off lanes load 0 and therefore never count as non-zero.
            x = tl.load(x_ptr + offset, mask=mask, other=0)
            is_nonzero = (x != 0).to(tl.int64)
            nonzero_count += tl.sum(is_nonzero)

        tl.store(out_ptr + pid_x, nonzero_count)

    # Tail pass: rows left over after the even split, one per program for
    # programs 0 .. remain - 1.
    remain = rows % num_p
    if pid_0 < remain:
        # Leftover rows start right after the evenly-split region.
        pid_x = rows // num_p * num_p + pid_0
        nonzero_count = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
        for start_n in range(0, N, BLOCK_SIZE):
            cols_offsets = start_n + tl.arange(0, BLOCK_SIZE)
            offset = pid_x * N + cols_offsets
            mask = offset < numel and cols_offsets < N
            x = tl.load(x_ptr + offset, mask=mask, other=0)
            is_nonzero = (x != 0).to(tl.int64)
            nonzero_count += tl.sum(is_nonzero)

        tl.store(out_ptr + pid_x, nonzero_count)

70 

71 

# Second stage of the chunked reduction: each program sums the N partial
# counts stored in one row of the row-major (rows, N) buffer at x_ptr and
# writes the row total to out_ptr[row].
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("count_nonzero"),
    key=["numel"],
    strategy=["align32"],
    warmup=1,
    rep=2,
)
@triton.jit
def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    row = tle.program_id(0)
    # Flat index of the first element of this row (loop invariant).
    row_base = row * N
    acc = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for chunk_start in range(0, N, BLOCK_SIZE):
        cols = chunk_start + tl.arange(0, BLOCK_SIZE)
        idx = row_base + cols
        # Out-of-range lanes read 0 and do not change the sum.
        vals = tl.load(x_ptr + idx, mask=(idx < numel) & (cols < N), other=0)
        acc += tl.sum(vals)
    tl.store(out_ptr + row, acc)

91 

92 

# First stage of the chunked reduction: program (row, chunk) counts the
# non-zeros in one BLOCK_SIZE-wide column tile of one row of the row-major
# (rows, N) input and stores that partial count at combin[row, chunk], where
# combin is indexed as row-major (rows, combin_N).
@libentry()
@triton.jit
def count_nonzero_combin_kernel(
    x_ptr, combin_ptr, N, combin_N, numel, BLOCK_SIZE: tl.constexpr
):
    row = tle.program_id(0)
    chunk = tle.program_id(1)
    cols = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    idx = row * N + cols
    # Masked-off lanes load 0 and therefore never count as non-zero.
    vals = tl.load(x_ptr + idx, mask=(idx < numel) & (cols < N), other=0)
    chunk_count = tl.sum((vals != 0).to(tl.int64))
    tl.store(combin_ptr + row * combin_N + chunk, chunk_count)

107 

108 

def count_nonzero(x, dim=None):
    """Count the non-zero elements of ``x``, optionally along one axis.

    Args:
        x: Input tensor.
        dim: Optional single axis to reduce over; negative values count from
            the end. When ``None``, every element of ``x`` is counted.

    Returns:
        An ``int64`` tensor: 0-d when ``dim`` is ``None``, otherwise of
        shape ``x.shape`` with axis ``dim`` removed.

    Note:
        An out-of-range ``dim`` fails the ``"Invalid dim"`` assertion.
    """
    logger.debug("GEMS_TSINGMICRO COUNT NONZERO")
    if dim is None:
        return _count_nonzero_all(x)
    assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
    # Normalise a negative axis once, so dim_compress, shape indexing and the
    # output-shape computation all agree on the same non-negative axis.
    return _count_nonzero_along_dim(x, dim % x.ndim)


def _count_nonzero_along_dim(x, dim):
    # Count non-zeros along the non-negative axis ``dim`` (int64 result).
    shape = x.shape
    reduce_len = shape[dim]
    numel = x.numel()
    out_shape = list(shape)
    del out_shape[dim]
    if numel == 0:
        # Nothing to count. This also avoids triton.cdiv(numel, 0) when the
        # reduced axis itself has length 0.
        return torch.zeros(out_shape, dtype=torch.int64, device=x.device)

    BLOCK_SIZE = 2048
    # The kernels index x as row-major (num_rows, reduce_len); dim_compress
    # is presumably what moves ``dim`` innermost -- confirm in flag_gems.utils.
    x = dim_compress(x, dim)
    x = x.contiguous().flatten()
    num_rows = numel // reduce_len
    num_chunks = triton.cdiv(reduce_len, BLOCK_SIZE)
    out = torch.zeros(out_shape, dtype=torch.int64, device=x.device)

    if num_chunks != 1:
        # Long rows: stage 1 writes one partial count per (row, chunk),
        # stage 2 sums the partials of each row into ``out``.
        partial = torch.zeros(
            (num_rows, num_chunks), dtype=torch.int64, device=x.device
        )
        count_nonzero_combin_kernel[(num_rows, num_chunks, 1)](
            x, partial, reduce_len, num_chunks, numel, BLOCK_SIZE
        )
        grid = lambda meta: (num_rows,)
        count_nonzero_combin_kernel_1[grid](
            partial, out, num_chunks, partial.numel()
        )
        return out

    # Short rows: one kernel, never more programs than rows or device cores.
    grid = lambda meta: (
        min(
            torch_device_fn.get_device_properties().multi_processor_count,
            num_rows,
        ),
    )
    count_nonzero_kernel[grid](x, out, reduce_len, numel)
    return out


def _count_nonzero_all(x):
    # Count every non-zero element of ``x``; returns a 0-d int64 tensor.
    x = x.contiguous().flatten()
    numel = x.numel()
    if numel == 0:
        return torch.zeros((), dtype=torch.int64, device=x.device)

    # NOTE(review): count_nonzero_kernel_1 accumulates with an int32
    # atomic_add, so totals above 2**31 - 1 would wrap. Moving to int64 needs
    # int64 atomic support on this backend -- confirm before changing.
    out = torch.zeros(1, dtype=torch.int32, device=x.device)
    BLOCK_SIZE = 1024 * 8
    grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)
    count_nonzero_kernel_1[grid](x, out, numel, BLOCK_SIZE=BLOCK_SIZE)
    return out[0].to(torch.int64)