Coverage for src/flag_gems/runtime/backend/_spacemit/heuristics_config_utils.py: 0%

134 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

import triton

2 

3 

def argmax_heur_block_m(args):
    """Return the ``BLOCK_M`` heuristic value for ``argmax``.

    Args:
        args: Mapping of kernel arguments; only ``args["M"]`` is read.

    Returns:
        4 when ``args["M"]`` is below 4096, otherwise 8.
    """
    return 4 if args["M"] < 4096 else 8

6 

7 

def argmax_heur_block_n(args):
    """Return the ``BLOCK_N`` heuristic value for ``argmax``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["N"])``, capped at 4096.
    """
    return min(4096, triton.next_power_of_2(args["N"]))

10 

11 

def argmin_heur_block_m(args):
    """Return the ``BLOCK_M`` heuristic value for ``argmin``.

    Args:
        args: Mapping of kernel arguments; only ``args["M"]`` is read.

    Returns:
        4 when ``args["M"]`` is below 4096, otherwise 8.
    """
    return 4 if args["M"] < 4096 else 8

14 

15 

def argmin_heur_block_n(args):
    """Return the ``BLOCK_N`` heuristic value for ``argmin``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["N"])``, capped at 4096.
    """
    return min(4096, triton.next_power_of_2(args["N"]))

18 

19 

def bmm_heur_divisible_m(args):
    """Report whether ``args["M"]`` is an exact multiple of ``args["TILE_M"]``.

    Args:
        args: Mapping of kernel arguments; reads ``"M"`` and ``"TILE_M"``.

    Returns:
        True when ``M % TILE_M == 0``, otherwise False.
    """
    return args["M"] % args["TILE_M"] == 0

22 

23 

def bmm_heur_divisible_n(args):
    """Report whether ``args["N"]`` is an exact multiple of ``args["TILE_N"]``.

    Args:
        args: Mapping of kernel arguments; reads ``"N"`` and ``"TILE_N"``.

    Returns:
        True when ``N % TILE_N == 0``, otherwise False.
    """
    return args["N"] % args["TILE_N"] == 0

26 

27 

def bmm_heur_divisible_k(args):
    """Return the ``DIVISIBLE_K`` heuristic value for ``bmm``.

    Args:
        args: Mapping of kernel arguments; reads ``"K"`` and ``"TILE_K"``.

    Returns:
        True when ``K % TILE_K == 0``, otherwise False.
    """
    return args["K"] % args["TILE_K"] == 0

30 

31 

def dropout_heur_block(args):
    """Return the ``BLOCK`` heuristic value for ``dropout``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    if args["N"] <= 512:
        return 512
    return 1024

37 

38 

def dropout_heur_num_warps(args):
    """Pick a warp count for ``dropout`` from ``args["N"]``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        4 for ``N <= 512``, 8 for ``512 < N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

46 

47 

def exponential_heur_block(args):
    """Return the ``BLOCK`` heuristic value for ``exponential_``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    if args["N"] <= 512:
        return 512
    return 1024

53 

54 

def exponential_heur_num_warps(args):
    """Pick a warp count for ``exponential_`` from ``args["N"]``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        4 for ``N <= 512``, 8 for ``512 < N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

62 

63 

def gather_heur_block_m(args):
    """Return the ``BLOCK_M`` heuristic value for ``gather``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``triton.next_power_of_2(triton.cdiv(args["N"], 2048))``, capped at 4.
    """
    return min(4, triton.next_power_of_2(triton.cdiv(args["N"], 2048)))

66 

67 

def gather_heur_block_n(args):
    """Return the ``BLOCK_N`` heuristic value for ``gather``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["N"])``, capped at 2048.
    """
    return min(2048, triton.next_power_of_2(args["N"]))

70 

71 

def index_select_heur_block_m(args):
    """Return the ``BLOCK_M`` heuristic value for ``index_select``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.
            NOTE(review): ``N == 0`` would divide by zero inside
            ``triton.cdiv`` -- confirm callers never pass it.

    Returns:
        ``triton.next_power_of_2(triton.cdiv(256, args["N"]))``, capped at 4.
    """
    return min(4, triton.next_power_of_2(triton.cdiv(256, args["N"])))

74 

75 

def index_select_heur_block_n(args):
    """Return the ``BLOCK_N`` heuristic value for ``index_select``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``triton.next_power_of_2(triton.cdiv(args["N"], 16))`` clamped to the
        inclusive range [16, 512].
    """
    m = min(triton.next_power_of_2(triton.cdiv(args["N"], 16)), 512)
    return max(m, 16)

79 

80 

def mm_heur_even_k(args):
    """Return the ``EVEN_K`` heuristic value for ``mm``.

    Args:
        args: Mapping of kernel arguments; reads ``"K"`` and
            ``"BLOCK_SIZE_K"``.

    Returns:
        True when ``K % BLOCK_SIZE_K == 0``, otherwise False.
    """
    return args["K"] % args["BLOCK_SIZE_K"] == 0

83 

84 

def rand_heur_block(args):
    """Return the ``BLOCK`` heuristic value for ``rand``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    if args["N"] <= 512:
        return 512
    return 1024

90 

91 

def rand_heur_num_warps(args):
    """Pick a warp count for ``rand`` from ``args["N"]``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        4 for ``N <= 512``, 8 for ``512 < N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

99 

100 

def randn_heur_block(args):
    """Return the ``BLOCK`` heuristic value for ``randn``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    if args["N"] <= 512:
        return 512
    return 1024

106 

107 

def randn_heur_num_warps(args):
    """Pick a warp count for ``randn`` from ``args["N"]``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        4 for ``N <= 512``, 8 for ``512 < N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

115 

116 

def softmax_heur_tile_k(args):
    """Return the ``TILE_K`` heuristic value for ``softmax_non_inner``.

    Starting from 1, the tile size is doubled while both conditions hold:
    the resulting block count ``M * cdiv(K, tile_k)`` exceeds ``NUM_SMS``,
    and the doubled tile would still be within ``min(K, MAX_TILE_K)``.

    Args:
        args: Mapping of kernel arguments; reads ``"M"`` and ``"K"``.

    Returns:
        A power of two no larger than ``max(1, min(K, 8192))``.
    """
    MAX_TILE_K = 8192
    # NOTE(review): presumably the number of parallel compute units on the
    # target device -- confirm against the backend's hardware description.
    NUM_SMS = 8
    tile_k = 1
    upper_bound = min(args["K"], MAX_TILE_K)
    while tile_k <= upper_bound:
        num_blocks = args["M"] * triton.cdiv(args["K"], tile_k)
        num_waves = num_blocks / NUM_SMS
        if (num_waves > 1) and (tile_k * 2 <= upper_bound):
            tile_k *= 2
        else:
            break
    return tile_k

130 

131 

def softmax_heur_tile_n_non_inner(args):
    """Return the ``TILE_N`` heuristic value for ``softmax_non_inner``.

    Args:
        args: Mapping of kernel arguments; only ``args["TILE_K"]`` is read.

    Returns:
        ``triton.cdiv(8192, args["TILE_K"])``.
    """
    return triton.cdiv(8192, args["TILE_K"])

134 

135 

def softmax_heur_one_tile_per_cta(args):
    """Return the ``ONE_TILE_PER_CTA`` heuristic value for the softmax kernels.

    Args:
        args: Mapping of kernel arguments; reads ``"TILE_N"`` and ``"N"``.

    Returns:
        True when ``TILE_N`` is at least ``N``, otherwise False.
    """
    return args["TILE_N"] >= args["N"]

138 

139 

def softmax_heur_num_warps_non_inner(args):
    """Pick a warp count from the product ``TILE_N * TILE_K``.

    Args:
        args: Mapping of kernel arguments; reads ``"TILE_N"`` and
            ``"TILE_K"``.

    Returns:
        4 when the product is below 2048, 8 when it is below 4096,
        otherwise 16.
    """
    tile_size = args["TILE_N"] * args["TILE_K"]
    if tile_size < 2048:
        return 4
    if tile_size < 4096:
        return 8
    return 16

148 

149 

def softmax_heur_tile_n_inner(args):
    """Return the ``TILE_N`` heuristic value for ``softmax_inner``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["N"])`` when ``N`` is at most 32 * 1024,
        otherwise the fixed value 4096.
    """
    if args["N"] <= (32 * 1024):
        return triton.next_power_of_2(args["N"])
    return 4096

155 

156 

def softmax_heur_num_warps_inner(args):
    """Pick a warp count from ``args["TILE_N"]``.

    Args:
        args: Mapping of kernel arguments; only ``args["TILE_N"]`` is read.

    Returns:
        4 when ``TILE_N`` is below 2048, 8 when it is below 4096,
        otherwise 16.
    """
    tile_size = args["TILE_N"]
    if tile_size < 2048:
        return 4
    if tile_size < 4096:
        return 8
    return 16

165 

166 

def softmax_heur_tile_n_bwd_non_inner(args):
    """Return the ``TILE_N`` heuristic value for ``softmax_backward_non_inner``.

    Args:
        args: Mapping of kernel arguments; only ``args["TILE_K"]`` is read.

    Returns:
        ``1024 // TILE_K``, but never less than 1.
    """
    return max(1, 1024 // args["TILE_K"])

169 

170 

def softmax_heru_tile_m(args):
    """Return the ``TILE_M`` heuristic value for ``softmax_backward_inner``.

    The function name is misspelled ("heru" instead of "heur"); it is kept
    because ``HEURISTICS_CONFIGS`` refers to it by this name.

    Args:
        args: Mapping of kernel arguments; only ``args["TILE_N"]`` is read.

    Returns:
        ``1024 // TILE_N``, but never less than 1.
    """
    return max(1, 1024 // args["TILE_N"])


# Correctly spelled alias matching the ``*_heur_*`` naming used by the other
# heuristics in this module; the misspelled original stays for compatibility.
softmax_heur_tile_m = softmax_heru_tile_m

173 

174 

def uniform_heur_block(args):
    """Return the ``BLOCK`` heuristic value for ``uniform``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        512 when ``args["N"]`` is at most 512, otherwise 1024.
    """
    if args["N"] <= 512:
        return 512
    return 1024

180 

181 

def uniform_heur_num_warps(args):
    """Pick a warp count for ``uniform`` from ``args["N"]``.

    Args:
        args: Mapping of kernel arguments; only ``args["N"]`` is read.

    Returns:
        4 for ``N <= 512``, 8 for ``512 < N <= 1024``, otherwise 16.
    """
    n = args["N"]
    if n <= 512:
        return 4
    if n <= 1024:
        return 8
    return 16

189 

190 

def var_mean_heur_block_n(args):
    """Return the ``BLOCK_N`` heuristic value for ``var_mean``.

    Args:
        args: Mapping of kernel arguments; only ``args["BLOCK_NUM"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["BLOCK_NUM"])``.
    """
    return triton.next_power_of_2(args["BLOCK_NUM"])

193 

194 

def upsample_nearest2d_SAME_H(args):
    """Return the ``SAME_H`` heuristic value for ``upsample_nearest2d``.

    Args:
        args: Mapping of kernel arguments; reads ``"OH"`` and ``"IH"``.

    Returns:
        True when ``args["OH"]`` equals ``args["IH"]``, otherwise False.
    """
    return args["OH"] == args["IH"]

197 

198 

def upsample_nearest2d_SAME_W(args):
    """Return the ``SAME_W`` heuristic value for ``upsample_nearest2d``.

    Args:
        args: Mapping of kernel arguments; reads ``"OW"`` and ``"IW"``.

    Returns:
        True when ``args["OW"]`` equals ``args["IW"]``, otherwise False.
    """
    return args["OW"] == args["IW"]

201 

202 

def batch_norm_heur_block_m(args):
    """Return the ``BLOCK_M`` heuristic value for ``batch_norm``.

    Args:
        args: Mapping of kernel arguments; only ``args["batch_dim"]`` is read.

    Returns:
        ``triton.next_power_of_2(args["batch_dim"])``, capped at 2048.
    """
    return min(2048, triton.next_power_of_2(args["batch_dim"]))

205 

206 

def batch_norm_heur_block_n(args):
    """Return the ``BLOCK_N`` heuristic value for ``batch_norm``.

    Args:
        args: Mapping of kernel arguments; reads ``"spatial_dim"`` directly
            and ``"batch_dim"`` through ``batch_norm_heur_block_m``.

    Returns:
        ``triton.next_power_of_2(args["spatial_dim"])``, capped so that
        ``BLOCK_M * BLOCK_N`` does not exceed ``2**14``, and never below 1.
    """
    # A maximum of 16384 elements are loaded at once.
    BLOCK_M = batch_norm_heur_block_m(args)
    BLOCK_N = triton.next_power_of_2(args["spatial_dim"])
    return min(BLOCK_N, max(1, 2**14 // BLOCK_M))

212 

213 

def vdot_heur_block_size(args):
    """Return the ``BLOCK_SIZE`` heuristic value for ``vdot``.

    Args:
        args: Mapping of kernel arguments; only ``args["n_elements"]`` is
            read.

    Returns:
        32 when ``n_elements`` is below 1024, 256 when it is below 8192,
        otherwise 1024.
    """
    n = args["n_elements"]
    if n < 1024:
        return 32
    if n < 8192:
        return 256
    return 1024

222 

223 

# Maps an operator name to its heuristics: each inner dict maps a kernel
# meta-parameter name to the function that computes it from the ``args``
# mapping.
# NOTE(review): several helpers defined in this module (for example
# ``bmm_heur_divisible_m``/``bmm_heur_divisible_n`` and the ``*_num_warps``
# functions) are not registered here -- confirm that is intentional for this
# backend.
HEURISTICS_CONFIGS = {
    "argmax": {
        "BLOCK_M": argmax_heur_block_m,
        "BLOCK_N": argmax_heur_block_n,
    },
    "argmin": {
        "BLOCK_M": argmin_heur_block_m,
        "BLOCK_N": argmin_heur_block_n,
    },
    "bmm": {
        "DIVISIBLE_K": bmm_heur_divisible_k,
    },
    "dropout": {
        "BLOCK": dropout_heur_block,
    },
    "exponential_": {
        "BLOCK": exponential_heur_block,
    },
    "gather": {
        "BLOCK_M": gather_heur_block_m,
        "BLOCK_N": gather_heur_block_n,
    },
    "index_select": {
        "BLOCK_M": index_select_heur_block_m,
        "BLOCK_N": index_select_heur_block_n,
    },
    "mm": {
        "EVEN_K": mm_heur_even_k,
    },
    "rand": {
        "BLOCK": rand_heur_block,
    },
    "randn": {
        "BLOCK": randn_heur_block,
    },
    "softmax_non_inner": {
        "TILE_K": softmax_heur_tile_k,
        "TILE_N": softmax_heur_tile_n_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_inner": {
        "TILE_N": softmax_heur_tile_n_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_non_inner": {
        "TILE_N": softmax_heur_tile_n_bwd_non_inner,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "softmax_backward_inner": {
        "TILE_M": softmax_heru_tile_m,
        "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
    },
    "uniform": {
        "BLOCK": uniform_heur_block,
    },
    "upsample_nearest2d": {
        "SAME_H": upsample_nearest2d_SAME_H,
        "SAME_W": upsample_nearest2d_SAME_W,
    },
    "var_mean": {
        "BLOCK_N": var_mean_heur_block_n,
    },
    "batch_norm": {
        "BLOCK_M": batch_norm_heur_block_m,
        "BLOCK_N": batch_norm_heur_block_n,
    },
    "vdot": {
        "BLOCK_SIZE": vdot_heur_block_size,
    },
}