Skip to content

Commit 1b5625c

Browse files
committed
ENH improve validation of required runtime fields after resolving benchmark defaults
1 parent cda21e7 commit 1b5625c

2 files changed

Lines changed: 75 additions & 2 deletions

File tree

python/cuml/cuml/benchmark/config.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -457,9 +457,30 @@ def _validate_post_defaults_entry(entry: dict[str, Any]) -> None:
457457
entry, context=f"benchmark '{entry.get('id', entry['algorithm'])}'", require_algorithm=True
458458
)
459459

460+
benchmark_name = entry.get("id", entry["algorithm"])
461+
462+
for field in ("dataset", "input_type", "dtype"):
463+
if not isinstance(entry.get(field), str) or not entry[field]:
464+
raise BenchmarkConfigError(
465+
f"Benchmark '{benchmark_name}' must define a non-empty "
466+
f"'{field}' after applying defaults"
467+
)
468+
469+
if not isinstance(entry.get("n_reps"), int):
470+
raise BenchmarkConfigError(
471+
f"Benchmark '{benchmark_name}' must define integer 'n_reps' "
472+
"after applying defaults"
473+
)
474+
475+
if not isinstance(entry.get("test_split"), (int, float)):
476+
raise BenchmarkConfigError(
477+
f"Benchmark '{benchmark_name}' must define numeric 'test_split' "
478+
"after applying defaults"
479+
)
480+
460481
if not entry.get("run_cpu", True) and not entry.get("run_gpu", True):
461482
raise BenchmarkConfigError(
462-
f"Benchmark '{entry.get('id', entry['algorithm'])}' cannot "
483+
f"Benchmark '{benchmark_name}' cannot "
463484
"disable both CPU and GPU execution"
464485
)
465486

python/cuml/tests/test_benchmark_config.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
from pathlib import Path
88

99
import pandas as pd
10+
import pytest
1011

11-
from cuml.benchmark.config import load_and_resolve_config
12+
from cuml.benchmark.config import BenchmarkConfigError, load_and_resolve_config
1213
from cuml.benchmark.run_benchmarks import _run_config_benchmarks, main
1314

1415

@@ -152,6 +153,57 @@ def test_load_and_resolve_config_expands_shape_pairs_and_param_grid(tmp_path):
152153
assert entry["param_override_list"] == [{"C": 0.25}, {"C": 1.0}]
153154

154155

156+
@pytest.mark.parametrize(
    ("defaults_block", "expected_field"),
    [
        # Each case supplies every required default up to, but not including,
        # the field that the loader is expected to complain about.
        ("", "dataset"),
        (" dataset: classification\n", "input_type"),
        (" dataset: classification\n input_type: numpy\n", "dtype"),
        (
            " dataset: classification\n input_type: numpy\n dtype: fp32\n",
            "n_reps",
        ),
        (
            " dataset: classification\n"
            " input_type: numpy\n"
            " dtype: fp32\n"
            " n_reps: 2\n",
            "test_split",
        ),
    ],
)
def test_load_and_resolve_config_requires_runtime_fields_after_defaults(
    tmp_path, defaults_block, expected_field
):
    """Check that a missing runtime field is reported after defaults apply.

    A config file is written whose ``defaults`` section contains only the
    lines in ``defaults_block`` plus the ``run_cpu``/``run_gpu`` switches.
    Loading it must raise ``BenchmarkConfigError`` with a message that names
    ``expected_field``.
    """
    # NOTE(review): whitespace inside these YAML literals is reproduced
    # exactly as it appears in the reviewed text -- confirm the indentation
    # against the repository copy.
    preamble = (
        "version: 1\n\n"
        "suite:\n"
        " name: missing-required\n"
        " tier: test\n"
        " description: missing field coverage\n\n"
        "defaults:\n"
    )
    remainder = (
        " run_cpu: true\n"
        " run_gpu: false\n\n"
        "benchmarks:\n"
        " - id: shaped_logreg\n"
        " algorithm: LogisticRegression\n"
        " operation: fit\n"
        " rows: [100]\n"
        " features: [8]\n"
    )

    yaml_file = tmp_path / "missing-required.yaml"
    yaml_file.write_text(
        preamble + defaults_block + remainder,
        encoding="utf-8",
    )

    expected_message = (
        rf"must define .*'{expected_field}'.*after applying defaults"
    )
    with pytest.raises(BenchmarkConfigError, match=expected_message):
        load_and_resolve_config(str(yaml_file))
205+
206+
155207
def test_run_config_benchmarks_uses_shape_pairs_without_cartesian_product(
156208
monkeypatch,
157209
):

0 commit comments

Comments
 (0)