Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/count_nonzero.py: 0%
114 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-06 06:51 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import dim_compress, libentry, libtuner
10from flag_gems.utils import triton_lang_extension as tle
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.jit
def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
    # Each program inspects one BLOCK_SIZE-wide slice of the flat input,
    # counts its non-zero entries and adds that partial count to out_ptr[0].
    first = tle.program_id(0) * BLOCK_SIZE
    idx = first + tl.arange(0, BLOCK_SIZE)
    # Lanes at or beyond `numel` are masked off and read as 0, so they can
    # never be counted as non-zero.
    vals = tl.load(x_ptr + idx, mask=idx < numel, other=0)
    partial = tl.sum((vals != 0).to(tl.int32), axis=0)
    # Several programs update the same location, hence the atomic add.
    tl.atomic_add(out_ptr, partial)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("count_nonzero"),
    key=["numel"],
    strategy=["align32"],
    warmup=1,
    rep=2,
)
@triton.jit
def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Row-wise non-zero count over a flat buffer viewed as rows of length N.
    # out_ptr[row] receives the number of non-zero entries of that row.
    # Rows are split across programs: every program takes an equal share of
    # consecutive rows, then the first `leftover` programs take one more row.
    prog = tle.program_id(0)
    n_progs = tle.num_programs(0)
    n_rows = (numel + N - 1) // N
    share = n_rows // n_progs

    # Equal share: rows [prog * share, (prog + 1) * share).
    for k in range(0, share):
        row = prog * share + k

        acc = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
        for col0 in range(0, N, BLOCK_SIZE):
            cols = col0 + tl.arange(0, BLOCK_SIZE)
            idx = row * N + cols
            # Stay inside both the buffer and the current row.
            keep = idx < numel and cols < N
            vals = tl.load(x_ptr + idx, mask=keep, other=0)
            acc += tl.sum((vals != 0).to(tl.int64))

        tl.store(out_ptr + row, acc)

    # Leftover rows (n_rows % n_progs of them) sit after all equal shares.
    leftover = n_rows % n_progs
    if prog < leftover:
        row = n_rows // n_progs * n_progs + prog
        acc = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
        for col0 in range(0, N, BLOCK_SIZE):
            cols = col0 + tl.arange(0, BLOCK_SIZE)
            idx = row * N + cols
            keep = idx < numel and cols < N
            vals = tl.load(x_ptr + idx, mask=keep, other=0)
            acc += tl.sum((vals != 0).to(tl.int64))

        tl.store(out_ptr + row, acc)
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("count_nonzero"),
    key=["numel"],
    strategy=["align32"],
    warmup=1,
    rep=2,
)
@triton.jit
def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
    # Sum each row of length N of the flat buffer at x_ptr and write the
    # total to out_ptr[row]. One program handles one row. Unlike the other
    # kernels in this file, values are added as-is (no `!= 0` test).
    row = tle.program_id(0)
    total = tl.full((), value=0, dtype=out_ptr.dtype.element_ty)
    for col0 in range(0, N, BLOCK_SIZE):
        cols = col0 + tl.arange(0, BLOCK_SIZE)
        idx = row * N + cols
        # Masked-off lanes read as 0 and therefore do not change the sum.
        keep = idx < numel and cols < N
        total += tl.sum(tl.load(x_ptr + idx, mask=keep, other=0))
    tl.store(out_ptr + row, total)
@libentry()
@triton.jit
def count_nonzero_combin_kernel(
    x_ptr, combin_ptr, N, combin_N, numel, BLOCK_SIZE: tl.constexpr
):
    # 2-D launch: axis 0 selects a row of length N, axis 1 selects one
    # BLOCK_SIZE-wide chunk of that row. The chunk's non-zero count is
    # written to combin_ptr[row * combin_N + chunk].
    row = tle.program_id(0)
    chunk = tle.program_id(1)
    cols = chunk * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    idx = row * N + cols
    # Stay inside both the buffer and the current row; masked lanes read 0.
    keep = idx < numel and cols < N
    vals = tl.load(x_ptr + idx, mask=keep, other=0)
    partial = tl.sum((vals != 0).to(tl.int64))
    tl.store(combin_ptr + row * combin_N + chunk, partial)
def count_nonzero(x, dim=None):
    """Count the non-zero elements of ``x``, optionally along one dimension.

    Args:
        x: Input tensor.
        dim: A single integer dimension to reduce over (negative values index
            from the end), or ``None`` to count over the whole tensor.

    Returns:
        An int64 tensor. With ``dim=None`` it is a 0-dim tensor holding the
        total count; otherwise it has the shape of ``x`` with ``dim`` removed.

    Raises:
        AssertionError: If ``dim`` is outside ``[-x.ndim, x.ndim)``.
    """
    logger.debug("GEMS_TSINGMICRO COUNT NONZERO")

    if dim is None:
        x = x.contiguous().flatten()
        numel = x.numel()

        # The flat kernel accumulates int32 partial counts with atomic adds;
        # the result is widened to int64 on return.
        out = torch.zeros(1, dtype=torch.int32, device=x.device)

        BLOCK_SIZE = 1024 * 8
        grid = lambda meta: (triton.cdiv(numel, meta["BLOCK_SIZE"]),)

        count_nonzero_kernel_1[grid](x, out, numel, BLOCK_SIZE=BLOCK_SIZE)

        return out[0].to(torch.int64)

    assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
    shape = x.shape
    numel = x.numel()
    reduce_len = shape[dim]

    # The result shape is the input shape without the reduced dimension; it
    # is the same for both launch strategies below.
    out_shape = list(shape)
    del out_shape[dim]
    out = torch.zeros(out_shape, dtype=torch.int64, device=x.device)

    # Empty input: every count is zero. Returning here also avoids dividing
    # by a zero-length reduced dimension when computing the launch grid.
    if numel == 0:
        return out

    BLOCK_SIZE = 2048
    # Move the reduced dimension so each reduction row is contiguous in the
    # flattened buffer the kernels index as row * N + col.
    x = dim_compress(x, dim)
    x = x.contiguous().flatten()

    n_chunks = triton.cdiv(reduce_len, BLOCK_SIZE)
    if n_chunks != 1:
        # Long rows: first count every BLOCK_SIZE chunk of every row into
        # `combin`, then sum the per-chunk counts of each row.
        combin_shape = list(shape)
        combin_shape[dim] = n_chunks
        combin = torch.zeros(combin_shape, dtype=torch.int64, device=x.device)
        grid = (triton.cdiv(numel, reduce_len), n_chunks, 1)
        count_nonzero_combin_kernel[grid](
            x, combin, reduce_len, n_chunks, numel, BLOCK_SIZE
        )
        combin_numel = combin.numel()
        grid = lambda meta: (triton.cdiv(combin_numel, n_chunks),)
        count_nonzero_combin_kernel_1[grid](combin, out, n_chunks, combin_numel)
        return out

    # Short rows: a single pass, with at most one program per
    # multiprocessor, each handling several rows.
    grid = lambda meta: (
        min(
            torch_device_fn.get_device_properties().multi_processor_count,
            triton.cdiv(numel, reduce_len),
        ),
    )
    count_nonzero_kernel[grid](x, out, reduce_len, numel)
    return out