|
25 | 25 | QWEN2VL_MODEL_ID = "Qwen/Qwen2-VL-2B-Instruct" |
26 | 26 | MLLAMA_MODEL_ID = "meta-llama/Llama-3.2-11B-Vision-Instruct" |
27 | 27 | LLAMA_GUARD_MODEL_ID = "meta-llama/Llama-Guard-3-1B" |
28 | | - |
| 28 | +CohereForAI_MODEL_ID = "CohereForAI/c4ai-command-r7b-12-2024" |
29 | 29 |
|
30 | 30 | @pytest.fixture(scope="function") |
31 | 31 | def phi3v_model_config(): |
@@ -711,6 +711,42 @@ def get_conversation(is_hf: bool): |
711 | 711 | assert hf_result == vllm_result |
712 | 712 |
|
713 | 713 |
|
| 714 | +@pytest.mark.parametrize("model", [ |
| 715 | +    QWEN2VL_MODEL_ID,  # chat_template is of type str. |
| 716 | +    CohereForAI_MODEL_ID,  # chat_template is of type dict. |
| 717 | +]) |
| 718 | +def test_chat_template_hf(model): |
| 719 | + """checks that chat_template is a dict type for HF models.""" |
| 720 | + |
| 721 | +    def get_conversation(): |
| 722 | +        return [ |
| 723 | +            {"role": "system", "content": "You are a helpful assistant."}, |
| 724 | +            {"role": "user", "content": "Hello, how are you?"}] |
| 725 | +    # Build a config for the model |
| 726 | +    model_config = ModelConfig(model, |
| 727 | +                               task="generate", |
| 728 | +                               tokenizer=model, |
| 729 | +                               tokenizer_mode="auto", |
| 730 | +                               trust_remote_code=True, |
| 731 | +                               dtype="auto", |
| 732 | +                               seed=0) |
| 733 | +    # Build the tokenizer group and grab the underlying tokenizer |
| 734 | +    tokenizer_group = TokenizerGroup( |
| 735 | +        model, |
| 736 | +        enable_lora=False, |
| 737 | +        max_num_seqs=5, |
| 738 | +        max_input_length=None, |
| 739 | +    ) |
| 740 | +    tokenizer = tokenizer_group.tokenizer |
| 741 | +    apply_hf_chat_template( |
| 742 | +        tokenizer, |
| 743 | +        conversation=get_conversation(), |
| 744 | +        # chat_template=None falls back to the tokenizer's default template. |
| 745 | +        chat_template=None, |
| 746 | +        add_generation_prompt=True, |
| 747 | +    ) |
| 748 | + |
| 749 | + |
714 | 750 | # yapf: disable |
715 | 751 | @pytest.mark.parametrize( |
716 | 752 | ("model", "expected_format"), |
|
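For context, a minimal sketch of the two chat_template shapes that the parametrization above exercises, assuming the standard transformers AutoTokenizer API and that both checkpoints are downloadable (the Cohere one may be gated):

    from transformers import AutoTokenizer

    # Qwen2-VL ships a single Jinja template, exposed as a plain str.
    qwen_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")
    print(type(qwen_tok.chat_template))    # <class 'str'>

    # The Command R7B tokenizer exposes several named templates as a dict,
    # e.g. "default", "rag", "tool_use" (the exact keys are an assumption here).
    cohere_tok = AutoTokenizer.from_pretrained(
        "CohereForAI/c4ai-command-r7b-12-2024", trust_remote_code=True)
    print(type(cohere_tok.chat_template))  # <class 'dict'>

Passing chat_template=None in the new test then relies on apply_hf_chat_template resolving the tokenizer's default template in either case.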