Coverage for src/flag_gems/runtime/backend/_thead/heuristics_config_utils.py: 0%

121 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-04 09:03 +0800

1""" 

2T-Head Zhenwu (真武) PPU Heuristics Configuration Utilities 

3 

4Provides dynamic parameter selection based on input tensor shapes 

5and PPU hardware characteristics. 

6 

7Hardware: Zhenwu 810E PPU 

8Key Features: 

9- Tensor Core with extended PTX instructions for AIU 

10- High memory bandwidth with optimized access patterns 

11- Multi-stream parallelism support 

12- ICN interconnect for multi-card scenarios 

13 

14Heuristics are designed to: 

151. Maximize Tensor Core utilization 

162. Optimize memory hierarchy usage (shared memory, L2 cache) 

173. Balance compute and memory bandwidth 

184. Adapt to different problem sizes dynamically 

19 

20Reference: 

21- PPU SDK v2.0.0 Documentation 

22- Triton support: 2.3.x - 3.4.x with AIU extensions 

23""" 

24 

25import torch 

26import triton 

27 

28 

def simple_elementwise_blocksize_heur(args):
    """Return the fixed BLOCK_SIZE for generic elementwise kernels on the PPU.

    ``args`` is accepted only to match the signature of the other heuristics
    in this module; the result does not depend on it.
    """
    del args  # intentionally unused
    block_size = 1024
    return block_size

31 

32 

def argmax_heur_block_m(args):
    """Choose BLOCK_M (rows per program) for argmax on the PPU.

    Thresholds on ``args["M"]``:
        M < 4096          -> 4
        4096 <= M < 16384 -> 8
        M >= 16384        -> 16
    """
    # (exclusive upper bound on M, BLOCK_M), checked from smallest bound up.
    for m_limit, block_m in ((4096, 4), (16384, 8)):
        if args["M"] < m_limit:
            return block_m
    return 16

42 

43 

def argmax_heur_block_n(args):
    """Choose BLOCK_N for argmax on the PPU.

    Rounds ``args["N"]`` up with ``triton.next_power_of_2`` and caps the
    result at 8192.
    """
    cap = 8192
    padded_n = triton.next_power_of_2(args["N"])
    return min(cap, padded_n)

48 

49 

def argmin_heur_block_m(args):
    """Choose BLOCK_M for argmin on the PPU.

    Uses the same choice as argmax by forwarding to ``argmax_heur_block_m``.
    """
    block_m = argmax_heur_block_m(args)
    return block_m

53 

54 

def argmin_heur_block_n(args):
    """Choose BLOCK_N for argmin on the PPU.

    Uses the same choice as argmax by forwarding to ``argmax_heur_block_n``.
    """
    block_n = argmax_heur_block_n(args)
    return block_n

58 

59 

def bmm_heur_divisible_m(args):
    """Return True when ``args["M"]`` is an exact multiple of ``args["TILE_M"]``."""
    m, tile_m = args["M"], args["TILE_M"]
    remainder = m % tile_m
    return remainder == 0

63 

64 

def bmm_heur_divisible_n(args):
    """Return True when ``args["N"]`` is an exact multiple of ``args["TILE_N"]``."""
    n, tile_n = args["N"], args["TILE_N"]
    remainder = n % tile_n
    return remainder == 0

68 

69 

def bmm_heur_divisible_k(args):
    """Return True when ``args["K"]`` is an exact multiple of ``args["TILE_K"]``."""
    k, tile_k = args["K"], args["TILE_K"]
    remainder = k % tile_k
    return remainder == 0

73 

74 

def dropout_heur_block(args):
    """Choose the BLOCK size for dropout from ``args["N"]``.

    N <= 512 -> 512, N <= 1024 -> 1024, otherwise 2048.
    """
    # (inclusive upper bound on N, BLOCK), checked from smallest bound up.
    for n_limit, block in ((512, 512), (1024, 1024)):
        if args["N"] <= n_limit:
            return block
    return 2048

83 

84 

def dropout_heur_num_warps(args):
    """Choose num_warps for dropout from ``args["N"]``.

    N <= 512 -> 4, N <= 1024 -> 8, otherwise 16.
    """
    # (inclusive upper bound on N, num_warps), checked from smallest bound up.
    for n_limit, warps in ((512, 4), (1024, 8)):
        if args["N"] <= n_limit:
            return warps
    return 16

93 

94 

def softmax_heur_tile_k(args):
    """Choose TILE_K for the non-inner softmax kernel on the PPU.

    Starting from 1, TILE_K is doubled while the launch would still need
    more than two waves of blocks across the device's multiprocessors. The
    block count is ``M * cdiv(K, TILE_K)``. TILE_K never exceeds
    ``min(K, 8192)``, and if that bound is below 1 the result is 1.

    The multiprocessor count is read from the current CUDA device. If that
    query raises, 128 is used instead.
    """
    max_tile_k = 8192
    tile_k = 1
    upper_bound = min(args["K"], max_tile_k)

    # Best-effort device query; any failure falls back to a fixed SM count.
    try:
        device = torch.cuda.current_device()
        num_sms = torch.cuda.get_device_properties(device).multi_processor_count
    except Exception:
        num_sms = 128  # fallback, noted in the original as the Zhenwu 810E default

    while tile_k <= upper_bound:
        num_waves = args["M"] * triton.cdiv(args["K"], tile_k) / num_sms
        # Stop once the grid fits in <= 2 waves, or once another doubling
        # would overshoot the upper bound.
        if num_waves <= 2 or tile_k * 2 > upper_bound:
            break
        tile_k *= 2

    return tile_k

123 

124 

def softmax_heur_tile_n_non_inner(args):
    """Choose TILE_N for the non-inner softmax kernel on the PPU.

    Computes ``cdiv(8192, TILE_K)``, so that TILE_N * TILE_K is roughly
    8192 elements.
    """
    tile_budget = 8192
    return triton.cdiv(tile_budget, args["TILE_K"])

128 

129 

def softmax_heur_one_tile_per_cta(args):
    """Return True when a single TILE_N-wide tile already spans all N elements."""
    tile_n, n = args["TILE_N"], args["N"]
    return n <= tile_n

133 

134 

def softmax_heur_num_warps_non_inner(args):
    """Choose num_warps for non-inner softmax from the tile area TILE_N * TILE_K.

    area < 2048 -> 4, area < 4096 -> 8, otherwise 16.
    """
    area = args["TILE_N"] * args["TILE_K"]
    for area_limit, warps in ((2048, 4), (4096, 8)):
        if area < area_limit:
            return warps
    return 16

144 

145 

def softmax_heur_tile_n_inner(args):
    """Choose TILE_N for the inner softmax kernel on the PPU.

    Rows of up to 32 Ki elements get a single power-of-two tile covering the
    whole row. Longer rows use a fixed 4096-element tile.
    """
    n = args["N"]
    return triton.next_power_of_2(n) if n <= (32 * 1024) else 4096

152 

153 

def softmax_heur_num_warps_inner(args):
    """Choose num_warps for inner softmax from TILE_N.

    TILE_N < 2048 -> 4, TILE_N < 4096 -> 8, otherwise 16.
    """
    tile_n = args["TILE_N"]
    for tile_limit, warps in ((2048, 4), (4096, 8)):
        if tile_n < tile_limit:
            return warps
    return 16

163 

164 

def layer_norm_heur_block_row_size(args):
    """Choose BLOCK_ROW_SIZE for persistent layer norm on the PPU.

    Rounds ``args["row_count"]`` up to a power of two and caps it at 32.
    """
    max_rows = 32
    padded_rows = triton.next_power_of_2(args["row_count"])
    return min(max_rows, padded_rows)

168 

169 

def batch_norm_heur_block_m(args):
    """Choose BLOCK_M for batch norm on the PPU.

    Rounds ``args["batch_dim"]`` up to a power of two and caps it at 2048.
    """
    max_block_m = 2048
    padded_batch = triton.next_power_of_2(args["batch_dim"])
    return min(max_block_m, padded_batch)

173 

174 

def batch_norm_heur_block_n(args):
    """Choose BLOCK_N for batch norm on the PPU.

    Starts from ``args["spatial_dim"]`` rounded up to a power of two. It is
    then limited so that BLOCK_M * BLOCK_N stays within 2**15 elements,
    with BLOCK_M taken from ``batch_norm_heur_block_m``. The result is never
    below 1 unless the rounded spatial size itself is.
    """
    block_m = batch_norm_heur_block_m(args)
    padded_spatial = triton.next_power_of_2(args["spatial_dim"])
    element_budget = 2**15
    cols_allowed = max(1, element_budget // block_m)
    return min(padded_spatial, cols_allowed)

184 

185 

def mv_heur_block_m(args):
    """Choose BLOCK_M for matrix-vector multiply on the PPU.

    Rounds ``args["M"]`` up to a power of two and caps it at 1024.
    """
    max_block_m = 1024
    padded_m = triton.next_power_of_2(args["M"])
    return min(max_block_m, padded_m)

189 

190 

def mv_heur_block_n(args):
    """Choose BLOCK_N for matrix-vector multiply on the PPU.

    Rounds ``args["N"]`` up to a power of two and caps it at 128.
    """
    max_block_n = 128
    padded_n = triton.next_power_of_2(args["N"])
    return min(max_block_n, padded_n)

194 

195 

def attention_heur_block_m(args):
    """Choose BLOCK_M for attention on the PPU: 64 when M < 1024, else 128."""
    return 64 if args["M"] < 1024 else 128

203 

204 

def attention_heur_block_n(args):
    """Choose BLOCK_N for attention on the PPU.

    N < 1024 -> 32, N < 4096 -> 64, otherwise 128.
    """
    for n_limit, block_n in ((1024, 32), (4096, 64)):
        if args["N"] < n_limit:
            return block_n
    return 128

213 

214 

def index_select_heur_block_m(args):
    """Choose BLOCK_M for index_select on the PPU.

    Computes ``cdiv(256, N)``, the number of N-wide rows needed to reach
    about 256 elements. That value is rounded up to a power of two and
    capped at 4. ``args["N"]`` is used as a divisor, so it must be non-zero.
    """
    rows_for_256 = triton.cdiv(256, args["N"])
    return min(4, triton.next_power_of_2(rows_for_256))

218 

219 

def index_select_heur_block_n(args):
    """Choose BLOCK_N for index_select on the PPU.

    Takes ``next_power_of_2(cdiv(N, 16))`` and clamps it to [16, 1024].
    """
    candidate = triton.next_power_of_2(triton.cdiv(args["N"], 16))
    capped = min(candidate, 1024)
    return max(capped, 16)

224 

225 

def gather_heur_block_m(args):
    """Choose BLOCK_M for gather on the PPU.

    Computes ``cdiv(N, 2048)``, rounds it up to a power of two and caps it
    at 4.
    """
    chunks = triton.cdiv(args["N"], 2048)
    return min(4, triton.next_power_of_2(chunks))

229 

230 

def gather_heur_block_n(args):
    """Choose BLOCK_N for gather on the PPU.

    Rounds ``args["N"]`` up to a power of two and caps it at 2048.
    """
    max_block_n = 2048
    padded_n = triton.next_power_of_2(args["N"])
    return min(max_block_n, padded_n)

234 

235 

def var_mean_heur_block_n(args):
    """Choose BLOCK_N for var_mean on the PPU: ``args["BLOCK_NUM"]`` rounded up to a power of two."""
    block_num = args["BLOCK_NUM"]
    return triton.next_power_of_2(block_num)

239 

240 

def upsample_nearest2d_SAME_H(args):
    """Return True when the output height ``OH`` equals the input height ``IH``."""
    out_h, in_h = args["OH"], args["IH"]
    return out_h == in_h

244 

245 

def upsample_nearest2d_SAME_W(args):
    """Return True when the output width ``OW`` equals the input width ``IW``."""
    out_w, in_w = args["OW"], args["IW"]
    return out_w == in_w

249 

250 

def mm_heur_even_k(args):
    """Return True when K divides evenly into BLOCK_K * SPLIT_K chunks.

    "Even" here means divisibility, not parity: the check is
    ``K % (BLOCK_K * SPLIT_K) == 0``. Presumably this lets the mm kernel
    skip masking on K; confirm against the kernel.
    """
    k = args["K"]
    k_step = args["BLOCK_K"] * args["SPLIT_K"]
    return k % k_step == 0

254 

255 

def rand_heur_block(args):
    """Choose the BLOCK size for random number generation: 512 when N <= 512, else 1024."""
    return 512 if args["N"] <= 512 else 1024

262 

263 

def rand_heur_num_warps(args):
    """Choose num_warps for random number generation from ``args["N"]``.

    N <= 512 -> 4, N <= 1024 -> 8, otherwise 16.
    """
    for n_limit, warps in ((512, 4), (1024, 8)):
        if args["N"] <= n_limit:
            return warps
    return 16

272 

273 

# Registry of PPU heuristics: operator name -> {parameter name -> callable}.
# Each callable takes the kernel's ``args`` mapping and returns the value for
# that parameter (presumably consumed via ``triton.heuristics``; confirm at
# the call sites).
HEURISTICS_CONFIGS = dict(
    argmax=dict(
        BLOCK_M=argmax_heur_block_m,
        BLOCK_N=argmax_heur_block_n,
    ),
    argmin=dict(
        BLOCK_M=argmin_heur_block_m,
        BLOCK_N=argmin_heur_block_n,
    ),
    bmm=dict(
        DIVISIBLE_M=bmm_heur_divisible_m,
        DIVISIBLE_N=bmm_heur_divisible_n,
        DIVISIBLE_K=bmm_heur_divisible_k,
    ),
    dropout=dict(
        BLOCK=dropout_heur_block,
        num_warps=dropout_heur_num_warps,
    ),
    softmax_non_inner=dict(
        TILE_K=softmax_heur_tile_k,
        TILE_N=softmax_heur_tile_n_non_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
        num_warps=softmax_heur_num_warps_non_inner,
    ),
    softmax_inner=dict(
        TILE_N=softmax_heur_tile_n_inner,
        ONE_TILE_PER_CTA=softmax_heur_one_tile_per_cta,
        num_warps=softmax_heur_num_warps_inner,
    ),
    layer_norm_persistent=dict(
        BLOCK_ROW_SIZE=layer_norm_heur_block_row_size,
    ),
    batch_norm=dict(
        BLOCK_M=batch_norm_heur_block_m,
        BLOCK_N=batch_norm_heur_block_n,
    ),
    mv=dict(
        BLOCK_M=mv_heur_block_m,
        BLOCK_N=mv_heur_block_n,
    ),
    attention=dict(
        BLOCK_M=attention_heur_block_m,
        BLOCK_N=attention_heur_block_n,
    ),
    index_select=dict(
        BLOCK_M=index_select_heur_block_m,
        BLOCK_N=index_select_heur_block_n,
    ),
    gather=dict(
        BLOCK_M=gather_heur_block_m,
        BLOCK_N=gather_heur_block_n,
    ),
    var_mean=dict(
        BLOCK_N=var_mean_heur_block_n,
    ),
    upsample_nearest2d=dict(
        SAME_H=upsample_nearest2d_SAME_H,
        SAME_W=upsample_nearest2d_SAME_W,
    ),
    mm=dict(
        EVEN_K=mm_heur_even_k,
    ),
    # rand and randn use the same heuristics but separate inner dicts.
    rand=dict(
        BLOCK=rand_heur_block,
        num_warps=rand_heur_num_warps,
    ),
    randn=dict(
        BLOCK=rand_heur_block,
        num_warps=rand_heur_num_warps,
    ),
    elementwise_generic=dict(
        BLOCK_SIZE=simple_elementwise_blocksize_heur,
        num_warps=lambda args: 8,  # fixed warp count, independent of args
    ),
)