Coverage for src/flag_gems/runtime/backend/_ascend/fused/sparse_attention.py: 0%

66 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-05 07:36 +0800

1import os 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

# Enable all blocks parallel to avoid the coreDim > 65535 issue on NPU.
# NOTE: this is a module-level assignment, so it runs at import time, applies to
# the whole process, and overwrites any value already present in the environment.
os.environ["TRITON_ALL_BLOCKS_PARALLEL"] = "1"

9 

10 

# ---------------------------------------------------------------------------
# Triton kernel: sparse attention with attention-sink
# Adapted for Ascend NPU: 1D grid, tiling for UB overflow
#
# Each program handles one (batch, query position) pair for all heads at once.
# It gathers the `topk` KV rows selected by `topk_idxs`, runs an online
# (streaming) softmax over them, adds the per-head sink term to the softmax
# denominator only, and writes the normalised (H, D) result to O.
# The same gathered KV rows serve as both keys and values.
# ---------------------------------------------------------------------------
@triton.jit
def sparse_attn_triton_kernel(
    Q,  # (b, m, h, d) bf16
    KV,  # (b, n, d) bf16
    O,  # (b, m, h, d) bf16
    attn_sink,  # (h,) fp32, read with unit stride
    topk_idxs,  # (b, m, topk) int32
    stride_qb,
    stride_qm,
    stride_qh,
    stride_qd,
    stride_kvb,
    stride_kvn,
    stride_kvd,
    stride_ob,
    stride_om,
    stride_oh,
    stride_od,
    stride_idxb,
    stride_idxm,
    stride_idxk,
    scale,  # multiplier applied to the raw Q @ KV^T scores
    topk,  # number of selected KV rows per query position
    kv_len,  # number of rows in KV along its second dimension
    H_ACTUAL,  # real head count; rows >= H_ACTUAL of the padded tile are masked
    BLOCK: tl.constexpr,  # outer tile size over the topk axis
    BLOCK_SUB: tl.constexpr,  # inner tile size actually loaded per step
    D: tl.constexpr,  # head dimension
    H: tl.constexpr,  # padded head count used for tile shapes
    BATCH_STRIDE: tl.constexpr,  # query positions per batch (decodes the 1D pid)
):
    # 1D grid: each task handles one (batch, seq_pos)
    pid = tl.program_id(0)
    pid_b = pid // BATCH_STRIDE
    pid_m = pid % BATCH_STRIDE

    # ---- load Q matrix: (H, D) — all heads at once ----
    q_base = Q + pid_b * stride_qb + pid_m * stride_qm
    offs_h = tl.arange(0, H)
    offs_d = tl.arange(0, D)
    h_mask = offs_h < H_ACTUAL
    q_ptrs = q_base + offs_h[:, None] * stride_qh + offs_d[None, :] * stride_qd
    # Padded head rows are loaded as zeros and never stored back.
    q_block = tl.load(q_ptrs, mask=h_mask[:, None], other=0.0)  # (H, D) bf16

    # ---- base pointers ----
    kv_base = KV + pid_b * stride_kvb
    idx_base = topk_idxs + pid_b * stride_idxb + pid_m * stride_idxm

    # ---- online softmax state ----
    acc_o = tl.zeros([H, D], dtype=tl.float32)
    scores_max = tl.full([H], float("-inf"), dtype=tl.float32)
    sum_exp = tl.zeros([H], dtype=tl.float32)

    # Two-level tiling: BLOCK (outer) -> BLOCK_SUB (inner)
    num_block_iter = (topk + BLOCK - 1) // BLOCK
    num_sub_iter = (BLOCK + BLOCK_SUB - 1) // BLOCK_SUB
    offs_blk = tl.arange(0, BLOCK_SUB)

    for t in range(num_block_iter):
        block_start = t * BLOCK
        # Process BLOCK elements in sub-tiles
        for s in range(num_sub_iter):
            sub_start = block_start + s * BLOCK_SUB
            raw_offs = sub_start + offs_blk  # (BLOCK_SUB,)
            idx_mask = raw_offs < topk
            idxs = tl.load(
                idx_base + raw_offs * stride_idxk, mask=idx_mask, other=0
            )  # (BLOCK_SUB,)

            # Clamp negative indices to 0 (matching PyTorch behavior on NPU).
            # Consequence: a negative index attends to KV row 0 rather than
            # being masked out.
            idxs = tl.where(idxs < 0, 0, idxs)

            # After the clamp every index is >= 0, so only the upper bound
            # still needs checking. Slots past `topk` are excluded as well.
            valid_mask = idx_mask & (idxs < kv_len)  # (BLOCK_SUB,)

            # -- gather KV block: (BLOCK_SUB, D) --
            kv_ptrs = (
                kv_base + idxs[:, None] * stride_kvn + offs_d[None, :] * stride_kvd
            )
            kv_block = tl.load(
                kv_ptrs, mask=valid_mask[:, None], other=0.0
            )  # (BLOCK_SUB, D) bf16

            # -- scores: Q @ KV^T -> (H, BLOCK_SUB) via GEMM --
            acc_s = tl.dot(
                q_block, tl.trans(kv_block)
            )  # (H, D) @ (D, BLOCK_SUB) = (H, BLOCK_SUB)
            acc_s = acc_s * scale
            # mask invalid positions to -inf
            mask_bias = tl.where(valid_mask, 0.0, float("-inf"))  # (BLOCK_SUB,)
            acc_s = acc_s + mask_bias[None, :]  # broadcast: (H, BLOCK_SUB)

            # -- online softmax update --
            scores_max_prev = scores_max
            block_max = tl.max(acc_s, axis=1)  # (H,)
            scores_max = tl.maximum(scores_max, block_max)

            # Until a valid key has been seen the running max is still -inf,
            # and (-inf) - (-inf) is NaN, which would poison acc_o and sum_exp.
            # Subtract a finite stand-in (0) in that case so that both the
            # correction and the probabilities become exp(-inf) == 0. When the
            # max is finite this is exactly the usual update.
            max_safe = tl.where(scores_max == float("-inf"), 0.0, scores_max)

            correction = tl.exp(scores_max_prev - max_safe)  # (H,)
            p = tl.exp(acc_s - max_safe[:, None])  # (H, BLOCK_SUB)

            # -- accumulate output: acc_o = acc_o * correction + P @ KV --
            acc_o = acc_o * correction[:, None]
            acc_o += tl.dot(
                p.to(tl.bfloat16), kv_block
            )  # (H, BLOCK_SUB) @ (BLOCK_SUB, D) = (H, D)

            scores_sum = tl.sum(p, axis=1)  # (H,)
            sum_exp = sum_exp * correction + scores_sum

    # ---- incorporate attn_sink ----
    # The sink contributes to the denominator only; it has no value vector.
    # If no valid key was seen, scores_max is -inf, this term is +inf and the
    # normalised output below is 0.
    sink_vals = tl.load(attn_sink + offs_h, mask=h_mask, other=0.0)  # (H,)
    sum_exp = sum_exp + tl.exp(sink_vals - scores_max)

    # ---- normalize ----
    acc_o = acc_o / sum_exp[:, None]

    # ---- store output: (H, D) ----
    o_base = O + pid_b * stride_ob + pid_m * stride_om
    o_ptrs = o_base + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od
    tl.store(o_ptrs, acc_o.to(tl.bfloat16), mask=h_mask[:, None])

137 

138 

# ---------------------------------------------------------------------------
# Python wrapper
# ---------------------------------------------------------------------------
def sparse_attn_triton(
    q: torch.Tensor,
    kv: torch.Tensor,
    attn_sink: torch.Tensor,
    topk_idxs: torch.Tensor,
    softmax_scale: float,
) -> torch.Tensor:
    """Run top-k sparse attention with an attention sink via the Triton kernel.

    Args:
        q: Query tensor of shape (b, m, h, d). The kernel annotates it as bf16.
        kv: Tensor of shape (b, n, d); the gathered rows are used as both keys
            and values.
        attn_sink: Per-head sink logits of shape (h,); the kernel annotates it
            as fp32. It is made contiguous before the launch because the
            kernel reads it with unit stride.
        topk_idxs: Indices into the ``n`` axis of ``kv`` with shape
            (b, m, topk); the kernel annotates it as int32. Negative indices
            are clamped to 0 and indices ``>= n`` are masked out by the kernel.
        softmax_scale: Multiplier applied to the raw attention scores.

    Returns:
        A new tensor with the same shape, dtype and strides as ``q``. An empty
        ``q`` yields an empty result without launching the kernel.
    """
    b, m, h, d = q.shape
    topk = topk_idxs.shape[-1]
    kv_len = kv.shape[1]
    o = torch.empty_like(q)

    # Nothing to compute: avoid launching an empty grid or building
    # zero-sized tiles inside the kernel.
    if q.numel() == 0:
        return o

    # The kernel addresses the sink as `attn_sink + head`, i.e. unit stride.
    # This is a no-op when the tensor is already contiguous.
    attn_sink = attn_sink.contiguous()

    # NPU optimization: use tiling to avoid UB overflow
    # BLOCK: number of KV elements per outer loop iteration
    # BLOCK_SUB: tile size for UB management
    # UB (192KB) constraint: need to fit q_block + kv_block + acc_o + intermediate buffers
    # Use fixed BLOCK to avoid edge cases with non-power-of-2 topk
    BLOCK = 64
    BLOCK_SUB = 16  # smaller chunks to fit UB (192KB), with multi-buffer overhead

    # H must be >= 16 for tl.dot; pad to next power of 2
    H_padded = max(16, triton.next_power_of_2(h))

    # NPU: use 1D grid, TRITON_ALL_BLOCKS_PARALLEL handles large grid
    grid = (b * m,)

    sparse_attn_triton_kernel[grid](
        q,
        kv,
        o,
        attn_sink,
        topk_idxs,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        q.stride(3),
        kv.stride(0),
        kv.stride(1),
        kv.stride(2),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        o.stride(3),
        topk_idxs.stride(0),
        topk_idxs.stride(1),
        topk_idxs.stride(2),
        softmax_scale,
        topk,
        kv_len,
        h,
        BLOCK=BLOCK,
        BLOCK_SUB=BLOCK_SUB,
        D=d,
        H=H_padded,
        BATCH_STRIDE=m,  # for 1D grid: pid = pid_b * m + pid_m
        num_warps=4,  # reduced for NPU
    )
    return o