21 | 21 |
22 | 22 | import pytest
23 | 23 |
24 | | -from .conftest import run_greedy_equality_correctness_test
| 24 | +from .conftest import run_greedy_equality_correctness_test, run_equality_correctness_test
25 | 25 |
26 | 26 | # main model
27 | 27 | MAIN_MODEL = "JackFram/llama-160m"
@@ -77,6 +77,49 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator, |
77 | 77 |                                          force_output_len=True)
78 | 78 |
79 | 79 |
| 80 | +@pytest.mark.parametrize(
| 81 | +    "common_llm_kwargs",
| 82 | +    [{
| 83 | +        # Skip cuda graph recording for fast test.
| 84 | +        "enforce_eager": True,
| 85 | +
| 86 | +        # Required for spec decode.
| 87 | +        "use_v2_block_manager": True,
| 88 | +
| 89 | +        # Print spec metrics.
| 90 | +        "disable_log_stats": False,
| 91 | +
| 92 | +        # Precision
| 93 | +        "dtype": PRECISION,
| 94 | +
| 95 | +        # Main model
| 96 | +        "model": MAIN_MODEL,
| 97 | +    }])
| 98 | +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
| 99 | +@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
| 100 | +@pytest.mark.parametrize("test_llm_kwargs", [
| 101 | +    {
| 102 | +        "speculative_model": SPEC_MODEL,
| 103 | +    },
| 104 | +])
| 105 | +@pytest.mark.parametrize("output_len", [
| 106 | +    128,
| 107 | +])
| 108 | +@pytest.mark.parametrize("batch_size", [1, 32])
| 109 | +@pytest.mark.parametrize("temperature", [0.1, 1.0])
| 110 | +@pytest.mark.parametrize("seed", [1])
| 111 | +def test_mlp_e2e_seeded_correctness(baseline_llm_generator, test_llm_generator,
| 112 | +                                    batch_size: int, output_len: int, temperature: float):
| 113 | +    """Verify seeded runs produce the same output."""
| 114 | +    run_equality_correctness_test(baseline_llm_generator,
| 115 | +                                  test_llm_generator,
| 116 | +                                  batch_size,
| 117 | +                                  max_output_len=output_len,
| 118 | +                                  temperature=temperature,
| 119 | +                                  seeded=True,
| 120 | +                                  force_output_len=True)
| 121 | +
| 122 | +
80 | 123 | @pytest.mark.parametrize(
81 | 124 |     "common_llm_kwargs",
82 | 125 |     [{
|
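The added test exercises run_equality_correctness_test with seeded=True: each request carries a fixed per-request seed, so random sampling at temperature 0.1 or 1.0 becomes deterministic and the speculative-decoding run can be compared token-for-token against the baseline. As a rough illustration of that pattern, here is a minimal sketch assuming two already-constructed vllm.LLM instances; the function name and signature below are hypothetical, and the real helper lives in tests/spec_decode/e2e/conftest.py.

# Hypothetical sketch of the seeded-equality check, not the actual
# conftest implementation.
from typing import List

from vllm import LLM, SamplingParams


def check_seeded_equality(baseline_llm: LLM, test_llm: LLM,
                          prompts: List[str], max_output_len: int,
                          temperature: float, seed: int = 1) -> None:
    # A fixed per-request seed makes sampling at temperature > 0
    # deterministic, so the two runs must agree token-for-token.
    params = SamplingParams(temperature=temperature,
                            max_tokens=max_output_len,
                            seed=seed)
    baseline_outputs = baseline_llm.generate(prompts, params)
    test_outputs = test_llm.generate(prompts, params)
    for baseline, test in zip(baseline_outputs, test_outputs):
        assert baseline.outputs[0].token_ids == test.outputs[0].token_ids

Parametrizing temperature over 0.1 and 1.0 is what sets this test apart from the greedy-equality tests in the same file: greedy decoding never reaches the seeded random-sampling path, while this test fails if speculative decoding diverges from the baseline under seeded sampling.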