Coverage for src/flag_gems/runtime/backend/_kunlunxin/ops/count_nonzero.py: 0%
88 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems.runtime import torch_device_fn
8from flag_gems.utils import dim_compress, libentry
9from flag_gems.utils import triton_lang_extension as ext
# Module logger: a child of the shared "flag_gems" logger, named after this
# module with any leading dots stripped from __name__.
logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
# Tuning constants for the launches below.
# NOTE(review): the values look like Kunlunxin hardware parameters (clusters
# per device, cores per cluster, per-core buffer length) -- confirm against
# the target device's documentation.
# Divides M in heur_m_block_size and caps the number of programs on the
# flattened (dim=None) path.
cluster_num = 12
# Upper bound applied to BLOCK_M (before rounding up to a power of two).
core_num = 64
# Forwarded to every kernel launch as the `buffer_size_limit` keyword.
buf_len_per_core = 2048
def heur_m_block_size(args):
    """Choose BLOCK_M (rows handled per program) for the per-dim kernel.

    Args:
        args: Mapping of kernel launch arguments. Only ``"M"`` (the number
            of rows) is read; a missing entry is treated as 0.

    Returns:
        A power of two >= 1: ``ceil(M / cluster_num)`` capped at
        ``core_num``, clamped to at least 1 and rounded up to the next
        power of two.
    """
    rows_per_cluster = triton.cdiv(args.get("M", 0), cluster_num)
    # Clamp to 1 before rounding: for M == 0 the unclamped value is 0, and a
    # zero BLOCK_M makes the launch grid's cdiv(M, BLOCK_M) divide by zero.
    return triton.next_power_of_2(max(1, min(rows_per_cluster, core_num)))
def heur_n_block_size(args):
    """Choose BLOCK_N (columns loaded per loop step) for the per-dim kernel.

    Args:
        args: Mapping of kernel launch arguments. Only ``"N"`` (the length
            of the reduced dimension) is read; a missing entry is treated
            as 0.

    Returns:
        A power of two in ``[1, 512]``: ``N`` capped at 512, clamped to at
        least 1 and rounded up to the next power of two.
    """
    # Clamp to 1 so an empty reduced dimension (N == 0) never yields a
    # zero-width block, which is not a valid tl.arange extent.
    return triton.next_power_of_2(max(1, min(args.get("N", 0), 512)))
@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": heur_m_block_size,
        "BLOCK_N": heur_n_block_size,
    },
)
@triton.jit
def count_nonzero_kernel_dim(
    inp,
    out,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Count the non-zero entries of each row of an (M, N) row-major buffer.

    Each program handles BLOCK_M consecutive rows, walks their N columns in
    steps of BLOCK_N, and writes one int64 count per row to ``out``.

    Args:
        inp: Pointer to the input; row ``r`` starts at ``inp + r * N``.
        out: Pointer to the per-row results (one element per row).
        M: Number of rows.
        N: Number of elements per row.
        BLOCK_M: Rows per program (compile-time constant).
        BLOCK_N: Columns loaded per loop iteration (compile-time constant).
    """
    block_id = ext.program_id(0)
    # Column vector of the row indices owned by this program.
    row_idx = block_id * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
    row_mask = row_idx < M
    row_base = inp + row_idx * N
    out_ptrs = out + row_idx

    # Per-lane int32 accumulators (the original author notes int32 is faster
    # for the intermediate counts); widened to int64 only after the row sum.
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.int32)
    for start in range(0, N, BLOCK_N):
        col_idx = start + tl.arange(0, BLOCK_N)[None, :]
        in_bounds = row_mask and (col_idx < N)
        # Out-of-range lanes load 0 and therefore contribute nothing.
        vals = tl.load(row_base + col_idx, mask=in_bounds, other=0)
        acc += (vals != 0).to(tl.int32)

    row_totals = tl.sum(acc, axis=1).to(tl.int64)
    tl.store(out_ptrs, row_totals[:, None], mask=row_mask)
@libentry()
@triton.jit
def count_nonzero_kernel_1d_parallel(
    inp,
    partial_out,
    N,
    BLOCK_N: tl.constexpr,
):
    """Count non-zero entries of a flat buffer, one partial count per program.

    Program ``p`` starts at offset ``p * BLOCK_N`` and advances by
    ``num_programs * BLOCK_N`` until it passes ``N``; its int64 total is
    written to ``partial_out[p]``.

    Args:
        inp: Pointer to the flat input buffer.
        partial_out: Pointer to the per-program results.
        N: Number of elements in ``inp``.
        BLOCK_N: Elements loaded per loop iteration (compile-time constant).
    """
    block_id = ext.program_id(0)
    n_blocks = ext.num_programs(0)
    step = n_blocks * BLOCK_N

    # Per-lane int32 accumulators (the original author notes int32 is faster
    # for the intermediate counts); widened to int64 only after the sum.
    acc = tl.zeros([BLOCK_N], dtype=tl.int32)
    for start in range(block_id * BLOCK_N, N, step):
        idx = start + tl.arange(0, BLOCK_N)
        in_bounds = idx < N
        # Lanes past the end load 0 and therefore contribute nothing.
        vals = tl.load(inp + idx, mask=in_bounds, other=0)
        acc += (vals != 0).to(tl.int32)

    block_total = tl.sum(acc, axis=0).to(tl.int64)
    tl.store(partial_out + block_id, block_total)
@libentry()
@triton.jit
def reduce_partial_counts(
    partial_in,
    out,
    num_partials,
    BLOCK: tl.constexpr,
):
    """Sum ``num_partials`` partial counts into a single value at ``out``.

    Args:
        partial_in: Pointer to the partial counts.
        out: Pointer that receives the total.
        num_partials: Number of valid entries in ``partial_in``.
        BLOCK: Entries loaded per loop iteration (compile-time constant).
    """
    acc = tl.zeros([BLOCK], dtype=tl.int64)
    for start in range(0, num_partials, BLOCK):
        idx = start + tl.arange(0, BLOCK)
        valid = idx < num_partials
        # Lanes beyond num_partials load 0 so they do not affect the sum.
        acc += tl.load(partial_in + idx, mask=valid, other=0)

    tl.store(out, tl.sum(acc, axis=0))
def count_nonzero(x, dim=None):
    """Count the non-zero elements of ``x``, overall or along one dimension.

    Args:
        x: Input tensor.
        dim: Dimension to reduce, in ``[-x.ndim, x.ndim)``. ``None`` (the
            default) counts over the whole tensor.

    Returns:
        An int64 tensor on ``x.device``. With ``dim`` given it has ``x``'s
        shape with that dimension removed; otherwise it is a 0-d tensor
        holding the total count.

    Raises:
        AssertionError: If ``dim`` is outside ``[-x.ndim, x.ndim)``.
    """
    logger.debug("GEMS_KUNLUNXIN COUNT NONZERO")

    if dim is None:
        # Whole-tensor count: per-program partial counts, then one reduction.
        flat = x.contiguous().flatten()
        numel = flat.numel()
        out = torch.zeros(1, dtype=torch.int64, device=flat.device)

        # Large block for memory throughput; at most cluster_num programs to
        # keep launch and reduction overhead low (always at least one).
        BLOCK_N = 2048
        num_blocks = max(1, min(cluster_num, triton.cdiv(numel, BLOCK_N)))

        with torch_device_fn.device(flat.device):
            if num_blocks == 1:
                # Small tensor: a single program writes the total directly.
                count_nonzero_kernel_1d_parallel[(1,)](
                    flat,
                    out,
                    numel,
                    BLOCK_N=BLOCK_N,
                    buffer_size_limit=buf_len_per_core,
                )
            else:
                # Large tensor: one partial per program, then sum them.
                partial = torch.zeros(
                    num_blocks, dtype=torch.int64, device=flat.device
                )
                count_nonzero_kernel_1d_parallel[(num_blocks,)](
                    flat,
                    partial,
                    numel,
                    BLOCK_N=BLOCK_N,
                    buffer_size_limit=buf_len_per_core,
                )
                reduce_partial_counts[(1,)](
                    partial,
                    out,
                    num_blocks,
                    BLOCK=triton.next_power_of_2(num_blocks),
                    buffer_size_limit=buf_len_per_core,
                )
        return out[0]

    # Kept as an assert so callers see the same exception type as before.
    assert dim >= -x.ndim and dim < x.ndim, "Invalid dim"
    # Normalise to a non-negative index before it reaches dim_compress and
    # the shape arithmetic (dim_compress's handling of negative dims is not
    # visible from this file).
    dim = dim % x.ndim

    shape = x.shape
    out_shape = list(shape)
    del out_shape[dim]
    out = torch.zeros(out_shape, dtype=torch.int64, device=x.device)

    numel = x.numel()
    if numel == 0:
        # Nothing to count. Returning here also avoids dividing by
        # shape[dim] == 0 and launching with an empty grid.
        return out

    N = shape[dim]
    M = numel // N  # exact: numel is a multiple of shape[dim]

    # Move the reduced dimension last, then flatten to a row-major (M, N)
    # buffer for the kernel.
    x = dim_compress(x, dim)
    x = x.contiguous().flatten()

    def grid(meta):
        return (triton.cdiv(M, meta["BLOCK_M"]),)

    with torch_device_fn.device(x.device):
        count_nonzero_kernel_dim[grid](
            x, out, M, N, buffer_size_limit=buf_len_per_core
        )
    return out