Coverage for src/flag_gems/runtime/backend/_kunlunxin/fused/sparse_attention.py: 0%

42 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-05-06 06:51 +0800

1import torch 

2import triton 

3import triton.language as tl 

4 

5 

6# --------------------------------------------------------------------------- 

7# Triton kernel: sparse attention with attention-sink 

8# grid = (m, b, h) — one program per (seq_pos, batch, head) 

9# Kunlunxin-adapted version 

10# --------------------------------------------------------------------------- 

@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d) bf16
    KV,  # (b, n, d) bf16
    O,  # (b, m, h, d) bf16
    attn_sink,  # (h,) fp32
    topk_idxs,  # (b, m, topk) int32
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,
    topk,
    D: tl.constexpr,
):
    """Sparse attention over gathered KV rows with a per-head attention sink.

    Each program handles one (seq_pos, batch, head) triple, taken from
    program ids 0, 1 and 2 respectively. It gathers the ``topk`` KV rows
    listed in ``topk_idxs`` for that (batch, seq_pos), runs an online softmax
    over the scaled q.kv scores, adds ``exp(attn_sink[head])`` to the softmax
    denominator only (the sink contributes no value vector), and writes the
    normalized ``D``-wide result to ``O``.

    Negative entries in ``topk_idxs`` are treated as padding: their address is
    clamped to row 0 so the load stays in bounds, but they receive zero
    softmax weight and never update the running maximum.

    ``attn_sink`` is indexed as ``attn_sink + head`` (unit stride). ``D`` is
    the full head dimension and is loaded unmasked via ``tl.arange(0, D)``.
    """
    pid_m = tl.program_id(0)
    pid_b = tl.program_id(1)
    pid_h = tl.program_id(2)

    # ---- load Q vector: (D,) for this head, upcast so the dot product is
    # ---- formed and reduced in fp32 like the output accumulator ----
    offs_d = tl.arange(0, D)
    q_base = Q + pid_b * stride_qb + pid_m * stride_qm + pid_h * stride_qh
    q_vec = tl.load(q_base + offs_d * stride_qd).to(tl.float32)

    # ---- base pointers ----
    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # ---- online softmax state ----
    acc_o = tl.zeros([D], dtype=tl.float32)
    score_max = float("-inf")
    sum_exp = 0.0

    # Process each topk element one by one
    for k in range(topk):
        # -- gather KV vector --
        idx = tl.load(idx_base + k * stride_idxk)  # scalar
        # Negative indices mark padding. Clamp only the address; the
        # contribution itself is masked out below.
        is_valid = idx >= 0
        safe_idx = tl.where(is_valid, idx, 0)

        kv_ptrs = kv_base + safe_idx * stride_kvn + offs_d * stride_kvd
        kv_vec = tl.load(kv_ptrs).to(tl.float32)  # (D,)

        # -- score: element-wise multiply then sum (1-D dot product) --
        score = tl.sum(q_vec * kv_vec) * scale

        # -- online softmax update --
        # Padded slots leave the running max untouched, rescale by exactly 1
        # and add weight 0. The guarded selects also discard the NaN/inf that
        # exp() yields while the running max is still -inf.
        score_max_prev = score_max
        score_max = tl.where(
            is_valid, tl.maximum(score_max_prev, score), score_max_prev
        )
        correction = tl.where(
            is_valid, tl.exp(score_max_prev - score_max), 1.0
        )
        p = tl.where(is_valid, tl.exp(score - score_max), 0.0)

        # -- accumulate output: acc_o = acc_o * correction + p * kv_vec --
        acc_o = acc_o * correction + p * kv_vec
        sum_exp = sum_exp * correction + p

    # ---- incorporate attn_sink (denominator only) ----
    sink_val = tl.load(attn_sink + pid_h)  # scalar
    sum_exp = sum_exp + tl.exp(sink_val - score_max)

    # ---- normalize ----
    acc_o = acc_o / sum_exp

    # ---- store output: (D,), cast to the output tensor's element type ----
    o_base = O + pid_b * stride_ob + pid_m * stride_om + pid_h * stride_oh
    o_ptrs = o_base + offs_d * stride_od
    tl.store(o_ptrs, acc_o.to(O.dtype.element_ty))

94 

95 

96# --------------------------------------------------------------------------- 

97# Python wrapper 

98# --------------------------------------------------------------------------- 

def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Launch ``sparse_attn_triton_kernel`` and return the attention output.

    Args:
        q: Query tensor, unpacked as ``(b, m, h, d)``.
        kv: Shared key/value tensor; three strides are passed to the kernel,
            whose signature comments describe it as ``(b, n, d)``.
        attn_sink: Per-head sink logits, described by the kernel as ``(h,)``.
        topk_idxs: Indices into ``kv`` rows; the last dimension is ``topk``
            and the kernel describes the layout as ``(b, m, topk)``.
        softmax_scale: Multiplier applied to each q.kv score before softmax.

    Returns:
        A new tensor allocated with ``torch.empty_like(q)`` holding the
        result for every (batch, seq_pos, head).
    """
    b, m, h, d = q.shape
    topk = topk_idxs.shape[-1]
    o = torch.empty_like(q)

    # Nothing to compute: avoid launching a grid with a zero-sized dimension.
    if q.numel() == 0:
        return o

    # The kernel addresses the sink as `attn_sink + head` with no stride
    # argument, so a strided view must be compacted first (no-op otherwise).
    attn_sink = attn_sink.contiguous()

    grid = (m, b, h)  # each program handles one (seq_pos, batch, head)
    sparse_attn_triton_kernel[grid](
        q,
        kv,
        o,
        attn_sink,
        topk_idxs,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        q.stride(3),
        kv.stride(0),
        kv.stride(1),
        kv.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        o.stride(3),
        topk_idxs.stride(0),
        topk_idxs.stride(1),
        topk_idxs.stride(2),
        softmax_scale,
        topk,
        D=d,
        num_warps=2,
    )
    return o