Coverage for src/flag_gems/runtime/backend/_tsingmicro/ops/flash_kernel.py: 0%
521 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import triton
2import triton.language as tl
4from flag_gems import runtime
5from flag_gems.utils import libentry, libtuner, tl_extra_shim
@triton.jit
def u64_to_lohi(x):
    """Split a 64-bit integer into two ``tl.uint32`` halves.

    Returns ``(x >> 32, x & 0xFFFFFFFF)``, i.e. the upper 32 bits first and
    the lower 32 bits second.

    NOTE(review): the name suggests a (lo, hi) order, but the first element
    returned is the upper half. Callers in this file unpack positionally, so
    confirm the intended order before changing it.
    """
    upper = (x >> 32).to(tl.uint32)
    lower = (x & 0xFFFFFFFF).to(tl.uint32)
    return upper, lower
@triton.jit
def u64_from_lohi(lo, hi):
    """Combine two 32-bit halves into one ``tl.uint64`` value.

    ``hi`` supplies bits 32..63 and ``lo`` supplies bits 0..31.

    The shift is parenthesised explicitly: in Python ``+`` binds tighter than
    ``<<``, so the unparenthesised form ``hi << 32 + lo`` would shift ``hi``
    by ``32 + lo`` bits instead of combining the two halves.
    """
    return (hi.to(tl.uint64) << 32) | lo.to(tl.uint64)
@triton.jit
def philox_(seed, subsequence, offset):
    """Generate four 32-bit random words from (seed, subsequence, offset).

    The seed provides the two key words, ``offset`` provides counter words
    c0/c1 and ``subsequence`` provides counter words c2/c3. Six rounds are
    run with a key bump after each, followed by one final round without a
    key bump.

    The halves come from ``u64_to_lohi``, whose first result is ``x >> 32``
    and whose second result is ``x & 0xFFFFFFFF``.
    """
    # Round multipliers and per-round key increments.
    kPhiloxSA: tl.constexpr = 0xD2511F53
    kPhiloxSB: tl.constexpr = 0xCD9E8D57
    kPhilox10A: tl.constexpr = 0x9E3779B9
    kPhilox10B: tl.constexpr = 0xBB67AE85

    k0, k1 = u64_to_lohi(seed.to(tl.uint64))
    c0, c1 = u64_to_lohi(offset.to(tl.uint64))
    c2, c3 = u64_to_lohi(subsequence.to(tl.uint64))

    # Statically unrolled rounds; each one ends by bumping the key.
    for _ in tl.static_range(6):
        prod_a = kPhiloxSA * c0.to(tl.uint64)
        prod_b = kPhiloxSB * c2.to(tl.uint64)
        a_first, a_second = u64_to_lohi(prod_a)
        b_first, b_second = u64_to_lohi(prod_b)
        c0, c1, c2, c3 = (
            b_second ^ c1 ^ k0,
            b_first,
            a_second ^ c3 ^ k1,
            a_first,
        )
        k0 += kPhilox10A
        k1 += kPhilox10B

    # Final round: same mixing step, no key bump afterwards.
    prod_a = kPhiloxSA * c0.to(tl.uint64)
    prod_b = kPhiloxSB * c2.to(tl.uint64)
    a_first, a_second = u64_to_lohi(prod_a)
    b_first, b_second = u64_to_lohi(prod_b)
    c0, c1, c2, c3 = (
        b_second ^ c1 ^ k0,
        b_first,
        a_second ^ c3 ^ k1,
        a_first,
    )

    return c0, c1, c2, c3
@triton.jit
def apply_dropout_mask(
    P,
    mask,
    encode_dropout_in_sign_bit: tl.constexpr,
):
    """Apply a dropout mask to ``P``.

    Where ``mask`` is true the element is either negated (when
    ``encode_dropout_in_sign_bit`` is set) or replaced by zero of ``P``'s
    dtype; elsewhere ``P`` is returned unchanged.
    """
    if encode_dropout_in_sign_bit:
        replacement = -P
    else:
        replacement = (P * 0).to(P.dtype)
    return tl.where(mask, replacement, P)
@triton.jit
def apply_dropout(
    P,
    row_start,
    col_start,
    n_cols,
    bid,
    hid,
    philox_seed,
    philox_offset,
    p_dropout_uint8: tl.constexpr,
    is_dropout: tl.constexpr,
    encode_dropout_in_sign_bit: tl.constexpr,
    NUM_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Apply Philox-based dropout to a (BLOCK_M, BLOCK_N) tile ``P``.

    When ``is_dropout`` is false ``P`` is returned untouched. Otherwise one
    Philox call is issued per (row, col // 4) position, the four resulting
    words are joined and reshaped to (BLOCK_M, BLOCK_N), and every element
    whose low byte is ``>= p_dropout_uint8`` is masked via
    ``apply_dropout_mask``.
    """
    if is_dropout:
        row_start = tl.multiple_of(row_start, BLOCK_M)
        col_start = tl.multiple_of(col_start, BLOCK_N)
        rows = row_start + tl.arange(0, BLOCK_M)[:, None]
        # Each Philox call yields four words, so column indices are scaled
        # down by 4.
        cols = col_start // 4 + tl.arange(0, BLOCK_N // 4)[None, :]

        subsequence = rows.to(tl.uint64) * n_cols + cols.to(tl.uint64)

        offset = philox_offset + bid * NUM_HEADS + hid
        # Adding a zero tensor gives the offset the same shape as subsequence.
        offset += subsequence * 0
        w0, w1, w2, w3 = philox_(philox_seed, subsequence, offset)

        words = tl.join(tl.join(w0, w1), tl.join(w2, w3)).reshape(BLOCK_M, BLOCK_N)

        drop = (words & 0xFF) >= p_dropout_uint8

        P = apply_dropout_mask(
            P, drop, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit
        )
    return P
@triton.jit
def apply_alibi(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    is_causal: tl.constexpr,
    is_alibi: tl.constexpr,
    alibi_slope: tl.constexpr = None,
):
    """Add the ALiBi bias to the score tile ``S`` when ``is_alibi`` is set.

    With ``is_alibi`` false, ``S`` is returned unchanged.
    """
    if is_alibi:
        if is_causal:
            # A row-independent bias gives the same attention output as the
            # standard ALiBi bias because softmax is shift invariant, i.e.
            # softmax(A + bias + const) = softmax(A + bias). Under causal
            # masking the two biases below are therefore interchangeable.
            # bias_1 = [
            #   -4, -3, -2, X, X,
            #   -4, -3, -2, -1, X,
            #   -4, -3, -2, -1, 0,
            # ]
            # bias_2 = [
            #   -2, -1, 0, X, X,
            #   -3, -2, -1, 0, X,
            #   -4, -3, -2, -1, 0,
            # ]
            bias = alibi_slope * (-max_seqlen_k + 1 + col_idx[None, :]).to(tl.float32)
        else:
            # Bias from the absolute row/column distance expression below.
            bias = -alibi_slope * tl.abs(
                col_idx[None, :] - max_seqlen_k + max_seqlen_q - row_idx[:, None]
            ).to(tl.float32)
        S += bias

    return S
@triton.jit
def apply_mask(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    window_size_left,
    window_size_right,
    is_even_mn: tl.constexpr,
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
):
    """Set masked-out entries of the score tile ``S`` to ``-inf``.

    Causal masking removes columns right of the per-row right bound, local
    masking removes columns outside [left bound, right bound], and when
    neither applies but the tile is not even, columns ``>= max_seqlen_k``
    are removed.
    """
    need_mask = is_causal | is_local | (not is_even_mn)
    if need_mask:
        cols = col_idx[None, :]
        # Beware of off-by-one errors: both bounds are inclusive.
        first_col = max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left)
        last_col = min(
            max_seqlen_k - 1, row_idx + max_seqlen_k - max_seqlen_q + window_size_right
        )

        if is_causal:
            S = tl.where(cols > last_col[:, None], float("-inf"), S)

        if is_local:
            outside_window = (cols > last_col[:, None]) | (cols < first_col[:, None])
            S = tl.where(outside_window, float("-inf"), S)

        if (not is_local) & (not is_causal) & (not is_even_mn):
            S = tl.where(cols >= max_seqlen_k, float("-inf"), S)

    return S
@triton.jit
def softmax_rescale(
    O_acc,
    S,
    row_max,
    row_sum,
    softmax_scale_log2e: tl.constexpr,
    is_border: tl.constexpr,
):
    """Perform one online-softmax update step for a new score tile ``S``.

    Updates the running row maximum, rescales the running row sum and the
    output accumulator by the change of maximum, and returns
    ``(O_acc, P, row_max, row_sum)`` where ``P`` is the exponentiated tile.

    With ``is_border`` set, a row maximum of ``-inf`` is replaced by 0 when
    computing the rescale factor.
    """
    old_max = row_max
    row_max = tl.maximum(row_max, tl.max(S, 1))

    if is_border:
        ref_max = tl.where(row_max == float("-inf"), 0, row_max)
    else:
        ref_max = row_max

    # Factor that re-expresses previously accumulated values w.r.t. ref_max.
    correction = tl.math.exp2((old_max - ref_max) * softmax_scale_log2e)
    row_sum = row_sum * correction
    O_acc = O_acc * correction[:, None]

    # Rows whose maximum is -inf are shifted by 0 rather than by -inf.
    shift = tl.where(row_max == float("-inf"), 0, row_max * softmax_scale_log2e)
    P = tl.math.exp2(S * softmax_scale_log2e - shift[:, None])
    row_sum = row_sum + tl.sum(P, 1)
    return O_acc, P, row_max, row_sum
@triton.jit
def apply_softcap(S, softcap, is_softcap: tl.constexpr):
    """Return ``tanh(S * softcap)`` when ``is_softcap`` is set, else ``S``."""
    if is_softcap:
        scaled = S * softcap
        S = tl_extra_shim.tanh(scaled)

    return S
def block_m_splitkv_heuristic(headdim):
    """Return the BLOCK_M choice: 128 for ``headdim <= 128``, otherwise 64."""
    if headdim <= 128:
        return 128
    return 64
def block_n_splitkv_heuristic(headdim):
    """Return the BLOCK_N choice: 64 for ``headdim <= 64``, otherwise 32."""
    if headdim <= 64:
        return 64
    return 32
def is_even_mn(M, N, BM, BN, WL, WR):
    """Tell whether the (M, N) problem is "even" for blocks (BM, BN).

    True only when all of the following hold:
      * ``M`` is a multiple of ``BM`` and ``N`` is a multiple of ``BN``;
      * one of ``M`` / ``N`` divides the other;
      * each window size (``WL``, ``WR``) is either -1 or a multiple of ``BN``.
    """
    if M % BM != 0 or N % BN != 0:
        return False
    if M % N != 0 and N % M != 0:
        return False
    if WL != -1 and WL % BN != 0:
        return False
    if WR != -1 and WR % BN != 0:
        return False
    return True
def block_m_splitkv_heuristic_spec_args(args):
    """Return BLOCK_M from a kernel-argument mapping, keyed on ``args["d"]``.

    Yields 128 when ``args["d"] <= 128`` and 64 otherwise.
    """
    if args["d"] <= 128:
        return 128
    return 64
def block_n_splitkv_heuristic_spec_args(args):
    """Return BLOCK_N from a kernel-argument mapping, keyed on ``args["d"]``.

    Yields 64 when ``args["d"] <= 64`` and 32 otherwise.
    """
    if args["d"] <= 64:
        return 64
    return 32
def is_even_mn_spec_args(args):
    """Evaluate the "even M/N" condition from a kernel-argument mapping.

    Reads ``seqlen_q``, ``seqlen_k``, ``BLOCK_M``, ``BLOCK_N``,
    ``window_size_left`` and ``window_size_right`` from ``args`` and returns
    True only when the sequence lengths are multiples of their block sizes,
    one sequence length divides the other, and each window size is either -1
    or a multiple of ``BLOCK_N``.
    """
    seqlen_q, block_m = args["seqlen_q"], args["BLOCK_M"]
    if seqlen_q % block_m != 0:
        return False

    seqlen_k, block_n = args["seqlen_k"], args["BLOCK_N"]
    if seqlen_k % block_n != 0:
        return False

    if seqlen_q % seqlen_k != 0 and seqlen_k % seqlen_q != 0:
        return False

    # Left window is checked before the right one.
    for key in ("window_size_left", "window_size_right"):
        window = args[key]
        if window != -1 and window % block_n != 0:
            return False
    return True
def keep(cfg):
    """Return True when ``cfg`` is one of the allow-listed tuning configs.

    A config is kept when its (BLOCK_M, BLOCK_N, num_warps) triple is either
    (128, 32, 4) or (128, 128, 8).
    """
    allowed = ((128, 32, 4), (128, 128, 8))
    signature = (cfg.kwargs["BLOCK_M"], cfg.kwargs["BLOCK_N"], cfg.num_warps)
    return signature in allowed
def prune_fwd_configs(configs, nargs, **kwargs):
    """Early-prune autotuning configs for the forward kernel.

    With dropout enabled (``nargs["is_dropout"]``), only configs with
    ``num_warps == 4`` and ``num_stages < 4`` are kept. Otherwise the configs
    are filtered to a single BLOCK_M chosen from ``nargs["seqlen_q"]``; for
    ``seqlen_q < 64`` the input ``configs`` object is returned unfiltered.
    """
    if nargs["is_dropout"]:
        return [cfg for cfg in configs if cfg.num_warps == 4 and cfg.num_stages < 4]

    seqlen_q = nargs["seqlen_q"]
    # (minimum seqlen_q, required BLOCK_M), checked from largest to smallest.
    block_m_by_seqlen = (
        (1024, 512),
        (512, 256),
        (256, 128),
        (128, 64),
        (64, 32),
    )
    for min_seqlen, block_m in block_m_by_seqlen:
        if seqlen_q >= min_seqlen:
            return [cfg for cfg in configs if cfg.kwargs["BLOCK_M"] == block_m]
    return configs
@libentry()
@libtuner(
    # configs=list(filter(keep, runtime.get_tuned_config("attention"))),
    configs=runtime.get_tuned_config("attention"),
    prune_configs_by={"early_config_prune": prune_fwd_configs},
    key=["seqlen_q", "d", "is_dropout"],
    strategy=[
        "align32",
        "align32",
        lambda a: a,
    ],
    warmup=1,
    rep=1,
)
@triton.heuristics(
    values={
        "PRE_LOAD_V": lambda args: False,
        "IS_EVEN_MN": lambda args: is_even_mn(
            args["seqlen_q"],
            args["seqlen_k"],
            args["BLOCK_M"],
            args["BLOCK_N"],
            args["window_size_left"],
            args["window_size_right"],
        ),
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k,
    cu_seqlens_k_ptr,
    is_seqused_k,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q: tl.constexpr,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Flash-attention forward kernel over one (row block, batch*head) program.

    Program axis 0 selects the BLOCK_M-row query block and axis 1 selects the
    combined batch/head index. The kernel walks the key/value columns in
    BLOCK_N steps in two phases: first the trailing "masking" columns (which
    need bounds / causal / local masking), then the remaining columns.
    The normalised output is stored to ``o_ptr`` and the per-row
    log-sum-exp to ``softmax_lse_ptr``. When both ``is_dropout`` and
    ``return_softmax`` are set, the sign-encoded dropped probabilities are
    stored to ``p_ptr``.

    The empty/NaN row test in the epilogue is written as
    ``(rowsum_ == 0) | (rowsum_ != rowsum_)``. The parentheses are required:
    ``|`` binds tighter than ``==`` in Python, so without them the expression
    compares ``rowsum_`` against ``0 | isnan`` and never detects NaN rows.
    """
    m_block = tl.program_id(0)
    bh = tl.program_id(1)
    hid = bh % h
    bid = bh // h
    num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    # We draw a minimum covering frame on the attention map that this CTA is assigned to process.
    # The frame edges are rounded to multiples of BLOCK_M and BLOCK_N for rows and columns respectively.

    col_min = 0
    if is_local:
        col_min = max(0, m_block * BLOCK_M + seqlen_k - seqlen_q - window_size_left)
        if not IS_EVEN_MN:
            # round left
            col_min = (col_min // BLOCK_N) * BLOCK_N

    col_max = seqlen_k
    if is_causal or is_local:
        col_max += (m_block - num_m_blocks + 1) * BLOCK_M
        if is_local:
            col_max += window_size_right
        col_max = min(seqlen_k, col_max)

    if not IS_EVEN_MN:
        # round right
        col_max = tl.cdiv(col_max, BLOCK_N) * BLOCK_N

    # Number of trailing columns that are processed with masking enabled.
    if (not is_causal) and (not is_local):
        if IS_EVEN_MN:
            masking_cols: tl.constexpr = 0
        else:
            masking_cols: tl.constexpr = BLOCK_N
    elif (
        is_causal | is_local
    ) and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
    else:
        # local
        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N

    if is_dropout:
        # philox_args holds the seed at index 0 and the offset at index 1.
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    # Load the query tile for this row block.
    q_batch_stride = tl.multiple_of(q_batch_stride, d * h)
    q_ptr += bid * q_batch_stride + hid * q_head_stride
    row_start = m_block * BLOCK_M
    row_idx = row_start + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, d)[None, :]
    qmask = row_idx[:, None] < seqlen_q
    if IS_EVEN_MN:
        Q = tl.load(q_ptr + q_off, cache_modifier=".cg")
    else:
        Q = tl.load(q_ptr + q_off, mask=qmask, cache_modifier=".cg")

    if return_softmax:
        p_ptr += (
            (bid * h + hid) * seqlen_q_rounded + m_block * BLOCK_M
        ) * seqlen_k_rounded
        p_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(
            0, BLOCK_N
        )
        p_bp0 = p_ptr + p_offset

    # Online-softmax running state.
    acc_ = tl.zeros((BLOCK_M, d), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    k_batch_stride = tl.multiple_of(k_batch_stride, d * hk)
    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    # NOTE(review): V is addressed with the K batch/head/row strides; the
    # v_*_stride parameters are not read here. Presumably K and V share a
    # layout - confirm against the caller.
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    # K is loaded as a (d, BLOCK_N) tile and V as a (BLOCK_N, d) tile.
    k_offset = tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, d)[:, None]
    v_offset = tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, d)[None, :]

    p_bk0 = k_ptr + k_offset
    p_bv0 = v_ptr + v_offset

    if is_causal | is_local | (not IS_EVEN_MN):
        # Cut short masking cols if there's not enough cols out there
        masking_cols = min(col_max - col_min, masking_cols)
        # Phase 1: trailing columns, walked from right to left with masking.
        for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
            col_start = col_max - col_shift - BLOCK_N
            col_start = tl.multiple_of(col_start, BLOCK_N)
            off = col_start * k_row_stride
            if IS_EVEN_MN:
                K = tl.load(p_bk0 + off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
            else:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
            S = tl.dot(Q, K, allow_tf32=False)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = col_start + tl.arange(0, BLOCK_N)
            row_idx = row_start + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            # tl.store(p_bp0 + col_start, S)
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=is_local,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )
            P = P.to(v_ptr.type.element_ty)

            if is_dropout:
                if return_softmax:
                    # Store probabilities with dropout encoded in the sign bit.
                    P_drop = P

                    P_drop = apply_dropout(
                        P_drop,
                        row_start,
                        col_start,
                        seqlen_k,
                        bid,
                        hid,
                        philox_seed,
                        philox_offset,
                        p_dropout_in_uint8_t,
                        is_dropout,
                        encode_dropout_in_sign_bit=True,
                        NUM_HEADS=h,
                        BLOCK_M=BLOCK_M,
                        BLOCK_N=BLOCK_N,
                    )
                    if IS_EVEN_MN:
                        tl.store(p_bp0 + col_start, P_drop)
                    else:
                        kvmask = col_idx < seqlen_k
                        tl.store(
                            p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :]
                        )

                # Dropped entries are zeroed for the P @ V product.
                P = apply_dropout(
                    P,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=False,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )

            if not PRE_LOAD_V:
                off = col_start * k_row_stride
                if IS_EVEN_MN:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
                else:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
            acc_ = tl.dot(P, V, acc_, allow_tf32=False)

    # Phase 2: remaining columns, loaded without bounds masks.
    for col_start in tl.range(
        col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages
    ):
        col_start = tl.multiple_of(col_start, BLOCK_N)
        off = col_start * k_row_stride
        K = tl.load(p_bk0 + off, cache_modifier=".cg")
        if PRE_LOAD_V:
            V = tl.load(p_bv0 + off, cache_modifier=".cg")
        S = tl.dot(Q, K)
        S = apply_softcap(S, softcap, is_softcap)
        col_idx = col_start + tl.arange(0, BLOCK_N)
        row_idx = row_start + tl.arange(0, BLOCK_M)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            if return_softmax:
                P_drop = P
                P_drop = apply_dropout(
                    P_drop,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=True,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )
                if IS_EVEN_MN:
                    tl.store(p_bp0 + col_start, P_drop)
                else:
                    kvmask = col_idx < seqlen_k
                    tl.store(p_bp0 + col_start, P_drop, mask=qmask & kvmask[None, :])

            P = apply_dropout(
                P,
                row_start,
                col_start,
                seqlen_k,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        if not PRE_LOAD_V:
            off = col_start * k_row_stride
            V = tl.load(p_bv0 + off, cache_modifier=".cg")
        acc_ = tl.dot(P, V, acc_)

    # LSE
    # Note, rowsum = exp(-rowmax) * exp(lse), therefore rowmax + log(rowsum) cancels
    # the effect of rowmax and outputs lse only.
    # Rows with an empty (zero) or NaN sum get lse = +inf and are left unscaled.
    invalid_row = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        invalid_row,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(invalid_row, 1.0, 1.0 / rowsum_)

    if is_dropout:
        acc_ *= inv_sum[:, None] * rp_dropout
    else:
        acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_batch_stride = tl.multiple_of(o_batch_stride, d * h)
    o_ptr += bid * o_batch_stride
    o_ptr += hid * o_head_stride
    o_offset = row_idx[:, None] * o_row_stride + tl.arange(0, d)

    if IS_EVEN_MN:
        tl.store(o_ptr + o_offset, out)
    else:
        tl.store(o_ptr + o_offset, out, mask=qmask)

    # Write back lse
    p_lse = softmax_lse_ptr + (bid * h + hid) * seqlen_q
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    if IS_EVEN_MN:
        tl.store(p_lse + row_idx, lse)
    else:
        tl.store(p_lse + row_idx, lse, mask=row_idx < seqlen_q)
@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k"])
def flash_fwd_bh_parallel_kernel():
    """Placeholder for a batch/head-parallel forward kernel (not implemented)."""
    # (TODO)
    return
@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": block_m_splitkv_heuristic_spec_args,
        "BLOCK_N": block_n_splitkv_heuristic_spec_args,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
        "PRE_LOAD_V": lambda args: True,
        "IS_EVEN_MN": is_even_mn_spec_args,
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_splitkv_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    blocks_per_split: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Split-KV flash-attention forward kernel.

    Program axis 0 selects the query row block, axis 1 the KV split and
    axis 2 the combined batch/head index. Each program attends over the
    ``blocks_per_split`` key/value blocks of its split and writes a
    normalised partial output plus a partial log-sum-exp into the per-split
    buffers addressed by ``o_ptr`` / ``softmax_lse_ptr``.

    Rows whose running sum is zero or NaN get ``lse = -inf`` and are left
    unscaled. The test is written as ``(rowsum_ == 0) | (rowsum_ != rowsum_)``
    with explicit parentheses: ``|`` binds tighter than ``==`` in Python, so
    the unparenthesised form never detects NaN rows.
    """
    m_block = tl.program_id(0)
    split_id = tl.program_id(1)
    bid = tl.program_id(2) // h
    hid = tl.program_id(2) % h

    # Half-open range of KV blocks assigned to this split.
    split_block_min = split_id * blocks_per_split
    split_block_max = split_block_min + blocks_per_split

    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
    if is_causal:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right,
                BLOCK_N,
            ),
        )

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0

    # First KV block index from which masking may be required.
    if not is_causal:
        if IS_EVEN_MN:
            masking_block_min = n_block_max
        else:
            masking_block_min = n_block_max - 1
    elif is_causal and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1

    # Load the query tile for this row block.
    q_ptr += bid * q_batch_stride
    q_ptr += hid * q_head_stride
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, d)[None, :]
    p_qm = q_ptr + q_off
    qmask = row_idx[:, None] < seqlen_q
    if IS_EVEN_MN:
        Q = tl.load(p_qm, cache_modifier=".cg")
    else:
        Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg")

    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    # NOTE(review): V is addressed with the K batch/head/row strides; the
    # v_*_stride parameters are not read here. Presumably K and V share a
    # layout - confirm against the caller.
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    # K is loaded as a (d, BLOCK_N) tile and V as a (BLOCK_N, d) tile.
    k_offset = tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, d)[:, None]
    p_k0 = k_ptr + k_offset

    v_offset = tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, d)[None, :]
    p_v0 = v_ptr + v_offset

    # Online-softmax running state.
    acc_ = tl.zeros((BLOCK_M, d), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if split_block_max <= masking_block_min:
        # no masking needed
        for n_block in tl.range(
            split_block_min, split_block_max, num_stages=num_stages
        ):
            kv_off = n_block * BLOCK_N * k_row_stride
            K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=False,
            )

            if not PRE_LOAD_V:
                V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)
    else:
        # This split overlaps the masking region: clamp to n_block_max and
        # apply bounds / causal masks on every block.
        for n_block in tl.range(split_block_min, min(split_block_max, n_block_max)):
            kv_off = n_block * BLOCK_N * k_row_stride
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            if IS_EVEN_MN:
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            else:
                kvmask = col_idx < seqlen_k
                K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )

            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=False,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )

            if not PRE_LOAD_V:
                if IS_EVEN_MN:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)

    # LSE
    # Rows with an empty (zero) or NaN sum get lse = -inf and are left unscaled.
    invalid_row = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        invalid_row,
        float("-inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(invalid_row, 1.0, 1.0 / rowsum_)

    # Rescale output
    acc_ *= inv_sum[:, None]

    # Write back output
    # o_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size)
    # grid = (seq_block, split, batch * head)
    o_split_ptr = o_ptr
    # + split, batch, head offsets, seq_block offsets are already added in row_idx
    o_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * d
    o_split_offset = row_idx[:, None] * d + tl.arange(0, d)
    o_split_ptr = tl.multiple_of(o_split_ptr, d)
    p_om = o_split_ptr + o_split_offset

    if IS_EVEN_MN:
        tl.store(p_om, acc_, cache_modifier=".cg")
    else:
        tl.store(p_om, acc_, mask=qmask, cache_modifier=".cg")

    # Write back lse
    # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
    lse_split_ptr = softmax_lse_ptr
    # + split, batch, head, seq_block offsets
    lse_split_ptr += (
        split_id * tl.num_programs(2) + tl.program_id(2)
    ) * seqlen_q + m_block * BLOCK_M

    if IS_EVEN_MN:
        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg")
    else:
        tl.store(
            lse_split_ptr + tl.arange(0, BLOCK_M),
            lse,
            mask=row_idx < seqlen_q,
            cache_modifier=".cg",
        )
@libentry()
@triton.jit
def flash_fwd_splitkv_combine_kernel(
    out_ptr,
    lse_ptr,
    out_splits_ptr,
    lse_splits_ptr,
    head_size: tl.constexpr,
    out_b_stride,
    out_s_stride,
    out_h_stride,
    n_splits,
    BLOCK_M: tl.constexpr,
    q_total,
    MAX_N_SPLITS: tl.constexpr,
):
    """Combine per-split partial outputs and LSEs into the final result.

    Each program handles BLOCK_M consecutive rows. It loads the LSE of every
    split for those rows, derives each split's weight relative to the summed
    exponentials, stores the combined LSE to ``lse_ptr`` and stores the
    weighted sum of the split outputs to ``out_ptr``.
    """
    pid = tl.program_id(0)
    rows = tl.arange(0, BLOCK_M)
    splits = tl.arange(0, MAX_N_SPLITS)

    # Advance every base pointer to this program's first row.
    lse_splits_ptr += pid * BLOCK_M
    lse_ptr += pid * BLOCK_M
    out_splits_ptr += pid * BLOCK_M * head_size
    out_ptr += pid * BLOCK_M * head_size
    # Distance between consecutive splits in the split buffers.
    lse_split_stride = tl.num_programs(0) * BLOCK_M
    out_split_stride = tl.num_programs(0) * BLOCK_M * head_size

    # Subtract the per-row maximum from the split LSEs for numerical stability.
    lse_split_offset = rows[:, None] + splits[None, :] * lse_split_stride
    lse_split_mask = (pid * BLOCK_M + rows[:, None] < q_total) & (
        splits[None, :] < n_splits
    )
    lse_splits = tl.load(
        lse_splits_ptr + lse_split_offset, mask=lse_split_mask, other=float("-inf")
    )
    max_lse = tl.max(lse_splits, 1)

    # Sum exp(lse(i) - max_lse) over all splits i to obtain Z = sumexp(QK) up
    # to the scale factor exp(-max_lse).
    Zi_scaled = tl.exp(lse_splits - max_lse[:, None])
    Z_scaled = tl.sum(Zi_scaled, 1)
    Zi_Z = Zi_scaled / Z_scaled[:, None]

    # Write back LSE
    lse = tl.log(Z_scaled) + max_lse
    out_mask = pid * BLOCK_M + rows < q_total
    tl.store(lse_ptr + rows, lse, mask=out_mask)

    # Gather the (BLOCK_M, MAX_N_SPLITS, head_size) block of split outputs.
    out_split_offset = (
        rows[:, None, None] * head_size
        + splits[None, :, None] * out_split_stride
        + tl.arange(0, head_size)[None, None, :]
    )
    out_split_mask = (pid * BLOCK_M + rows[:, None, None] < q_total) & (
        splits[None, :, None] < n_splits
    )
    out_splits = tl.load(
        out_splits_ptr + out_split_offset, mask=out_split_mask, other=0
    )
    # Weighted sum over the split axis.
    out = tl.sum(Zi_Z[:, :, None] * out_splits, 1)
    out = out.to(out_ptr.type.element_ty)

    # Write back output
    out_offset = rows[:, None] * out_s_stride + tl.arange(0, head_size)
    tl.store(out_ptr + out_offset, out, mask=out_mask[:, None])
@triton.jit
def block_to_cache_index(
    n_block, page_table_ptr, block_size, page_stride, row_stride, BLOCK_N
):
    """Translate a logical KV block index into a row index in the paged cache.

    The block's first row (``n_block * BLOCK_N``) is split into a virtual
    page number and an in-page offset; the page table maps the virtual page
    to a cache page, and ``cache_page * block_size + offset`` is returned.
    ``page_stride`` and ``row_stride`` are accepted but not used here.
    """
    first_row = n_block * BLOCK_N
    in_page_offset = first_row % block_size
    virtual_page = first_row // block_size
    # page_table_ptr already points at the start of the current batch element.
    physical_page = tl.load(page_table_ptr + virtual_page).to(tl.int32)
    return physical_page * block_size + in_page_offset
1090@libentry()
1091@triton.jit(
1092 do_not_specialize=[
1093 "seqlen_q",
1094 "seqlen_k",
1095 "seqlen_q_rounded",
1096 "seqlen_k_rounded",
1097 "total_q",
1098 ]
1099)
1100def flash_varlen_fwd_kernel(
1101 q_ptr,
1102 k_ptr,
1103 v_ptr,
1104 o_ptr,
1105 p_ptr,
1106 softmax_lse_ptr,
1107 q_row_stride,
1108 k_row_stride,
1109 v_row_stride,
1110 q_head_stride,
1111 k_head_stride,
1112 v_head_stride,
1113 o_row_stride,
1114 o_head_stride,
1115 q_batch_stride,
1116 k_batch_stride,
1117 v_batch_stride,
1118 o_batch_stride,
1119 is_cu_seqlens_q: tl.constexpr,
1120 cu_seqlens_q_ptr,
1121 is_cu_seqlens_k: tl.constexpr,
1122 cu_seqlens_k_ptr,
1123 is_seqused_k: tl.constexpr,
1124 seqused_k_ptr,
1125 # sizes
1126 b: tl.constexpr,
1127 bk: tl.constexpr,
1128 h: tl.constexpr,
1129 hk: tl.constexpr,
1130 h_hk_ratio: tl.constexpr,
1131 seqlen_q,
1132 seqlen_k,
1133 seqlen_q_rounded,
1134 seqlen_k_rounded,
1135 d: tl.constexpr,
1136 d_rounded: tl.constexpr,
1137 # scaling factors
1138 is_softcap: tl.constexpr,
1139 softcap: tl.constexpr,
1140 scale_softmax: tl.constexpr,
1141 scale_softmax_log2: tl.constexpr,
1142 # dropout
1143 is_dropout: tl.constexpr,
1144 p_dropout: tl.constexpr,
1145 rp_dropout: tl.constexpr,
1146 p_dropout_in_uint8_t: tl.constexpr,
1147 philox_args,
1148 return_softmax: tl.constexpr,
1149 # causal and swa
1150 is_causal: tl.constexpr,
1151 is_local: tl.constexpr,
1152 window_size_left: tl.constexpr,
1153 window_size_right: tl.constexpr,
1154 seqlenq_ngroups_swapped: tl.constexpr,
1155 # alibi
1156 is_alibi: tl.constexpr,
1157 alibi_slopes_ptr,
1158 alibi_slopes_batch_stride: tl.constexpr,
1159 # block table
1160 total_q,
1161 page_table_ptr,
1162 page_table_batch_stride: tl.constexpr,
1163 block_size: tl.constexpr,
1164 # kernel params
1165 BLOCK_M: tl.constexpr,
1166 BLOCK_N: tl.constexpr,
1167 num_warps: tl.constexpr,
1168 num_stages: tl.constexpr,
1169):
1170 m_block = tl.program_id(0)
1171 bid = tl.program_id(1)
1172 hid = tl.program_id(2)
1173 # num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)
1175 if is_cu_seqlens_q:
1176 q_eos = tl.load(cu_seqlens_q_ptr + bid + 1).to(tl.int32)
1177 q_bos = tl.load(cu_seqlens_q_ptr + bid).to(tl.int32)
1178 q_len = q_eos - q_bos
1179 # Current request's start offset in the batched Q
1180 q_offset = q_bos * q_row_stride
1181 o_offset = q_bos * o_row_stride
1182 lse_offset = q_bos * 1
1183 else:
1184 q_len = seqlen_q
1185 q_offset = bid * q_batch_stride
1186 o_offset = bid * o_batch_stride
1187 lse_offset = bid * seqlen_q
1189 if is_cu_seqlens_k:
1190 k_eos = tl.load(cu_seqlens_k_ptr + bid + 1).to(tl.int32)
1191 k_bos = tl.load(cu_seqlens_k_ptr + bid).to(tl.int32)
1192 k_len_cache = k_eos - k_bos
1193 # k_offset = k_bos * k_row_stride
1194 else:
1195 k_len_cache = seqlen_k
1196 # k_offset = bid * k_batch_stride
1198 # v_head_offset = (hid / h_hk_ratio) * k_head_stride
1200 if is_seqused_k:
1201 k_len = tl.load(seqused_k_ptr + bid).to(tl.int32)
1202 else:
1203 k_len = k_len_cache
1205 # Noop CTA
1206 if m_block * BLOCK_M > q_len:
1207 return
1209 # is_even_mn = (q_len % BLOCK_M == 0) and (k_len % BLOCK_N == 0)
1210 is_even_mn: tl.constexpr = False
1212 if is_local:
1213 n_block_min = max(
1214 0, (m_block * BLOCK_M + k_len - q_len - window_size_left) // BLOCK_N
1215 )
1216 else:
1217 n_block_min = 0
1219 n_block_max = tl.cdiv(k_len, BLOCK_N)
1220 if is_causal or is_local:
1221 n_block_max = min(
1222 n_block_max,
1223 tl.cdiv(
1224 (m_block + 1) * BLOCK_M + k_len - q_len + window_size_right, BLOCK_N
1225 ),
1226 )
1228 if is_dropout:
1229 philox_seed = tl.load(philox_args).to(tl.uint64)
1230 philox_offset = tl.load(philox_args + 1).to(tl.uint64)
1232 # start processing kv blocks
1233 page_table_ptr += bid * page_table_batch_stride
1234 q_row_offset = hid * q_head_stride
1235 k_row_offset = (hid // h_hk_ratio) * k_head_stride
1237 gQ = tl.make_block_ptr(
1238 base=q_ptr + q_offset + q_row_offset,
1239 shape=(q_len, d),
1240 strides=(q_row_stride, 1),
1241 offsets=(0, 0),
1242 block_shape=(BLOCK_M, d),
1243 order=(1, 0),
1244 )
1246 gK = tl.make_block_ptr(
1247 base=k_ptr + k_row_offset,
1248 shape=(d, bk * block_size),
1249 strides=(1, k_row_stride),
1250 offsets=(0, 0),
1251 block_shape=(d, BLOCK_N),
1252 order=(0, 1),
1253 )
1255 gV = tl.make_block_ptr(
1256 base=v_ptr + k_row_offset,
1257 shape=(bk * block_size, d),
1258 strides=(k_row_stride, 1),
1259 offsets=(0, 0),
1260 block_shape=(BLOCK_N, d),
1261 order=(1, 0),
1262 )
1264 bQ = tl.load(gQ.advance([m_block * BLOCK_M, 0]), boundary_check=(0,))
1266 acc_ = tl.zeros((BLOCK_M, d), dtype=tl.float32)
1267 rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
1268 rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)
1270 if is_alibi:
1271 alibi_offset = bid * alibi_slopes_batch_stride + hid
1272 alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
1273 alibi_slope /= scale_softmax
1274 else:
1275 alibi_slope = 0.0
1277 if not is_causal and not is_local:
1278 n_masking_steps = 1
1279 elif is_even_mn:
1280 n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
1281 else:
1282 n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1
1284 n_masking_steps = min(n_block_max - n_block_min, n_masking_steps)
1285 for step in tl.range(0, n_masking_steps):
1286 # for step in tl.range(1):
1287 n_block = n_block_max - 1 - step
1288 cache_row_index = block_to_cache_index(
1289 n_block,
1290 page_table_ptr,
1291 block_size,
1292 page_table_batch_stride,
1293 k_row_stride,
1294 BLOCK_N,
1295 )
1296 bK = tl.load(gK.advance([0, cache_row_index]), boundary_check=(1,))
1297 # preload V
1298 bV = tl.load(gV.advance([cache_row_index, 0]), boundary_check=(0,))
1299 S = tl.dot(bQ, bK, out_dtype=tl.float32)
1300 S = apply_softcap(S, softcap, is_softcap)
1301 col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
1302 row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
1303 S = apply_alibi(
1304 S,
1305 col_idx,
1306 row_idx,
1307 q_len,
1308 k_len,
1309 is_causal=is_causal,
1310 is_alibi=is_alibi,
1311 alibi_slope=alibi_slope,
1312 )
1313 S = apply_mask(
1314 S,
1315 col_idx,
1316 row_idx,
1317 q_len,
1318 k_len,
1319 window_size_left,
1320 window_size_right,
1321 is_even_mn=is_even_mn,
1322 is_causal=is_causal,
1323 is_local=is_local,
1324 )
1326 acc_, P, rowmax_, rowsum_ = softmax_rescale(
1327 acc_,
1328 S,
1329 rowmax_,
1330 rowsum_,
1331 softmax_scale_log2e=scale_softmax_log2,
1332 is_border=True,
1333 )
1334 P = P.to(v_ptr.type.element_ty)
1336 if is_dropout:
1337 P = apply_dropout(
1338 P,
1339 n_block * BLOCK_N,
1340 m_block * BLOCK_M,
1341 k_len,
1342 bid,
1343 hid,
1344 philox_seed,
1345 philox_offset,
1346 p_dropout_in_uint8_t,
1347 is_dropout,
1348 encode_dropout_in_sign_bit=False,
1349 NUM_HEADS=h,
1350 BLOCK_M=BLOCK_M,
1351 BLOCK_N=BLOCK_N,
1352 )
1354 acc_ = tl.dot(P, bV, acc_)
1356 for n_block in tl.range(
1357 n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1
1358 ):
1359 cache_row_index = block_to_cache_index(
1360 n_block,
1361 page_table_ptr,
1362 block_size,
1363 page_table_batch_stride,
1364 k_row_stride,
1365 BLOCK_N,
1366 )
1367 bK = tl.load(gK.advance([0, cache_row_index]))
1368 # preload V
1369 bV = tl.load(gV.advance([cache_row_index, 0]))
1370 S = tl.dot(bQ, bK, out_dtype=tl.float32)
1371 S = apply_softcap(S, softcap, is_softcap)
1372 col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
1373 row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
1374 S = apply_alibi(
1375 S,
1376 col_idx,
1377 row_idx,
1378 q_len,
1379 k_len,
1380 is_causal=is_causal,
1381 is_alibi=is_alibi,
1382 alibi_slope=alibi_slope,
1383 )
1384 S = apply_mask(
1385 S,
1386 col_idx,
1387 row_idx,
1388 q_len,
1389 k_len,
1390 window_size_left,
1391 window_size_right,
1392 is_even_mn=True,
1393 is_causal=False,
1394 is_local=is_local,
1395 )
1397 acc_, P, rowmax_, rowsum_ = softmax_rescale(
1398 acc_,
1399 S,
1400 rowmax_,
1401 rowsum_,
1402 softmax_scale_log2e=scale_softmax_log2,
1403 is_border=is_local,
1404 )
1405 P = P.to(v_ptr.type.element_ty)
1407 if is_dropout:
1408 P = apply_dropout(
1409 P,
1410 m_block * BLOCK_M,
1411 n_block * BLOCK_N,
1412 k_len,
1413 bid,
1414 hid,
1415 philox_seed,
1416 philox_offset,
1417 p_dropout_in_uint8_t,
1418 is_dropout,
1419 encode_dropout_in_sign_bit=False,
1420 NUM_HEADS=h,
1421 BLOCK_M=BLOCK_M,
1422 BLOCK_N=BLOCK_N,
1423 )
1424 acc_ = tl.dot(P, bV, acc_)
1426 # LSE
1427 lse = tl.where(
1428 rowsum_ == 0 | (rowsum_ != rowsum_),
1429 float("inf"),
1430 rowmax_ * scale_softmax + tl.log(rowsum_),
1431 )
1432 inv_sum = tl.where(rowsum_ == 0 | (rowsum_ != rowsum_), 1.0, 1.0 / rowsum_)
1434 acc_ *= inv_sum[:, None]
1436 out = acc_.to(o_ptr.type.element_ty) # noqa
1438 # Write back output
1439 o_row_offset = hid * o_head_stride
1441 gO = tl.make_block_ptr(
1442 base=o_ptr + o_offset + o_row_offset,
1443 shape=(q_len, d),
1444 strides=(o_row_stride, 1),
1445 offsets=(0, 0),
1446 block_shape=(BLOCK_M, d),
1447 order=(1, 0),
1448 )
1449 tl.store(gO.advance([m_block * BLOCK_M, 0]), out, boundary_check=(0,))
1451 # Write back lse
1452 # lse shape: [h, total_q]
1453 softmax_lse_ptr += hid * total_q
1454 lse_row_offset = lse_offset + m_block * BLOCK_M + tl.arange(0, BLOCK_M)
1455 tl.store(softmax_lse_ptr + lse_row_offset, lse, mask=lse_row_offset < total_q)