Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion vllm/model_executor/layers/quantization/fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
self,
is_checkpoint_fp8_serialized: bool = False,
activation_scheme: str = "dynamic",
ignored_layers: Optional[List[str]] = None,
) -> None:
self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
if is_checkpoint_fp8_serialized:
Expand All @@ -41,6 +42,7 @@ def __init__(
raise ValueError(
f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme
self.ignored_layers = ignored_layers or []

@classmethod
def get_name(cls) -> str:
Expand All @@ -63,8 +65,10 @@ def from_config(cls, config: Dict[str, Any]) -> "Fp8Config":
quant_method = cls.get_from_keys(config, ["quant_method"])
is_checkpoint_fp8_serialized = ("fp8" in quant_method)
activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
ignored_layers = cls.get_from_keys(config, ["ignored_layers"])
return cls(is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
activation_scheme=activation_scheme)
activation_scheme=activation_scheme,
ignored_layers=ignored_layers)

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
Expand Down
25 changes: 17 additions & 8 deletions vllm/model_executor/models/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,17 +65,22 @@ def __init__(
prefix: str = "",
) -> None:
super().__init__()
ignored_gate_up_proj = all(
p in quant_config.ignored_layers
for p in [f"{prefix}.gate_proj", f"{prefix}.up_proj"])
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
quant_config=quant_config if not ignored_gate_up_proj else None,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
ignore_down_proj = f"{prefix}.down_proj" in quant_config.ignored_layers
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
bias=bias,
quant_config=quant_config if not ignore_down_proj else None,
prefix=f"{prefix}.down_proj")
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
Expand Down Expand Up @@ -128,21 +133,25 @@ def __init__(
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
ignore_qkv_proj = all(
p in quant_config.ignored_layers for p in
[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"])

self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
quant_config=quant_config if not ignore_qkv_proj else None,
prefix=f"{prefix}.qkv_proj",
)
ignore_o_proj = f"{prefix}.o_proj" in quant_config.ignored_layers
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
quant_config=quant_config,
quant_config=quant_config if not ignore_o_proj else None,
prefix=f"{prefix}.o_proj",
)

Expand Down