Coverage for src/flag_gems/runtime/backend/_ascend/fused/moe_sum.py: 0%

84 statements  

« prev     ^ index     » next       coverage.py v7.6.9, created at 2026-06-10 07:09 +0800

1import logging 

2 

3import torch 

4import triton 

5import triton.language as tl 

6 

7from flag_gems import runtime 

8from flag_gems.runtime import torch_device_fn 

9from flag_gems.utils import libentry 

10from flag_gems.utils import triton_lang_extension as ext 

11 

12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip(".")) 

13 

14 

@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("moe_sum"),
    key=["hidden_size", "topk"],
)
@triton.jit
def moe_sum_kernel(
    input_ptr,
    output_ptr,
    num_tokens,
    topk,
    hidden_size,
    input_stride_token,
    input_stride_topk,
    output_stride_token,
    IS_CONTIGUOUS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE_SUB: tl.constexpr,
):
    """
    Sum the ``topk`` expert rows of each token into a single output row.

    The work is divided into (token, hidden-block) tasks. A task covers
    ``BLOCK_SIZE`` consecutive hidden elements of one token and walks them in
    chunks of ``BLOCK_SIZE_SUB``. Every program starts at the task equal to
    its program id and then advances by the number of launched programs, so
    all tasks are processed even when the launch grid is smaller than the
    task count.

    Args:
        input_ptr: Pointer to the input elements.
        output_ptr: Pointer to the output elements; its element type decides
            the type the float32 accumulator is cast to before storing.
        num_tokens: Number of tokens (rows of the output).
        topk: Number of expert rows summed per token.
        hidden_size: Number of hidden elements per row.
        input_stride_token: Element stride between tokens of the input
            (only read when ``IS_CONTIGUOUS`` is false).
        input_stride_topk: Element stride between expert rows of the input
            (only read when ``IS_CONTIGUOUS`` is false).
        output_stride_token: Element stride between tokens of the output
            (only read when ``IS_CONTIGUOUS`` is false).
        IS_CONTIGUOUS: When true, offsets are derived from ``topk`` and
            ``hidden_size`` instead of the stride arguments.
        BLOCK_SIZE: Hidden elements covered by one task.
        BLOCK_SIZE_SUB: Hidden elements loaded/stored per inner iteration.

    Note:
        The hidden dimension is always addressed with an implicit stride of
        one element, for both input and output.
    """
    pid = ext.program_id(0)
    num_programs = tl.num_programs(0)

    # Task partition: one task per (token, hidden-block) pair.
    num_hidden_blocks = tl.cdiv(hidden_size, BLOCK_SIZE)
    total_tasks = num_tokens * num_hidden_blocks

    # Strided task loop: with a grid of at least `total_tasks` programs each
    # program runs one iteration; with a smaller grid the remaining tasks are
    # picked up here instead of being skipped.
    for task_idx in range(pid, total_tasks, num_programs):
        token_idx = task_idx // num_hidden_blocks
        hidden_base = (task_idx % num_hidden_blocks) * BLOCK_SIZE

        # IS_CONTIGUOUS is a compile-time constant, so only one of the two
        # addressing schemes is present in the compiled kernel.
        if IS_CONTIGUOUS:
            expert_stride = hidden_size
            input_base = input_ptr + token_idx * topk * hidden_size
            output_base = output_ptr + token_idx * hidden_size
        else:
            expert_stride = input_stride_topk
            input_base = input_ptr + token_idx * input_stride_token
            output_base = output_ptr + token_idx * output_stride_token

        for sub_idx in range(0, BLOCK_SIZE, BLOCK_SIZE_SUB):
            h_indices = hidden_base + sub_idx + tl.arange(0, BLOCK_SIZE_SUB)
            # Guards the tail when hidden_size is not a multiple of the
            # block sizes.
            valid_mask = h_indices < hidden_size

            # Accumulate in float32 regardless of the input element type.
            result = tl.zeros((BLOCK_SIZE_SUB,), dtype=tl.float32)
            for k in range(topk):
                val = tl.load(
                    input_base + k * expert_stride + h_indices,
                    mask=valid_mask,
                    other=0.0,
                    care_padding=False,
                )
                result += val.to(tl.float32)

            tl.store(
                output_base + h_indices,
                result.to(output_ptr.dtype.element_ty),
                mask=valid_mask,
            )

108 

109 

# Specialized kernel for topk=2 (most common in MoE)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("moe_sum"),
    key=["hidden_size"],
)
@triton.jit
def moe_sum_kernel_topk2(
    input_ptr,
    output_ptr,
    num_tokens,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE_SUB: tl.constexpr,
):
    """
    Sum exactly two expert rows per token, with the expert loop written out.

    Offsets are computed as ``token * 2 * hidden_size`` for the input and
    ``token * hidden_size`` for the output, with the hidden dimension
    addressed at unit stride; no stride arguments are taken.

    The work is divided into (token, hidden-block) tasks of ``BLOCK_SIZE``
    hidden elements, processed in chunks of ``BLOCK_SIZE_SUB``. Every program
    starts at the task equal to its program id and then advances by the
    number of launched programs, so all tasks are processed even when the
    launch grid is smaller than the task count.

    Args:
        input_ptr: Pointer to the input elements.
        output_ptr: Pointer to the output elements; its element type decides
            the type the float32 sum is cast to before storing.
        num_tokens: Number of tokens (rows of the output).
        hidden_size: Number of hidden elements per row.
        BLOCK_SIZE: Hidden elements covered by one task.
        BLOCK_SIZE_SUB: Hidden elements loaded/stored per inner iteration.
    """
    pid = ext.program_id(0)
    num_programs = tl.num_programs(0)

    num_hidden_blocks = tl.cdiv(hidden_size, BLOCK_SIZE)
    total_tasks = num_tokens * num_hidden_blocks

    # Strided task loop: tasks beyond the launched grid size are handled by
    # later iterations instead of being skipped.
    for task_idx in range(pid, total_tasks, num_programs):
        token_idx = task_idx // num_hidden_blocks
        hidden_base = (task_idx % num_hidden_blocks) * BLOCK_SIZE

        input_row = input_ptr + token_idx * 2 * hidden_size
        output_row = output_ptr + token_idx * hidden_size

        for sub_idx in range(0, BLOCK_SIZE, BLOCK_SIZE_SUB):
            h_indices = hidden_base + sub_idx + tl.arange(0, BLOCK_SIZE_SUB)
            # Guards the tail when hidden_size is not a multiple of the
            # block sizes.
            valid_mask = h_indices < hidden_size

            first_ptr = input_row + h_indices

            # The two expert rows sit hidden_size elements apart.
            val0 = tl.load(first_ptr, mask=valid_mask, other=0.0, care_padding=False)
            val1 = tl.load(
                first_ptr + hidden_size,
                mask=valid_mask,
                other=0.0,
                care_padding=False,
            )

            # Add in float32, then cast to the output element type.
            result = val0.to(tl.float32) + val1.to(tl.float32)

            tl.store(
                output_row + h_indices,
                result.to(output_ptr.dtype.element_ty),
                mask=valid_mask,
            )

156 

157 

def moe_sum(
    input: torch.Tensor,
    output: torch.Tensor,
):
    """
    MoE sum operation for the Ascend backend.

    Sums ``input`` over the expert dimension (dim=1) and writes the result
    into ``output`` in place.

    Args:
        input: Tensor of shape ``(num_tokens, topk, hidden_size)``.
        output: Tensor of shape ``(num_tokens, hidden_size)`` that receives
            the sums.

    Returns:
        None. The result is written to ``output``.

    Raises:
        ValueError: If ``input`` is not 3-dimensional or ``output`` is not
            2-dimensional (raised by the shape/stride unpacking).
    """
    logger.debug("GEMS_ASCEND MOE_SUM")

    num_tokens, topk, hidden_size = input.shape
    out_s0, out_s1 = output.stride()

    # Nothing to write; also avoids requesting a zero-size launch grid.
    if num_tokens == 0 or hidden_size == 0:
        return

    # The kernels are not given the hidden-dimension strides and address that
    # dimension with a unit stride. Layouts that violate this are routed
    # through a compact copy (input) or a temporary row buffer (output).
    src = input if input.stride(2) == 1 else input.contiguous()
    if out_s1 == 1:
        dst = output
    else:
        dst = torch.empty(
            (num_tokens, hidden_size), dtype=output.dtype, device=output.device
        )
        out_s0 = dst.stride(0)
    in_s0, in_s1, _ = src.stride()

    # Fully packed layouts let the kernels derive offsets from the shape.
    is_contiguous = (
        in_s1 == hidden_size
        and in_s0 == topk * hidden_size
        and out_s0 == hidden_size
    )

    def grid(meta):
        n_blocks = triton.cdiv(hidden_size, meta["BLOCK_SIZE"])
        total = num_tokens * n_blocks
        # NOTE(review): the launch is capped at 65535 programs; this is only
        # correct if the kernels iterate over the remaining (token, block)
        # tasks - confirm against the kernel implementations.
        return (min(total, 65535),)

    with torch_device_fn.device(input.device):
        # Use the specialized kernel for topk=2 on packed layouts.
        if topk == 2 and is_contiguous:
            moe_sum_kernel_topk2[grid](
                src,
                dst,
                num_tokens,
                hidden_size,
            )
        else:
            moe_sum_kernel[grid](
                src,
                dst,
                num_tokens,
                topk,
                hidden_size,
                in_s0,
                in_s1,
                out_s0,
                IS_CONTIGUOUS=is_contiguous,
            )

    # Scatter the temporary result back into the caller's strided output.
    if dst is not output:
        output.copy_(dst)