Coverage for src/flag_gems/runtime/backend/_cambricon/ops/softmax.py: 0%
552 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import copy
2import logging
3import math
5import torch
6import triton
7import triton.language as tl
9from flag_gems import runtime
10from flag_gems.runtime import torch_device_fn
11from flag_gems.utils import libentry, libtuner
13from ..utils import MAX_NRAM_SIZE, TOTAL_CORE_NUM
14from .zeros import zero_
16logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
17MAX_N = 16384
def align(max_block):
    """Round ``max_block`` down to a power of two.

    ``max_block`` is returned unchanged when it already equals
    ``triton.next_power_of_2(max_block)``; otherwise half of that next power
    of two is returned.
    """
    upper = triton.next_power_of_2(max_block)
    if upper == max_block:
        return max_block
    return upper // 2
def config_prune1(configs, named_args, **kwargs):
    """Early-prune hook for the forward non-inner softmax autotuner.

    When ``N < MAX_N`` every incoming config is replaced by three derived
    configs that all use ``TILE_N = N``:

    1. ``TILE_K`` = the K extent handled per program, ``num_stages = 1``;
    2. ``TILE_K`` = ``align(MAX_NRAM_SIZE // 4 // (2 * N + 1))``,
       ``num_stages = 1``;
    3. ``TILE_K`` = ``align(MAX_NRAM_SIZE // 4 // (c * N + 1))`` with
       ``c = 4`` for float32 input and ``c = 3`` otherwise, ``num_stages = 3``.

    Otherwise the incoming config is kept as is.  Configs are deduplicated on
    ``(TILE_K, TILE_N, num_warps, num_stages)`` (the first one wins).  Two
    extra configs with ``TILE_K = 1``, ``TILE_N = N``, ``num_warps = 1`` and
    ``num_stages`` of 3 and 1 are always appended.

    Args:
        configs: candidate autotuner configs.
        named_args: kernel arguments by name; ``M``, ``N``, ``K`` and
            ``input_ptr`` are read.
        **kwargs: ignored.

    Returns:
        The list of configs to benchmark.
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]
    src = named_args["input_ptr"]
    unique = {}

    for cfg in configs:
        tile_k = cfg.kwargs["TILE_K"]
        tile_n = cfg.kwargs["TILE_N"]
        warps = cfg.num_warps
        stages = cfg.num_stages

        if N >= MAX_N:
            # Large N: keep the tuned config untouched.
            unique.setdefault((tile_k, tile_n, warps, stages), cfg)
            continue

        tile_n = N

        # Candidate 1: split K evenly over the programs along axis 1.
        cfg = copy.deepcopy(cfg)
        cfg.kwargs["TILE_N"] = tile_n
        tile_k = math.ceil(K / max(TOTAL_CORE_NUM // M, 1))
        cfg.kwargs["TILE_K"] = tile_k
        cfg.num_stages = 1
        unique.setdefault((tile_k, tile_n, warps, 1), cfg)

        # Candidate 2: largest aligned TILE_K for the non-pipelined budget.
        cfg = copy.deepcopy(cfg)
        tile_k = align(MAX_NRAM_SIZE // 4 // (2 * tile_n + 1))
        cfg.kwargs["TILE_K"] = tile_k
        cfg.num_stages = 1
        unique.setdefault((tile_k, tile_n, warps, 1), cfg)

        # Candidate 3: pipelined variant with a dtype-dependent budget.
        cfg = copy.deepcopy(cfg)
        per_column = 4 if src.dtype == torch.float32 else 3
        tile_k = align(MAX_NRAM_SIZE // 4 // (per_column * tile_n + 1))
        cfg.kwargs["TILE_K"] = tile_k
        cfg.num_stages = 3
        unique.setdefault((tile_k, tile_n, warps, 3), cfg)

    pruned = list(unique.values())

    # Always offer a one-column-per-tile config, pipelined and not.
    pipelined_extra = copy.deepcopy(pruned[0])
    pipelined_extra.kwargs["TILE_K"] = 1
    pipelined_extra.kwargs["TILE_N"] = N
    pipelined_extra.num_warps = 1
    pipelined_extra.num_stages = 3
    pruned.append(pipelined_extra)

    plain_extra = copy.deepcopy(pipelined_extra)
    plain_extra.num_stages = 1
    pruned.append(plain_extra)
    return pruned
def softmax_tile_mode_for_non_inner(M, N, K, TILE_N, TILE_K):
    """Pick the tiling mode for the non-inner softmax kernels.

    Returns:
        0 when one tile covers N and the K tiles of all programs cover K,
        1 when one tile covers N but K needs a loop,
        2 when N itself has to be processed in several tiles.
    """
    covers_k = TILE_K * max(TOTAL_CORE_NUM // M, 1) >= K
    covers_n = TILE_N >= N
    if not covers_n:
        return 2
    return 0 if covers_k else 1
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_non_inner"),
    key=[
        "N",
        "K",
    ],
    prune_configs_by={"early_config_prune": config_prune1},
)
@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
@triton.jit
def softmax_kernel_non_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Softmax over the middle axis of a buffer addressed as
    # m * N * K + n * K + k.  Program axis 0 selects m, program axis 1
    # splits K.  TILE_MODE selects one of three strategies:
    #   0 - a single [TILE_N, TILE_K] tile per program, no loops;
    #   1 - loop over this program's K range, all of N in one tile;
    #   2 - loop over K and, inside, two passes over N in TILE_N chunks.
    batch_id = tl.program_id(0)
    k_pid = tl.program_id(1)

    k_span = tl.cdiv(K, tl.num_programs(axis=1))
    k_begin = k_pid * k_span

    if TILE_MODE == 0:
        k_ids = k_pid * TILE_K + tl.arange(0, TILE_K)
        n_ids = tl.arange(0, TILE_N)
        idx = batch_id * N * K + n_ids[:, None] * K + k_ids[None, :]
        in_bounds = k_ids[None, :] < K
        x = tl.load(input_ptr + idx, mask=in_bounds, other=-float("inf")).to(
            tl.float32
        )
        exp_x = tl.exp(x - tl.max(x, axis=0)[None, :])
        inv_sum = 1.0 / tl.sum(exp_x, axis=0)[None, :]
        tl.store(output_ptr + idx, exp_x * inv_sum, mask=in_bounds)
    elif TILE_MODE == 1:
        for k_base in range(0, k_span, TILE_K):
            k_ids = k_begin + k_base + tl.arange(0, TILE_K)
            n_ids = tl.arange(0, TILE_N)
            idx = batch_id * N * K + n_ids[:, None] * K + k_ids[None, :]
            in_bounds = k_ids[None, :] < K
            x = tl.load(input_ptr + idx, mask=in_bounds, other=-float("inf")).to(
                tl.float32
            )
            exp_x = tl.exp(x - tl.max(x, axis=0)[None, :])
            inv_sum = 1.0 / tl.sum(exp_x, axis=0)[None, :]
            tl.store(output_ptr + idx, exp_x * inv_sum, mask=in_bounds)
    else:
        for k_base in range(0, k_span, TILE_K):
            k_ids = k_begin + k_base + tl.arange(0, TILE_K)
            # Element-wise running maximum and rescaled running sum.
            run_max = tl.full([TILE_N, TILE_K], value=float("-inf"), dtype=tl.float32)
            run_sum = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)
            # Pass 1: accumulate statistics over N.
            for n_base in range(0, N, TILE_N):
                n_ids = n_base + tl.arange(0, TILE_N)
                idx = batch_id * N * K + n_ids[:, None] * K + k_ids[None, :]
                in_bounds = (n_ids[:, None] < N) & (k_ids[None, :] < K)
                x = tl.load(input_ptr + idx, mask=in_bounds, other=-float("inf")).to(
                    tl.float32
                )
                new_max = tl.maximum(run_max, x)
                # Slots that have only seen -inf keep their sum untouched.
                still_empty = new_max == float("-inf")
                run_sum = tl.where(
                    still_empty,
                    run_sum,
                    run_sum * tl.exp(run_max - new_max) + tl.exp(x - new_max),
                )
                run_max = new_max
            # Collapse the TILE_N slots into one max / sum per k column.
            col_max = tl.max(run_max, 0)  # (TILE_K,)
            col_sum = tl.sum(run_sum * tl.exp(run_max - col_max[None, :]), 0)  # (TILE_K,)
            inv_sum = 1.0 / col_sum
            # Pass 2: normalise and write out.
            for n_base in range(0, N, TILE_N):
                n_ids = n_base + tl.arange(0, TILE_N)
                idx = batch_id * N * K + n_ids[:, None] * K + k_ids[None, :]
                in_bounds = (n_ids[:, None] < N) & (k_ids[None, :] < K)
                x = tl.load(input_ptr + idx, mask=in_bounds, other=-float("inf")).to(
                    tl.float32
                )
                y = tl.exp(x - col_max[None, :]) * inv_sum[None, :]
                tl.store(output_ptr + idx, y, mask=in_bounds)
def config_prune2(configs, named_args, **kwargs):
    """Early-prune hook for the forward inner softmax autotuner.

    When ``N < MAX_N`` every incoming config is replaced by three derived
    configs that all use ``BLOCK_N = N`` (so no reduction loop is needed):

    1. ``BLOCK_M = ceil(M / TOTAL_CORE_NUM)``, ``num_stages = 1``;
    2. ``BLOCK_M = align(MAX_NRAM_SIZE // 4 // (2 * N + 1))``,
       ``num_stages = 1``;
    3. ``BLOCK_M = align(MAX_NRAM_SIZE // 4 // (c * N + 1))`` with ``c = 6``
       for float32 input and ``c = 4`` otherwise, ``num_stages = 3``.

    Otherwise the incoming config is kept as is.  Only one config is kept per
    ``(BLOCK_M, BLOCK_N, num_warps, num_stages)`` key.  Two heuristic configs
    with ``BLOCK_M = 1``, ``BLOCK_N = N``, ``num_warps = 1`` and
    ``num_stages`` of 3 and 1 are always appended.

    Args:
        configs: candidate autotuner configs.
        named_args: kernel arguments by name; ``M``, ``N`` and ``input_ptr``
            are read.
        **kwargs: ignored.

    Returns:
        The list of configs to benchmark.
    """
    M = named_args["M"]
    N = named_args["N"]
    src = named_args["input_ptr"]
    unique = {}

    for cfg in configs:
        block_m = cfg.kwargs["BLOCK_M"]
        block_n = cfg.kwargs["BLOCK_N"]
        warps = cfg.num_warps
        stages = cfg.num_stages

        if N < MAX_N:
            block_n = N

            # Candidate 1: spread the rows evenly over all cores.
            cfg = copy.deepcopy(cfg)
            cfg.kwargs["BLOCK_N"] = block_n
            block_m = math.ceil(M / TOTAL_CORE_NUM)
            cfg.kwargs["BLOCK_M"] = block_m
            cfg.num_stages = 1
            unique.setdefault((block_m, block_n, warps, 1), cfg)

            # Candidate 2: largest aligned BLOCK_M without pipelining.
            cfg = copy.deepcopy(cfg)
            block_m = align(MAX_NRAM_SIZE // 4 // (2 * block_n + 1))
            cfg.kwargs["BLOCK_M"] = block_m
            cfg.num_stages = 1
            unique.setdefault((block_m, block_n, warps, 1), cfg)

            # Candidate 3: pipelined variant with a dtype-dependent budget.
            cfg = copy.deepcopy(cfg)
            per_row = 6 if src.dtype == torch.float32 else 4
            block_m = align(MAX_NRAM_SIZE // 4 // (per_row * block_n + 1))
            cfg.kwargs["BLOCK_M"] = block_m
            cfg.num_stages = 3
            unique.setdefault((block_m, block_n, warps, 3), cfg)
        else:
            # Large N: keep the tuned config untouched.
            unique.setdefault((block_m, block_n, warps, stages), cfg)

    pruned = list(unique.values())

    # Heuristic one-row-per-block configs, pipelined and not.
    pipelined_extra = copy.deepcopy(pruned[0])
    pipelined_extra.kwargs["BLOCK_M"] = 1
    pipelined_extra.kwargs["BLOCK_N"] = N
    pipelined_extra.num_warps = 1
    pipelined_extra.num_stages = 3
    pruned.append(pipelined_extra)

    plain_extra = copy.deepcopy(pipelined_extra)
    plain_extra.num_stages = 1
    pruned.append(plain_extra)
    return pruned
def softmax_tile_mode_for_inner(args):
    """Pick the tiling mode for the inner softmax kernels.

    Args:
        args: mapping providing ``BLOCK_M``, ``BLOCK_N``, ``M`` and ``N``.

    Returns:
        0 when one tile covers N and ``BLOCK_M * TOTAL_CORE_NUM`` covers M,
        1 when one tile covers N but M needs a loop,
        2 when N itself has to be processed in several tiles.
    """
    covers_m = args["BLOCK_M"] * TOTAL_CORE_NUM >= args["M"]
    covers_n = args["BLOCK_N"] >= args["N"]
    if not covers_n:
        return 2
    return 0 if covers_m else 1
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_inner"),
    key=[
        "M",
        "N",
    ],
    prune_configs_by={"early_config_prune": config_prune2},
)
@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
@triton.jit
def softmax_kernel_inner(
    output_ptr,
    input_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Softmax along the last axis of a buffer addressed as m * N + n.
    # Program axis 0 splits the M rows.  TILE_MODE selects the strategy:
    #   0 - a single [BLOCK_M, BLOCK_N] tile per program, no loops;
    #   1 - loop over this program's rows, a whole row per tile;
    #   2 - loop over rows and, inside, two passes over N in BLOCK_N chunks.
    pid_m = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    split_m = tl.cdiv(M, pnum)
    m_start = pid_m * split_m

    if TILE_MODE == 0:
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        n_offset = tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M
        input_ptrs = input_ptr + offset
        inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
        row_minus_max = inp - tl.max(inp, axis=1)[:, None]
        numerator = tl.exp(row_minus_max)
        denominator = tl.sum(numerator, axis=1)[:, None]
        recip = 1.0 / denominator
        softmax_output = numerator * recip
        output_ptrs = output_ptr + offset
        tl.store(output_ptrs, softmax_output, mask=mask)
    elif TILE_MODE == 1:
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            n_offset = tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N + n_offset[None, :]
            mask = m_offset[:, None] < M
            input_ptrs = input_ptr + offset
            inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(tl.float32)
            # The tile is transposed so that the reduction runs over axis 0,
            # then transposed back before the store.
            trans_inp = tl.trans(inp)
            row_minus_max = trans_inp - tl.max(trans_inp, axis=0)[None, :]
            numerator = tl.exp(row_minus_max)
            denominator = tl.sum(numerator, axis=0)[None, :]
            recip = 1.0 / denominator
            softmax_output = tl.trans(numerator * recip)
            output_ptrs = output_ptr + offset
            tl.store(output_ptrs, softmax_output, mask=mask)
    else:
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            # Element-wise running maximum and rescaled running sum.
            block_max = tl.full(
                [BLOCK_M, BLOCK_N], value=float("-inf"), dtype=tl.float32
            )
            block_sum = tl.full([BLOCK_M, BLOCK_N], value=0.0, dtype=tl.float32)
            # Pass 1: accumulate statistics over N.
            # Specialization does not improve performance in this example, as tested.
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                # Element-wise `&`, matching the other kernels in this file,
                # instead of Python's `and` on tensors.
                mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                cur_max = tl.maximum(block_max, inp)
                # Slots that have only seen -inf keep their sum untouched.
                all_neg_inf = cur_max == float("-inf")
                block_sum = tl.where(
                    all_neg_inf,
                    block_sum,
                    block_sum * tl.exp(block_max - cur_max) + tl.exp(inp - cur_max),
                )
                block_max = cur_max

            # Collapse the BLOCK_N slots into one max / sum per row.
            trans_block_max = tl.trans(block_max)
            trans_block_sum = tl.trans(block_sum)
            max_reduced = tl.max(trans_block_max, 0)
            total_sum = tl.sum(
                trans_block_sum * tl.exp(trans_block_max - max_reduced[None, :]), 0
            )
            recip_total_sum = 1.0 / total_sum
            total_max = max_reduced

            # Pass 2: normalise and write out.
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
                inp = tl.load(input_ptr + offset, mask=mask, other=-float("inf")).to(
                    tl.float32
                )
                o = tl.exp(inp - total_max[:, None]) * recip_total_sum[:, None]
                tl.store(output_ptr + offset, o, mask=mask)
@triton.jit
def softmax_kernel_inner_k_partial_stats(
    x_ptr,
    max_buf_ptr,
    sum_buf_ptr,
    M,
    N,
    T,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Step 1 of the split inner softmax: for every (row block, N tile) pair
    # store the tile maximum and the sum of exp(x - tile_max) at
    # [row, tile_id] of max_buf / sum_buf (row stride T).
    pnum = tl.num_programs(axis=0)
    pid = tl.program_id(0)
    # Ceil-divide so a trailing partial row block is still processed when
    # BLOCK_M does not divide M; the row masks below already handle the tail.
    total_blocks = tl.cdiv(M, BLOCK_M) * T
    # Tasks are handed out to programs in contiguous chunks.
    work_per_core = (total_blocks + pnum - 1) // pnum
    start = pid * work_per_core
    end = tl.minimum(start + work_per_core, total_blocks)

    for task in range(start, end):
        row_id = task // T
        tile_id = task % T

        offs_m = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

        tile = tl.load(
            x_ptr + offs_m[:, None] * N + offs_n[None, :],
            mask=mask,
            other=-float("inf"),
        ).to(tl.float32)

        tile_max = tl.max(tile, axis=1)
        all_neg_inf = tile_max == -float("inf")

        # A tile that only holds -inf contributes a sum of 0.
        tile_sum = tl.where(
            all_neg_inf,
            0.0,
            tl.sum(tl.exp(tile - tile_max[:, None]), axis=1),
        )

        tl.store(max_buf_ptr + offs_m * T + tile_id, tile_max, mask=(offs_m < M))
        tl.store(sum_buf_ptr + offs_m * T + tile_id, tile_sum, mask=(offs_m < M))
@triton.jit
def softmax_kernel_inner_k_merge_stats(
    max_buf_ptr,
    sum_buf_ptr,
    gmax_ptr,
    gsum_ptr,
    M: tl.constexpr,
    T: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    # Step 2 of the split inner softmax: merge the T per-tile maxima and sums
    # of each row into one maximum (gmax) and one rescaled sum (gsum) per row.
    row_ids = tl.program_id(axis=0) * BLOCK_M + tl.arange(0, BLOCK_M)  # [BM]
    row_ok = row_ids < M
    buf_idx = row_ids[:, None] * T + tl.arange(0, T)[None, :]
    load_mask = row_ids[:, None] < M

    partial_max = tl.load(max_buf_ptr + buf_idx, mask=load_mask, other=-float("inf"))
    partial_sum = tl.load(sum_buf_ptr + buf_idx, mask=load_mask, other=0.0).to(
        tl.float32
    )

    row_max = tl.max(partial_max, axis=1)
    rescale = tl.exp(partial_max - row_max[:, None])
    # Rows whose maximum is -inf get a zero scale, hence a zero sum.
    rescale = tl.where(row_max[:, None] == -float("inf"), 0.0, rescale)
    row_sum = tl.sum(partial_sum * rescale, axis=1)

    tl.store(gmax_ptr + row_ids, row_max, mask=row_ok)
    tl.store(gsum_ptr + row_ids, row_sum, mask=row_ok)
@triton.jit
def softmax_kernel_inner_k_write_softmax(
    x_ptr,
    y_ptr,
    gmax_ptr,
    gsum_ptr,
    M,
    N,
    T,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # Step 3 of the split inner softmax: write exp(x - gmax) / gsum for every
    # (row block, N tile) pair, using the per-row statistics in gmax / gsum.
    pnum = tl.num_programs(axis=0)
    pid = tl.program_id(0)
    # Ceil-divide so a trailing partial row block is still written when
    # BLOCK_M does not divide M; the row masks below already handle the tail.
    total_blocks = tl.cdiv(M, BLOCK_M) * T
    # Tasks are handed out to programs in contiguous chunks.
    work_per_core = (total_blocks + pnum - 1) // pnum
    start = pid * work_per_core
    end = tl.minimum(start + work_per_core, total_blocks)

    for task in range(start, end):
        row_id = task // T
        tile_id = task % T

        offs_m = row_id * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = tile_id * BLOCK_N + tl.arange(0, BLOCK_N)
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)

        # load global stats
        gmax = tl.load(gmax_ptr + offs_m, mask=(offs_m < M), other=-float("inf")).to(
            tl.float32
        )
        gsum = tl.load(gsum_ptr + offs_m, mask=(offs_m < M), other=0.0).to(tl.float32)

        # load tile
        tile = tl.load(
            x_ptr + offs_m[:, None] * N + offs_n[None, :],
            mask=mask,
            other=-float("inf"),
        ).to(tl.float32)

        # Rows with a non-positive sum are written as 0 instead of dividing.
        valid = gsum[:, None] > 0

        out = tl.where(
            valid,
            tl.exp(tile - gmax[:, None]) / gsum[:, None],
            0.0,
        )

        tl.store(y_ptr + offs_m[:, None] * N + offs_n[None, :], out, mask=mask)
480# ------------------------ backward -------------------------------
def nram_usage_for_backward_non_inner(bn, bk, tile_mode, num_stages, dtype):
    """Return the NRAM estimate ``(coef * bn + 1) * bk * 4`` for a tile.

    ``coef`` depends on the tiling mode, on whether the kernel is pipelined
    (``num_stages != 1``) and on whether ``dtype`` is float32:

    * mode 0: 3;
    * mode 1: 3 without pipelining, else 7 for float32 and 6 otherwise;
    * any other mode: 5 without pipelining, else 13 for float32 and
      10 otherwise.

    The trailing factor 4 is presumably the byte size of a float32 element.
    """
    if tile_mode == 0:
        coef = 3
    elif tile_mode == 1:
        coef = 3 if num_stages == 1 else (7 if dtype == torch.float32 else 6)
    else:
        coef = 5 if num_stages == 1 else (13 if dtype == torch.float32 else 10)
    return (coef * bn + 1) * bk * 4
def config_prune3(configs, named_args, **kwargs):
    """Early-prune hook for the backward non-inner softmax autotuner.

    If a single tile of ``N`` by the per-program K extent fits the mode-0,
    non-pipelined NRAM estimate, exactly one config with that tile shape and
    ``num_stages = 1`` is returned (copied from ``configs[0]``).

    Otherwise a config is kept only when

    * its ``TILE_K`` is a multiple of 256 bytes worth of elements
      (64 for float32, 128 otherwise), and
    * its NRAM estimate lies in ``[MAX_NRAM_SIZE // 2, MAX_NRAM_SIZE]``.

    Args:
        configs: candidate autotuner configs.
        named_args: kernel arguments by name; ``M``, ``N``, ``K`` and
            ``output_ptr`` are read.
        **kwargs: ignored.

    Returns:
        The list of configs to benchmark; it may be empty when no config
        passes both filters.
    """
    M = named_args["M"]
    N = named_args["N"]
    K = named_args["K"]
    dtype = named_args["output_ptr"].dtype
    k_per_core = math.ceil(K / max(TOTAL_CORE_NUM // M, 1))

    # Fast path: everything fits in one tile, so no loop is needed.
    if nram_usage_for_backward_non_inner(N, k_per_core, 0, 1, dtype) < MAX_NRAM_SIZE:
        single = copy.deepcopy(configs[0])
        single.kwargs["TILE_K"] = k_per_core
        single.kwargs["TILE_N"] = N
        single.num_stages = 1
        return [single]

    # The lowest dimension is aligned to 256B for loads and stores.
    elems_per_256b = 256 // 4 if dtype == torch.float32 else 256 // 2

    def keep(cfg):
        tile_k = cfg.kwargs["TILE_K"]
        tile_n = cfg.kwargs["TILE_N"]
        stages = cfg.num_stages
        if tile_k % elems_per_256b != 0:
            return False
        mode = softmax_tile_mode_for_non_inner(M, N, K, tile_n, tile_k)
        nram = nram_usage_for_backward_non_inner(tile_n, tile_k, mode, stages, dtype)
        # Reject tiles that overflow NRAM or use less than half of it.
        return not (nram > MAX_NRAM_SIZE or nram < MAX_NRAM_SIZE // 2)

    return [cfg for cfg in configs if keep(cfg)]
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_non_inner_bw"),
    key=[
        "N",
        "K",
    ],
    prune_configs_by={"early_config_prune": config_prune3},
)
@triton.heuristics(runtime.get_heuristic_config("softmax_backward_non_inner"))
@triton.jit
def softmax_backward_kernel_non_inner(
    output_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    K,
    TILE_N: tl.constexpr,
    TILE_K: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Softmax backward over the middle axis of buffers addressed as
    # m * N * K + n * K + k:
    #   in_grad = out * (out_grad - sum_n(out * out_grad)).
    # Program axis 0 selects m, program axis 1 splits K.  TILE_MODE:
    #   0 - a single [TILE_N, TILE_K] tile per program, no loops;
    #   1 - loop over this program's K range, all of N in one tile;
    #   2 - loop over K and, inside, two passes over N in TILE_N chunks.
    pid_m = tl.program_id(0)
    pid_k = tl.program_id(1)

    p_k_num = tl.num_programs(axis=1)
    split_k = tl.cdiv(K, p_k_num)
    k_start = pid_k * split_k

    if TILE_MODE == 0:
        n_offset = tl.arange(0, TILE_N)
        k_offset = pid_k * TILE_K + tl.arange(0, TILE_K)
        offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
        mask = k_offset[None, :] < K
        out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
        scale = tl.sum(out_tile * out_grad_tile, axis=0)
        in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
        tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    elif TILE_MODE == 1:
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            n_offset = tl.arange(0, TILE_N)
            offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
            # Element-wise `&`, matching the other masks in this file,
            # instead of Python's `and` on tensors.
            mask = (k_offset[None, :] < K) & (n_offset[:, None] < N)
            out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
            scale = tl.sum(out_tile * out_grad_tile, axis=0)
            in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
            tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    else:
        for k_idx in range(0, split_k, TILE_K):
            k_offset = k_start + k_idx + tl.arange(0, TILE_K)
            scale = tl.zeros([TILE_N, TILE_K], dtype=tl.float32)
            # Pass 1: accumulate out * out_grad over N.
            # Specialization does not improve performance in this example, as tested.
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                scale += out_tile * out_grad_tile
            scale = tl.sum(scale, axis=0)
            # Pass 2: apply the correction and write the input gradient.
            for start_n in range(0, N, TILE_N):
                n_offset = start_n + tl.arange(0, TILE_N)
                offset = pid_m * N * K + n_offset[:, None] * K + k_offset[None, :]
                mask = (n_offset[:, None] < N) & (k_offset[None, :] < K)
                out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                in_grad_tile = out_tile * (out_grad_tile - scale[None, :])
                tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
def config_prune4(configs, named_args, **kwargs):
    """Early-prune hook for the backward inner softmax autotuner.

    When ``N < MAX_N`` every incoming config is replaced by three derived
    configs that all use ``BLOCK_N = N`` (so no reduction loop is needed):

    1. ``BLOCK_M = ceil(M / TOTAL_CORE_NUM)``, ``num_stages = 1``;
    2. ``BLOCK_M = align(MAX_NRAM_SIZE // 4 // (3 * N + 1))``,
       ``num_stages = 1``;
    3. ``BLOCK_M = align(MAX_NRAM_SIZE // 4 // (c * N + 1))`` with ``c = 7``
       for float32 output and ``c = 6`` otherwise, ``num_stages = 3``.

    Otherwise the incoming config is kept as is.  Only one config is kept per
    ``(BLOCK_M, BLOCK_N, num_warps, num_stages)`` key.  Two heuristic configs
    with ``BLOCK_M = 1``, ``BLOCK_N = N``, ``num_warps = 1`` and
    ``num_stages`` of 3 and 1 are always appended.

    Args:
        configs: candidate autotuner configs.
        named_args: kernel arguments by name; ``M``, ``N`` and ``output_ptr``
            are read.
        **kwargs: ignored.

    Returns:
        The list of configs to benchmark.
    """
    M = named_args["M"]
    N = named_args["N"]
    fwd_out = named_args["output_ptr"]
    unique = {}

    for cfg in configs:
        block_m = cfg.kwargs["BLOCK_M"]
        block_n = cfg.kwargs["BLOCK_N"]
        warps = cfg.num_warps
        stages = cfg.num_stages

        if N < MAX_N:
            block_n = N

            # Candidate 1: spread the rows evenly over all cores.
            cfg = copy.deepcopy(cfg)
            cfg.kwargs["BLOCK_N"] = block_n
            block_m = math.ceil(M / TOTAL_CORE_NUM)
            cfg.kwargs["BLOCK_M"] = block_m
            cfg.num_stages = 1
            unique.setdefault((block_m, block_n, warps, 1), cfg)

            # Candidate 2: largest aligned BLOCK_M without pipelining.
            cfg = copy.deepcopy(cfg)
            block_m = align(MAX_NRAM_SIZE // 4 // (3 * block_n + 1))
            cfg.kwargs["BLOCK_M"] = block_m
            cfg.num_stages = 1
            unique.setdefault((block_m, block_n, warps, 1), cfg)

            # Candidate 3: pipelined variant with a dtype-dependent budget.
            cfg = copy.deepcopy(cfg)
            per_row = 7 if fwd_out.dtype == torch.float32 else 6
            block_m = align(MAX_NRAM_SIZE // 4 // (per_row * block_n + 1))
            cfg.kwargs["BLOCK_M"] = block_m
            cfg.num_stages = 3
            unique.setdefault((block_m, block_n, warps, 3), cfg)
        else:
            # Large N: keep the tuned config untouched.
            unique.setdefault((block_m, block_n, warps, stages), cfg)

    pruned = list(unique.values())

    # Heuristic one-row-per-block configs, pipelined and not.
    pipelined_extra = copy.deepcopy(pruned[0])
    pipelined_extra.kwargs["BLOCK_M"] = 1
    pipelined_extra.kwargs["BLOCK_N"] = N
    pipelined_extra.num_warps = 1
    pipelined_extra.num_stages = 3
    pruned.append(pipelined_extra)

    plain_extra = copy.deepcopy(pipelined_extra)
    plain_extra.num_stages = 1
    pruned.append(plain_extra)
    return pruned
@libentry()
@libtuner(
    configs=runtime.get_tuned_config("softmax_inner_bw"),
    key=[
        "M",
        "N",
    ],
    prune_configs_by={"early_config_prune": config_prune4},
)
@triton.heuristics(
    values=runtime.get_heuristic_config("softmax_backward_inner"),
)
@triton.jit
def softmax_backward_kernel_inner(
    output_ptr,
    out_grad_ptr,
    in_grad_ptr,
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    TILE_MODE: tl.constexpr,
):
    # Softmax backward along the last axis of buffers addressed as m * N + n:
    #   in_grad = out * (out_grad - sum_n(out * out_grad)).
    # Program axis 0 splits the M rows.  TILE_MODE:
    #   0 - a single [BLOCK_M, BLOCK_N] tile per program, no loops;
    #   1 - loop over this program's rows, a whole row per tile;
    #   2 - loop over rows and, inside, two passes over N in BLOCK_N chunks.
    pid_m = tl.program_id(0)
    pnum = tl.num_programs(axis=0)
    split_m = tl.cdiv(M, pnum)
    m_start = pid_m * split_m

    if TILE_MODE == 0:
        m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        n_offset = tl.arange(0, BLOCK_N)
        offset = m_offset[:, None] * N + n_offset[None, :]
        mask = m_offset[:, None] < M
        out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
        out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
        scale = tl.sum(out_tile * out_grad_tile, 1)
        in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
        tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    elif TILE_MODE == 1:
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            n_offset = tl.arange(0, BLOCK_N)
            offset = m_offset[:, None] * N + n_offset[None, :]
            mask = m_offset[:, None] < M
            out_tile = tl.load(output_ptr + offset, mask=mask).to(tl.float32)
            out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
            scale = tl.sum(out_tile * out_grad_tile, 1)
            in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
            tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
    else:
        for m_idx in range(0, split_m, BLOCK_M):
            m_offset = m_start + m_idx + tl.arange(0, BLOCK_M)
            scale = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            # Pass 1: accumulate out * out_grad over N.
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                # Element-wise `&`, matching the other kernels in this file,
                # instead of Python's `and` on tensors.
                mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
                out_tile = tl.load(
                    output_ptr + offset, mask=mask, eviction_policy="evict_last"
                ).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                scale += out_tile * out_grad_tile
            scale = tl.sum(scale, 1)
            # Pass 2: apply the correction and write the input gradient.
            for start_n in range(0, N, BLOCK_N):
                n_offset = start_n + tl.arange(0, BLOCK_N)
                offset = m_offset[:, None] * N + n_offset[None, :]
                mask = (m_offset[:, None] < M) & (n_offset[None, :] < N)
                out_tile = tl.load(
                    output_ptr + offset, mask=mask, eviction_policy="evict_first"
                ).to(tl.float32)
                out_grad_tile = tl.load(out_grad_ptr + offset, mask=mask).to(tl.float32)
                in_grad_tile = out_tile * (out_grad_tile - scale[:, None])
                tl.store(in_grad_ptr + offset, in_grad_tile, mask=mask)
def softmax(self, dim, half_to_float=False):
    """Compute softmax of ``self`` along ``dim`` with the Cambricon kernels.

    The tensor is viewed as ``(M, N, K)`` where ``N`` is the size of ``dim``,
    ``M`` the product of the leading sizes and ``K`` the product of the
    trailing sizes.  Dispatch:

    * ``K > 1``: ``softmax_kernel_non_inner``;
    * ``K == 1`` and (``M > TOTAL_CORE_NUM`` or ``N < 65536``):
      ``softmax_kernel_inner``;
    * otherwise: a three-kernel path (per-tile statistics, merge, write) that
      splits each row over several programs.

    Args:
        self: input tensor; it is made contiguous before any kernel launch.
        dim: reduction axis, in ``[-self.ndim, self.ndim)``.
        half_to_float: when true the result is allocated as ``torch.float32``
            instead of ``self.dtype``.

    Returns:
        A new tensor with the shape of ``self``.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logger.debug("GEMS_CAMBRICON SOFTMAX")

    assert dim >= -self.ndim and dim < self.ndim, "Invalid dim"

    # Decide the output dtype first so every return path, including the
    # empty-tensor one, honours half_to_float.
    if half_to_float:
        dtype = torch.float32
    else:
        dtype = self.dtype

    # Special handling for an empty tensor: there is nothing to normalise, so
    # return a zero-filled buffer of the same shape.
    if self.numel() == 0:
        out_shape = list(self.shape)
        out = torch.empty(out_shape, dtype=dtype, device=self.device)
        zero_(out)
        return out

    dim = dim % self.ndim
    M = 1
    N = self.shape[dim]
    for i in range(dim):
        M *= self.shape[i]  # pre_dim
    self = self.contiguous()
    out = torch.empty_like(self, dtype=dtype)
    K = self.numel() // M // N  # post_dim

    with torch_device_fn.device(self.device):
        if K > 1:
            logger.debug("GEMS_CAMBRICON SOFTMAX USE NON INNER")
            grid = lambda meta: (M, max(TOTAL_CORE_NUM // M, 1), 1)
            softmax_kernel_non_inner[grid](
                out,
                self,
                M,
                N,
                K,
            )
        else:
            logger.debug("GEMS_CAMBRICON SOFTMAX USE INNER")
            if M > TOTAL_CORE_NUM or N < 1024 * 8 * 8:
                softmax_kernel_inner[TOTAL_CORE_NUM, 1, 1](
                    out,
                    self,
                    M,
                    N,
                )
            else:
                # Few, very long rows: split every row into T tiles so that
                # all programs get work.
                block_m = 1
                block_n = 8192 * 4
                if dtype is torch.float32:
                    block_n = 8192 * 2
                # workspace
                T = (N + block_n - 1) // block_n
                max_buf = torch.empty((M, T), device=self.device, dtype=torch.float32)
                sum_buf = torch.empty((M, T), device=self.device, dtype=torch.float32)
                gmax = torch.empty((M,), device=self.device, dtype=torch.float32)
                gsum = torch.empty((M,), device=self.device, dtype=torch.float32)
                # kernel 1: per-tile stats
                softmax_kernel_inner_k_partial_stats[(TOTAL_CORE_NUM,)](
                    self,
                    max_buf,
                    sum_buf,
                    M,
                    N,
                    T,
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    bottleneck="simd",
                    num_stages=3,
                )
                # kernel 2: merge stats along N-tiles
                grid_merge = (triton.cdiv(M, block_m),)
                softmax_kernel_inner_k_merge_stats[grid_merge](
                    max_buf, sum_buf, gmax, gsum, M, T, BLOCK_M=block_m
                )
                # The write kernel uses tiles half as wide; T is recomputed
                # because that kernel only uses it to enumerate its tiles.
                block_n = block_n // 2
                T = (N + block_n - 1) // block_n
                # kernel 3: write normalized outputs
                softmax_kernel_inner_k_write_softmax[(TOTAL_CORE_NUM,)](
                    self,
                    out,
                    gmax,
                    gsum,
                    M,
                    N,
                    T,
                    BLOCK_M=block_m,
                    BLOCK_N=block_n,
                    bottleneck="simd",
                    num_stages=3,
                )
    return out
def softmax_backward(grad_output, output, dim, input_dtype):
    """Compute the input gradient of softmax along ``dim``.

    The tensors are viewed as ``(M, N, K)`` where ``N`` is the size of
    ``dim``, ``M`` the product of the leading sizes and ``K`` the product of
    the trailing sizes.  ``K > 1`` dispatches to
    ``softmax_backward_kernel_non_inner``, otherwise to
    ``softmax_backward_kernel_inner``.

    Args:
        grad_output: gradient with respect to the softmax output.
        output: the softmax forward result.
        dim: reduction axis, in ``[-output.ndim, output.ndim)``.
        input_dtype: dtype of the forward input; the returned gradient is
            allocated with it (``None`` keeps ``output.dtype``).

    Returns:
        A new contiguous tensor with the shape of ``output``.

    Raises:
        AssertionError: if ``dim`` is out of range.
    """
    logger.debug("GEMS_CAMBRICON SOFTMAX VJP")

    assert dim >= -output.ndim and dim < output.ndim, "Invalid dim"
    dim = dim % output.ndim
    M = 1
    N = output.shape[dim]
    for i in range(dim):
        M *= output.shape[i]

    grad_output = grad_output.contiguous()
    # The kernels address every buffer linearly, so ``output`` must be
    # contiguous as well; this also makes ``empty_like`` allocate a
    # contiguous gradient instead of inheriting arbitrary strides.
    output = output.contiguous()
    in_grad = torch.empty_like(output, dtype=input_dtype)

    # Nothing to compute for an empty tensor; returning here also avoids the
    # division by zero below and launching kernels over empty buffers.
    if output.numel() == 0:
        return in_grad

    K = output.numel() // M // N

    with torch_device_fn.device(in_grad.device):
        if K > 1:
            logger.debug("GEMS_CAMBRICON SOFTMAX VJP USE NON INNER")
            grid = lambda meta: (M, max(TOTAL_CORE_NUM // M, 1), 1)
            softmax_backward_kernel_non_inner[grid](
                output,
                grad_output,
                in_grad,
                M,
                N,
                K,
            )
        else:
            logger.debug("GEMS_CAMBRICON SOFTMAX VJP USE INNER")
            softmax_backward_kernel_inner[TOTAL_CORE_NUM, 1, 1](
                output,
                grad_output,
                in_grad,
                M,
                N,
            )
    return in_grad