
Commit d6569b7

WIP

1 parent: e9d517f

3 files changed: +26 -8 lines


vllm/model_executor/models/llama.py

Lines changed: 15 additions & 4 deletions
@@ -483,6 +483,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
     mistral_mapping = {
         "layers": "model.layers",
         "attention": "self_attn",
+        "qscale_act": "input_scale",
+        "qscale_weight": "weight_scale",
+        "kv_fake_quantizer.qscale_act": "kv_scale",
         "wq": "q_proj",
         "wk": "k_proj",
         "wv": "v_proj",
@@ -603,15 +606,23 @@ def permute(w: torch.Tensor, n_heads: int):
         modules = name.split(".")

         # rotary embeds should be sliced
-        if "wk" in modules:
+        if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_key_value_heads)
-        elif "wq" in modules:
+        elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
                                     self.config.num_attention_heads)

-        for item in modules:
-            if item in mapping and mapping[item] not in name:
+        num_modules = len(modules)
+        for i in range(num_modules):
+            item = modules[i]
+            next_item = modules[i + 1] if i < num_modules - 1 else None
+
+            combined_item = f"{item}.{next_item}" if next_item is not None else None
+
+            if combined_item in mapping:
+                name = name.replace(combined_item, mapping[combined_item])
+            elif item in mapping and mapping[item] not in name:
                 name = name.replace(item, mapping[item])

         return name, loaded_weight
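
Why the lookahead: the suffix "qscale_act" is ambiguous on its own. Under a linear layer it should become "input_scale", but paired with "kv_fake_quantizer" it must map to "kv_scale", so combined "item.next_item" keys are checked before single tokens. The new `modules[-1] == "weight"` guards also restrict the rotary permutation to actual weight tensors, so the newly mapped scale tensors are never permuted. A minimal standalone sketch of the renaming loop (toy mapping subset, illustrative key names):

# Minimal sketch of the two-token lookahead rename above; the mapping
# is a toy subset of mistral_mapping.
mapping = {
    "layers": "model.layers",
    "attention": "self_attn",
    "qscale_act": "input_scale",
    "qscale_weight": "weight_scale",
    "kv_fake_quantizer.qscale_act": "kv_scale",
    "wk": "k_proj",
}


def rename(name: str) -> str:
    modules = name.split(".")
    num_modules = len(modules)
    for i in range(num_modules):
        item = modules[i]
        next_item = modules[i + 1] if i < num_modules - 1 else None
        combined_item = (f"{item}.{next_item}"
                         if next_item is not None else None)
        # Combined "parent.child" keys win over single-token keys, so
        # "kv_fake_quantizer.qscale_act" does not fall through to the
        # generic "qscale_act" -> "input_scale" rule.
        if combined_item in mapping:
            name = name.replace(combined_item, mapping[combined_item])
        elif item in mapping and mapping[item] not in name:
            name = name.replace(item, mapping[item])
    return name


print(rename("layers.0.attention.wk.qscale_weight"))
# model.layers.0.self_attn.k_proj.weight_scale
print(rename("layers.0.attention.kv_fake_quantizer.qscale_act"))
# model.layers.0.self_attn.kv_scale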

vllm/transformers_utils/config.py

Lines changed: 10 additions & 3 deletions
@@ -218,12 +218,12 @@ def load_params_config(model, revision) -> PretrainedConfig:
         "hidden_dim": "intermediate_size",
     }

-    def recurse_elems(elem: Any):
-        if isinstance(elem, dict):
+    def recurse_elems(elem: Any, wrap_to_hf_config: bool = True):
+        if isinstance(elem, dict) and wrap_to_hf_config:
             config_dict = {}
             for key, value in elem.items():
                 key = config_mapping.get(key, key)
-                config_dict[key] = recurse_elems(value)
+                config_dict[key] = recurse_elems(value, wrap_to_hf_config=False)
             return PretrainedConfig(**config_dict)
         else:
             return elem
@@ -236,6 +236,12 @@ def recurse_elems(elem: Any):
     config_dict["max_position_embeddings"] = config_dict.get(
         "max_position_embeddings", 128_000)

+    if config_dict.get("quantization") is not None:
+        config_dict["quantization_config"] = {
+            "quant_method": "fp8",
+            "activation_scheme": "static"
+        }
+
     if config_dict.get("moe") is not None:
         config_dict["architectures"] = ["MixtralForCausalLM"]
     else:
@@ -252,6 +258,7 @@ def recurse_elems(elem: Any):
         config_dict["model_type"] = "pixtral"

     config = recurse_elems(config_dict)
+
     return config
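
Two effects combine here: a params.json that declares a "quantization" section is advertised to vLLM as fp8 with static activation scales via an HF-style "quantization_config", and recurse_elems now wraps only the top-level dict in a PretrainedConfig, leaving nested dicts (such as that new "quantization_config") as plain dicts. A standalone sketch under those assumptions (toy values; the contents of the "quantization" block are invented, only its presence matters):

from typing import Any

from transformers import PretrainedConfig

config_mapping = {"dim": "hidden_size", "hidden_dim": "intermediate_size"}


def recurse_elems(elem: Any, wrap_to_hf_config: bool = True):
    # Only the outermost dict becomes a PretrainedConfig; nested dicts
    # are returned unchanged because the recursive call passes False.
    if isinstance(elem, dict) and wrap_to_hf_config:
        config_dict = {}
        for key, value in elem.items():
            key = config_mapping.get(key, key)
            config_dict[key] = recurse_elems(value, wrap_to_hf_config=False)
        return PretrainedConfig(**config_dict)
    return elem


# Toy params.json contents; the "quantization" payload is made up,
# only its presence triggers the new branch.
config_dict = {"dim": 4096, "quantization": {"some_field": "fp8"}}

if config_dict.get("quantization") is not None:
    config_dict["quantization_config"] = {
        "quant_method": "fp8",
        "activation_scheme": "static",
    }

config = recurse_elems(config_dict)
print(config.hidden_size)  # 4096
print(isinstance(config.quantization_config, dict))  # True: stays a dict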

vllm/transformers_utils/tokenizers/mistral.py

Lines changed: 1 addition & 1 deletion
@@ -220,7 +220,7 @@ def convert_ids_to_tokens(

         tokens = [self.tokenizer.id_to_piece(id) for id in ids]

-        if any(t.strip() == "�" for t in tokens):
+        if any(t.strip() == "�" for t in tokens) and isinstance(self.tokenizer, Tekkenizer):
             # if any stripped decoded token is undefined
             # because it's invalid unicode then pass bytes
             # See: https://github.com/vllm-project/vllm/pull/8640
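
The added isinstance check limits the byte-level fallback to Tekkenizer, whose pieces can decode to the replacement character "�" when a token is not valid UTF-8 on its own; other Mistral tokenizers keep their pieces untouched. A rough sketch of the guarded path (the byte fallback via id_to_byte_piece is assumed from the surrounding method, not shown in this diff):

# Rough sketch of the guarded byte fallback; Tekkenizer and
# id_to_byte_piece come from mistral_common (assumed here, not
# shown in this commit).
from mistral_common.tokens.tokenizers.tekken import Tekkenizer


def convert_ids_to_tokens_sketch(tokenizer, ids):
    tokens = [tokenizer.id_to_piece(i) for i in ids]
    # Only Tekkenizer pieces can be invalid unicode on their own, so
    # other Mistral tokenizers never take the byte path now.
    if any(t.strip() == "�" for t in tokens) and isinstance(
            tokenizer, Tekkenizer):
        tokens = [tokenizer.id_to_byte_piece(i) for i in ids]
    return tokens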
