Coverage for src/flag_gems/runtime/backend/_ascend/fused/moe_sum.py: 0%
84 statements
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
« prev ^ index » next coverage.py v7.6.9, created at 2026-06-04 09:03 +0800
1import logging
3import torch
4import triton
5import triton.language as tl
7from flag_gems import runtime
8from flag_gems.runtime import torch_device_fn
9from flag_gems.utils import libentry
10from flag_gems.utils import triton_lang_extension as ext
12logger = logging.getLogger("flag_gems").getChild(__name__.lstrip("."))
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("moe_sum"),
    key=["hidden_size", "topk"],
)
@triton.jit
def moe_sum_kernel(
    input_ptr,
    output_ptr,
    num_tokens,
    topk,
    hidden_size,
    input_stride_token,
    input_stride_topk,
    output_stride_token,
    IS_CONTIGUOUS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE_SUB: tl.constexpr,
):
    """
    Sum `topk` expert rows per token into one output row.

    Work is split into tiles, one tile per (token, hidden block of
    BLOCK_SIZE elements).  Each program starts at tile `program_id(0)` and
    advances by `num_programs(0)` until all `num_tokens * num_hidden_blocks`
    tiles are done, so the result is complete even when the launch grid is
    smaller than the number of tiles (e.g. when the launcher caps the grid).

    Args:
        input_ptr: pointer to the input; element (t, k, h) is read at
            `t * token_stride + k * expert_stride + h`, i.e. the hidden
            dimension is always indexed with unit stride.
        output_ptr: pointer to the output; element (t, h) is written at
            `t * token_stride + h`.
        num_tokens: number of tokens (rows of the output).
        topk: number of expert rows summed per token (runtime value).
        hidden_size: number of elements per row.
        input_stride_token: element stride between tokens in the input;
            only used when IS_CONTIGUOUS is false.
        input_stride_topk: element stride between expert rows in the input;
            only used when IS_CONTIGUOUS is false.
        output_stride_token: element stride between tokens in the output;
            only used when IS_CONTIGUOUS is false.
        IS_CONTIGUOUS: when true the strides are derived from `topk` and
            `hidden_size` instead of the stride arguments.
        BLOCK_SIZE: hidden elements covered by one tile.
        BLOCK_SIZE_SUB: hidden elements processed per inner iteration.
    """
    pid = ext.program_id(0)
    num_progs = tl.num_programs(0)

    # Task partition: one tile per (token, hidden block).
    num_hidden_blocks = tl.cdiv(hidden_size, BLOCK_SIZE)
    total_tiles = num_tokens * num_hidden_blocks

    # IS_CONTIGUOUS is a compile-time constant, so only one branch is
    # compiled; the contiguous variant needs no stride arguments at all.
    if IS_CONTIGUOUS:
        in_token_stride = topk * hidden_size
        in_expert_stride = hidden_size
        out_token_stride = hidden_size
    else:
        in_token_stride = input_stride_token
        in_expert_stride = input_stride_topk
        out_token_stride = output_stride_token

    # Strided tile loop: a grid smaller than `total_tiles` still covers every
    # tile, and programs with pid >= total_tiles simply do nothing.
    for tile in range(pid, total_tiles, num_progs):
        token_idx = tile // num_hidden_blocks
        block_idx = tile % num_hidden_blocks
        hidden_base = block_idx * BLOCK_SIZE

        in_row = input_ptr + token_idx * in_token_stride
        out_row = output_ptr + token_idx * out_token_stride

        for sub_idx in range(0, BLOCK_SIZE, BLOCK_SIZE_SUB):
            h_indices = hidden_base + sub_idx + tl.arange(0, BLOCK_SIZE_SUB)
            valid_mask = h_indices < hidden_size

            # Accumulate in float32 regardless of the input element type.
            acc = tl.zeros((BLOCK_SIZE_SUB,), dtype=tl.float32)

            # `topk` is a runtime argument, so this is an ordinary loop.
            for k in range(topk):
                # NOTE(review): `care_padding` is not a core Triton argument;
                # presumably an Ascend backend extension — confirm semantics.
                val = tl.load(
                    in_row + k * in_expert_stride + h_indices,
                    mask=valid_mask,
                    other=0.0,
                    care_padding=False,
                )
                acc += val.to(tl.float32)

            # Cast back to the output element type on store.
            tl.store(
                out_row + h_indices,
                acc.to(output_ptr.dtype.element_ty),
                mask=valid_mask,
            )
# Specialized kernel for topk=2 (most common in MoE)
@libentry()
@triton.autotune(
    configs=runtime.get_tuned_config("moe_sum"),
    key=["hidden_size"],
)
@triton.jit
def moe_sum_kernel_topk2(
    input_ptr,
    output_ptr,
    num_tokens,
    hidden_size,
    BLOCK_SIZE: tl.constexpr,
    BLOCK_SIZE_SUB: tl.constexpr,
):
    """
    Sum exactly two expert rows per token, with the expert loop written out.

    Addresses are computed as for a densely packed layout: input element
    (t, k, h) is read at `t * 2 * hidden_size + k * hidden_size + h` and
    output element (t, h) is written at `t * hidden_size + h`.

    Work is split into tiles, one per (token, hidden block of BLOCK_SIZE
    elements).  Each program starts at tile `program_id(0)` and advances by
    `num_programs(0)`, so every tile is processed even when the launch grid
    is smaller than the number of tiles.

    Args:
        input_ptr: pointer to the packed input.
        output_ptr: pointer to the packed output.
        num_tokens: number of tokens (rows of the output).
        hidden_size: number of elements per row.
        BLOCK_SIZE: hidden elements covered by one tile.
        BLOCK_SIZE_SUB: hidden elements processed per inner iteration.
    """
    pid = ext.program_id(0)
    num_progs = tl.num_programs(0)

    num_hidden_blocks = tl.cdiv(hidden_size, BLOCK_SIZE)
    total_tiles = num_tokens * num_hidden_blocks

    # Strided tile loop: programs with pid >= total_tiles do nothing, and a
    # grid smaller than `total_tiles` still covers every tile.
    for tile in range(pid, total_tiles, num_progs):
        token_idx = tile // num_hidden_blocks
        block_idx = tile % num_hidden_blocks

        hidden_base = block_idx * BLOCK_SIZE
        input_token_offset = token_idx * 2 * hidden_size
        output_token_offset = token_idx * hidden_size

        for sub_idx in range(0, BLOCK_SIZE, BLOCK_SIZE_SUB):
            h_indices = hidden_base + sub_idx + tl.arange(0, BLOCK_SIZE_SUB)
            valid_mask = h_indices < hidden_size

            # Expert row 0 starts at `base`, expert row 1 one row later.
            base = input_ptr + input_token_offset + h_indices

            # NOTE(review): `care_padding` is not a core Triton argument;
            # presumably an Ascend backend extension — confirm semantics.
            val0 = tl.load(base, mask=valid_mask, other=0.0, care_padding=False)
            val1 = tl.load(
                base + hidden_size, mask=valid_mask, other=0.0, care_padding=False
            )

            # Add in float32, then cast to the output element type on store.
            result = val0.to(tl.float32) + val1.to(tl.float32)

            out_ptr = output_ptr + output_token_offset + h_indices
            tl.store(
                out_ptr,
                result.to(output_ptr.dtype.element_ty),
                mask=valid_mask,
            )
def moe_sum(
    input: torch.Tensor,
    output: torch.Tensor,
):
    """
    MoE sum operation optimized for Ascend NPU.

    Sums over the expert dimension (dim=1) and writes the result into
    `output` in place; nothing is returned.

    Args:
        input: tensor of shape (num_tokens, topk, hidden_size).
        output: tensor of shape (num_tokens, hidden_size) that receives the
            sums.

    Raises:
        ValueError: if `input` is not 3-D or `output` is not 2-D (raised by
            the shape/stride unpacking).
    """
    logger.debug("GEMS_ASCEND MOE_SUM")

    num_tokens, topk, hidden_size = input.shape

    # Empty problem: nothing to write, and no zero-sized grid is launched.
    if num_tokens == 0 or hidden_size == 0:
        return

    # The kernels address the hidden dimension with unit stride (they only
    # receive the token/expert strides), so normalise layouts that violate
    # this before launching.
    if hidden_size > 1 and input.stride(2) != 1:
        input = input.contiguous()

    needs_staging = hidden_size > 1 and output.stride(1) != 1
    if needs_staging:
        # Compute into a dense buffer, then copy into the strided output.
        target = torch.empty(
            output.shape, dtype=output.dtype, device=output.device
        )
    else:
        target = output

    in_s0, in_s1, in_s2 = input.stride()
    out_s0, out_s1 = target.stride()

    # Densely packed input and output can take the stride-free kernel paths.
    is_contiguous = (
        in_s2 == 1
        and in_s1 == hidden_size
        and in_s0 == topk * hidden_size
        and out_s1 == 1
        and out_s0 == hidden_size
    )

    max_grid = 65535

    def grid(meta):
        n_blocks = triton.cdiv(hidden_size, meta["BLOCK_SIZE"])
        total = num_tokens * n_blocks
        # NOTE(review): the grid is capped at `max_grid`; when `total`
        # exceeds it the kernels must iterate over the remaining tiles
        # themselves — verify against the kernel implementations.
        return (min(total, max_grid),)

    with torch_device_fn.device(input.device):
        # Use specialized kernel for topk=2 (most common case)
        if topk == 2 and is_contiguous:
            moe_sum_kernel_topk2[grid](
                input,
                target,
                num_tokens,
                hidden_size,
            )
        else:
            moe_sum_kernel[grid](
                input,
                target,
                num_tokens,
                topk,
                hidden_size,
                in_s0,
                in_s1,
                out_s0,
                IS_CONTIGUOUS=is_contiguous,
            )

    if needs_staging:
        output.copy_(target)