Coverage for src/flag_gems/runtime/backend/_thead/heuristics_config_utils.py: 0%
121 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1"""
T-Head Zhenwu (真武) PPU Heuristics Configuration Utilities

Provides dynamic parameter selection based on input tensor shapes
and PPU hardware characteristics.

Hardware: Zhenwu 810E PPU
Key Features:
- Tensor Core with extended PTX instructions for AIU
- High memory bandwidth with optimized access patterns
- Multi-stream parallelism support
- ICN interconnect for multi-card scenarios

Heuristics are designed to:
1. Maximize Tensor Core utilization
2. Optimize memory hierarchy usage (shared memory, L2 cache)
3. Balance compute and memory bandwidth
4. Adapt to different problem sizes dynamically

References:
- PPU SDK v2.0.0 Documentation
- Triton support: 2.3.x - 3.4.x with AIU extensions
23"""
25import torch
26import triton
def simple_elementwise_blocksize_heur(args):
    """Return the fixed BLOCK_SIZE for generic elementwise kernels.

    ``args`` is accepted only to match the heuristic-callable signature; it
    is never inspected, so the result is always 1024.
    """
    block_size = 1024
    return block_size
def argmax_heur_block_m(args):
    """Pick BLOCK_M for argmax from the size of the M dimension.

    Mapping: M < 4096 -> 4, 4096 <= M < 16384 -> 8, M >= 16384 -> 16.
    """
    m = args["M"]
    if m >= 16384:
        return 16
    if m >= 4096:
        return 8
    return 4
def argmax_heur_block_n(args):
    """Pick BLOCK_N for argmax: N rounded up to a power of two, capped at 8192."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(rounded_n, 8192)
def argmin_heur_block_m(args):
    """Pick BLOCK_M for argmin.

    argmin uses the same tiling as argmax, so the decision is delegated to
    ``argmax_heur_block_m``.
    """
    block_m = argmax_heur_block_m(args)
    return block_m
def argmin_heur_block_n(args):
    """Pick BLOCK_N for argmin.

    argmin uses the same tiling as argmax, so the decision is delegated to
    ``argmax_heur_block_n``.
    """
    block_n = argmax_heur_block_n(args)
    return block_n
def bmm_heur_divisible_m(args):
    """Return True when M splits evenly into TILE_M-sized tiles."""
    remainder = args["M"] % args["TILE_M"]
    return remainder == 0
def bmm_heur_divisible_n(args):
    """Return True when N splits evenly into TILE_N-sized tiles."""
    remainder = args["N"] % args["TILE_N"]
    return remainder == 0
def bmm_heur_divisible_k(args):
    """Return True when K splits evenly into TILE_K-sized tiles."""
    remainder = args["K"] % args["TILE_K"]
    return remainder == 0
def dropout_heur_block(args):
    """Pick the dropout BLOCK size from N.

    N <= 512 -> 512, N <= 1024 -> 1024, anything larger -> 2048.
    """
    n = args["N"]
    for limit in (512, 1024):
        if n <= limit:
            return limit
    return 2048
def dropout_heur_num_warps(args):
    """Pick num_warps for dropout from N.

    N <= 512 -> 4 warps, N <= 1024 -> 8 warps, anything larger -> 16 warps.
    """
    n = args["N"]
    if n > 1024:
        return 16
    if n > 512:
        return 8
    return 4
def _ppu_sm_count(default=128):
    """Return the SM count of the current torch.cuda device.

    Best effort: if the device query raises anything, ``default`` is
    returned instead (128 is the original author's default for the
    Zhenwu 810E).
    """
    try:
        props = torch.cuda.get_device_properties(torch.cuda.current_device())
        return props.multi_processor_count
    except Exception:
        return default


def softmax_heur_tile_k(args):
    """Choose TILE_K for softmax on PPU.

    TILE_K starts at 1 and is doubled while both of these hold:

    * the launch would still need more than two "waves" of CTAs, where
      ``waves = M * cdiv(K, TILE_K) / NUM_SMS``; and
    * the doubled tile still fits within ``min(K, 8192)``.

    A larger TILE_K means fewer CTAs, so the loop stops once the grid fits
    in about two waves or the size cap is reached. The result is always a
    power of two (1 when K < 2).
    """
    max_tile_k = 8192
    limit = min(args["K"], max_tile_k)
    num_sms = _ppu_sm_count()

    tile_k = 1
    while tile_k <= limit:
        waves = args["M"] * triton.cdiv(args["K"], tile_k) / num_sms
        if not (waves > 2 and tile_k * 2 <= limit):
            break
        tile_k *= 2
    return tile_k
def softmax_heur_tile_n_non_inner(args):
    """Choose TILE_N for non-inner softmax as ``cdiv(8192, TILE_K)``.

    This makes TILE_N * TILE_K at least 8192 elements for a positive TILE_K.
    """
    element_budget = 8192
    return triton.cdiv(element_budget, args["TILE_K"])
def softmax_heur_one_tile_per_cta(args):
    """Return True when one TILE_N-wide tile already spans all N elements."""
    tile_n, n = args["TILE_N"], args["N"]
    return tile_n >= n
def softmax_heur_num_warps_non_inner(args):
    """Pick num_warps for non-inner softmax from the tile area TILE_N * TILE_K.

    area < 2048 -> 4 warps, area < 4096 -> 8 warps, otherwise 16 warps.
    """
    area = args["TILE_N"] * args["TILE_K"]
    for bound, warps in ((2048, 4), (4096, 8)):
        if area < bound:
            return warps
    return 16
def softmax_heur_tile_n_inner(args):
    """Choose TILE_N for inner softmax.

    For N up to 32 * 1024, the whole row is covered by rounding N up to a
    power of two. Longer rows fall back to a fixed 4096-wide tile.
    """
    n = args["N"]
    if n > 32 * 1024:
        return 4096
    return triton.next_power_of_2(n)
def softmax_heur_num_warps_inner(args):
    """Pick num_warps for inner softmax from TILE_N.

    TILE_N < 2048 -> 4 warps, TILE_N < 4096 -> 8 warps, otherwise 16 warps.
    """
    tile_n = args["TILE_N"]
    if tile_n >= 4096:
        return 16
    if tile_n >= 2048:
        return 8
    return 4
def layer_norm_heur_block_row_size(args):
    """Pick BLOCK_ROW_SIZE for layer norm: row_count rounded up to a power of two, capped at 32."""
    rounded_rows = triton.next_power_of_2(args["row_count"])
    return min(rounded_rows, 32)
def batch_norm_heur_block_m(args):
    """Pick BLOCK_M for batch norm: batch_dim rounded up to a power of two, capped at 2048."""
    rounded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(rounded_batch, 2048)
def batch_norm_heur_block_n(args):
    """Pick BLOCK_N for batch norm.

    Start from spatial_dim rounded up to a power of two. Then cap it so that
    BLOCK_M * BLOCK_N stays within a 2**15-element budget. The cap itself is
    never below 1.
    """
    block_m = batch_norm_heur_block_m(args)
    spatial_block = triton.next_power_of_2(args["spatial_dim"])
    budget_cap = max(1, (1 << 15) // block_m)
    return min(spatial_block, budget_cap)
def mv_heur_block_m(args):
    """Pick BLOCK_M for matrix-vector multiply: M rounded up to a power of two, capped at 1024."""
    rounded_m = triton.next_power_of_2(args["M"])
    return min(rounded_m, 1024)
def mv_heur_block_n(args):
    """Pick BLOCK_N for matrix-vector multiply: N rounded up to a power of two, capped at 128."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(rounded_n, 128)
def attention_heur_block_m(args):
    """Pick BLOCK_M for attention: 64 when M < 1024, otherwise 128."""
    return 128 if args["M"] >= 1024 else 64
def attention_heur_block_n(args):
    """Pick BLOCK_N for attention.

    N < 1024 -> 32, N < 4096 -> 64, otherwise 128.
    """
    n = args["N"]
    for bound, block_n in ((1024, 32), (4096, 64)):
        if n < bound:
            return block_n
    return 128
def index_select_heur_block_m(args):
    """Pick BLOCK_M for index_select as ``next_power_of_2(cdiv(256, N))`` capped at 4."""
    target_rows = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(target_rows))
def index_select_heur_block_n(args):
    """Pick BLOCK_N for index_select.

    Computes ``next_power_of_2(cdiv(N, 16))`` and clamps it to [16, 1024].
    """
    block_n = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    block_n = min(block_n, 1024)
    return max(block_n, 16)
def gather_heur_block_m(args):
    """Pick BLOCK_M for gather as ``next_power_of_2(cdiv(N, 2048))`` capped at 4."""
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))
def gather_heur_block_n(args):
    """Pick BLOCK_N for gather: N rounded up to a power of two, capped at 2048."""
    rounded_n = triton.next_power_of_2(args["N"])
    return min(rounded_n, 2048)
def var_mean_heur_block_n(args):
    """Pick BLOCK_N for var_mean: BLOCK_NUM rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)
def upsample_nearest2d_SAME_H(args):
    """Return True when the output height OH equals the input height IH."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h
def upsample_nearest2d_SAME_W(args):
    """Return True when the output width OW equals the input width IW."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w
def mm_heur_even_k(args):
    """Return True when K is an exact multiple of BLOCK_K * SPLIT_K.

    "Even" here means divisibility by the combined K step, not parity.
    """
    k = args["K"]
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return k % k_step == 0
def rand_heur_block(args):
    """Pick the BLOCK size for random-number kernels: 512 when N <= 512, else 1024."""
    return 1024 if args["N"] > 512 else 512
def rand_heur_num_warps(args):
    """Pick num_warps for random-number kernels.

    N <= 512 -> 4 warps, N <= 1024 -> 8 warps, anything larger -> 16 warps.
    """
    n = args["N"]
    for limit, warps in ((512, 4), (1024, 8)):
        if n <= limit:
            return warps
    return 16
def _elementwise_generic_num_warps(args):
    """Return the fixed num_warps (8) for generic elementwise kernels; ``args`` is unused."""
    return 8


# Registry of all PPU heuristics, keyed by operator name. Each inner dict
# maps a kernel meta-parameter name to a callable that takes the launch
# ``args`` mapping and returns that parameter's value. Every operator gets
# its own inner dict object, even where the callables are shared.
HEURISTICS_CONFIGS = {
    # argmin reuses argmax's tiling through its thin wrapper functions.
    "argmax": {"BLOCK_M": argmax_heur_block_m, "BLOCK_N": argmax_heur_block_n},
    "argmin": {"BLOCK_M": argmin_heur_block_m, "BLOCK_N": argmin_heur_block_n},
    # Booleans: does each dimension divide evenly into its tile?
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {"BLOCK": dropout_heur_block, "num_warps": dropout_heur_num_warps},
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "layer_norm_persistent": {"BLOCK_ROW_SIZE": layer_norm_heur_block_row_size},
    "batch_norm": {"BLOCK_M": batch_norm_heur_block_m, "BLOCK_N": batch_norm_heur_block_n},
    "mv": {"BLOCK_M": mv_heur_block_m, "BLOCK_N": mv_heur_block_n},
    "attention": {"BLOCK_M": attention_heur_block_m, "BLOCK_N": attention_heur_block_n},
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "gather": {"BLOCK_M": gather_heur_block_m, "BLOCK_N": gather_heur_block_n},
    "var_mean": {"BLOCK_N": var_mean_heur_block_n},
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
    },
    "mm": {"EVEN_K": mm_heur_even_k},
    # rand and randn share the same callables but keep separate dicts.
    "rand": {"BLOCK": rand_heur_block, "num_warps": rand_heur_num_warps},
    "randn": {"BLOCK": rand_heur_block, "num_warps": rand_heur_num_warps},
    "elementwise_generic": {
        "BLOCK_SIZE": simple_elementwise_blocksize_heur,
        "num_warps": _elementwise_generic_num_warps,
    },
}