|
7 | 7 | from pathlib import Path |
8 | 8 |
|
9 | 9 | import pandas as pd |
| 10 | +import pytest |
10 | 11 |
|
11 | | -from cuml.benchmark.config import load_and_resolve_config |
| 12 | +from cuml.benchmark.config import BenchmarkConfigError, load_and_resolve_config |
12 | 13 | from cuml.benchmark.run_benchmarks import _run_config_benchmarks, main |
13 | 14 |
|
14 | 15 |
|
@@ -152,6 +153,57 @@ def test_load_and_resolve_config_expands_shape_pairs_and_param_grid(tmp_path): |
152 | 153 | assert entry["param_override_list"] == [{"C": 0.25}, {"C": 1.0}] |
153 | 154 |
|
154 | 155 |
|
| 156 | +@pytest.mark.parametrize( |
| 157 | + ("defaults_block", "expected_field"), |
| 158 | + [ |
| 159 | + ("", "dataset"), |
| 160 | + (" dataset: classification\n", "input_type"), |
| 161 | + (" dataset: classification\n input_type: numpy\n", "dtype"), |
| 162 | + ( |
| 163 | + " dataset: classification\n input_type: numpy\n dtype: fp32\n", |
| 164 | + "n_reps", |
| 165 | + ), |
| 166 | + ( |
| 167 | + " dataset: classification\n" |
| 168 | + " input_type: numpy\n" |
| 169 | + " dtype: fp32\n" |
| 170 | + " n_reps: 2\n", |
| 171 | + "test_split", |
| 172 | + ), |
| 173 | + ], |
| 174 | +) |
| 175 | +def test_load_and_resolve_config_requires_runtime_fields_after_defaults( |
| 176 | + tmp_path, defaults_block, expected_field |
| 177 | +): |
| 178 | + config_path = tmp_path / "missing-required.yaml" |
| 179 | + config_path.write_text( |
| 180 | + ( |
| 181 | + "version: 1\n\n" |
| 182 | + "suite:\n" |
| 183 | + " name: missing-required\n" |
| 184 | + " tier: test\n" |
| 185 | + " description: missing field coverage\n\n" |
| 186 | + "defaults:\n" |
| 187 | + f"{defaults_block}" |
| 188 | + " run_cpu: true\n" |
| 189 | + " run_gpu: false\n\n" |
| 190 | + "benchmarks:\n" |
| 191 | + " - id: shaped_logreg\n" |
| 192 | + " algorithm: LogisticRegression\n" |
| 193 | + " operation: fit\n" |
| 194 | + " rows: [100]\n" |
| 195 | + " features: [8]\n" |
| 196 | + ), |
| 197 | + encoding="utf-8", |
| 198 | + ) |
| 199 | + |
| 200 | + with pytest.raises( |
| 201 | + BenchmarkConfigError, |
| 202 | + match=rf"must define .*'{expected_field}'.*after applying defaults", |
| 203 | + ): |
| 204 | + load_and_resolve_config(str(config_path)) |
| 205 | + |
| 206 | + |
155 | 207 | def test_run_config_benchmarks_uses_shape_pairs_without_cartesian_product( |
156 | 208 | monkeypatch, |
157 | 209 | ): |
|
0 commit comments