Coverage for src/flag_gems/fused/DSA/indexer_k_tiled.py: 14%
58 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-05-26 06:59 +0800
1import torch
2import triton
3import triton.language as tl
# Autotuning candidates for `triton_lighting_indexer_k_tiled`.
# num_stages / num_warps are launch options and must be passed as keyword
# arguments to triton.Config. Its first positional argument is the dict of
# kernel meta-parameters, and putting launch options there does not apply them.
# Depending on the Triton version, they are overridden by Config's defaults or
# clash with them.
indexer_fwd_configs = [
    triton.Config({}, num_stages=2, num_warps=4),
    triton.Config({}, num_stages=4, num_warps=8),
]
@triton.autotune(
    configs=indexer_fwd_configs,
    key=["Q", "K", "H", "D"],
)
@triton.jit
def triton_lighting_indexer_k_tiled(
    q_index,
    k_index,
    cu_bg_seqlens,
    cu_ed_seqlens,
    weights,
    logits,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kd,
    stride_wh,
    stride_lm,
    stride_ln,
    Q: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    TK: tl.constexpr,
    D: tl.constexpr,
    CU: tl.constexpr,
    BQ: tl.constexpr,
    BK: tl.constexpr,
):
    """Write ReLU-weighted q.k logits for one (query block, key tile) program.

    Program (i_sh, i_k) handles query rows r in [i_sh*BQ, (i_sh+1)*BQ).
    For each row r it covers the key window
        [cu_bg_seqlens[r] + i_k*TK, min(cu_ed_seqlens[r], cu_bg_seqlens[r] + (i_k+1)*TK)),
    clamped to [0, K). For every key n in that window it stores
        logits[r, n] = sum_h relu(q[r, h, :] . k[n, :]) * weights[r, h]
    with the dot product accumulated in fp16. Positions outside every window
    are not written.

    q_index and weights are addressed along a flattened (Q*H) axis with a
    single stride (stride_qh / stride_wh). That only works when the row and
    head dimensions are laid out contiguously relative to each other.
    D, BQ*H and BK are used as tl.arange extents, so each must be a power of two.
    """
    i_sh, i_k = tl.program_id(0), tl.program_id(1)

    # Per-row key windows for this query block. Rows past CU get an empty
    # window (huge begin, very negative end) so they never widen [bos, eos).
    offs_cu = tl.arange(0, BQ) + i_sh * BQ
    mask_cu = offs_cu < CU
    bos_vec = tl.load(cu_bg_seqlens + offs_cu, mask_cu, 1000000000) + i_k * TK
    eos_vec = tl.load(cu_ed_seqlens + offs_cu, mask_cu, -1000000000)  # [BQ]
    # Clip each row to its own TK-wide tile. bos_vec already includes the
    # i_k * TK shift, so the tile ends at bos_vec + TK. Adding (i_k + 1) * TK
    # would shift twice: tiles would overlap and redo work quadratically.
    eos_vec = tl.minimum(eos_vec, bos_vec + TK)
    # Union of the row windows, clamped to the valid key range [0, K).
    bos = max(bos_vec.min(0), 0)
    eos = min(eos_vec.max(0), K)
    CK = eos - bos
    if CK > 0:
        k_base = k_index + bos * stride_kn
        o_base = logits + bos * stride_ln

        # Flattened (query row, head) axis: element j is row j // H, head j % H
        # (matches the reshape(BK, BQ, H) used for the head reduction below).
        offs_bq = tl.arange(0, BQ * H) + i_sh * (BQ * H)
        offs_boq = tl.arange(0, BQ) + i_sh * BQ
        offs_d = tl.arange(0, D)
        mask_bq = offs_bq < Q * H
        mask_d = offs_d < D
        mask_boq = offs_boq < Q

        q_ptr = q_index + offs_bq[:, None] * stride_qh + offs_d[None, :] * stride_qd
        q_msk = mask_bq[:, None] & mask_d[None, :]
        q_blk = tl.load(q_ptr, q_msk, 0.0).to(tl.float16)  # [BQ*H, D]

        w_ptr = weights + offs_bq * stride_wh
        w_blk = tl.load(w_ptr, mask_bq, 0.0).to(tl.float16)  # [BQ*H]

        num_kblocks = tl.cdiv(CK, BK)
        # NOTE(review): `warp_specialize` is a `tl.range` option. Passed to the
        # builtin `range`, the Triton frontend may ignore it — confirm against
        # the installed Triton version before relying on it.
        for ck in range(num_kblocks, warp_specialize=True):
            offs_bk = ck * BK + tl.arange(0, BK)  # key offsets relative to bos
            mask_bk = bos + offs_bk < eos
            k_ptr = k_base + offs_d[:, None] * stride_kd + offs_bk[None, :] * stride_kn
            k_msk = mask_d[:, None] & mask_bk[None, :]
            k_blk = tl.load(k_ptr, k_msk, 0.0).to(tl.float16)  # [D, BK]
            acc = tl.dot(q_blk, k_blk, out_dtype=tl.float16)  # [BQ*H, BK]
            acc = tl.maximum(acc, 0.0) * w_blk[:, None]
            # Sum over heads: [BQ*H, BK] -> [BK, BQ, H] -> [BK, BQ] -> [BQ, BK].
            out_blk = acc.trans().reshape(BK, BQ, H).sum(-1).trans()
            out_ptr = (
                o_base + offs_boq[:, None] * stride_lm + offs_bk[None, :] * stride_ln
            )
            # Store only keys inside each row's own window; the rest keep
            # whatever the caller pre-filled.
            out_msk = (
                mask_boq[:, None]
                & mask_bk[None, :]
                & (bos_vec[:, None] <= offs_bk[None, :] + bos)
                & (eos_vec[:, None] > offs_bk[None, :] + bos)
            )
            tl.store(out_ptr, out_blk.to(tl.float16), out_msk)
def triton_lighting_indexer_k_tiled_interface(
    q, kv, weights, cu_seqlen_ks, cu_seqlen_ke
):
    """Compute indexer logits for each query row over its own key window.

    Args:
        q: query tensor indexed as (Q, H, D).
        kv: key tensor indexed as (K, D).
        weights: per-(query row, head) weights indexed as (Q, H).
        cu_seqlen_ks: 1-D tensor; row i's key window starts at cu_seqlen_ks[i].
        cu_seqlen_ke: 1-D tensor; row i's key window ends (exclusive) at
            cu_seqlen_ke[i].

    Returns:
        float32 tensor of shape (Q, K) on ``q.device``.
        - Entry [i, n] is sum_h relu(q[i, h] . kv[n]) * weights[i, h], computed
          in fp16, for n in [cu_seqlen_ks[i], cu_seqlen_ke[i]) clamped to [0, K).
        - Every other entry is -inf.
    """
    Q, H, D = q.shape[0], q.shape[1], q.shape[2]
    K = kv.shape[0]
    CU = cu_seqlen_ks.shape[0]
    # Allocate on the inputs' device, not the default "cuda" device. Otherwise
    # inputs on a non-default GPU would get an output buffer on the wrong device.
    logits = torch.full([Q, K], float("-inf"), device=q.device, dtype=torch.float32)
    if Q == 0 or K == 0:
        # Nothing to compute; skip compilation/autotuning of the kernel.
        return logits

    # The kernel walks q and weights along a flattened (Q*H) axis with a single
    # stride, and reads the cu_seqlen tensors with unit stride. Enforce those
    # layouts (no-op when the tensors are already contiguous).
    q = q.contiguous()
    weights = weights.contiguous()
    cu_seqlen_ks = cu_seqlen_ks.contiguous()
    cu_seqlen_ke = cu_seqlen_ke.contiguous()

    BQ = 1  # query rows per program
    BK = 64  # keys per inner-loop step
    TK = 2048  # keys per program along the key-tile grid axis
    grid = (triton.cdiv(Q, BQ), triton.cdiv(K, TK))
    triton_lighting_indexer_k_tiled[grid](
        q,
        kv,
        cu_seqlen_ks,
        cu_seqlen_ke,
        weights,
        logits,
        q.stride(1),
        q.stride(2),
        kv.stride(0),
        kv.stride(1),
        weights.stride(1),
        logits.stride(0),
        logits.stride(1),
        Q,
        H,
        K,
        TK,
        D,
        CU,
        BQ,
        BK,
    )
    return logits