
Commit 8c8b6a8

review comments
Signed-off-by: Bill Nell <[email protected]>
Parent: de74947

File tree

5 files changed: +35 -38 lines

vllm/distributed/device_communicators/base_device_communicator.py

Lines changed: 2 additions & 2 deletions
@@ -266,14 +266,14 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None
             module
             for module in model.modules()
             # TODO(bnell): Should use isinstance but can't. Maybe search for
-            # presence of quant_method.init_prepare_finalize?
+            # presence of quant_method.maybe_init_modular_kernel?
             if (
                 module.__class__.__name__ == "FusedMoE"
                 or module.__class__.__name__ == "SharedFusedMoE"
             )
         ]
         for module in moe_modules:
-            module.init_prepare_finalize()
+            module.maybe_init_modular_kernel()

     def dispatch(
         self,
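
The TODO in this hunk suggests replacing the class-name match with a capability probe. A minimal sketch of that alternative, assuming a hypothetical helper name (find_moe_modules) and probing for the new maybe_init_modular_kernel hook on each module's quant_method, as the comment proposes:

import torch

def find_moe_modules(model: torch.nn.Module) -> list[torch.nn.Module]:
    # Duck-typed probe: keep modules whose quant_method exposes
    # maybe_init_modular_kernel, instead of matching class names.
    return [
        module
        for module in model.modules()
        if hasattr(getattr(module, "quant_method", None), "maybe_init_modular_kernel")
    ]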

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 31 additions & 28 deletions
@@ -119,7 +119,6 @@ def __init__(self, moe: FusedMoEConfig):
         super().__init__()
         self.moe = moe
         self.moe_quant_config: FusedMoEQuantConfig | None = None
-        self.topk_indices_dtype = None

     @abstractmethod
     def create_weights(
@@ -244,7 +243,7 @@ def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
         else:
             return None

-    def init_prepare_finalize(
+    def maybe_init_modular_kernel(
         self, layer: torch.nn.Module
     ) -> FusedMoEModularKernel | None:
         assert self.moe is not None
@@ -260,8 +259,6 @@ def init_prepare_finalize(
         logger.debug(
             "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
         )
-        assert self.topk_indices_dtype is None
-        self.topk_indices_dtype = prepare_finalize.topk_indices_dtype()
         experts = self.select_gemm_impl(prepare_finalize, layer)
         return FusedMoEModularKernel(
             prepare_finalize,
@@ -289,6 +286,10 @@ def get_fused_moe_quant_config(
     ) -> FusedMoEQuantConfig | None:
         raise NotImplementedError

+    @property
+    def topk_indices_dtype(self) -> torch.dtype | None:
+        return None
+
     @property
     def supports_eplb(self) -> bool:
         return False
@@ -328,31 +329,33 @@ def apply(
 class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp):
     def __init__(
         self,
-        old_moe_method: FusedMoEMethodBase,
+        old_quant_method: FusedMoEMethodBase,
         fused_experts: FusedMoEModularKernel,
     ):
-        super().__init__(old_moe_method.moe)
-        # Find better way to copy attributes?
-        # self.__dict__.update(old_moe_method.__dict__)
-
-        self.moe_quant_config = old_moe_method.moe_quant_config
+        super().__init__(old_quant_method.moe)
+        # Find better way to copy attributes? Should we even copy attributes?
+        # self.__dict__.update(old_quant_method.__dict__)
+        self.moe_quant_config = old_quant_method.moe_quant_config
         self.fused_experts = fused_experts
-        self.topk_indices_dtype = old_moe_method.topk_indices_dtype
-        self.disable_expert_map = not fused_experts.supports_expert_map()
-        self.old_method_name = old_moe_method.__class__.__name__
-        self._supports_eplb = old_moe_method.supports_eplb
-        self._allow_inplace = old_moe_method.allow_inplace
-        if isinstance(old_moe_method, torch.nn.Module):
-            self.load_state_dict(old_moe_method.state_dict())
-        logger.debug("Swapping out %s", self.old_method_name)
+        self.disable_expert_map = getattr(
+            old_quant_method,
+            "disable_expert_map",
+            not fused_experts.supports_expert_map(),
+        )
+        self.old_quant_method = old_quant_method
+        logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__)
+
+    @property
+    def topk_indices_dtype(self) -> torch.dtype | None:
+        return self.fused_experts.prepare_finalize.topk_indices_dtype()

     @property
     def supports_eplb(self) -> bool:
-        return self._supports_eplb
+        return self.old_quant_method.supports_eplb

     @property
     def allow_inplace(self) -> bool:
-        return self._allow_inplace
+        return self.old_quant_method.allow_inplace

     def create_weights(
         self,
@@ -405,10 +408,11 @@ def apply(
             assert isinstance(layer, FusedMoE)
         else:
             raise NotImplementedError(
-                f"EPLB is not supported for {self.old_method_name}"
+                "EPLB is not supported for "
+                f"{self.old_quant_method.__class__.__name__}."
             )

-        select_result = FusedMoE.select_experts(
+        topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
             use_grouped_topk=use_grouped_topk,
@@ -431,8 +435,6 @@ def apply(
             zero_expert_type=zero_expert_type,
         )

-        topk_weights, topk_ids, zero_expert_result = select_result
-
         result = self.fused_experts(
             hidden_states=x,
             w1=layer.w13_weight,
@@ -1421,7 +1423,7 @@ def _get_quant_method() -> FusedMoEMethodBase:
         )

         if not isinstance(
-            quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
+            self.quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod)
         ):
             raise NotImplementedError(
                 "is_act_and_mul=False is supported only for unquantized "
@@ -1441,6 +1443,7 @@ def _get_quant_method() -> FusedMoEMethodBase:
             # If you plan to add support for more quantization methods,
             # please refer to the implementation in `Fp8MoEMethod`.
             raise NotImplementedError(
+                f"EPLB is not supported for {self.quant_method.__class__.__name__}. "
                 "EPLB is only supported for FP8 quantization for now."
             )

@@ -1466,12 +1469,12 @@ def _get_quant_method() -> FusedMoEMethodBase:
         self.batched_hidden_states: torch.Tensor | None = None
         self.batched_router_logits: torch.Tensor | None = None

-    # Note: init_prepare_finalize should only be called by
+    # Note: maybe_init_modular_kernel should only be called by
     # prepare_communication_buffer_for_model.
     # This is called after all weight loading and post-processing, so it
     # should be safe to swap out the quant_method.
-    def init_prepare_finalize(self) -> None:
-        mk = self.quant_method.init_prepare_finalize(self)
+    def maybe_init_modular_kernel(self) -> None:
+        mk = self.quant_method.maybe_init_modular_kernel(self)
         if mk is not None:
             self.quant_method = FusedMoEModularMethod(self.quant_method, mk)
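
After this change the wrapper no longer snapshots flags from the wrapped method at construction time; it keeps a reference to old_quant_method and forwards supports_eplb, allow_inplace, and topk_indices_dtype through read-only properties, so the wrapped object stays the single source of truth. A stripped-down sketch of the delegation pattern (toy classes, not vLLM code):

class BaseMethod:
    @property
    def supports_eplb(self) -> bool:
        return False

class Fp8Method(BaseMethod):
    @property
    def supports_eplb(self) -> bool:
        return True

class ModularMethod(BaseMethod):
    def __init__(self, old_quant_method: BaseMethod):
        # Keep a live reference instead of copying flags up front.
        self.old_quant_method = old_quant_method

    @property
    def supports_eplb(self) -> bool:
        # Delegate to the wrapped method on every access.
        return self.old_quant_method.supports_eplb

wrapped = ModularMethod(Fp8Method())
assert wrapped.supports_eplb is True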

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 2 additions & 5 deletions
@@ -709,12 +709,9 @@ def __init__(

     def supports_expert_map(self) -> bool:
         """
-        A flag indicating whether or not this class supports expert maps
+        A flag indicating whether or not this class supports expert maps.
         """
-        return (
-            self.prepare_finalize.num_dispatchers() <= 1
-            and self.fused_experts.supports_expert_map()
-        )
+        return self.fused_experts.supports_expert_map()

     def output_is_reduced(self) -> bool:
         """

vllm/model_executor/layers/quantization/moe_wna16.py

Lines changed: 0 additions & 1 deletion
@@ -226,7 +226,6 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
-        self.moe = layer
         layer.quant_config = self.quant_config
         bit8_pack_factor = self.quant_config.bit8_pack_factor
         group_size = self.quant_config.group_size
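
Dropping self.moe = layer looks like it removes a type confusion: FusedMoEMethodBase.__init__ (shown in layer.py above) stores a FusedMoEConfig in self.moe, and overwriting it with the layer module here would break later code that expects the config. This reading is an inference from the surrounding diff, not stated in the commit.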

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 0 additions & 2 deletions
@@ -181,8 +181,6 @@ def get_quant_method(
 class Mxfp4MoEMethod(FusedMoEMethodBase):
     def __init__(self, moe: FusedMoEConfig):
         super().__init__(moe)
-        self.topk_indices_dtype = None
-        self.moe = moe
         self.mxfp4_backend = get_mxfp4_backend()
         self.max_capture_size = (
             get_current_vllm_config().compilation_config.max_cudagraph_capture_size
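
With topk_indices_dtype now a setter-less property on FusedMoEMethodBase, a subclass can no longer assign it as an instance attribute; the assignment would raise at construction time, which is presumably why these initializations are removed (self.moe = moe was already redundant, since super().__init__(moe) sets it). A minimal illustration of the property conflict, using toy classes:

class Base:
    @property
    def topk_indices_dtype(self):
        return None

class Sub(Base):
    def __init__(self):
        # Assigning over a setter-less property raises AttributeError.
        self.topk_indices_dtype = None

try:
    Sub()
except AttributeError as exc:
    print(f"AttributeError: {exc}")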
