Coverage for src/flag_gems/runtime/backend/_ascend/heuristics_config_utils.py: 0%

136 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import triton 

2 

3 

def argmax_heur_block_m(args):
    """Return the constant ``BLOCK_M`` heuristic value for ``argmax``.

    Args:
        args: Heuristic argument mapping; not consulted by this function.

    Returns:
        The fixed value ``16``.
    """
    return 16

6 

7 

def argmax_heur_block_n(args):
    """Return the constant ``BLOCK_N`` heuristic value for ``argmax``.

    Args:
        args: Heuristic argument mapping; not consulted by this function.

    Returns:
        The fixed value ``100``.
    """
    # NOTE(review): 100 is not a power of two -- confirm the consuming kernel
    # accepts a non-power-of-two BLOCK_N on this backend.
    return 100

10 

11 

def argmin_heur_block_m(args):
    """Return the constant ``BLOCK_M`` heuristic value for ``argmin``.

    Args:
        args: Heuristic argument mapping; not consulted by this function.

    Returns:
        The fixed value ``16``.
    """
    return 16

14 

15 

def argmin_heur_block_n(args):
    """Return the constant ``BLOCK_N`` heuristic value for ``argmin``.

    Args:
        args: Heuristic argument mapping; not consulted by this function.

    Returns:
        The fixed value ``100``.
    """
    # NOTE(review): 100 is not a power of two -- confirm the consuming kernel
    # accepts a non-power-of-two BLOCK_N on this backend.
    return 100

18 

19 

def bmm_heur_divisible_m(args):
    """Report whether ``M`` is an exact multiple of ``TILE_M``.

    Args:
        args: Mapping providing the ``"M"`` and ``"TILE_M"`` entries.

    Returns:
        ``True`` when ``args["M"] % args["TILE_M"]`` is zero, else ``False``.
    """
    return args["M"] % args["TILE_M"] == 0

22 

23 

def bmm_heur_divisible_n(args):
    """Report whether ``N`` is an exact multiple of ``TILE_N``.

    Args:
        args: Mapping providing the ``"N"`` and ``"TILE_N"`` entries.

    Returns:
        ``True`` when ``args["N"] % args["TILE_N"]`` is zero, else ``False``.
    """
    return args["N"] % args["TILE_N"] == 0

26 

27 

def bmm_heur_divisible_k(args):
    """Report whether ``K`` is an exact multiple of ``TILE_K``.

    Args:
        args: Mapping providing the ``"K"`` and ``"TILE_K"`` entries.

    Returns:
        ``True`` when ``args["K"] % args["TILE_K"]`` is zero, else ``False``.
    """
    return args["K"] % args["TILE_K"] == 0

30 

31 

def dropout_heur_block(args):
    """Choose the ``BLOCK`` heuristic value for ``dropout`` from ``N``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        ``512`` when ``N <= 512``, otherwise ``4096``.
    """
    if args["N"] <= 512:
        return 512
    else:
        return 4096

37 

38 

def dropout_heur_num_warps(args):
    """Choose the ``num_warps`` heuristic value for ``dropout`` from ``N``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        ``4`` when ``N <= 512``, ``8`` when ``N <= 1024``, otherwise ``16``.
    """
    if args["N"] <= 512:
        return 4
    elif args["N"] <= 1024:
        return 8
    else:
        return 16

46 

47 

def exponential_heur_block(args):
    """Choose the ``BLOCK`` heuristic value for ``exponential_`` from ``N``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        ``512`` when ``N <= 512``, otherwise ``1024``.
    """
    if args["N"] <= 512:
        return 512
    else:
        return 1024

53 

54 

def exponential_heur_num_warps(args):
    """Choose the ``num_warps`` heuristic value for ``exponential_`` from ``N``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        ``4`` when ``N <= 512``, ``8`` when ``N <= 1024``, otherwise ``16``.
    """
    if args["N"] <= 512:
        return 4
    elif args["N"] <= 1024:
        return 8
    else:
        return 16

62 

63 

def gather_heur_block_m(args):
    """Compute the ``BLOCK_M`` heuristic value for ``gather``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        The next power of two of ``cdiv(N, 2048)``, capped at ``4``.
    """
    return min(4, triton.next_power_of_2(triton.cdiv(args["N"], 2048)))

66 

67 

def gather_heur_block_n(args):
    """Compute the ``BLOCK_N`` heuristic value for ``gather``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        The next power of two of ``N``, capped at ``2048``.
    """
    return min(2048, triton.next_power_of_2(args["N"]))

70 

71 

def index_select_heur_block_m(args):
    """Compute the ``BLOCK_M`` heuristic value for ``index_select``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        The next power of two of ``cdiv(256, N)``, capped at ``4``.
    """
    return min(4, triton.next_power_of_2(triton.cdiv(256, args["N"])))

74 

75 

def index_select_heur_block_n(args):
    """Compute the ``BLOCK_N`` heuristic value for ``index_select``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        The next power of two of ``cdiv(N, 16)``, clamped to the inclusive
        range ``[16, 512]``.
    """
    m = min(triton.next_power_of_2(triton.cdiv(args["N"], 16)), 512)
    return max(m, 16)

79 

80 

def rand_heur_block(args):
    """Choose the ``BLOCK`` heuristic value for ``rand`` from ``N``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        ``2048`` when ``N <= 512``, otherwise ``4097``.
    """
    if args["N"] <= 512:
        return 2048
    else:
        # NOTE(review): 4097 is not a power of two; presumably a deliberate
        # backend-specific tuning value -- confirm before changing.
        return 4097

86 

87 

def rand_heur_num_warps(args):
    """Choose the ``num_warps`` heuristic value for ``rand`` from ``N``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        ``4`` when ``N <= 512``, ``8`` when ``N <= 1024``, otherwise ``16``.
    """
    if args["N"] <= 512:
        return 4
    elif args["N"] <= 1024:
        return 8
    else:
        return 16

95 

96 

def randn_heur_block(args):
    """Choose the ``BLOCK`` heuristic value for ``randn`` from ``N``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        ``2048`` when ``N <= 512``, otherwise ``4097``.
    """
    if args["N"] <= 512:
        return 2048
    else:
        # NOTE(review): 4097 is not a power of two; presumably a deliberate
        # backend-specific tuning value -- confirm before changing.
        return 4097

102 

103 

def randn_heur_num_warps(args):
    """Choose the ``num_warps`` heuristic value for ``randn`` from ``N``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        ``4`` when ``N <= 512``, ``8`` when ``N <= 1024``, otherwise ``16``.
    """
    if args["N"] <= 512:
        return 4
    elif args["N"] <= 1024:
        return 8
    else:
        return 16

111 

112 

def softmax_heur_tile_k(args):
    """Pick ``TILE_K`` for the non-inner softmax by doubling from 1.

    The tile is doubled while the resulting block count still exceeds
    ``NUM_SMS`` and the doubled tile stays within ``min(K, MAX_TILE_K)``.

    Args:
        args: Mapping providing the ``"M"`` and ``"K"`` entries.

    Returns:
        A power of two no larger than ``max(1, min(K, 4096))``.
    """
    MAX_TILE_K = 4096
    # FIXME:
    #   NUM_SMS should be obtained by API.
    #   It is actually the number of AIV cores which depends on the Ascend version.
    NUM_SMS = 40
    tile_k = 1
    upper_bound = min(args["K"], MAX_TILE_K)
    while tile_k <= upper_bound:
        num_blocks = args["M"] * triton.cdiv(args["K"], tile_k)
        num_waves = num_blocks / NUM_SMS
        # Grow only while there is more than one wave of blocks and the next
        # doubling would still fit under the upper bound.
        if (num_waves > 1) and (tile_k * 2 <= upper_bound):
            tile_k *= 2
        else:
            break
    return tile_k

129 

130 

def softmax_heur_tile_n_non_inner(args):
    """Compute ``TILE_N`` for the non-inner softmax from ``TILE_K``.

    Args:
        args: Mapping providing the ``"TILE_K"`` entry.

    Returns:
        ``cdiv(1024, TILE_K)``.
    """
    return triton.cdiv(1024, args["TILE_K"])

133 

134 

def softmax_heur_one_tile_per_cta(args):
    """Report whether a single ``TILE_N`` tile spans all of ``N``.

    Args:
        args: Mapping providing the ``"TILE_N"`` and ``"N"`` entries.

    Returns:
        ``True`` when ``TILE_N >= N``, else ``False``.
    """
    return args["TILE_N"] >= args["N"]

137 

138 

def softmax_heur_num_warps_non_inner(args):
    """Choose ``num_warps`` for the non-inner softmax from the tile area.

    Args:
        args: Mapping providing the ``"TILE_N"`` and ``"TILE_K"`` entries.

    Returns:
        ``4`` when ``TILE_N * TILE_K < 2048``, ``8`` when it is ``< 4096``,
        otherwise ``16``.
    """
    tile_size = args["TILE_N"] * args["TILE_K"]
    if tile_size < 2048:
        return 4
    elif tile_size < 4096:
        return 8
    else:
        return 16

147 

148 

def softmax_heur_tile_n_inner(args):
    """Compute ``TILE_N`` for the inner softmax from ``N``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        The next power of two of ``N`` when ``N <= 32 * 1024``, otherwise
        ``4096``.
    """
    if args["N"] <= (32 * 1024):
        return triton.next_power_of_2(args["N"])
    else:
        return 4096

154 

155 

def softmax_heur_num_warps_inner(args):
    """Choose ``num_warps`` for the inner softmax from ``TILE_N``.

    Args:
        args: Mapping providing the ``"TILE_N"`` entry.

    Returns:
        ``4`` when ``TILE_N < 2048``, ``8`` when ``TILE_N < 4096``,
        otherwise ``16``.
    """
    tile_size = args["TILE_N"]
    if tile_size < 2048:
        return 4
    elif tile_size < 4096:
        return 8
    else:
        return 16

164 

165 

def softmax_heur_tile_n_bwd_non_inner(args):
    """Compute ``TILE_N`` for the non-inner softmax backward from ``TILE_K``.

    Args:
        args: Mapping providing the ``"TILE_K"`` entry.

    Returns:
        ``1024 // TILE_K``, but never less than ``1``.
    """
    return max(1, 1024 // args["TILE_K"])

168 

169 

def softmax_heur_tile_m(args):
    """Compute ``TILE_M`` for the inner softmax backward from ``TILE_N``.

    Args:
        args: Mapping providing the ``"TILE_N"`` entry.

    Returns:
        ``1024 // TILE_N``, but never less than ``1``.
    """
    return max(1, 1024 // args["TILE_N"])

172 

173 

def uniform_heur_block(args):
    """Choose the ``BLOCK`` heuristic value for ``uniform`` from ``N``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        ``512`` when ``N <= 512``, ``4097`` when ``N >= 1073741824`` (2**30),
        otherwise ``1024``.
    """
    if args["N"] <= 512:
        return 512
    elif args["N"] >= 1073741824:
        # NOTE(review): 4097 is not a power of two; presumably a deliberate
        # choice for very large inputs on this backend -- confirm before
        # changing.
        return 4097
    else:
        return 1024

181 

182 

def uniform_heur_num_warps(args):
    """Choose the ``num_warps`` heuristic value for ``uniform`` from ``N``.

    Args:
        args: Mapping providing the ``"N"`` entry.

    Returns:
        ``4`` when ``N <= 512``, ``8`` when ``N <= 1024``, otherwise ``16``.
    """
    if args["N"] <= 512:
        return 4
    elif args["N"] <= 1024:
        return 8
    else:
        return 16

190 

191 

def var_mean_heur_block_n(args):
    """Compute the ``BLOCK_N`` heuristic value for ``var_mean``.

    Args:
        args: Mapping providing the ``"BLOCK_NUM"`` entry.

    Returns:
        The next power of two of ``BLOCK_NUM``.
    """
    return triton.next_power_of_2(args["BLOCK_NUM"])

194 

195 

def upsample_nearest2d_SAME_H(args):
    """Report whether ``OH`` equals ``IH`` for ``upsample_nearest2d``.

    Args:
        args: Mapping providing the ``"OH"`` and ``"IH"`` entries.

    Returns:
        ``True`` when ``args["OH"] == args["IH"]``, else ``False``.
    """
    return args["OH"] == args["IH"]

198 

199 

def upsample_nearest2d_SAME_W(args):
    """Report whether ``OW`` equals ``IW`` for ``upsample_nearest2d``.

    Args:
        args: Mapping providing the ``"OW"`` and ``"IW"`` entries.

    Returns:
        ``True`` when ``args["OW"] == args["IW"]``, else ``False``.
    """
    return args["OW"] == args["IW"]

202 

203 

def batch_norm_heur_block_m(args):
    """Compute the ``BLOCK_M`` heuristic value for ``batch_norm``.

    Args:
        args: Mapping providing the ``"batch_dim"`` entry.

    Returns:
        The next power of two of ``batch_dim``, capped at ``64``.
    """
    return min(64, triton.next_power_of_2(args["batch_dim"]))

206 

207 

def batch_norm_heur_block_n(args):
    """Compute the ``BLOCK_N`` heuristic value for ``batch_norm``.

    ``BLOCK_N`` is limited so that ``BLOCK_M * BLOCK_N`` does not exceed
    ``2**10`` (with a floor of one column).

    Args:
        args: Mapping providing the ``"batch_dim"`` and ``"spatial_dim"``
            entries.

    Returns:
        The smaller of the next power of two of ``spatial_dim`` and
        ``max(1, 2**10 // BLOCK_M)``.
    """
    BLOCK_M = batch_norm_heur_block_m(args)
    BLOCK_N = triton.next_power_of_2(args["spatial_dim"])
    return min(BLOCK_N, max(1, 2**10 // BLOCK_M))

212 

213 

def vdot_heur_block_size(args):
    """Choose the ``BLOCK_SIZE`` heuristic value for ``vdot``.

    Args:
        args: Mapping providing the ``"n_elements"`` entry.

    Returns:
        ``32`` when ``n_elements < 1024``, ``256`` when ``n_elements < 8192``,
        otherwise ``1024``.
    """
    n = args["n_elements"]
    if n < 1024:
        return 32
    elif n < 8192:
        return 256
    else:
        return 1024

222 

223 

def mm_heur_even_k(args):
    """Report whether ``K`` is an exact multiple of ``BLOCK_K * SPLIT_K``.

    Args:
        args: Mapping providing the ``"K"``, ``"BLOCK_K"`` and ``"SPLIT_K"``
            entries.

    Returns:
        ``True`` when ``K % (BLOCK_K * SPLIT_K)`` is zero, else ``False``.
    """
    return args["K"] % (args["BLOCK_K"] * args["SPLIT_K"]) == 0

226 

227 

# Registry of heuristic callables: maps a configuration name to a dict of
# {parameter name: function taking the ``args`` mapping and returning the
# value for that parameter}.
HEURISTICS_CONFIGS = {
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    # baddbmm reuses the bmm divisibility checks.
    "baddbmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "bmm": {
        "DIVISIBLE_M": bmm_heur_divisible_m,
        "DIVISIBLE_N": bmm_heur_divisible_n,
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
        "num_warps": dropout_heur_num_warps,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
        "num_warps": exponential_heur_num_warps,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
        "num_warps": rand_heur_num_warps,
    },
    "randn": {
        "BLOCK": randn_heur_block,
        "num_warps": randn_heur_num_warps,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_non_inner,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
        "num_warps": softmax_heur_num_warps_inner,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heur_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
        "num_warps": uniform_heur_num_warps,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
    # NOTE(review): the trailing numbers are carried over from the original
    # comments; they look like reference values (BLOCK_N differs: 16 vs 64) --
    # confirm their meaning with the author.
    "mha_block_16": {
        "BLOCK_M": lambda args: 16,  # 16
        "BLOCK_N": lambda args: 16,  # 64
        "num_warps": lambda args: 4,  # 4
        "num_stages": lambda args: 3,  # 3
    },
}