Commit 9fb400f

fix wrong cache name and remove some commented code
Signed-off-by: jiahanc <[email protected]>
1 parent 8e72f08 commit 9fb400f

File tree

6 files changed: +55 additions, −190 deletions

csrc/trtllm_fused_moe_kernel_launcher.cu

Lines changed: 5 additions & 5 deletions
@@ -735,8 +735,8 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
     TVM_FFI_ICHECK(topk_group.has_value()) << "if n_group is given, topk_group must be given";
     TVM_FFI_ICHECK_EQ(num_experts % n_group.value(), 0)
         << "num_experts must be divisible by n_group";
-    // TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
-    //     << "Current routing kernel (with groups) only supports top_k<=8 && top_k>0.";
+    TVM_FFI_ICHECK(top_k <= 10 && top_k > 0)
+        << "Current routing kernel (with groups) only supports top_k<=10 && top_k>0.";
     TVM_FFI_ICHECK(topk_group.value() <= 4 && topk_group.value() > 0)
         << "Current routing kernel only (with groups) supports topk_group<=4 && topk_group > 0.";
     TVM_FFI_ICHECK_LE(topk_group.value(), n_group.value())
@@ -749,9 +749,9 @@ Array<Tensor> trtllm_fp4_block_scale_moe_launcher(
              static_cast<RoutingMethodType>(routing_method_type) ==
                  RoutingMethodType::RenormalizeNaive ||
              static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::TopK) {
-    // TVM_FFI_ICHECK(top_k <= 8 && top_k > 0)
-    //     << "Current routing kernel (no groups, renormalize/topk) only supports top_k<=8 && "
-    //        "top_k>0.";
+    TVM_FFI_ICHECK(top_k <= 10 && top_k > 0)
+        << "Current routing kernel (no groups, renormalize/topk) only supports top_k<=10 && "
+           "top_k>0.";
   } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {
     TVM_FFI_ICHECK_EQ(top_k, 1)
         << "Current routing kernel (no groups, Llama4) only supports top_k=1.";

csrc/trtllm_fused_moe_routing_renormalize.cu

Lines changed: 0 additions & 14 deletions
@@ -463,20 +463,6 @@ void run(Data const& data, void* stream) {
   }
 }
 
-// void run(Data const& data, void* stream) {
-//   TVM_FFI_ICHECK(data.mPtrExpertIdx != nullptr || data.mPtrScores != nullptr)
-//       << "Routing kernel requires at least one input parameter";
-//   TVM_FFI_ICHECK(data.mPtrPermutedIdxSize != nullptr && data.mPtrCtaIdxXyToBatchIdx != nullptr &&
-//                  data.mPtrCtaIdxXyToMnLimit != nullptr && data.mPtrNumNonExitingCtas != nullptr)
-//       << "Llama4 routing kernel expects permuted idx and grouped Gemm launch config buffers";
-//   TVM_FFI_ICHECK_LE(data.mTopK, MaxNumTopExperts)
-//       << "Routing kernel expects topK experts <= " << MaxNumTopExperts << ", got " << data.mTopK;
-//   TVM_FFI_ICHECK_LT(data.mPaddingLog2, 8)
-//       << "Routing kernel expects padding log2 < 8, got " << data.mPaddingLog2;
-
-//   runImpl(data, stream);
-// }
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 }  // namespace routingRenormalize

flashinfer/fused_moe/core.py

Lines changed: 10 additions & 6 deletions
@@ -171,7 +171,9 @@ def _maybe_get_cached_w3_w1_permute_indices(
     epilogue_tile_m: int,
     num_elts_per_sf: Union[None, int] = None,
 ) -> torch.Tensor:
-    if dst_w3_w1_weight.shape not in _cache_permute_indices:
+    # Create a unique cache key that includes all parameters affecting the permutation
+    cache_key = ("w3_w1", dst_w3_w1_weight.shape)
+    if cache_key not in _cache_permute_indices:
         # Get permute indices and chain them together
         permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(dst_w3_w1_weight)
         if num_elts_per_sf is None:
@@ -185,10 +187,10 @@ def _maybe_get_cached_w3_w1_permute_indices(
             num_elts_per_sf=num_elts_per_sf,
         )
         # Memoize permute indices as recompute is **very** costly
-        _cache_permute_indices[dst_w3_w1_weight.shape] = permute0[permute1].to(
+        _cache_permute_indices[cache_key] = permute0[permute1].to(
             dst_w3_w1_weight.device
         )
-    permute_indices = _cache_permute_indices[dst_w3_w1_weight.shape]
+    permute_indices = _cache_permute_indices[cache_key]
     return permute_indices
 
 
@@ -198,7 +200,9 @@ def get_w2_permute_indices_with_cache(
     epilogue_tile_m: int,
     num_elts_per_sf: Union[None, int] = None,
 ) -> torch.Tensor:
-    if dst_w2_weight.shape not in _cache_permute_indices:
+    # Create a unique cache key that includes all parameters affecting the permutation
+    cache_key = ("w2", dst_w2_weight.shape)
+    if cache_key not in _cache_permute_indices:
         if num_elts_per_sf is None:
             permute_indices = get_shuffle_matrix_a_row_indices(
                 dst_w2_weight, epilogue_tile_m
@@ -210,8 +214,8 @@ def get_w2_permute_indices_with_cache(
             num_elts_per_sf=num_elts_per_sf,
         ).to(dst_w2_weight.device)
         # Memoize permute indices as recompute is **very** costly
-        _cache_permute_indices[dst_w2_weight.shape] = permute_indices
-        permute_indices = _cache_permute_indices[dst_w2_weight.shape]
+        _cache_permute_indices[cache_key] = permute_indices
+        permute_indices = _cache_permute_indices[cache_key]
     return permute_indices
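The point of the new key is that permute indices for the fused w3/w1 weight and for the w2 weight must not share a cache slot just because the two tensors happen to have the same shape. Below is a minimal standalone sketch of that idea; the builder functions and names are hypothetical stand-ins, not the real flashinfer helpers.

```python
from typing import Dict, Tuple

import torch


# Hypothetical stand-ins for the permute-index builders in
# flashinfer/fused_moe/core.py; here they just return distinct permutations.
def build_w3_w1_permutation(weight: torch.Tensor) -> torch.Tensor:
    return torch.arange(weight.shape[0] - 1, -1, -1)


def build_w2_permutation(weight: torch.Tensor) -> torch.Tensor:
    return torch.arange(weight.shape[0])


_cache: Dict[Tuple[str, torch.Size], torch.Tensor] = {}


def get_permutation(kind: str, weight: torch.Tensor) -> torch.Tensor:
    # Keying on (kind, shape) avoids returning w3_w1 indices for a w2 weight
    # of the same shape, which a shape-only key would do.
    key = (kind, weight.shape)
    if key not in _cache:
        builder = build_w3_w1_permutation if kind == "w3_w1" else build_w2_permutation
        _cache[key] = builder(weight)
    return _cache[key]


w = torch.empty(4, 8)
assert not torch.equal(get_permutation("w3_w1", w), get_permutation("w2", w))
```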

include/flashinfer/trtllm/fused_moe/RoutingKernelTopK.cuh

Lines changed: 0 additions & 154 deletions
@@ -138,160 +138,6 @@ struct Sort<4, RedType> {
   }
 };
 
-// For N > 4, use a generic bubble sort approach for simplicity
-// This is not the most efficient but adequate for small N
-template <typename RedType>
-struct Sort<5, RedType> {
-  static __device__ void run(RedType* topK) {
-#pragma unroll
-    for (int i = 0; i < 4; ++i) {
-#pragma unroll
-      for (int j = 0; j < 4 - i; ++j) {
-        TOPK_SWAP(j, j + 1);
-      }
-    }
-  }
-};
-
-template <typename RedType>
-struct Sort<6, RedType> {
-  static __device__ void run(RedType* topK) {
-#pragma unroll
-    for (int i = 0; i < 5; ++i) {
-#pragma unroll
-      for (int j = 0; j < 5 - i; ++j) {
-        TOPK_SWAP(j, j + 1);
-      }
-    }
-  }
-};
-
-template <typename RedType>
-struct Sort<7, RedType> {
-  static __device__ void run(RedType* topK) {
-#pragma unroll
-    for (int i = 0; i < 6; ++i) {
-#pragma unroll
-      for (int j = 0; j < 6 - i; ++j) {
-        TOPK_SWAP(j, j + 1);
-      }
-    }
-  }
-};
-
-template <typename RedType>
-struct Sort<8, RedType> {
-  static __device__ void run(RedType* topK) {
-#pragma unroll
-    for (int i = 0; i < 7; ++i) {
-#pragma unroll
-      for (int j = 0; j < 7 - i; ++j) {
-        TOPK_SWAP(j, j + 1);
-      }
-    }
-  }
-};
-
-template <typename RedType>
-struct Sort<9, RedType> {
-  static __device__ void run(RedType* topK) {
-#pragma unroll
-    for (int i = 0; i < 8; ++i) {
-#pragma unroll
-      for (int j = 0; j < 8 - i; ++j) {
-        TOPK_SWAP(j, j + 1);
-      }
-    }
-  }
-};
-
-template <typename RedType>
-struct Sort<10, RedType> {
-  static __device__ void run(RedType* topK) {
-#pragma unroll
-    for (int i = 0; i < 9; ++i) {
-#pragma unroll
-      for (int j = 0; j < 9 - i; ++j) {
-        TOPK_SWAP(j, j + 1);
-      }
-    }
-  }
-};
-
-template <typename RedType>
-struct Sort<11, RedType> {
-  static __device__ void run(RedType* topK) {
-#pragma unroll
-    for (int i = 0; i < 10; ++i) {
-#pragma unroll
-      for (int j = 0; j < 10 - i; ++j) {
-        TOPK_SWAP(j, j + 1);
-      }
-    }
-  }
-};
-
-template <typename RedType>
-struct Sort<12, RedType> {
-  static __device__ void run(RedType* topK) {
-#pragma unroll
-    for (int i = 0; i < 11; ++i) {
-#pragma unroll
-      for (int j = 0; j < 11 - i; ++j) {
-        TOPK_SWAP(j, j + 1);
-      }
-    }
-  }
-};
-template <typename RedType>
-struct Sort<13, RedType> {
-  static __device__ void run(RedType* topK) {
-#pragma unroll
-    for (int i = 0; i < 12; ++i) {
-#pragma unroll
-      for (int j = 0; j < 12 - i; ++j) {
-        TOPK_SWAP(j, j + 1);
-      }
-    }
-  }
-};
-template <typename RedType>
-struct Sort<14, RedType> {
-  static __device__ void run(RedType* topK) {
-#pragma unroll
-    for (int i = 0; i < 13; ++i) {
-#pragma unroll
-      for (int j = 0; j < 13 - i; ++j) {
-        TOPK_SWAP(j, j + 1);
-      }
-    }
-  }
-};
-template <typename RedType>
-struct Sort<15, RedType> {
-  static __device__ void run(RedType* topK) {
-#pragma unroll
-    for (int i = 0; i < 14; ++i) {
-#pragma unroll
-      for (int j = 0; j < 14 - i; ++j) {
-        TOPK_SWAP(j, j + 1);
-      }
-    }
-  }
-};
-template <typename RedType>
-struct Sort<16, RedType> {
-  static __device__ void run(RedType* topK) {
-#pragma unroll
-    for (int i = 0; i < 15; ++i) {
-#pragma unroll
-      for (int j = 0; j < 15 - i; ++j) {
        TOPK_SWAP(j, j + 1);
-      }
-    }
-  }
-};
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <int K, typename Type>
tests/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ def is_cuda_oom_error_str(e: str) -> bool:
 
 @pytest.hookimpl(tryfirst=True)
 def pytest_runtest_call(item):
-    # Wrap the test call so we don't invoke item.runtest() ourselves; yield lets pytest run it.
+    # skip OOM error and missing JIT cache errors
     try:
         item.runtest()
     except (torch.cuda.OutOfMemoryError, RuntimeError) as e:
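For context, this hook wraps item.runtest() and converts CUDA out-of-memory failures into skips. The following is a hedged, self-contained sketch of such a conftest hook, assuming a simplified is_cuda_oom_error_str helper in place of the one defined earlier in tests/conftest.py; it is not the repository's actual implementation.

```python
import pytest
import torch


def is_cuda_oom_error_str(e: str) -> bool:
    # Simplified stand-in for the helper defined in tests/conftest.py.
    return "out of memory" in e.lower()


@pytest.hookimpl(tryfirst=True)
def pytest_runtest_call(item):
    # Skip tests that fail only because the GPU ran out of memory.
    try:
        item.runtest()
    except torch.cuda.OutOfMemoryError:
        pytest.skip("CUDA out of memory")
    except RuntimeError as e:
        if is_cuda_oom_error_str(str(e)):
            pytest.skip("CUDA out of memory (RuntimeError)")
        raise
```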

tests/moe/test_trtllm_gen_fused_moe.py

Lines changed: 39 additions & 10 deletions
@@ -105,7 +105,7 @@ def capture(self, hidden_states_sample, **runtime_args):
         self.input_tensor = hidden_states_sample.clone()
 
         # Warmup
-        with torch.cuda.stream(torch_stream), autotune(False):
+        with torch.cuda.stream(torch_stream), autotune(True):
             for _ in range(1):
                 self._run_moe_computation(runtime_args)
 
@@ -1832,13 +1832,14 @@ def _compute_moe_actual_unified(moe_impl, args_dequant, args, **kwargs):
 
 @pytest.fixture(scope="module")
 def cache_permute_indices():
-    _cache_permute_indices: Dict[torch.Size, torch.Tensor] = {}
+    # The cache key is now a tuple of (weight_type, shape)
+    _cache_permute_indices: Dict[tuple, torch.Tensor] = {}
     return _cache_permute_indices
 
 
 @pytest.mark.parametrize("num_tokens", [1, 8, 1024])
 @pytest.mark.parametrize("hidden_size", [1024, 8192])
-@pytest.mark.parametrize("intermediate_size", [1024, 768, 384, 512])
+@pytest.mark.parametrize("intermediate_size", [2048, 1024, 768, 512, 384])
 @pytest.mark.parametrize(
     "moe_impl",
     [
@@ -1905,8 +1906,8 @@ def cache_permute_indices():
         ),
         pytest.param(
             {
-                "num_experts": 512,
-                "top_k": 10,
+                "num_experts": 256,
+                "top_k": 8,
                 "padding": 8,
                 "n_groups": None,
                 "top_k_groups": None,
@@ -1916,9 +1917,9 @@ def cache_permute_indices():
                 "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe],
             },
             id="Renorm",
-            # marks=pytest.mark.skip(
-            #     reason="Disabled for testing speed - similar to RenormalizeNaive"
-            # ),
+            marks=pytest.mark.skip(
+                reason="Disabled for testing speed - similar to RenormalizeNaive"
+            ),
         ),
         pytest.param(
             {
@@ -1929,6 +1930,20 @@ def cache_permute_indices():
                 "top_k_groups": None,
                 "routed_scaling": None,
                 "has_routing_bias": False,
+                "routing_method_type": RoutingMethodType.Renormalize,
+                "compatible_moe_impls": [FP8BlockScaleMoe, FP4Moe],
+            },
+            id="Qwen3_next",
+        ),
+        pytest.param(
+            {
+                "num_experts": 256,
+                "top_k": 8,
+                "padding": 8,
+                "n_groups": None,
+                "top_k_groups": None,
+                "routed_scaling": None,
+                "has_routing_bias": False,
                 "routing_method_type": RoutingMethodType.RenormalizeNaive,
                 "compatible_moe_impls": [FP4Moe, FP8BlockScaleMoe],
             },
@@ -2041,6 +2056,20 @@ def test_moe_quantization_classes(
             f"Skip for testing speed: {gated_act_type} + {hidden_size} + {intermediate_size}"
         )
 
+    # Skip large intermediate sizes for configurations with many experts
+    if routing_config["num_experts"] >= 512 and intermediate_size > 512:
+        pytest.skip(
+            f"Skipping for testing speed: intermediate_size={intermediate_size} with {routing_config['num_experts']} experts"
+        )
+
+    # Skip large intermediate size and hidden size for configurations with small epxerts
+    if routing_config["num_experts"] < 512 and (
+        intermediate_size > 512 or hidden_size > 1024
+    ):
+        pytest.skip(
+            f"Skipping for testing speed: intermediate_size={intermediate_size} with {routing_config['num_experts']} experts"
+        )
+
     if type(moe_impl) not in routing_config["compatible_moe_impls"]:
         pytest.skip(
             f"Incompatible: {moe_impl.name} + {routing_config['routing_method_type'].name}"
@@ -2085,10 +2114,10 @@ def test_moe_quantization_classes(
         )
         else 64,
     )
-    padding = tile_tokens_dim
+
     # Validation checks
     assert top_k <= num_experts
-    # assert top_k <= 8
+    assert top_k <= 10
     if (top_k_groups is not None) and (n_groups is not None) and (n_groups > 0):
         assert top_k_groups <= 4
         assert num_experts > n_groups
