Coverage for src/flag_gems/ops/flash_kernel.py: 13%
574 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-27 08:02 +0800
1import triton
2import triton.language as tl
4from flag_gems import runtime
5from flag_gems.utils import libentry, tl_extra_shim
@triton.jit
def u64_to_lohi(x):
    """Split ``x`` into two 32-bit words.

    NOTE: despite the name, the first element of the returned pair is the
    upper 32 bits (``x >> 32``) and the second is the lower 32 bits.
    """
    upper_word = (x >> 32).to(tl.uint32)
    lower_word = (x & 0xFFFFFFFF).to(tl.uint32)
    return upper_word, lower_word
@triton.jit
def u64_from_lohi(lo, hi):
    """Pack two 32-bit words into one uint64.

    ``hi`` is placed in the upper 32 bits and ``lo`` in the lower 32 bits.
    """
    # The shift must be parenthesised: `+` binds tighter than `<<`, so the
    # unparenthesised form would shift `hi` by (32 + lo) instead of by 32.
    return (hi.to(tl.uint64) << 32) + lo.to(tl.uint64)
@triton.jit
def philox_(seed, subsequence, offset):
    """Generate four 32-bit pseudo-random words per element from a counter-based RNG.

    The key is taken from ``seed`` and the four counter words from ``offset``
    and ``subsequence``; each is cast to uint64 and split with ``u64_to_lohi``.
    Six rounds are run with a key bump after each, followed by one final round
    without a key bump. Returns the four counter words ``(c0, c1, c2, c3)``.
    """
    # Per-round key increments (names suggest Philox constants — TODO confirm).
    kPhilox10A: tl.constexpr = 0x9E3779B9
    kPhilox10B: tl.constexpr = 0xBB67AE85
    k0, k1 = u64_to_lohi(seed.to(tl.uint64))
    c0, c1 = u64_to_lohi(offset.to(tl.uint64))
    c2, c3 = u64_to_lohi(subsequence.to(tl.uint64))

    # pragma unroll
    # Round multipliers; products are formed in 64 bits and then split.
    kPhiloxSA: tl.constexpr = 0xD2511F53
    kPhiloxSB: tl.constexpr = 0xCD9E8D57
    for _ in tl.static_range(6):
        res0 = kPhiloxSA * c0.to(tl.uint64)
        res1 = kPhiloxSB * c2.to(tl.uint64)
        res0_x, res0_y = u64_to_lohi(res0)
        res1_x, res1_y = u64_to_lohi(res1)
        c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x
        k0 += kPhilox10A
        k1 += kPhilox10B

    # Final (seventh) round: same mixing step, no key update afterwards.
    res0 = kPhiloxSA * c0.to(tl.uint64)
    res1 = kPhiloxSB * c2.to(tl.uint64)
    res0_x, res0_y = u64_to_lohi(res0)
    res1_x, res1_y = u64_to_lohi(res1)
    c0, c1, c2, c3 = res1_y ^ c1 ^ k0, res1_x, res0_y ^ c3 ^ k1, res0_x

    return c0, c1, c2, c3
@triton.jit
def apply_dropout_mask(
    P,
    mask,
    encode_dropout_in_sign_bit: tl.constexpr,
):
    """Replace the elements of ``P`` selected by ``mask``.

    Selected elements are negated when ``encode_dropout_in_sign_bit`` is set,
    otherwise they are replaced by ``P * 0`` cast back to ``P``'s dtype.
    Unselected elements are returned unchanged.
    """
    # The flag is a compile-time constant, so only one replacement is built.
    if encode_dropout_in_sign_bit:
        replacement = -P
    else:
        replacement = (P * 0).to(P.dtype)
    return tl.where(mask, replacement, P)
@triton.jit
def apply_dropout(
    P,
    row_start,
    col_start,
    n_cols,
    bid,
    hid,
    philox_seed,
    philox_offset,
    p_dropout_uint8: tl.constexpr,
    is_dropout: tl.constexpr,
    encode_dropout_in_sign_bit: tl.constexpr,
    NUM_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Apply a pseudo-random dropout mask to the tile ``P``.

    When ``is_dropout`` is false ``P`` is returned untouched. Otherwise random
    words are drawn with ``philox_`` for the tile starting at
    ``(row_start, col_start)`` and every element whose random low byte is
    ``>= p_dropout_uint8`` is handed to ``apply_dropout_mask`` as masked.
    """
    if is_dropout:
        row_start = tl.multiple_of(row_start, BLOCK_M)
        col_start = tl.multiple_of(col_start, BLOCK_N)
        row = row_start + tl.arange(0, BLOCK_M)[:, None]
        # Down scale col_idx by 4
        # (each philox_ call yields four words, joined below into BLOCK_N columns)
        col = col_start // 4 + tl.arange(0, BLOCK_N // 4)[None, :]

        subsequence = row.to(tl.uint64) * n_cols + col.to(tl.uint64)

        offset = philox_offset + bid * NUM_HEADS + hid
        # Adding a zero tensor shaped like `subsequence`; presumably this only
        # broadcasts `offset` to the tile shape — TODO confirm.
        offset += subsequence * 0
        r0, r1, r2, r3 = philox_(philox_seed, subsequence, offset)

        # Interleave the four outputs and fold them back to a full tile.
        r = tl.join(tl.join(r0, r1), tl.join(r2, r3)).reshape(BLOCK_M, BLOCK_N)

        # Only the low 8 bits of each random word are compared to the threshold.
        mask = (r & 0xFF) >= p_dropout_uint8

        P = apply_dropout_mask(
            P, mask, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit
        )
    return P
@triton.jit
def apply_alibi(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    is_causal: tl.constexpr,
    is_alibi: tl.constexpr,
    alibi_slope: tl.constexpr = None,
):
    """Add an ALiBi-style linear bias to the score tile ``S``.

    No-op when ``is_alibi`` is false. In the causal case the bias depends only
    on the column index; otherwise it is ``-alibi_slope`` times the absolute
    distance between the column index and the row index shifted by
    ``max_seqlen_k - max_seqlen_q``. Returns the (possibly biased) ``S``.
    """
    if is_alibi:
        if is_causal:
            # The row independent alibi bias renders the same attention output
            # as with the standard alibi because softmax is shift invariant, i.e.,
            # softmax(A + bias + const) = softmax(A + bias). The following two
            # biases are no different if causal is true.
            # bias_1 = [
            #   -4, -3, -2, X, X,
            #   -4, -3, -2, -1, X,
            #   -4, -3, -2, -1, 0,
            # ]
            # bias_2 = [
            #   -2, -1, 0, X, X,
            #   -3, -2, -1, 0, X,
            #   -4, -3, -2, -1, 0,
            # ]
            bias = alibi_slope * (-max_seqlen_k + 1 + col_idx[None, :]).to(tl.float32)
            S += bias
        else:
            bias = -alibi_slope * tl.abs(
                col_idx[None, :] - max_seqlen_k + max_seqlen_q - row_idx[:, None]
            ).to(tl.float32)
            S += bias

    return S
@triton.jit
def apply_mask(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    window_size_left,
    window_size_right,
    is_even_mn: tl.constexpr,
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
):
    """Set masked-out entries of the score tile ``S`` to ``-inf``.

    Causal masking removes columns to the right of the per-row right bound,
    local masking removes columns outside ``[col_lb, col_rb]``, and when
    neither applies but the tile is not even, columns ``>= max_seqlen_k`` are
    removed. Returns ``S`` unchanged when no masking is required.
    """
    need_mask = is_causal | is_local | (not is_even_mn)
    # need_mask: tl.constexpr = is_causal | is_local
    if need_mask:
        # Extra care should be taken to avoid off-by-one errors: both col_lb and col_rb are inclusive!
        col_lb = max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left)
        col_rb = min(
            max_seqlen_k - 1, row_idx + max_seqlen_k - max_seqlen_q + window_size_right
        )

        if is_causal:
            S = tl.where(col_idx[None, :] > col_rb[:, None], float("-inf"), S)

        if is_local:
            S = tl.where(
                (col_idx[None, :] > col_rb[:, None])
                | (col_idx[None, :] < col_lb[:, None]),
                float("-inf"),
                S,
            )

        # Plain out-of-range column masking for ragged (non-even) tiles.
        if (not is_local) & (not is_causal) & (not is_even_mn):
            S = tl.where(col_idx[None, :] >= max_seqlen_k, float("-inf"), S)

    return S
@triton.jit
def softmax_rescale(
    O_acc,
    S,
    row_max,
    row_sum,
    softmax_scale_log2e: tl.constexpr,
    is_border: tl.constexpr,
    # is_init: tl.constexpr
):
    """Perform one online-softmax update step for a new score tile ``S``.

    Updates the running row maximum, rescales the running row sum and the
    output accumulator by ``exp2((prev_max - cur_max) * softmax_scale_log2e)``,
    and computes the exponentiated tile ``P``.

    Returns ``(O_acc, P, row_max, row_sum)``.
    """
    prev_max = row_max
    row_max = tl.maximum(row_max, tl.max(S, 1))

    # On border tiles a row maximum of -inf is replaced by 0 before it is used
    # in the rescale factor.
    if is_border:
        cur_max = tl.where(row_max == float("-inf"), 0, row_max)
    else:
        cur_max = row_max

    p_scale = tl.math.exp2((prev_max - cur_max) * softmax_scale_log2e)
    row_sum *= p_scale
    O_acc *= p_scale[:, None]

    # The -inf guard is applied here regardless of `is_border`.
    max_scaled = tl.where(row_max == float("-inf"), 0, row_max * softmax_scale_log2e)
    P = tl.math.exp2(S * softmax_scale_log2e - max_scaled[:, None])
    row_sum = row_sum + tl.sum(P, 1)
    return O_acc, P, row_max, row_sum
@triton.jit
def apply_softcap(S, softcap, is_softcap: tl.constexpr):
    """Return ``tanh(S * softcap)`` when ``is_softcap`` is set, else ``S`` unchanged."""
    capped = S
    if is_softcap:
        capped = tl_extra_shim.tanh(S * softcap)

    return capped
def block_m_splitkv_heuristic(headdim):
    """Pick BLOCK_M for the split-KV kernel: 128 for head dims up to 128, else 64."""
    if headdim <= 128:
        return 128
    return 64
def block_n_splitkv_heuristic(headdim):
    """Pick BLOCK_N for the split-KV kernel: 64 for head dims up to 64, else 32."""
    if headdim <= 64:
        return 64
    return 32
def is_even_mn(M, N, BM, BN, WL, WR):
    """Return True when the problem tiles evenly.

    All of the following must hold: ``M`` is a multiple of ``BM`` and ``N`` a
    multiple of ``BN``; one of ``M``/``N`` divides the other; and each window
    size (``WL``, ``WR``) is either ``-1`` or a multiple of ``BN``.
    """
    # Guard clauses keep the original left-to-right evaluation order.
    if not (M % BM == 0 and N % BN == 0):
        return False
    if not (M % N == 0 or N % M == 0):
        return False
    if not (WL == -1 or WL % BN == 0):
        return False
    return WR == -1 or WR % BN == 0
def block_m_splitkv_heuristic_spec_args(args):
    """Heuristic callback: BLOCK_M is 128 when ``args["d"] <= 128``, else 64."""
    head_dim = args["d"]
    if head_dim <= 128:
        return 128
    return 64
def block_n_splitkv_heuristic_spec_args(args):
    """Heuristic callback: BLOCK_N is 64 when ``args["d"] <= 64``, else 32."""
    head_dim = args["d"]
    if head_dim <= 64:
        return 64
    return 32
def is_even_mn_spec_args(args):
    """Heuristic callback: evenness test reading its inputs from ``args``.

    True when ``seqlen_q``/``seqlen_k`` are multiples of ``BLOCK_M``/``BLOCK_N``,
    one sequence length divides the other, and both window sizes are either
    ``-1`` or a multiple of ``BLOCK_N``.
    """
    # Keys are read lazily and in the original order.
    if (
        args["seqlen_q"] % args["BLOCK_M"] != 0
        or args["seqlen_k"] % args["BLOCK_N"] != 0
    ):
        return False

    if (
        args["seqlen_q"] % args["seqlen_k"] != 0
        and args["seqlen_k"] % args["seqlen_q"] != 0
    ):
        return False

    for window_key in ("window_size_left", "window_size_right"):
        if args[window_key] != -1 and args[window_key] % args["BLOCK_N"] != 0:
            return False
    return True
def keep(cfg, must_keep=None):
    """Decide whether an autotune config should be retained.

    Args:
        cfg: config object exposing ``kwargs["BLOCK_M"]``, ``kwargs["BLOCK_N"]``
            and ``num_warps``.
        must_keep: optional container of configs that are always retained.

    Returns:
        ``True`` if ``(BLOCK_M, BLOCK_N, num_warps)`` is one of the allowed
        combinations or ``cfg`` is in ``must_keep``; ``False`` otherwise.
        The result is always a real ``bool`` (never ``None`` or an empty
        container), so it is safe to compare or serialise.
    """
    signature = (cfg.kwargs["BLOCK_M"], cfg.kwargs["BLOCK_N"], cfg.num_warps)
    if signature in ((128, 32, 4), (128, 128, 8)):
        return True

    # we always keep configurations in `must_keep`
    return bool(must_keep) and cfg in must_keep
def prune_fwd_configs(configs, nargs, **kwargs):
    """Early-prune hook for the forward kernel's autotuner.

    With dropout enabled only configs with ``num_warps == 4`` and
    ``num_stages < 4`` survive (returned as a new list); otherwise the
    incoming ``configs`` object is returned as is.
    """
    if not nargs["is_dropout"]:
        return configs
    return [cfg for cfg in configs if cfg.num_warps == 4 and cfg.num_stages < 4]
def flash_fwd_kernel_heur_block_k(args):
    """Heuristic callback: BLOCK_K is the next power of two of ``args["d"]``."""
    head_dim = args["d"]
    return triton.next_power_of_2(head_dim)
@libentry()
@triton.autotune(
    configs=list(filter(keep, runtime.get_tuned_config("attention"))),
    prune_configs_by={"early_config_prune": prune_fwd_configs},
    key=["d", "is_dropout"],
)
@triton.heuristics(
    values={
        "BLOCK_K": flash_fwd_kernel_heur_block_k,
        "PRE_LOAD_V": lambda args: False,
        "IS_EVEN_MN": lambda args: is_even_mn(
            args["seqlen_q"],
            args["seqlen_k"],
            args["BLOCK_M"],
            args["BLOCK_N"],
            args["window_size_left"],
            args["window_size_right"],
        ),
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k,
    cu_seqlens_k_ptr,
    is_seqused_k,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q: tl.constexpr,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    k_page_stride: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Forward attention kernel processing one BLOCK_M row block per program.

    Grid axis 0 selects the row block (``m_block``); grid axis 1 encodes
    batch and head as ``bid * h + hid``. The kernel walks the key/value
    columns in two phases: first the right-most tiles that need masking
    (iterated right to left), then the remaining tiles without causal or
    padding masks. It writes the normalised output to ``o_ptr``, the
    log-sum-exp per row to ``softmax_lse_ptr`` and, when both ``is_dropout``
    and ``return_softmax`` are set, the sign-encoded probabilities to ``p_ptr``.
    """
    m_block = tl.program_id(0)
    bh = tl.program_id(1)
    hid = bh % h
    bid = bh // h
    num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    # We draw a minimum covering frame on the attention map that this CTA is assigned to process.
    # The frame edges are rounded to multiples of BLOCK_M and BLOCK_N for rows and columns respectively.

    col_min = 0
    if is_local:
        col_min = max(0, m_block * BLOCK_M + seqlen_k - seqlen_q - window_size_left)
        if not IS_EVEN_MN:
            # round left
            col_min = (col_min // BLOCK_N) * BLOCK_N

    col_max = seqlen_k
    if is_causal or is_local:
        col_max += (m_block - num_m_blocks + 1) * BLOCK_M
        if is_local:
            col_max += window_size_right
        col_max = min(seqlen_k, col_max)

    if not IS_EVEN_MN:
        # round right
        col_max = tl.cdiv(col_max, BLOCK_N) * BLOCK_N

    # Number of right-most columns that must go through the masked loop.
    if (not is_causal) and (not is_local):
        if IS_EVEN_MN:
            masking_cols: tl.constexpr = 0
        else:
            masking_cols: tl.constexpr = BLOCK_N
    elif (
        is_causal | is_local
    ) and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
    else:
        # local
        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        # Pre-divide so the slope survives the later multiplication by the softmax scale.
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    q_batch_stride = tl.multiple_of(q_batch_stride, d * h)
    q_ptr += bid * q_batch_stride + hid * q_head_stride
    row_start = m_block * BLOCK_M
    row_idx = row_start + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    # FIX: `IS_EVEN_MN & d == BLOCK_K` parsed as `(IS_EVEN_MN & d) == BLOCK_K`
    # because `&` binds tighter than `==`; the unmasked path was never selected.
    if IS_EVEN_MN & (d == BLOCK_K):
        Q = tl.load(q_ptr + q_off, cache_modifier=".cg")
    else:
        Q = tl.load(q_ptr + q_off, mask=qmask, cache_modifier=".cg")

    if return_softmax:
        p_ptr += (
            (bid * h + hid) * seqlen_q_rounded + m_block * BLOCK_M
        ) * seqlen_k_rounded
        p_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(
            0, BLOCK_N
        )
        p_bp0 = p_ptr + p_offset

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    k_batch_stride = tl.multiple_of(k_batch_stride, d * hk)
    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    # NOTE(review): V is addressed with the K strides; v_batch_stride,
    # v_head_stride and v_row_stride are unused. This is only correct when K
    # and V share a layout — confirm against the caller.
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )

    p_bk0 = k_ptr + k_offset
    p_bv0 = v_ptr + v_offset

    if is_causal | is_local | (not IS_EVEN_MN):
        # Cut short masking cols if there's not enough cols out there
        masking_cols = min(col_max - col_min, masking_cols)
        # Phase 1: masked tiles, walked from the right edge towards the left.
        for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
            col_start = col_max - col_shift - BLOCK_N
            col_start = tl.multiple_of(col_start, BLOCK_N)
            off = col_start * k_row_stride
            if IS_EVEN_MN & (d == BLOCK_K):
                K = tl.load(p_bk0 + off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
            elif d == BLOCK_K:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(p_bk0 + off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
            else:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_bk0 + off,
                    mask=kvmask[None, :] & dmask[:, None],
                    cache_modifier=".cg",
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        cache_modifier=".cg",
                    )
            S = tl.dot(Q, K, allow_tf32=False)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = col_start + tl.arange(0, BLOCK_N)
            row_idx = row_start + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            # tl.store(p_bp0 + col_start, S)
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=is_local,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )
            P = P.to(v_ptr.type.element_ty)

            if is_dropout:
                if return_softmax:
                    P_drop = P

                    P_drop = apply_dropout(
                        P_drop,
                        row_start,
                        col_start,
                        seqlen_k,
                        bid,
                        hid,
                        philox_seed,
                        philox_offset,
                        p_dropout_in_uint8_t,
                        is_dropout,
                        encode_dropout_in_sign_bit=True,
                        NUM_HEADS=h,
                        BLOCK_M=BLOCK_M,
                        BLOCK_N=BLOCK_N,
                    )
                    if IS_EVEN_MN:
                        tl.store(p_bp0 + col_start, P_drop)
                    else:
                        kvmask = col_idx < seqlen_k
                        # FIX: the stored tile is (BLOCK_M, BLOCK_N); `qmask`
                        # is (BLOCK_M, BLOCK_K) and carries the head-dim mask,
                        # so only the row bound is combined with `kvmask`.
                        tl.store(
                            p_bp0 + col_start,
                            P_drop,
                            mask=(row_idx[:, None] < seqlen_q) & kvmask[None, :],
                        )

                P = apply_dropout(
                    P,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=False,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )

            if not PRE_LOAD_V:
                off = col_start * k_row_stride
                if IS_EVEN_MN & (d == BLOCK_K):
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(p_bv0 + off, mask=kvmask[:, None], cache_modifier=".cg")
                else:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        cache_modifier=".cg",
                    )
            acc_ = tl.dot(P, V, acc_, allow_tf32=False)

    # Phase 2: remaining tiles; only the head-dim mask (and local window) applies.
    for col_start in tl.range(
        col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages
    ):
        col_start = tl.multiple_of(col_start, BLOCK_N)
        off = col_start * k_row_stride
        if d == BLOCK_K:
            K = tl.load(p_bk0 + off, cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
        else:
            K = tl.load(p_bk0 + off, mask=dmask[:, None], cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_bv0 + off, mask=dmask[None, :], cache_modifier=".cg")

        S = tl.dot(Q, K)
        S = apply_softcap(S, softcap, is_softcap)
        col_idx = col_start + tl.arange(0, BLOCK_N)
        row_idx = row_start + tl.arange(0, BLOCK_M)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            if return_softmax:
                P_drop = P
                P_drop = apply_dropout(
                    P_drop,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=True,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )
                if IS_EVEN_MN:
                    tl.store(p_bp0 + col_start, P_drop)
                else:
                    kvmask = col_idx < seqlen_k
                    # FIX: row bound only; see the matching store in phase 1.
                    tl.store(
                        p_bp0 + col_start,
                        P_drop,
                        mask=(row_idx[:, None] < seqlen_q) & kvmask[None, :],
                    )

            P = apply_dropout(
                P,
                row_start,
                col_start,
                seqlen_k,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        if not PRE_LOAD_V:
            off = col_start * k_row_stride
            if d == BLOCK_K:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
            else:
                V = tl.load(p_bv0 + off, mask=dmask[None, :], cache_modifier=".cg")
        acc_ = tl.dot(P, V, acc_)

    # LSE
    # Note, rowsum = exp(-rowmax) * exp(lse), therefore rowmax + log(rowsum) cancels
    # the effect of rowmax and outputs lse only.
    # FIX: `rowsum_ == 0 | (rowsum_ != rowsum_)` parsed as
    # `rowsum_ == (0 | (...))`, so NaN row sums were never detected.
    invalid_sum = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        invalid_sum,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(invalid_sum, 1.0, 1.0 / rowsum_)

    if is_dropout:
        acc_ *= inv_sum[:, None] * rp_dropout
    else:
        acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_batch_stride = tl.multiple_of(o_batch_stride, d * h)
    o_ptr += bid * o_batch_stride
    o_ptr += hid * o_head_stride
    o_offset = row_idx[:, None] * o_row_stride + tl.arange(0, BLOCK_K)

    if IS_EVEN_MN & (d == BLOCK_K):
        tl.store(o_ptr + o_offset, out)
    else:
        tl.store(o_ptr + o_offset, out, mask=qmask)

    # Write back lse
    p_lse = softmax_lse_ptr + (bid * h + hid) * seqlen_q
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    if IS_EVEN_MN:
        tl.store(p_lse + row_idx, lse)
    else:
        tl.store(p_lse + row_idx, lse, mask=row_idx < seqlen_q)
@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k"])
def flash_fwd_bh_parallel_kernel():
    """Placeholder kernel; not implemented yet (takes no arguments, does nothing)."""
    # (TODO)
    pass
def flash_fwd_splitkv_kernel_heur_block_k(args):
    """Heuristic callback: BLOCK_K is the next power of two of ``args["d"]``."""
    head_dim = args["d"]
    return triton.next_power_of_2(head_dim)
@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": block_m_splitkv_heuristic_spec_args,
        "BLOCK_N": block_n_splitkv_heuristic_spec_args,
        "BLOCK_K": flash_fwd_splitkv_kernel_heur_block_k,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
        "PRE_LOAD_V": lambda args: True,
        "IS_EVEN_MN": is_even_mn_spec_args,
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_splitkv_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    k_page_stride: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    blocks_per_split: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Split-KV forward attention kernel.

    Grid axis 0 selects the row block, axis 1 the KV split and axis 2 encodes
    batch and head as ``bid * h + hid``. Each program attends to
    ``blocks_per_split`` consecutive key/value blocks and writes a partial,
    locally normalised output plus its log-sum-exp into per-split buffers
    (``o_ptr`` / ``softmax_lse_ptr``) for a later combine step.
    """
    m_block = tl.program_id(0)
    split_id = tl.program_id(1)
    bid = tl.program_id(2) // h
    hid = tl.program_id(2) % h

    split_block_min = split_id * blocks_per_split
    split_block_max = split_block_min + blocks_per_split

    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
    if is_causal:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right,
                BLOCK_N,
            ),
        )

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0

    # First KV block index that needs masking; blocks below it are mask-free.
    if not is_causal:
        if IS_EVEN_MN:
            masking_block_min = n_block_max
        else:
            masking_block_min = n_block_max - 1
    elif is_causal and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1

    q_ptr += bid * q_batch_stride
    q_ptr += hid * q_head_stride
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    p_qm = q_ptr + q_off
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    # FIX: `IS_EVEN_MN & BLOCK_K == d` parsed as `(IS_EVEN_MN & BLOCK_K) == d`
    # because `&` binds tighter than `==`; the unmasked path was never selected.
    if IS_EVEN_MN & (BLOCK_K == d):
        Q = tl.load(p_qm, cache_modifier=".cg")
    else:
        Q = tl.load(p_qm, mask=qmask, cache_modifier=".cg")

    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    # NOTE(review): V is addressed with the K strides; the v_* stride
    # arguments are unused. Confirm K and V always share a layout.
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    p_k0 = k_ptr + k_offset

    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )
    p_v0 = v_ptr + v_offset

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if split_block_max <= masking_block_min:
        # no masking needed
        for n_block in tl.range(
            split_block_min, split_block_max, num_stages=num_stages
        ):
            kv_off = n_block * BLOCK_N * k_row_stride
            if d == BLOCK_K:
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
            else:
                K = tl.load(
                    p_k0 + kv_off, mask=dmask[:, None], cache_modifier=".cg", other=0.0
                )
            if PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=False,
            )

            if not PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)
    else:
        # This split touches blocks that need causal and/or padding masks.
        for n_block in tl.range(split_block_min, min(split_block_max, n_block_max)):
            kv_off = n_block * BLOCK_N * k_row_stride
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            if IS_EVEN_MN & (d == BLOCK_K):
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            elif d == BLOCK_K:
                kvmask = col_idx < seqlen_k
                K = tl.load(p_k0 + kv_off, mask=kvmask[None, :], cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )
            else:
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_k0 + kv_off,
                    mask=dmask[:, None] & kvmask[None, :],
                    cache_modifier=".cg",
                    other=0.0,
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )

            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=False,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )

            if not PRE_LOAD_V:
                if IS_EVEN_MN & (d == BLOCK_K):
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    V = tl.load(
                        p_v0 + kv_off, mask=kvmask[:, None], cache_modifier=".cg"
                    )
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)

    # LSE
    # FIX: `rowsum_ == 0 | (rowsum_ != rowsum_)` parsed as
    # `rowsum_ == (0 | (...))`, so NaN row sums were never detected.
    invalid_sum = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        invalid_sum,
        float("-inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(invalid_sum, 1.0, 1.0 / rowsum_)

    # Rescale output
    acc_ *= inv_sum[:, None]

    # Write back output
    # o_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size)
    # grid = (seq_block, split, batch * head)
    o_split_ptr = o_ptr
    # + split, batch, head offsets, seq_block offsets are already added in row_idx
    o_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * d
    o_split_offset = row_idx[:, None] * d + tl.arange(0, BLOCK_K)
    o_split_ptr = tl.multiple_of(o_split_ptr, d)
    p_om = o_split_ptr + o_split_offset

    if IS_EVEN_MN & (BLOCK_K == d):
        tl.store(p_om, acc_, cache_modifier=".cg")
    else:
        tl.store(p_om, acc_, mask=qmask, cache_modifier=".cg")

    # Write back lse
    # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
    lse_split_ptr = softmax_lse_ptr
    # + split, batch, head, seq_block offsets
    lse_split_ptr += (
        split_id * tl.num_programs(2) + tl.program_id(2)
    ) * seqlen_q + m_block * BLOCK_M

    if IS_EVEN_MN:
        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg")
    else:
        tl.store(
            lse_split_ptr + tl.arange(0, BLOCK_M),
            lse,
            mask=row_idx < seqlen_q,
            cache_modifier=".cg",
        )
@libentry()
@triton.jit
def flash_fwd_splitkv_combine_kernel(
    out_ptr,
    lse_ptr,
    out_splits_ptr,
    lse_splits_ptr,
    head_size: tl.constexpr,
    out_split_stride,
    lse_split_stride,
    out_b_stride,
    out_s_stride,
    out_h_stride,
    n_splits,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    q_total,
    MAX_N_SPLITS: tl.constexpr,
):
    """Merge per-split partial outputs and log-sum-exps into the final result.

    Each program handles ``BLOCK_M`` consecutive rows. It loads up to
    ``MAX_N_SPLITS`` split LSE values per row (masked to ``n_splits`` and
    ``q_total``), derives per-split weights ``exp(lse_i - max) / sum``, writes
    the combined LSE to ``lse_ptr`` and the weighted sum of the split outputs
    to ``out_ptr``.
    """
    pid = tl.program_id(0)
    lse_splits_ptr += pid * BLOCK_M
    lse_ptr += pid * BLOCK_M
    out_splits_ptr += pid * BLOCK_M * head_size
    out_ptr += pid * BLOCK_M * head_size

    # Subtracting maximum from each of the split lse's for better numerical stability
    lse_split_offset = (
        tl.arange(0, BLOCK_M)[:, None]
        + tl.arange(0, MAX_N_SPLITS)[None, :] * lse_split_stride
    )
    lse_split_mask = (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] < q_total) & (
        tl.arange(0, MAX_N_SPLITS)[None, :] < n_splits
    )
    # Masked-out entries load as -inf so they contribute exp(-inf) = 0 below.
    lse_splits = tl.load(
        lse_splits_ptr + lse_split_offset, mask=lse_split_mask, other=float("-inf")
    )
    max_lse = tl.max(lse_splits, 1)

    # Sum exp(lse(i) - max_lse) over all split i to obtain Z=sumexp(QK) up to a scaled factor exp(-max_lse)
    Zi_scaled = tl.exp(lse_splits - max_lse[:, None])
    Z_scaled = tl.sum(Zi_scaled, 1)
    Zi_Z = Zi_scaled / Z_scaled[:, None]

    # Write back LSE
    lse = tl.log(Z_scaled) + max_lse
    out_mask = pid * BLOCK_M + tl.arange(0, BLOCK_M) < q_total
    tl.store(lse_ptr + tl.arange(0, BLOCK_M), lse, mask=out_mask)

    # 3-D gather of the split outputs, indexed (row, split, head-dim).
    out_split_offset = (
        tl.arange(0, BLOCK_M)[:, None, None] * head_size
        + tl.arange(0, MAX_N_SPLITS)[None, :, None] * out_split_stride
        + tl.arange(0, BLOCK_K)[None, None, :]
    )
    out_split_mask = (
        (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None, None] < q_total)
        & (tl.arange(0, MAX_N_SPLITS)[None, :, None] < n_splits)
        & (tl.arange(0, BLOCK_K)[None, None, :] < head_size)
    )
    out_splits = tl.load(
        out_splits_ptr + out_split_offset, mask=out_split_mask, other=0.0
    )
    # Weighted reduction over the split axis.
    out = tl.sum(Zi_Z[:, :, None] * out_splits, 1)
    out = out.to(out_ptr.type.element_ty)

    # Write back output
    out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, BLOCK_K)
    dmask = tl.arange(0, BLOCK_K) < head_size
    tl.store(out_ptr + out_offset, out, mask=out_mask[:, None] & dmask[None, :])
@triton.jit
def virtual_to_cache_offset(
    virtual_index,
    max_virtual_index,
    page_table_ptr,
    block_size,
    k_row_stride,
    k_page_stride,
    boundary_check: tl.constexpr = False,
):
    """Translate a virtual KV sequence index into an element offset in the paged cache.

    ``virtual_index`` is the kv sequence index within the current batch
    element and ``page_table_ptr`` already points at that element's block
    table; ``block_size`` is the number of rows per page. With
    ``boundary_check`` the page-table lookup is masked to indices below
    ``max_virtual_index`` (out-of-range lookups read page 0).
    """
    table_slot = virtual_index // block_size
    row_in_page = virtual_index % block_size
    entry_ptr = page_table_ptr + table_slot
    if boundary_check:
        physical_page = tl.load(
            entry_ptr,
            mask=virtual_index < max_virtual_index,
            other=0,
        )
    else:
        physical_page = tl.load(entry_ptr)
    # Widen before multiplying by the page stride.
    physical_page = physical_page.to(tl.int64)
    return physical_page * k_page_stride + row_in_page * k_row_stride
@triton.jit
def load_from_kvcache(
    virtual_index,
    max_virtual_index,
    page_table_ptr,
    k_ptr_base,
    v_ptr_base,
    block_size,
    d: tl.constexpr,
    k_row_stride,
    BLOCK_K: tl.constexpr,
    k_page_stride=0,
    boundary_check: tl.constexpr = False,
):
    """Gather one K tile and one V tile from a paged KV cache.

    The cache offset of every position in ``virtual_index`` is obtained from
    ``virtual_to_cache_offset``; the same offsets (hence the same row/page
    strides) are applied to both ``k_ptr_base`` and ``v_ptr_base``. Feature
    elements are addressed with unit stride.

    Args:
        virtual_index: 1-D KV sequence positions of the tile.
        max_virtual_index: exclusive bound; positions at or beyond it load 0.
        page_table_ptr: block-table row of the current batch element.
        k_ptr_base / v_ptr_base: K / V cache pointers for the current head.
        block_size: rows per page.
        d: real head dimension; feature lanes ``>= d`` load 0 when
            ``d != BLOCK_K``.
        k_row_stride / k_page_stride: row and page strides of the cache.
        BLOCK_K: padded head dimension of the tile.
        boundary_check: forwarded to ``virtual_to_cache_offset``.

    Returns:
        ``(bK, bV)`` where ``bK`` is laid out feature-major
        (``[BLOCK_K, len(virtual_index)]``) and ``bV`` row-major
        (``[len(virtual_index), BLOCK_K]``).
    """
    row_base = virtual_to_cache_offset(
        virtual_index,
        max_virtual_index,
        page_table_ptr,
        block_size,
        k_row_stride,
        k_page_stride,
        boundary_check,
    )
    feat = tl.arange(0, BLOCK_K)

    # Sequence-validity masks, shaped to broadcast against each tile layout.
    k_seq_ok = virtual_index[None, :] < max_virtual_index[None, :]
    v_seq_ok = virtual_index[:, None] < max_virtual_index[:, None]

    if d == BLOCK_K:
        # Every feature lane is real; only the sequence bound needs masking.
        k_mask = k_seq_ok
        v_mask = v_seq_ok
    else:
        # Padded head dimension: additionally drop feature lanes past d.
        k_mask = (feat[:, None] < d) & k_seq_ok
        v_mask = (feat[None, :] < d) & v_seq_ok

    k_ptrs = k_ptr_base + (feat[:, None] + row_base[None, :])
    v_ptrs = v_ptr_base + (feat[None, :] + row_base[:, None])
    bK = tl.load(k_ptrs, mask=k_mask, other=0.0)
    bV = tl.load(v_ptrs, mask=v_mask, other=0.0)
    return bK, bV
@libentry()
@triton.jit(
    do_not_specialize=[
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "b",
        "bk",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "total_q",
        "k_page_stride",
    ]
)
def flash_varlen_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q: tl.constexpr,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b,
    bk,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    k_page_stride,
    # kernel params
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Variable-length attention forward pass for one query tile.

    Each program handles the ``BLOCK_M`` query rows selected by
    ``program_id(0)`` for batch element ``program_id(1)`` and query head
    ``program_id(2)``. KV tiles are visited from the last one backwards: the
    first loop covers the tiles that need boundary/causal/local masking, the
    second loop covers the remaining tiles.

    Sequence extents:
        * Q: ``cu_seqlens_q_ptr[bid : bid + 2]`` when ``is_cu_seqlens_q``,
          otherwise ``seqlen_q`` rows located through ``q_batch_stride``.
        * K/V: ``cu_seqlens_k_ptr`` when ``is_cu_seqlens_k``, otherwise
          ``seqlen_k``; ``seqused_k_ptr[bid]`` overrides the length when
          ``is_seqused_k``.
        * With ``is_paged`` K/V rows are gathered through ``page_table_ptr``
          (see ``load_from_kvcache``); otherwise they are read contiguously.

    Outputs:
        * ``o_ptr``: normalised attention output for the tile.
        * ``softmax_lse_ptr``: log-sum-exp per row, laid out ``[h, total_q]``;
          rows whose softmax denominator is 0 or NaN store ``+inf``.

    NOTE(review): V is addressed with K's head/row strides (``v_row_stride``,
    ``v_head_stride`` and ``v_batch_stride`` are accepted but unused), so K
    and V are presumed to share a layout - confirm against the launcher.
    NOTE(review): ``p_ptr`` / ``return_softmax`` are not used by this kernel.
    """
    m_block = tl.program_id(0)
    bid = tl.program_id(1)
    hid = tl.program_id(2)

    if is_cu_seqlens_q:
        q_eos = tl.load(cu_seqlens_q_ptr + bid + 1).to(tl.int32)
        q_bos = tl.load(cu_seqlens_q_ptr + bid).to(tl.int32)
        q_len = q_eos - q_bos
        # Current request's start offset in the batched Q
        q_offset = q_bos * q_row_stride
        o_offset = q_bos * o_row_stride
        lse_offset = q_bos * 1
    else:
        q_len = seqlen_q
        q_offset = bid * q_batch_stride
        o_offset = bid * o_batch_stride
        lse_offset = bid * seqlen_q

    # k_offset is defined on both branches so the contiguous (non-paged) path
    # never depends on k_bos, which only exists when cu_seqlens_k is given.
    if is_cu_seqlens_k:
        k_eos = tl.load(cu_seqlens_k_ptr + bid + 1).to(tl.int32)
        k_bos = tl.load(cu_seqlens_k_ptr + bid).to(tl.int32)
        k_len_cache = k_eos - k_bos
        k_offset = k_bos * k_row_stride
    else:
        k_len_cache = seqlen_k
        k_offset = bid * k_batch_stride

    if is_seqused_k:
        k_len = tl.load(seqused_k_ptr + bid).to(tl.int32)
    else:
        k_len = k_len_cache

    # Noop CTA: a tile that starts at or beyond q_len owns no query rows.
    if m_block * BLOCK_M >= q_len:
        return

    # The even-tile fast path is disabled; every border tile is masked.
    is_even_mn: tl.constexpr = False

    if is_local:
        n_block_min = max(
            0, (m_block * BLOCK_M + k_len - q_len - window_size_left) // BLOCK_N
        )
    else:
        n_block_min = 0

    n_block_max = tl.cdiv(k_len, BLOCK_N)
    if is_causal or is_local:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + k_len - q_len + window_size_right, BLOCK_N
            ),
        )

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    # Locate the page table entry for the current batch element
    if is_paged:
        page_table_ptr += bid * page_table_batch_stride
    # Calculate the starting offset of q for the current head
    q_row_offset = hid * q_head_stride
    # Calculate the starting offset of k and v for the current head
    k_row_offset = (hid // h_hk_ratio) * k_head_stride
    # Shift the k, v pointers to align with the current head
    k_ptr_base = k_ptr + k_row_offset
    v_ptr_base = v_ptr + k_row_offset
    # Start of the current sequence for the contiguous (non-paged) path.
    k_ptr_seq = k_ptr_base + k_offset
    v_ptr_seq = v_ptr_base + k_offset

    gQ = tl.make_block_ptr(
        base=q_ptr + q_offset + q_row_offset,
        shape=(q_len, d),
        strides=(q_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    # Out-of-range lanes are zero-filled explicitly: without a padding option
    # their value is undefined and would leak into the dot products.
    bQ = tl.load(
        gQ.advance([m_block * BLOCK_M, 0]),
        boundary_check=(0, 1),
        padding_option="zero",
    )

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    if not is_causal and not is_local:
        n_masking_steps = 1
    elif is_even_mn:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1

    n_masking_steps = min(n_block_max - n_block_min, n_masking_steps)

    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    # Phase 1: trailing KV tiles, which need boundary / causal / local masks.
    n_block = n_block_max - 1
    for _ in tl.range(0, n_masking_steps):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        if is_paged:
            bK, bV = load_from_kvcache(
                col_idx,
                k_len,
                page_table_ptr,
                k_ptr_base,
                v_ptr_base,
                block_size,
                d,
                k_row_stride,
                BLOCK_K=BLOCK_K,
                k_page_stride=k_page_stride,
                boundary_check=True,
            )
        else:
            start_n = n_block * BLOCK_N
            gK = tl.make_block_ptr(
                base=k_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            gV = tl.make_block_ptr(
                base=v_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            bK = tl.load(gK, boundary_check=(0, 1), padding_option="zero")
            bK = tl.trans(bK)
            bV = tl.load(gV, boundary_check=(0, 1), padding_option="zero")
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=is_even_mn,
            is_causal=is_causal,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=True,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            # apply_dropout takes (row_start, col_start): rows follow the
            # query tile, columns follow the KV tile.
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        acc_ = tl.dot(P, bV, acc_)
        n_block -= 1

    # Phase 2: remaining KV tiles, walked backwards down to n_block_min.
    # NOTE(review): the contiguous loads below carry no boundary check, which
    # presumes these tiles are fully inside k_len and that d == BLOCK_K -
    # confirm against the launcher.
    for n_block in tl.range(
        n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1
    ):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        if is_paged:
            bK, bV = load_from_kvcache(
                col_idx,
                k_len,
                page_table_ptr,
                k_ptr_base,
                v_ptr_base,
                block_size,
                d,
                k_row_stride,
                BLOCK_K=BLOCK_K,
                k_page_stride=k_page_stride,
            )
        else:
            start_n = n_block * BLOCK_N
            gK = tl.make_block_ptr(
                base=k_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            gV = tl.make_block_ptr(
                base=v_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            bK = tl.load(gK)
            bK = tl.trans(bK)
            bV = tl.load(gV)
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )
        acc_ = tl.dot(P, bV, acc_)

    # LSE
    # Rows with an empty or NaN denominator get lse = +inf and are left
    # un-normalised. Each comparison is parenthesised because `|` binds
    # tighter than `==` / `!=`.
    degenerate_row = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        degenerate_row,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(degenerate_row, 1.0, 1.0 / rowsum_)

    acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_row_offset = hid * o_head_stride

    gO = tl.make_block_ptr(
        base=o_ptr + o_offset + o_row_offset,
        shape=(q_len, d),
        strides=(o_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    tl.store(gO.advance([m_block * BLOCK_M, 0]), out, boundary_check=(0, 1))

    # Write back lse
    # lse shape: [h, total_q]
    softmax_lse_ptr += hid * total_q
    lse_row_offset = lse_offset + m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(
        softmax_lse_ptr + lse_row_offset,
        lse,
        mask=lse_row_offset < (lse_offset + q_len),
    )