Coverage for src/flag_gems/runtime/backend/_sunrise/ops/flash_kernel.py: 0%
572 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-05 07:36 +0800
1import triton
2import triton.language as tl
4from flag_gems.utils import libentry, tl_extra_shim
@triton.jit
def u64_to_lohi(x):
    """Split a 64-bit integer into two ``uint32`` halves.

    NOTE: despite the name, the tuple is ordered ``(upper, lower)``: the
    first element holds bits 63..32 and the second holds bits 31..0.
    """
    upper = (x >> 32).to(tl.uint32)
    lower = (x & 0xFFFFFFFF).to(tl.uint32)
    return upper, lower
@triton.jit
def u64_from_lohi(lo, hi):
    """Combine two 32-bit halves into one ``uint64`` equal to ``(hi << 32) + lo``.

    Args:
        lo: value supplying the low 32 bits.
        hi: value supplying the high 32 bits.
    """
    # `+` binds tighter than `<<`, so the shift has to be parenthesised;
    # without the parentheses this evaluated as ``hi << (32 + lo)``.
    return (hi.to(tl.uint64) << 32) + lo.to(tl.uint64)
@triton.jit
def philox_(seed, subsequence, offset):
    """Run a 7-round Philox-style mixing function and return four 32-bit words.

    The key is taken from ``seed`` and the four-word counter from ``offset``
    and ``subsequence``.  Six rounds are each followed by a key bump; a seventh
    round is then applied without a further bump.

    Returns:
        Tuple ``(c0, c1, c2, c3)`` of ``uint32`` values.
    """
    # Per-round key increments.
    KEY_BUMP_0: tl.constexpr = 0x9E3779B9
    KEY_BUMP_1: tl.constexpr = 0xBB67AE85
    # Per-round multipliers.
    MULT_0: tl.constexpr = 0xD2511F53
    MULT_1: tl.constexpr = 0xCD9E8D57

    # u64_to_lohi returns (upper 32 bits, lower 32 bits).
    k0, k1 = u64_to_lohi(seed.to(tl.uint64))
    c0, c1 = u64_to_lohi(offset.to(tl.uint64))
    c2, c3 = u64_to_lohi(subsequence.to(tl.uint64))

    for _ in tl.static_range(6):
        wide0 = MULT_0 * c0.to(tl.uint64)
        wide1 = MULT_1 * c2.to(tl.uint64)
        wide0_hi, wide0_lo = u64_to_lohi(wide0)
        wide1_hi, wide1_lo = u64_to_lohi(wide1)
        c0, c1, c2, c3 = wide1_lo ^ c1 ^ k0, wide1_hi, wide0_lo ^ c3 ^ k1, wide0_hi
        k0 += KEY_BUMP_0
        k1 += KEY_BUMP_1

    # Final round: same mixing step, but the key is not advanced afterwards.
    wide0 = MULT_0 * c0.to(tl.uint64)
    wide1 = MULT_1 * c2.to(tl.uint64)
    wide0_hi, wide0_lo = u64_to_lohi(wide0)
    wide1_hi, wide1_lo = u64_to_lohi(wide1)
    c0, c1, c2, c3 = wide1_lo ^ c1 ^ k0, wide1_hi, wide0_lo ^ c3 ^ k1, wide0_hi

    return c0, c1, c2, c3
@triton.jit
def apply_dropout_mask(
    P,
    mask,
    encode_dropout_in_sign_bit: tl.constexpr,
):
    """Replace the elements of ``P`` selected by ``mask``.

    Where ``mask`` is true the element is negated (when
    ``encode_dropout_in_sign_bit`` is set) or zeroed (otherwise); all other
    elements pass through unchanged.
    """
    if encode_dropout_in_sign_bit:
        replacement = -P
    else:
        # Multiply by zero and cast back so the result keeps P's dtype.
        replacement = (P * 0).to(P.dtype)
    return tl.where(mask, replacement, P)
@triton.jit
def apply_dropout(
    P,
    row_start,
    col_start,
    n_cols,
    bid,
    hid,
    philox_seed,
    philox_offset,
    p_dropout_uint8: tl.constexpr,
    is_dropout: tl.constexpr,
    encode_dropout_in_sign_bit: tl.constexpr,
    NUM_HEADS: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Apply dropout to a ``(BLOCK_M, BLOCK_N)`` tile ``P``.

    Random words come from ``philox_``; one call yields four words, so the
    column index is divided by four and the four outputs are joined back into
    a full-width tile.  An element is dropped when the low byte of its random
    word is ``>= p_dropout_uint8``.  When ``is_dropout`` is false ``P`` is
    returned untouched.
    """
    if is_dropout:
        row_start = tl.multiple_of(row_start, BLOCK_M)
        col_start = tl.multiple_of(col_start, BLOCK_N)
        tile_rows = row_start + tl.arange(0, BLOCK_M)[:, None]
        # Each generator call covers four columns, hence the /4 on columns.
        tile_col_groups = col_start // 4 + tl.arange(0, BLOCK_N // 4)[None, :]

        subsequence = tile_rows.to(tl.uint64) * n_cols + tile_col_groups.to(tl.uint64)

        offset = philox_offset + bid * NUM_HEADS + hid
        # Adding a zero tensor broadcasts the scalar offset to the tile shape.
        offset += subsequence * 0
        w0, w1, w2, w3 = philox_(philox_seed, subsequence, offset)

        words = tl.join(tl.join(w0, w1), tl.join(w2, w3)).reshape(BLOCK_M, BLOCK_N)

        dropped = (words & 0xFF) >= p_dropout_uint8

        P = apply_dropout_mask(
            P, dropped, encode_dropout_in_sign_bit=encode_dropout_in_sign_bit
        )
    return P
@triton.jit
def apply_alibi(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    is_causal: tl.constexpr,
    is_alibi: tl.constexpr,
    alibi_slope: tl.constexpr = None,
):
    """Add the ALiBi positional bias to the score tile ``S``.

    When ``is_alibi`` is false ``S`` is returned unchanged.
    """
    if is_alibi:
        if is_causal:
            # A row-independent bias gives the same attention output as the
            # standard ALiBi bias because softmax is shift invariant, i.e.
            # softmax(A + bias + const) = softmax(A + bias). Under causal
            # masking the two biases below are therefore interchangeable:
            # bias_1 = [
            #   -4, -3, -2,  X,  X,
            #   -4, -3, -2, -1,  X,
            #   -4, -3, -2, -1,  0,
            # ]
            # bias_2 = [
            #   -2, -1,  0,  X,  X,
            #   -3, -2, -1,  0,  X,
            #   -4, -3, -2, -1,  0,
            # ]
            col_term = (-max_seqlen_k + 1 + col_idx[None, :]).to(tl.float32)
            S += alibi_slope * col_term
        else:
            # Absolute distance between each key column and the query row,
            # with the rows aligned to the end of the key sequence.
            distance = tl.abs(
                col_idx[None, :] - max_seqlen_k + max_seqlen_q - row_idx[:, None]
            ).to(tl.float32)
            S += -alibi_slope * distance

    return S
@triton.jit
def apply_mask(
    S,
    col_idx,
    row_idx,
    max_seqlen_q,
    max_seqlen_k,
    window_size_left,
    window_size_right,
    is_even_mn: tl.constexpr,
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
):
    """Set masked-out entries of the score tile ``S`` to ``-inf``.

    Causal mode masks columns right of the per-row right bound, local mode
    masks columns outside ``[left bound, right bound]``, and the plain
    non-even case masks columns at or beyond ``max_seqlen_k``.  Nothing is
    done when none of the three applies.
    """
    if is_causal | is_local | (not is_even_mn):
        # Take care to avoid off-by-one errors: both bounds are inclusive!
        first_col = max(0, row_idx + max_seqlen_k - max_seqlen_q - window_size_left)
        last_col = min(
            max_seqlen_k - 1, row_idx + max_seqlen_k - max_seqlen_q + window_size_right
        )

        if is_causal:
            S = tl.where(col_idx[None, :] > last_col[:, None], float("-inf"), S)

        if is_local:
            outside_window = (col_idx[None, :] > last_col[:, None]) | (
                col_idx[None, :] < first_col[:, None]
            )
            S = tl.where(outside_window, float("-inf"), S)

        if not (is_local | is_causal | is_even_mn):
            # Only the ragged tail of the key sequence needs masking.
            S = tl.where(col_idx[None, :] >= max_seqlen_k, float("-inf"), S)

    return S
@triton.jit
def softmax_rescale(
    O_acc,
    S,
    row_max,
    row_sum,
    softmax_scale_log2e: tl.constexpr,
    is_border: tl.constexpr,
    # is_init: tl.constexpr
):
    """Perform one online-softmax update step for a new score tile ``S``.

    The running accumulator and row sum are rescaled to the new running row
    maximum, and the tile's unnormalised probabilities are added to the sum.

    Returns:
        ``(O_acc, P, row_max, row_sum)`` with ``P`` the base-2 exponentials
        of the scaled, max-shifted scores.
    """
    old_max = row_max
    new_max = tl.maximum(row_max, tl.max(S, 1))

    if is_border:
        # A fully masked row has max == -inf; use 0 so the exponent below
        # does not become (-inf) - (-inf).
        ref_max = tl.where(new_max == float("-inf"), 0, new_max)
    else:
        ref_max = new_max

    correction = tl.math.exp2((old_max - ref_max) * softmax_scale_log2e)
    row_sum *= correction
    O_acc *= correction[:, None]

    shift = tl.where(new_max == float("-inf"), 0, new_max * softmax_scale_log2e)
    P = tl.math.exp2(S * softmax_scale_log2e - shift[:, None])
    row_sum = row_sum + tl.sum(P, 1)
    return O_acc, P, new_max, row_sum
@triton.jit
def apply_softcap(S, softcap, is_softcap: tl.constexpr):
    """Return ``tanh(S * softcap)`` when ``is_softcap`` is set, else ``S``."""
    if is_softcap:
        scaled = S * softcap
        S = tl_extra_shim.tanh(scaled)

    return S
def block_m_splitkv_heuristic(headdim):
    """Pick BLOCK_M for the split-KV kernel: 128 up to head dim 128, else 64."""
    if headdim <= 128:
        return 128
    return 64
def block_n_splitkv_heuristic(headdim):
    """Pick BLOCK_N for the split-KV kernel: 64 up to head dim 64, else 32."""
    if headdim <= 64:
        return 64
    return 32
def is_even_mn(M, N, BM, BN, WL, WR):
    """Tell whether an ``M x N`` problem tiles evenly with ``BM x BN`` blocks.

    All of the following must hold: ``M`` and ``N`` are multiples of their
    block sizes, one of ``M``/``N`` divides the other, and each window size
    is either ``-1`` or a multiple of ``BN``.
    """
    if M % BM != 0 or N % BN != 0:
        return False
    if M % N != 0 and N % M != 0:
        return False
    left_aligned = WL == -1 or WL % BN == 0
    right_aligned = WR == -1 or WR % BN == 0
    return left_aligned and right_aligned
def block_m_splitkv_heuristic_spec_args(args):
    """Heuristic callback: BLOCK_M is 128 when ``args["d"] <= 128``, else 64."""
    head_dim = args["d"]
    if head_dim <= 128:
        return 128
    return 64
def block_n_splitkv_heuristic_spec_args(args):
    """Heuristic callback: BLOCK_N is 64 when ``args["d"] <= 64``, else 32."""
    head_dim = args["d"]
    if head_dim <= 64:
        return 64
    return 32
def is_even_mn_spec_args(args):
    """Heuristic callback computing ``IS_EVEN_MN`` from the kernel arguments.

    True only when both sequence lengths are multiples of their block sizes,
    one sequence length divides the other, and each window size is either
    ``-1`` or a multiple of ``BLOCK_N``.
    """
    seqlen_q = args["seqlen_q"]
    seqlen_k = args["seqlen_k"]
    block_n = args["BLOCK_N"]

    if seqlen_q % args["BLOCK_M"] != 0 or seqlen_k % block_n != 0:
        return False
    if seqlen_q % seqlen_k != 0 and seqlen_k % seqlen_q != 0:
        return False

    left = args["window_size_left"]
    if left != -1 and left % block_n != 0:
        return False
    right = args["window_size_right"]
    if right != -1 and right % block_n != 0:
        return False
    return True
def keep(cfg, must_keep=None):
    """Decide whether an autotune configuration survives pruning.

    Two ``(BLOCK_M, BLOCK_N, num_warps)`` combinations are always accepted.
    Any other configuration is accepted only if it appears in ``must_keep``;
    when ``must_keep`` is falsy that falsy value itself is returned.
    """
    signature = (cfg.kwargs["BLOCK_M"], cfg.kwargs["BLOCK_N"], cfg.num_warps)
    if signature in ((128, 32, 4), (128, 128, 8)):
        return True
    # configurations listed in `must_keep` are always kept
    return must_keep and cfg in must_keep
def prune_fwd_configs(configs, nargs, **kwargs):
    """Restrict autotune configurations when dropout is enabled.

    With ``nargs["is_dropout"]`` set, only configurations using 4 warps and
    fewer than 4 stages are returned (as a new list); otherwise ``configs``
    is returned as-is.
    """
    if not nargs["is_dropout"]:
        return configs
    return [cfg for cfg in configs if cfg.num_warps == 4 and cfg.num_stages < 4]
def flash_fwd_kernel_heur_block_k(args):
    """Heuristic callback: BLOCK_K is ``args["d"]`` rounded up to a power of two."""
    head_dim = args["d"]
    return triton.next_power_of_2(head_dim)
@libentry()
@triton.heuristics(
    values={
        "BLOCK_K": flash_fwd_kernel_heur_block_k,
        "PRE_LOAD_V": lambda args: False,
        "IS_EVEN_MN": lambda args: is_even_mn(
            args["seqlen_q"],
            args["seqlen_k"],
            args["BLOCK_M"],
            args["BLOCK_N"],
            args["window_size_left"],
            args["window_size_right"],
        ),
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k,
    cu_seqlens_k_ptr,
    is_seqused_k,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q: tl.constexpr,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Flash-attention forward kernel for one BLOCK_M-row tile of queries.

    ``program_id(0)`` selects the query row block and ``program_id(1)`` the
    flattened ``batch * h + head`` index.  The key/value columns are swept in
    two phases: the trailing blocks that need masking are processed first
    (from right to left), then the remaining blocks are processed without
    causal/ragged masking.  The normalised output tile is written through
    ``o_ptr`` and the per-row log-sum-exp through ``softmax_lse_ptr``.  When
    both ``is_dropout`` and ``return_softmax`` are set, the probabilities with
    dropout encoded in the sign bit are also written through ``p_ptr``.

    NOTE: V is addressed with the K batch/head/row strides; the ``v_*_stride``
    arguments are not read here.
    """
    m_block = tl.program_id(0)
    bh = tl.program_id(1)
    hid = bh % h
    bid = bh // h
    num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    # We draw a minimum covering frame on the attention map that this CTA is assigned to process.
    # The frame edges are rounded to multiples of BLOCK_M and BLOCK_N for rows and columns respectively.

    col_min = 0
    if is_local:
        col_min = max(0, m_block * BLOCK_M + seqlen_k - seqlen_q - window_size_left)
        if not IS_EVEN_MN:
            # round left
            col_min = (col_min // BLOCK_N) * BLOCK_N

    col_max = seqlen_k
    if is_causal or is_local:
        col_max += (m_block - num_m_blocks + 1) * BLOCK_M
        if is_local:
            col_max += window_size_right
        col_max = min(seqlen_k, col_max)

    if not IS_EVEN_MN:
        # round right
        col_max = tl.cdiv(col_max, BLOCK_N) * BLOCK_N

    # Number of trailing columns that must go through the masking loop.
    if (not is_causal) and (not is_local):
        if IS_EVEN_MN:
            masking_cols: tl.constexpr = 0
        else:
            masking_cols: tl.constexpr = BLOCK_N
    elif (
        is_causal | is_local
    ) and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_cols: tl.constexpr = tl.cdiv(BLOCK_M, BLOCK_N) * BLOCK_N
    else:
        # local
        masking_cols: tl.constexpr = (tl.cdiv(BLOCK_M, BLOCK_N) + 1) * BLOCK_N

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        # Scores are multiplied by the softmax scale later, so pre-divide.
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    q_batch_stride = tl.multiple_of(q_batch_stride, d * h)
    q_ptr += bid * q_batch_stride + hid * q_head_stride
    row_start = m_block * BLOCK_M
    row_idx = row_start + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    # `&` binds tighter than `==`, so the comparison must be parenthesised;
    # the unparenthesised form compared `(IS_EVEN_MN & d)` against BLOCK_K.
    if IS_EVEN_MN & (d == BLOCK_K):
        Q = tl.load(q_ptr + q_off, cache_modifier=".cg")
    else:
        # other=0.0 keeps padded lanes from feeding undefined values to tl.dot.
        Q = tl.load(q_ptr + q_off, mask=qmask, other=0.0, cache_modifier=".cg")

    if return_softmax:
        p_ptr += (
            (bid * h + hid) * seqlen_q_rounded + m_block * BLOCK_M
        ) * seqlen_k_rounded
        p_offset = tl.arange(0, BLOCK_M)[:, None] * seqlen_k_rounded + tl.arange(
            0, BLOCK_N
        )
        p_bp0 = p_ptr + p_offset

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    k_batch_stride = tl.multiple_of(k_batch_stride, d * hk)
    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    # K is loaded transposed as (BLOCK_K, BLOCK_N); V as (BLOCK_N, BLOCK_K).
    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )

    p_bk0 = k_ptr + k_offset
    p_bv0 = v_ptr + v_offset

    if is_causal | is_local | (not IS_EVEN_MN):
        # Cut short masking cols if there's not enough cols out there
        masking_cols = min(col_max - col_min, masking_cols)
        for col_shift in tl.range(0, masking_cols, step=BLOCK_N):
            col_start = col_max - col_shift - BLOCK_N
            col_start = tl.multiple_of(col_start, BLOCK_N)
            off = col_start * k_row_stride
            if IS_EVEN_MN & (d == BLOCK_K):
                K = tl.load(p_bk0 + off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
            elif d == BLOCK_K:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_bk0 + off, mask=kvmask[None, :], other=0.0, cache_modifier=".cg"
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None],
                        other=0.0,
                        cache_modifier=".cg",
                    )
            else:
                col_idx = col_start + tl.arange(0, BLOCK_N)
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_bk0 + off,
                    mask=kvmask[None, :] & dmask[:, None],
                    other=0.0,
                    cache_modifier=".cg",
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        other=0.0,
                        cache_modifier=".cg",
                    )
            S = tl.dot(Q, K, allow_tf32=False)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = col_start + tl.arange(0, BLOCK_N)
            row_idx = row_start + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=is_local,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )
            P = P.to(v_ptr.type.element_ty)

            if is_dropout:
                if return_softmax:
                    P_drop = P

                    P_drop = apply_dropout(
                        P_drop,
                        row_start,
                        col_start,
                        seqlen_k,
                        bid,
                        hid,
                        philox_seed,
                        philox_offset,
                        p_dropout_in_uint8_t,
                        is_dropout,
                        encode_dropout_in_sign_bit=True,
                        NUM_HEADS=h,
                        BLOCK_M=BLOCK_M,
                        BLOCK_N=BLOCK_N,
                    )
                    if IS_EVEN_MN:
                        tl.store(p_bp0 + col_start, P_drop)
                    else:
                        kvmask = col_idx < seqlen_k
                        # P_drop is (BLOCK_M, BLOCK_N): mask by valid rows and
                        # valid key columns. `qmask` is (BLOCK_M, BLOCK_K) and
                        # carries the head-dim mask, so it does not apply here.
                        tl.store(
                            p_bp0 + col_start,
                            P_drop,
                            mask=(row_idx[:, None] < seqlen_q) & kvmask[None, :],
                        )

                P = apply_dropout(
                    P,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=False,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )

            if not PRE_LOAD_V:
                off = col_start * k_row_stride
                if IS_EVEN_MN & (d == BLOCK_K):
                    V = tl.load(p_bv0 + off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None],
                        other=0.0,
                        cache_modifier=".cg",
                    )
                else:
                    kvmask = col_idx < seqlen_k
                    V = tl.load(
                        p_bv0 + off,
                        mask=kvmask[:, None] & dmask[None, :],
                        other=0.0,
                        cache_modifier=".cg",
                    )
            acc_ = tl.dot(P, V, acc_, allow_tf32=False)

    # Remaining column blocks: no causal / ragged-edge masking is required.
    for col_start in tl.range(
        col_min, col_max - masking_cols, step=BLOCK_N, num_stages=num_stages
    ):
        col_start = tl.multiple_of(col_start, BLOCK_N)
        off = col_start * k_row_stride
        if d == BLOCK_K:
            K = tl.load(p_bk0 + off, cache_modifier=".cg")
            if PRE_LOAD_V:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
        else:
            K = tl.load(
                p_bk0 + off, mask=dmask[:, None], other=0.0, cache_modifier=".cg"
            )
            if PRE_LOAD_V:
                V = tl.load(
                    p_bv0 + off, mask=dmask[None, :], other=0.0, cache_modifier=".cg"
                )

        S = tl.dot(Q, K)
        S = apply_softcap(S, softcap, is_softcap)
        col_idx = col_start + tl.arange(0, BLOCK_N)
        row_idx = row_start + tl.arange(0, BLOCK_M)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            seqlen_q,
            seqlen_k,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            if return_softmax:
                P_drop = P
                P_drop = apply_dropout(
                    P_drop,
                    row_start,
                    col_start,
                    seqlen_k,
                    bid,
                    hid,
                    philox_seed,
                    philox_offset,
                    p_dropout_in_uint8_t,
                    is_dropout,
                    encode_dropout_in_sign_bit=True,
                    NUM_HEADS=h,
                    BLOCK_M=BLOCK_M,
                    BLOCK_N=BLOCK_N,
                )
                if IS_EVEN_MN:
                    tl.store(p_bp0 + col_start, P_drop)
                else:
                    kvmask = col_idx < seqlen_k
                    # See the masking loop: mask by row/column validity only.
                    tl.store(
                        p_bp0 + col_start,
                        P_drop,
                        mask=(row_idx[:, None] < seqlen_q) & kvmask[None, :],
                    )

            P = apply_dropout(
                P,
                row_start,
                col_start,
                seqlen_k,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        if not PRE_LOAD_V:
            off = col_start * k_row_stride
            if d == BLOCK_K:
                V = tl.load(p_bv0 + off, cache_modifier=".cg")
            else:
                V = tl.load(
                    p_bv0 + off, mask=dmask[None, :], other=0.0, cache_modifier=".cg"
                )
        acc_ = tl.dot(P, V, acc_)

    # LSE
    # Note, rowsum = exp(-rowmax) * exp(lse), therefore rowmax + log(rowsum) cancels
    # the effect of rowmax and outputs lse only.
    # `|` binds tighter than `==`, so each comparison is parenthesised; the
    # unparenthesised form never flagged NaN row sums.
    degenerate = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        degenerate,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(degenerate, 1.0, 1.0 / rowsum_)

    if is_dropout:
        acc_ *= inv_sum[:, None] * rp_dropout
    else:
        acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_batch_stride = tl.multiple_of(o_batch_stride, d * h)
    o_ptr += bid * o_batch_stride
    o_ptr += hid * o_head_stride
    o_offset = row_idx[:, None] * o_row_stride + tl.arange(0, BLOCK_K)

    if IS_EVEN_MN & (d == BLOCK_K):
        tl.store(o_ptr + o_offset, out)
    else:
        tl.store(o_ptr + o_offset, out, mask=qmask)

    # Write back lse
    p_lse = softmax_lse_ptr + (bid * h + hid) * seqlen_q
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)

    if IS_EVEN_MN:
        tl.store(p_lse + row_idx, lse)
    else:
        tl.store(p_lse + row_idx, lse, mask=row_idx < seqlen_q)
@triton.jit(do_not_specialize=["seqlen_q", "seqlen_k"])
def flash_fwd_bh_parallel_kernel():
    """Placeholder kernel; not implemented yet."""
    # (TODO)
    pass
def flash_fwd_splitkv_kernel_heur_block_k(args):
    """Heuristic callback: BLOCK_K is ``args["d"]`` rounded up to a power of two."""
    head_dim = args["d"]
    return triton.next_power_of_2(head_dim)
@libentry()
@triton.heuristics(
    values={
        "BLOCK_M": block_m_splitkv_heuristic_spec_args,
        "BLOCK_N": block_n_splitkv_heuristic_spec_args,
        "BLOCK_K": flash_fwd_splitkv_kernel_heur_block_k,
        "num_warps": lambda args: 4,
        "num_stages": lambda args: 3,
        "PRE_LOAD_V": lambda args: True,
        "IS_EVEN_MN": is_even_mn_spec_args,
    }
)
@triton.jit(
    do_not_specialize=["seqlen_q", "seqlen_k", "seqlen_q_rounded", "seqlen_k_rounded"]
)
def flash_fwd_splitkv_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b: tl.constexpr,
    bk: tl.constexpr,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    IS_EVEN_MN: tl.constexpr,
    PRE_LOAD_V: tl.constexpr,
    blocks_per_split: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    """Split-KV flash-attention forward kernel producing per-split partials.

    ``program_id(0)`` selects the query row block, ``program_id(1)`` the key
    split and ``program_id(2)`` the flattened ``batch * h + head`` index.
    Each program processes key blocks
    ``[split_id * blocks_per_split, (split_id + 1) * blocks_per_split)`` and
    writes its normalised partial output through ``o_ptr`` and its partial
    log-sum-exp through ``softmax_lse_ptr``; rows with no contribution get an
    LSE of ``-inf``.

    NOTE: V is addressed with the K batch/head/row strides; the ``v_*_stride``
    arguments are not read here.
    """
    m_block = tl.program_id(0)
    split_id = tl.program_id(1)
    bid = tl.program_id(2) // h
    hid = tl.program_id(2) % h

    split_block_min = split_id * blocks_per_split
    split_block_max = split_block_min + blocks_per_split

    n_block_max = tl.cdiv(seqlen_k, BLOCK_N)
    if is_causal:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + seqlen_k - seqlen_q + window_size_right,
                BLOCK_N,
            ),
        )

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        # Scores are multiplied by the softmax scale later, so pre-divide.
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0

    # First key block (inclusive) from which masking has to be applied.
    if not is_causal:
        if IS_EVEN_MN:
            masking_block_min = n_block_max
        else:
            masking_block_min = n_block_max - 1
    elif is_causal and IS_EVEN_MN:  # causal implies window_size_right is zero
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        masking_block_min = n_block_max - tl.cdiv(BLOCK_M, BLOCK_N) - 1

    q_ptr += bid * q_batch_stride
    q_ptr += hid * q_head_stride
    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    q_off = row_idx[:, None] * q_row_stride + tl.arange(0, BLOCK_K)[None, :]
    p_qm = q_ptr + q_off
    dmask = tl.arange(0, BLOCK_K) < d
    qmask = dmask[None, :] & (row_idx[:, None] < seqlen_q)
    # `&` binds tighter than `==`, so the comparison must be parenthesised;
    # the unparenthesised form compared `(IS_EVEN_MN & BLOCK_K)` against d.
    if IS_EVEN_MN & (BLOCK_K == d):
        Q = tl.load(p_qm, cache_modifier=".cg")
    else:
        # other=0.0 keeps padded lanes from feeding undefined values to tl.dot.
        Q = tl.load(p_qm, mask=qmask, other=0.0, cache_modifier=".cg")

    h_hk_ratio = h // hk
    k_ptr += bid * k_batch_stride
    k_ptr += (hid // h_hk_ratio) * k_head_stride
    v_ptr += bid * k_batch_stride
    v_ptr += (hid // h_hk_ratio) * k_head_stride

    # K is loaded transposed as (BLOCK_K, BLOCK_N); V as (BLOCK_N, BLOCK_K).
    k_offset = (
        tl.arange(0, BLOCK_N)[None, :] * k_row_stride + tl.arange(0, BLOCK_K)[:, None]
    )
    p_k0 = k_ptr + k_offset

    v_offset = (
        tl.arange(0, BLOCK_N)[:, None] * k_row_stride + tl.arange(0, BLOCK_K)[None, :]
    )
    p_v0 = v_ptr + v_offset

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if split_block_max <= masking_block_min:
        # no masking needed
        for n_block in tl.range(
            split_block_min, split_block_max, num_stages=num_stages
        ):
            kv_off = n_block * BLOCK_N * k_row_stride
            if d == BLOCK_K:
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
            else:
                K = tl.load(
                    p_k0 + kv_off, mask=dmask[:, None], cache_modifier=".cg", other=0.0
                )
            if PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=False,
            )

            if not PRE_LOAD_V:
                if d == BLOCK_K:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)
    else:
        # This split overlaps the blocks that need causal / ragged masking.
        for n_block in tl.range(split_block_min, min(split_block_max, n_block_max)):
            kv_off = n_block * BLOCK_N * k_row_stride
            col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
            row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
            if IS_EVEN_MN & (d == BLOCK_K):
                K = tl.load(p_k0 + kv_off, cache_modifier=".cg")
                if PRE_LOAD_V:
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
            elif d == BLOCK_K:
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_k0 + kv_off,
                    mask=kvmask[None, :],
                    cache_modifier=".cg",
                    other=0.0,
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            else:
                kvmask = col_idx < seqlen_k
                K = tl.load(
                    p_k0 + kv_off,
                    mask=dmask[:, None] & kvmask[None, :],
                    cache_modifier=".cg",
                    other=0.0,
                )
                if PRE_LOAD_V:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )

            S = tl.dot(Q, K)
            S = apply_softcap(S, softcap, is_softcap)
            S = apply_alibi(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                is_causal=is_causal,
                is_alibi=is_alibi,
                alibi_slope=alibi_slope,
            )
            S = apply_mask(
                S,
                col_idx,
                row_idx,
                seqlen_q,
                seqlen_k,
                window_size_left,
                window_size_right,
                is_even_mn=IS_EVEN_MN,
                is_causal=is_causal,
                is_local=False,
            )

            acc_, P, rowmax_, rowsum_ = softmax_rescale(
                acc_,
                S,
                rowmax_,
                rowsum_,
                softmax_scale_log2e=scale_softmax_log2,
                is_border=(is_causal or is_local),
            )

            if not PRE_LOAD_V:
                if IS_EVEN_MN & (d == BLOCK_K):
                    V = tl.load(p_v0 + kv_off, cache_modifier=".cg")
                elif d == BLOCK_K:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
                else:
                    V = tl.load(
                        p_v0 + kv_off,
                        mask=dmask[None, :] & kvmask[:, None],
                        cache_modifier=".cg",
                        other=0.0,
                    )
            P = P.to(v_ptr.type.element_ty)
            acc_ = tl.dot(P, V, acc_)

    # LSE
    # `|` binds tighter than `==`, so each comparison is parenthesised; the
    # unparenthesised form never flagged NaN row sums.
    degenerate = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        degenerate,
        float("-inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(degenerate, 1.0, 1.0 / rowsum_)

    # Rescale output
    acc_ *= inv_sum[:, None]

    # Write back output
    # o_splits layout = (n_splits, batch_size, num_heads, seqlen_q, head_size)
    # grid = (seq_block, split, batch * head)
    o_split_ptr = o_ptr
    # + split, batch, head offsets, seq_block offsets are already added in row_idx
    o_split_ptr += (split_id * tl.num_programs(2) + tl.program_id(2)) * seqlen_q * d
    o_split_offset = row_idx[:, None] * d + tl.arange(0, BLOCK_K)
    o_split_ptr = tl.multiple_of(o_split_ptr, d)
    p_om = o_split_ptr + o_split_offset

    if IS_EVEN_MN & (BLOCK_K == d):
        tl.store(p_om, acc_, cache_modifier=".cg")
    else:
        tl.store(p_om, acc_, mask=qmask, cache_modifier=".cg")

    # Write back lse
    # lse_splits layout = (n_splits, batch_size, num_heads, seqlen_q)
    lse_split_ptr = softmax_lse_ptr
    # + split, batch, head, seq_block offsets
    lse_split_ptr += (
        split_id * tl.num_programs(2) + tl.program_id(2)
    ) * seqlen_q + m_block * BLOCK_M

    if IS_EVEN_MN:
        tl.store(lse_split_ptr + tl.arange(0, BLOCK_M), lse, cache_modifier=".cg")
    else:
        tl.store(
            lse_split_ptr + tl.arange(0, BLOCK_M),
            lse,
            mask=row_idx < seqlen_q,
            cache_modifier=".cg",
        )
@libentry()
@triton.jit
def flash_fwd_splitkv_combine_kernel(
    out_ptr,
    lse_ptr,
    out_splits_ptr,
    lse_splits_ptr,
    head_size: tl.constexpr,
    out_split_stride,
    lse_split_stride,
    out_b_stride,
    out_s_stride,
    out_h_stride,
    n_splits,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    q_total,
    MAX_N_SPLITS: tl.constexpr,
):
    """Merge per-split partial outputs and LSEs into the final result.

    Each program handles ``BLOCK_M`` consecutive rows.  The partial outputs
    are combined with weights ``exp(lse_i - lse)`` where ``lse`` is the
    log-sum-exp over the split LSEs.  Rows whose split LSEs are all ``-inf``
    (no key contributed in any split) produce a zero output and an LSE of
    ``+inf`` instead of NaN.
    """
    pid = tl.program_id(0)
    lse_splits_ptr += pid * BLOCK_M
    lse_ptr += pid * BLOCK_M
    out_splits_ptr += pid * BLOCK_M * head_size
    out_ptr += pid * BLOCK_M * head_size

    # Subtracting maximum from each of the split lse's for better numerical stability
    lse_split_offset = (
        tl.arange(0, BLOCK_M)[:, None]
        + tl.arange(0, MAX_N_SPLITS)[None, :] * lse_split_stride
    )
    lse_split_mask = (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] < q_total) & (
        tl.arange(0, MAX_N_SPLITS)[None, :] < n_splits
    )
    lse_splits = tl.load(
        lse_splits_ptr + lse_split_offset, mask=lse_split_mask, other=float("-inf")
    )
    max_lse = tl.max(lse_splits, 1)
    # A row with every split at -inf would give (-inf) - (-inf) = NaN below;
    # shift by 0 for those rows and special-case them afterwards.
    no_contrib = max_lse == float("-inf")
    max_lse = tl.where(no_contrib, 0.0, max_lse)

    # Sum exp(lse(i) - max_lse) over all split i to obtain Z=sumexp(QK) up to a scaled factor exp(-max_lse)
    Zi_scaled = tl.exp(lse_splits - max_lse[:, None])
    Z_scaled = tl.sum(Zi_scaled, 1)
    # Avoid 0/0 for rows without any contribution: their weights are zero.
    Zi_Z = tl.where(no_contrib[:, None], 0.0, Zi_scaled / Z_scaled[:, None])

    # Write back LSE
    lse = tl.where(no_contrib, float("inf"), tl.log(Z_scaled) + max_lse)
    out_mask = pid * BLOCK_M + tl.arange(0, BLOCK_M) < q_total
    tl.store(lse_ptr + tl.arange(0, BLOCK_M), lse, mask=out_mask)

    out_split_offset = (
        tl.arange(0, BLOCK_M)[:, None, None] * head_size
        + tl.arange(0, MAX_N_SPLITS)[None, :, None] * out_split_stride
        + tl.arange(0, BLOCK_K)[None, None, :]
    )
    out_split_mask = (
        (pid * BLOCK_M + tl.arange(0, BLOCK_M)[:, None, None] < q_total)
        & (tl.arange(0, MAX_N_SPLITS)[None, :, None] < n_splits)
        & (tl.arange(0, BLOCK_K)[None, None, :] < head_size)
    )
    out_splits = tl.load(
        out_splits_ptr + out_split_offset, mask=out_split_mask, other=0.0
    )
    # Weighted sum over the split axis.
    out = tl.sum(Zi_Z[:, :, None] * out_splits, 1)
    out = out.to(out_ptr.type.element_ty)

    # Write back output
    out_offset = tl.arange(0, BLOCK_M)[:, None] * out_s_stride + tl.arange(0, BLOCK_K)
    dmask = tl.arange(0, BLOCK_K) < head_size
    tl.store(out_ptr + out_offset, out, mask=out_mask[:, None] & dmask[None, :])
@triton.jit
def virtual_to_cache(
    virtual_index,
    max_virtual_index,
    page_table_ptr,
    block_size,
    boundary_check: tl.constexpr = False,
):
    """Translate a virtual KV index into an index in the paged KV cache.

    ``virtual_index`` is the kv sequence index within the current batch
    element, ``page_table_ptr`` already points at that element's block-table
    entry and ``block_size`` is the size of each page.  With
    ``boundary_check`` the page-table read is masked to indices below
    ``max_virtual_index`` (page 0 is used for the rest).
    """
    page_slot = virtual_index // block_size
    within_page = virtual_index % block_size
    table_entry_ptr = page_table_ptr + page_slot
    if boundary_check:
        physical_page = tl.load(
            table_entry_ptr,
            mask=virtual_index < max_virtual_index,
            other=0,
        ).to(tl.int32)
    else:
        physical_page = tl.load(table_entry_ptr).to(tl.int32)
    return physical_page * block_size + within_page
@triton.jit
def load_from_kvcache(
    virtual_index,
    max_virtual_index,
    page_table_ptr,
    k_ptr_base,
    v_ptr_base,
    block_size,
    d: tl.constexpr,
    k_row_stride,
    BLOCK_K: tl.constexpr,
    boundary_check: tl.constexpr = False,
):
    """Gather a K tile and a V tile from the paged KV cache.

    Virtual indices are mapped to cache rows with ``virtual_to_cache``.  K is
    returned with the head dimension on axis 0 and V with it on axis 1.
    Positions at or beyond ``max_virtual_index`` (and head-dim lanes ``>= d``
    when ``d != BLOCK_K``) are filled with ``0.0``.  Both tiles use
    ``k_row_stride`` as the row stride.
    """
    cache_rows = virtual_to_cache(
        virtual_index, max_virtual_index, page_table_ptr, block_size, boundary_check
    )
    k_offset = tl.arange(0, BLOCK_K)[:, None] + cache_rows[None, :] * k_row_stride
    v_offset = tl.arange(0, BLOCK_K)[None, :] + cache_rows[:, None] * k_row_stride

    # Sequence-validity masks shaped for the K and V tiles respectively.
    k_seq_ok = virtual_index[None, :] < max_virtual_index[None, :]
    v_seq_ok = virtual_index[:, None] < max_virtual_index[:, None]
    if d == BLOCK_K:
        bK_mask = k_seq_ok
        bV_mask = v_seq_ok
    else:
        # The head dimension is padded up to BLOCK_K; mask the padding lanes.
        bK_mask = (tl.arange(0, BLOCK_K)[:, None] < d) & k_seq_ok
        bV_mask = (tl.arange(0, BLOCK_K)[None, :] < d) & v_seq_ok

    bK = tl.load(k_ptr_base + k_offset, mask=bK_mask, other=0.0)
    bV = tl.load(v_ptr_base + v_offset, mask=bV_mask, other=0.0)
    return bK, bV
@libentry()
@triton.jit(
    do_not_specialize=[
        "q_batch_stride",
        "k_batch_stride",
        "v_batch_stride",
        "o_batch_stride",
        "b",
        "bk",
        "seqlen_q",
        "seqlen_k",
        "seqlen_q_rounded",
        "seqlen_k_rounded",
        "total_q",
    ]
)
def flash_varlen_fwd_kernel(
    q_ptr,
    k_ptr,
    v_ptr,
    o_ptr,
    p_ptr,
    softmax_lse_ptr,
    q_row_stride,
    k_row_stride,
    v_row_stride,
    q_head_stride,
    k_head_stride,
    v_head_stride,
    o_row_stride,
    o_head_stride,
    q_batch_stride,
    k_batch_stride,
    v_batch_stride,
    o_batch_stride,
    is_cu_seqlens_q: tl.constexpr,
    cu_seqlens_q_ptr,
    is_cu_seqlens_k: tl.constexpr,
    cu_seqlens_k_ptr,
    is_seqused_k: tl.constexpr,
    seqused_k_ptr,
    # sizes
    b,
    bk,
    h: tl.constexpr,
    hk: tl.constexpr,
    h_hk_ratio: tl.constexpr,
    seqlen_q,
    seqlen_k,
    seqlen_q_rounded,
    seqlen_k_rounded,
    d: tl.constexpr,
    d_rounded: tl.constexpr,
    # scaling factors
    is_softcap: tl.constexpr,
    softcap: tl.constexpr,
    scale_softmax: tl.constexpr,
    scale_softmax_log2: tl.constexpr,
    # dropout
    is_dropout: tl.constexpr,
    p_dropout: tl.constexpr,
    rp_dropout: tl.constexpr,
    p_dropout_in_uint8_t: tl.constexpr,
    philox_args,
    return_softmax: tl.constexpr,
    # causal and swa
    is_causal: tl.constexpr,
    is_local: tl.constexpr,
    window_size_left: tl.constexpr,
    window_size_right: tl.constexpr,
    seqlenq_ngroups_swapped: tl.constexpr,
    is_paged: tl.constexpr,
    # alibi
    is_alibi: tl.constexpr,
    alibi_slopes_ptr,
    alibi_slopes_batch_stride: tl.constexpr,
    # block table
    total_q,
    page_table_ptr,
    page_table_batch_stride: tl.constexpr,
    block_size: tl.constexpr,
    # kernel params
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    num_warps: tl.constexpr,
    num_stages: tl.constexpr,
):
    # Attention forward pass for variable-length and/or paged-KV batches.
    #
    # Program ids: axis 0 selects a BLOCK_M-row query tile, axis 1 the batch
    # element and axis 2 the query head. Each program:
    #   1. resolves the query / key lengths of its batch element, either from
    #      the cu_seqlens_* prefix arrays or from the fixed seqlen_* values
    #      (seqused_k_ptr overrides the key length when is_seqused_k is set);
    #   2. walks the KV blocks from last to first: the first n_masking_steps
    #      blocks are loaded with boundary checks and masked with the
    #      causal/local settings, the remaining ones are loaded unchecked;
    #   3. normalises the accumulator and stores the output tile to o_ptr and
    #      the per-row log-sum-exp to softmax_lse_ptr ([h, total_q] layout).
    #
    # K/V tiles come from the page table when is_paged is set, otherwise
    # from block pointers based at the current sequence.
    #
    # Several parameters (e.g. p_ptr, b, bk, hk, the *_rounded sizes,
    # p_dropout, rp_dropout, return_softmax, seqlenq_ngroups_swapped) are
    # accepted but not read by this body.
    m_block = tl.program_id(0)
    bid = tl.program_id(1)
    hid = tl.program_id(2)
    # num_m_blocks = tl.cdiv(seqlen_q, BLOCK_M)

    if is_cu_seqlens_q:
        q_eos = tl.load(cu_seqlens_q_ptr + bid + 1).to(tl.int32)
        q_bos = tl.load(cu_seqlens_q_ptr + bid).to(tl.int32)
        q_len = q_eos - q_bos
        # Current request's start offset in the batched Q
        q_offset = q_bos * q_row_stride
        o_offset = q_bos * o_row_stride
        lse_offset = q_bos
    else:
        q_len = seqlen_q
        q_offset = bid * q_batch_stride
        o_offset = bid * o_batch_stride
        lse_offset = bid * seqlen_q

    # Start offset of this sequence inside K / V for the non-paged path.
    # Both branches define it so the non-paged path also compiles when
    # cu_seqlens_k is absent (k_bos used to be undefined in that case).
    if is_cu_seqlens_k:
        k_eos = tl.load(cu_seqlens_k_ptr + bid + 1).to(tl.int32)
        k_bos = tl.load(cu_seqlens_k_ptr + bid).to(tl.int32)
        k_len_cache = k_eos - k_bos
        k_seq_offset = k_bos * k_row_stride
        v_seq_offset = k_bos * k_row_stride
    else:
        k_len_cache = seqlen_k
        k_seq_offset = bid * k_batch_stride
        v_seq_offset = bid * v_batch_stride

    if is_seqused_k:
        k_len = tl.load(seqused_k_ptr + bid).to(tl.int32)
    else:
        k_len = k_len_cache

    # Noop CTA: a tile that starts at or past q_len owns no query rows, so
    # every store below would be fully masked anyway.
    if m_block * BLOCK_M >= q_len:
        return

    # is_even_mn = (q_len % BLOCK_M == 0) and (k_len % BLOCK_N == 0)
    is_even_mn: tl.constexpr = False

    if is_local:
        n_block_min = max(
            0, (m_block * BLOCK_M + k_len - q_len - window_size_left) // BLOCK_N
        )
    else:
        n_block_min = 0

    n_block_max = tl.cdiv(k_len, BLOCK_N)
    if is_causal or is_local:
        n_block_max = min(
            n_block_max,
            tl.cdiv(
                (m_block + 1) * BLOCK_M + k_len - q_len + window_size_right, BLOCK_N
            ),
        )

    if is_dropout:
        philox_seed = tl.load(philox_args).to(tl.uint64)
        philox_offset = tl.load(philox_args + 1).to(tl.uint64)

    # Locate the page table entry for the current batch element
    if is_paged:
        page_table_ptr += bid * page_table_batch_stride
    # Calculate the starting offset of q for the current head
    q_row_offset = hid * q_head_stride
    # Calculate the starting offset of k and v for the current head
    k_row_offset = (hid // h_hk_ratio) * k_head_stride
    # Shift the k, v pointers to align with the current head
    # NOTE(review): V is addressed with K's head/row strides throughout;
    # v_head_stride and v_row_stride are never read. Confirm K and V always
    # share a layout.
    k_ptr_base = k_ptr + k_row_offset
    v_ptr_base = v_ptr + k_row_offset
    # Per-sequence bases for the non-paged path (loop invariant).
    k_ptr_seq = k_ptr_base + k_seq_offset
    v_ptr_seq = v_ptr_base + v_seq_offset

    gQ = tl.make_block_ptr(
        base=q_ptr + q_offset + q_row_offset,
        shape=(q_len, d),
        strides=(q_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    bQ = tl.load(gQ.advance([m_block * BLOCK_M, 0]), boundary_check=(0, 1))

    acc_ = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
    rowmax_ = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)
    rowsum_ = tl.zeros([BLOCK_M], dtype=tl.float32)

    if is_alibi:
        alibi_offset = bid * alibi_slopes_batch_stride + hid
        alibi_slope = tl.load(alibi_slopes_ptr + alibi_offset)
        alibi_slope /= scale_softmax
    else:
        alibi_slope = 0.0

    if not is_causal and not is_local:
        n_masking_steps = 1
    elif is_even_mn:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N)
    else:
        n_masking_steps = tl.cdiv(BLOCK_M, BLOCK_N) + 1

    n_masking_steps = min(n_block_max - n_block_min, n_masking_steps)

    row_idx = m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    n_block = n_block_max - 1
    # Phase 1: trailing KV blocks, loaded with boundary checks and masked.
    for step in tl.range(0, n_masking_steps):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        if is_paged:
            bK, bV = load_from_kvcache(
                col_idx,
                k_len,
                page_table_ptr,
                k_ptr_base,
                v_ptr_base,
                block_size,
                d,
                k_row_stride,
                BLOCK_K=BLOCK_K,
                boundary_check=True,
            )
        else:
            start_n = n_block * BLOCK_N
            gK = tl.make_block_ptr(
                base=k_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            gV = tl.make_block_ptr(
                base=v_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            bK = tl.load(gK, boundary_check=(0, 1))
            bK = tl.trans(bK)
            bV = tl.load(gV, boundary_check=(0, 1))
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=is_even_mn,
            is_causal=is_causal,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=True,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            # apply_dropout takes (row_start, col_start): rows follow the
            # query tile, columns the KV block - same order as phase 2.
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )

        acc_ = tl.dot(P, bV, acc_)
        n_block -= 1

    # Phase 2: remaining KV blocks, loaded without boundary checks.
    for n_block in tl.range(
        n_block_max - n_masking_steps - 1, n_block_min - 1, step=-1
    ):
        col_idx = n_block * BLOCK_N + tl.arange(0, BLOCK_N)
        if is_paged:
            bK, bV = load_from_kvcache(
                col_idx,
                k_len,
                page_table_ptr,
                k_ptr_base,
                v_ptr_base,
                block_size,
                d,
                k_row_stride,
                BLOCK_K=BLOCK_K,
            )
        else:
            start_n = n_block * BLOCK_N
            gK = tl.make_block_ptr(
                base=k_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            gV = tl.make_block_ptr(
                base=v_ptr_seq,
                shape=(k_len, d),
                strides=(k_row_stride, 1),
                offsets=(start_n, 0),
                block_shape=(BLOCK_N, BLOCK_K),
                order=(0, 1),
            )
            bK = tl.load(gK)
            bK = tl.trans(bK)
            bV = tl.load(gV)
        S = tl.dot(bQ, bK, out_dtype=tl.float32)
        S = apply_softcap(S, softcap, is_softcap)
        S = apply_alibi(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            is_causal=is_causal,
            is_alibi=is_alibi,
            alibi_slope=alibi_slope,
        )
        S = apply_mask(
            S,
            col_idx,
            row_idx,
            q_len,
            k_len,
            window_size_left,
            window_size_right,
            is_even_mn=True,
            is_causal=False,
            is_local=is_local,
        )

        acc_, P, rowmax_, rowsum_ = softmax_rescale(
            acc_,
            S,
            rowmax_,
            rowsum_,
            softmax_scale_log2e=scale_softmax_log2,
            is_border=is_local,
        )
        P = P.to(v_ptr.type.element_ty)

        if is_dropout:
            P = apply_dropout(
                P,
                m_block * BLOCK_M,
                n_block * BLOCK_N,
                k_len,
                bid,
                hid,
                philox_seed,
                philox_offset,
                p_dropout_in_uint8_t,
                is_dropout,
                encode_dropout_in_sign_bit=False,
                NUM_HEADS=h,
                BLOCK_M=BLOCK_M,
                BLOCK_N=BLOCK_N,
            )
        acc_ = tl.dot(P, bV, acc_)

    # LSE
    # Rows whose sum is zero or NaN get lse = +inf and are left unscaled.
    # The comparison must be parenthesised: `|` binds tighter than `==`, so
    # `rowsum_ == 0 | (...)` would never flag the NaN rows.
    invalid_row = (rowsum_ == 0) | (rowsum_ != rowsum_)
    lse = tl.where(
        invalid_row,
        float("inf"),
        rowmax_ * scale_softmax + tl.log(rowsum_),
    )
    inv_sum = tl.where(invalid_row, 1.0, 1.0 / rowsum_)

    # NOTE(review): rp_dropout is not applied here even when is_dropout is
    # set; confirm whether callers expect the dropout rescale.
    acc_ *= inv_sum[:, None]

    out = acc_.to(o_ptr.type.element_ty)  # noqa

    # Write back output
    o_row_offset = hid * o_head_stride

    gO = tl.make_block_ptr(
        base=o_ptr + o_offset + o_row_offset,
        shape=(q_len, d),
        strides=(o_row_stride, 1),
        offsets=(0, 0),
        block_shape=(BLOCK_M, BLOCK_K),
        order=(1, 0),
    )
    tl.store(gO.advance([m_block * BLOCK_M, 0]), out, boundary_check=(0, 1))

    # Write back lse
    # lse shape: [h, total_q]
    softmax_lse_ptr += hid * total_q
    lse_row_offset = lse_offset + m_block * BLOCK_M + tl.arange(0, BLOCK_M)
    tl.store(
        softmax_lse_ptr + lse_row_offset,
        lse,
        mask=lse_row_offset < (lse_offset + q_len),
    )