
Commit 56b355a

rebase + fix layer_utils
1 parent 0decaa3 commit 56b355a

File tree

3 files changed, +10 -13 lines changed:

tests/compile/test_full_graph.py
vllm/_custom_ops.py
vllm/model_executor/layers/quantization/utils/layer_utils.py

tests/compile/test_full_graph.py

Lines changed: 3 additions & 3 deletions
@@ -28,13 +28,13 @@
 ]

 # TODO: enable in pytorch 2.5
-if False and is_quant_method_supported("aqlm"):
+if False and is_quant_method_supported("aqlm"):  # noqa: SIM223
     TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
         "quantization": "aqlm"
     }))

 # TODO: enable in pytorch 2.5
-if False and is_quant_method_supported("gguf"):
+if False and is_quant_method_supported("gguf"):  # noqa: SIM223
     TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
         "quantization": "gguf"
     }))

@@ -85,7 +85,7 @@ def test_full_graph(model_info, tp_size, backend):

     # Inductor doesn't support fp8/gptq_marlin_24 yet.
     quantization = model_kwargs.get("quantization")
-    if (quantization == "fp8"
+    if (quantization == "fp8" or quantization == "gptq_marlin"
             or quantization == "gptq_marlin_24") and backend != "eager":
         return
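
The `# noqa: SIM223` markers silence the flake8-simplify/ruff rule that flags `and` expressions containing `False`; the `False and ...` prefix here is an intentional temporary disable (see the TODO), not a bug. The second hunk extends the existing Inductor skip to also cover `gptq_marlin` on non-eager backends. A minimal standalone sketch of the silenced pattern; `is_feature_supported` is a hypothetical stand-in, not part of the test suite:

# `False and ...` keeps the block syntactically valid while guaranteeing it
# never runs; SIM223 flags the always-false expression, so the explicit noqa
# marks the disable as intentional.
def is_feature_supported(name: str) -> bool:  # hypothetical stand-in helper
    return True

if False and is_feature_supported("demo"):  # noqa: SIM223
    print("temporarily disabled branch")
else:
    print("short-circuit: is_feature_supported() is never called")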

vllm/_custom_ops.py

Lines changed: 3 additions & 7 deletions
@@ -416,8 +416,8 @@ def _fp8_marlin_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor,
 @torch.library.register_fake("_C::machete_gemm")
 def machete_gemm_fake(
         a: torch.Tensor,
-        b_q: torch.
-        Tensor,  # Should be the tensor returned by machete_prepack_B
+        # Should be the tensor returned by machete_prepack_B
+        b_q: torch.Tensor,
         b_type: ScalarType,
         b_scales: Optional[torch.Tensor] = None,
         b_zeros: Optional[torch.Tensor] = None,

@@ -613,16 +613,12 @@ def machete_prepack_B(b_q_weight: torch.Tensor,
     return torch.ops._C.machete_prepack_B(b_q_weight, b_type)


-# TODO: has to be a better way to do this
-try:
-    torch.ops._C.permute_cols  # noqa B018
+if hasattr(torch.ops._C, 'permute_cols'):

     @torch.library.register_fake("_C::permute_cols")
     def _permute_cols_fake(a: torch.Tensor,
                            perm: torch.Tensor) -> torch.Tensor:
         return torch.empty_like(a)
-except Exception:
-    pass


 def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
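
The first hunk only moves the `b_q` comment onto its own line so the formatter no longer splits the `torch.Tensor` annotation. The second hunk swaps the try/except probe for the optional `permute_cols` op for a plain `hasattr` check before registering its fake (meta) implementation. A minimal sketch of the two probing styles, using the always-present `torch.ops.aten` namespace and a deliberately nonexistent op name so it runs on stock PyTorch (not vLLM code):

import torch

# New style: hasattr() simply returns False when the op was not compiled in,
# so optional registration can be guarded by a plain `if`.
if hasattr(torch.ops.aten, "nonexistent_demo_op"):
    print("op available: register its fake/meta implementation here")
else:
    print("op missing: skip registration")

# Old style: touch the attribute and swallow the resulting error.
try:
    torch.ops.aten.nonexistent_demo_op  # noqa: B018
    print("op available")
except AttributeError:
    print("op missing (same information, but via exception handling)")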

vllm/model_executor/layers/quantization/utils/layer_utils.py

Lines changed: 4 additions & 3 deletions
@@ -21,13 +21,14 @@ def replace_parameter(mod: torch.nn.Module, name: str,
                       new: Union[torch.Tensor, torch.nn.Parameter]) -> None:

     old = getattr(mod, name)
-    if old.dtype == new.dtype and \
+    if type(old) is type(new) and old.dtype == new.dtype and \
         old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
         # If we can just update in-place to avoid re-registering
         # can be faster if the underlying storage is the same
         update_tensor_inplace(old, new)
     else:
         # Fallback re-register parameter
         if not isinstance(new, torch.nn.Parameter):
-            new = torch.nn.Parameter(new)
-        mod.register_parameter(name, torch.nn.Parameter(new))
+            new = torch.nn.Parameter(new, requires_grad=False)
+        mod.register_parameter(name,
+                               torch.nn.Parameter(new, requires_grad=False))
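
`replace_parameter` now takes the in-place fast path only when old and new are the same class (presumably so an in-place copy never silently mixes a plain tensor with a subclassed parameter type), and the re-registration fallback wraps the tensor with `requires_grad=False`, because `torch.nn.Parameter` defaults to `requires_grad=True`, which is not wanted for inference-only weights. A small standalone illustration of that default, using a plain `torch.nn.Linear` rather than the vLLM helper:

import torch

linear = torch.nn.Linear(4, 4)
new_weight = torch.randn(4, 4)

# Wrapping with the defaults silently turns gradient tracking back on.
print(torch.nn.Parameter(new_weight).requires_grad)   # True

# Passing requires_grad=False matches the inference-only intent of the swap.
linear.register_parameter(
    "weight", torch.nn.Parameter(new_weight, requires_grad=False))
print(linear.weight.requires_grad)                     # False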

0 commit comments
