Coverage for src/flag_gems/fused/DSA/indexer_k_tiled.py: 14%

58 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-26 06:59 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

# Autotuning candidates for ``triton_lighting_indexer_k_tiled``.
# ``num_stages`` / ``num_warps`` are launch options and must be given as
# keyword arguments of ``triton.Config``. If they are placed in the
# meta-parameter dict instead, the Config's own (default) ``num_warps`` /
# ``num_stages`` fields take precedence, and both candidates end up with the
# same launch setting.
indexer_fwd_configs = [
    triton.Config({}, num_stages=2, num_warps=4),
    triton.Config({}, num_stages=4, num_warps=8),
]

9 

10 

@triton.autotune(
    configs=indexer_fwd_configs,
    key=["Q", "K", "H", "D"],
)
@triton.jit
def triton_lighting_indexer_k_tiled(
    q_index,
    k_index,
    cu_bg_seqlens,
    cu_ed_seqlens,
    weights,
    logits,
    stride_qh,
    stride_qd,
    stride_kn,
    stride_kd,
    stride_wh,
    stride_lm,
    stride_ln,
    Q: tl.constexpr,
    H: tl.constexpr,
    K: tl.constexpr,
    TK: tl.constexpr,
    D: tl.constexpr,
    CU: tl.constexpr,
    BQ: tl.constexpr,
    BK: tl.constexpr,
):
    """Compute indexer logits for one (query block, key tile) program.

    Program ``(i_sh, i_k)`` handles query rows ``m`` in
    ``[i_sh * BQ, (i_sh + 1) * BQ)``. For each row it visits the keys ``n`` in
    ``[ks[m] + i_k * TK, min(ke[m], ks[m] + (i_k + 1) * TK))`` intersected
    with ``[0, K)``. Here ``ks`` / ``ke`` are read from ``cu_bg_seqlens`` /
    ``cu_ed_seqlens``; ``ks`` is an inclusive start and ``ke`` an exclusive
    end. For every visited key it stores::

        logits[m, n] = sum_h w[m, h] * relu(q[m, h, :] . k[n, :])

    Operands are cast to float16 and the dot product and head sum are carried
    out in float16. Positions outside a row's window are not written. Rows
    ``m >= CU`` get an empty window.

    Addressing used below:
      * ``q_index`` as a flat ``[Q*H, D]`` matrix:
        ``row * stride_qh + d * stride_qd``.
      * ``weights`` as a flat ``[Q*H]`` vector: ``row * stride_wh``.
      * ``k_index`` as ``[K, D]`` with strides ``stride_kn`` / ``stride_kd``.
      * ``logits`` as ``[Q, K]`` with strides ``stride_lm`` / ``stride_ln``.

    NOTE(review): the union of tiles ``i_k`` covers a row's full window only
    if the grid's second dimension is at least ``cdiv(ke[m] - ks[m], TK)``.
    The interface's ``cdiv(K, TK)`` satisfies this when ``0 <= ks`` and
    ``ke <= K``.

    NOTE(review): ``Q``, ``K`` and ``CU`` are constexpr, so every new
    sequence length triggers a recompile (and an autotune pass).
    ``tl.arange`` also requires ``D`` and ``BQ * H`` to be powers of two.
    Confirm both are intended.
    """
    i_sh, i_k = tl.program_id(0), tl.program_id(1)

    # Per-row key windows of this query block. Rows past CU get an empty
    # window: a huge start and a very negative end.
    offs_cu = tl.arange(0, BQ) + i_sh * BQ
    mask_cu = offs_cu < CU
    bos_vec = tl.load(cu_bg_seqlens + offs_cu, mask_cu, 1000000000) + i_k * TK
    eos_vec = tl.load(cu_ed_seqlens + offs_cu, mask_cu, -1000000000)  # [BQ]
    # bos_vec already carries the i_k * TK shift, so this program's tile ends
    # exactly TK keys later. Tiles do not overlap, and each program does at
    # most TK keys of work per row.
    eos_vec = tl.minimum(eos_vec, bos_vec + TK)

    # Union of the rows' windows, clamped to the valid key range [0, K).
    bos, eos = max(bos_vec.min(0), 0), min(eos_vec.max(0), K)
    n_keys = eos - bos
    if n_keys > 0:
        k_base = k_index + bos * stride_kn
        o_base = logits + bos * stride_ln
        # Flattened (query, head) rows: row r is query r // H, head r % H.
        offs_bq = tl.arange(0, BQ * H) + i_sh * (BQ * H)
        offs_boq = tl.arange(0, BQ) + i_sh * BQ
        offs_d = tl.arange(0, D)
        mask_bq = offs_bq < Q * H
        mask_d = offs_d < D
        mask_boq = offs_boq < Q

        q_ptr = q_index + offs_bq[:, None] * stride_qh + offs_d[None, :] * stride_qd
        q_msk = mask_bq[:, None] & mask_d[None, :]
        q_blk = tl.load(q_ptr, q_msk, 0.0).to(tl.float16)  # [BQ*H, D]

        w_ptr = weights + offs_bq * stride_wh
        w_blk = tl.load(w_ptr, mask_bq, 0.0).to(tl.float16)  # [BQ*H]

        # Row offsets into logits reach Q * K elements, which can exceed
        # 2**31 - 1. Compute them in int64 (loop-invariant, so hoisted).
        out_row = offs_boq[:, None].to(tl.int64) * stride_lm  # [BQ, 1]

        n_chunks = tl.cdiv(n_keys, BK)
        for ck in range(n_chunks, warp_specialize=True):
            offs_bk = ck * BK + tl.arange(0, BK)
            mask_bk = bos + offs_bk < eos
            k_ptr = k_base + offs_d[:, None] * stride_kd + offs_bk[None, :] * stride_kn
            k_msk = mask_d[:, None] & mask_bk[None, :]
            k_blk = tl.load(k_ptr, k_msk, 0.0).to(tl.float16)  # [D, BK]
            acc = tl.dot(q_blk, k_blk, out_dtype=tl.float16)  # [BQ*H, BK]
            acc = tl.maximum(acc, 0.0) * w_blk[:, None]
            # Sum the H heads of each query: [BQ*H, BK] -> [BQ, BK].
            out_blk = acc.trans().reshape(BK, BQ, H).sum(-1).trans()
            # Absolute key index of each column, used for per-row window masks.
            key_idx = offs_bk[None, :] + bos
            out_msk = (
                mask_boq[:, None]
                & mask_bk[None, :]
                & (bos_vec[:, None] <= key_idx)
                & (eos_vec[:, None] > key_idx)
            )
            out_ptr = o_base + out_row + offs_bk[None, :] * stride_ln
            tl.store(out_ptr, out_blk.to(tl.float16), out_msk)

92 

93 

def triton_lighting_indexer_k_tiled_interface(
    q, kv, weights, cu_seqlen_ks, cu_seqlen_ke
):
    """Compute lightning-indexer logits with the K-tiled Triton kernel.

    Args:
        q: query tensor indexed as ``[Q, H, D]``.
        kv: key tensor indexed as ``[K, D]``.
        weights: per-(query, head) weights indexed as ``[Q, H]``.
        cu_seqlen_ks: per-query inclusive start key index. Its length is
            passed to the kernel as ``CU``.
        cu_seqlen_ke: per-query exclusive end key index.

    Returns:
        A float32 ``[Q, K]`` tensor on ``q.device``. The kernel writes entry
        ``[m, n]`` as ``sum_h weights[m, h] * relu(q[m, h] . kv[n])``, computed
        in float16, for ``cu_seqlen_ks[m] <= n < cu_seqlen_ke[m]`` (within
        ``[0, K)`` and ``m < CU``). Every other entry stays ``-inf``.
    """
    Q, H, D = q.shape[0], q.shape[1], q.shape[2]
    K = kv.shape[0]
    CU = cu_seqlen_ks.shape[0]
    # Allocate on the inputs' device rather than the current default CUDA
    # device, so the kernel never mixes tensors from different GPUs.
    logits = torch.full([Q, K], float("-inf"), device=q.device, dtype=torch.float32)
    if Q == 0 or K == 0:
        # Nothing to compute; avoid compiling/launching on an empty grid.
        return logits

    # The kernel addresses q as a flat [Q*H, D] matrix and weights as a flat
    # [Q*H] vector, using only their dim-1 strides. That is valid only when
    # dim 0 is packed tightly over dim 1, so repack otherwise.
    if q.stride(0) != H * q.stride(1):
        q = q.contiguous()
    if weights.stride(0) != H * weights.stride(1):
        weights = weights.contiguous()

    BQ = 1  # query rows per program
    BK = 64  # keys per inner-loop chunk
    TK = 2048  # keys per program along grid axis 1
    NQ = triton.cdiv(Q, BQ)
    NK = triton.cdiv(K, TK)
    grid = (NQ, NK)
    triton_lighting_indexer_k_tiled[grid](
        q,
        kv,
        cu_seqlen_ks,
        cu_seqlen_ke,
        weights,
        logits,
        q.stride(1),
        q.stride(2),
        kv.stride(0),
        kv.stride(1),
        weights.stride(1),
        logits.stride(0),
        logits.stride(1),
        Q,
        H,
        K,
        TK,
        D,
        CU,
        BQ,
        BK,
    )
    return logits