|
8 | 8 |
|
9 | 9 | from vllm.platforms import current_platform |
10 | 10 |
|
11 | | -from ...utils import check_outputs_equal |
| 11 | +from ...utils import check_logprobs_close, check_outputs_equal |
12 | 12 |
|
13 | 13 | MODELS = [ |
14 | 14 | "meta-llama/Llama-2-7b-hf", |
@@ -43,18 +43,40 @@ def test_models( |
43 | 43 | dtype: str, |
44 | 44 | max_tokens: int, |
45 | 45 | ) -> None: |
46 | | - with hf_runner(model, dtype=dtype) as hf_model: |
47 | | - hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) |
48 | 46 |
|
49 | | - with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: |
50 | | - vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) |
51 | | - |
52 | | - check_outputs_equal( |
53 | | - outputs_0_lst=hf_outputs, |
54 | | - outputs_1_lst=vllm_outputs, |
55 | | - name_0="hf", |
56 | | - name_1="vllm", |
57 | | - ) |
| 47 | + if model == "openbmb/MiniCPM3-4B": |
| 48 | +        # The output becomes slightly different when upgrading to
| 49 | +        # PyTorch 2.5, so check logprobs closeness instead of exact
| 50 | +        # output equality.
| 51 | + NUM_LOG_PROBS = 8 |
| 52 | + with hf_runner(model, dtype=dtype) as hf_model: |
| 53 | + hf_outputs = hf_model.generate_greedy_logprobs_limit( |
| 54 | + example_prompts, max_tokens, NUM_LOG_PROBS) |
| 55 | + |
| 56 | + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: |
| 57 | + vllm_outputs = vllm_model.generate_greedy_logprobs( |
| 58 | + example_prompts, max_tokens, NUM_LOG_PROBS) |
| 59 | + |
| 60 | + check_logprobs_close( |
| 61 | + outputs_0_lst=hf_outputs, |
| 62 | + outputs_1_lst=vllm_outputs, |
| 63 | + name_0="hf", |
| 64 | + name_1="vllm", |
| 65 | + ) |
| 66 | + else: |
| 67 | + with hf_runner(model, dtype=dtype) as hf_model: |
| 68 | + hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) |
| 69 | + |
| 70 | + with vllm_runner(model, dtype=dtype, enforce_eager=True) as vllm_model: |
| 71 | + vllm_outputs = vllm_model.generate_greedy(example_prompts, |
| 72 | + max_tokens) |
| 73 | + |
| 74 | + check_outputs_equal( |
| 75 | + outputs_0_lst=hf_outputs, |
| 76 | + outputs_1_lst=vllm_outputs, |
| 77 | + name_0="hf", |
| 78 | + name_1="vllm", |
| 79 | + ) |
58 | 80 |
|
59 | 81 |
|
60 | 82 | @pytest.mark.parametrize("model", MODELS) |
|
0 commit comments