Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions autosklearn/automl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Callable, Iterable, List, Mapping, Optional, Tuple, Union
from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple

import copy
import io
Expand Down Expand Up @@ -210,7 +210,7 @@ def __init__(
get_smac_object_callback: Optional[Callable] = None,
smac_scenario_args: Optional[Mapping] = None,
logging_config: Optional[Mapping] = None,
metric: Optional[Union[Scorer, List[Scorer], Tuple[Scorer]]] = None,
metric: Optional[Scorer | Sequence[Scorer]] = None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not neccessary, just something to know Optional[X] == Union[X, None] == X | None
i..e you could write Scorer | Sequence[Scorer] | None = None

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's look nice, will do.

scoring_functions: Optional[list[Scorer]] = None,
get_trials_callback: Optional[IncorporateRunResultCallback] = None,
dataset_compression: bool | Mapping[str, Any] = True,
Expand Down Expand Up @@ -692,7 +692,7 @@ def fit(
# defined in the estimator fit call
if self._metric is None:
raise ValueError("No metric given.")
if isinstance(self._metric, (List, Tuple)):
if isinstance(self._metric, Sequence):
for entry in self._metric:
if not isinstance(entry, Scorer):
raise ValueError(
Expand Down Expand Up @@ -796,7 +796,7 @@ def fit(
task=self._task,
metric=(
self._metric[0]
if isinstance(self._metric, (List, Tuple))
if isinstance(self._metric, Sequence)
else self._metric
),
ensemble_size=self._ensemble_size,
Expand Down Expand Up @@ -1501,9 +1501,7 @@ def fit_ensemble(
dataset_name=dataset_name if dataset_name else self._dataset_name,
task=task if task else self._task,
metric=(
self._metric[0]
if isinstance(self._metric, (List, Tuple))
else self._metric
self._metric[0] if isinstance(self._metric, Sequence) else self._metric
),
ensemble_size=ensemble_size if ensemble_size else self._ensemble_size,
ensemble_nbest=ensemble_nbest if ensemble_nbest else self._ensemble_nbest,
Expand Down
6 changes: 4 additions & 2 deletions autosklearn/estimators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# -*- encoding: utf-8 -*-
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union
from __future__ import annotations

from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union

import dask.distributed
import joblib
Expand Down Expand Up @@ -46,7 +48,7 @@ def __init__(
smac_scenario_args=None,
logging_config=None,
metadata_directory=None,
metric: Optional[Union[Scorer, List[Scorer], Tuple[Scorer]]] = None,
metric: Optional[Scorer | Sequence[Scorer]] = None,
scoring_functions: Optional[List[Scorer]] = None,
load_models: bool = True,
get_trials_callback=None,
Expand Down
21 changes: 17 additions & 4 deletions autosklearn/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,18 @@
# -*- encoding: utf-8 -*-
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast
from __future__ import annotations

from typing import (
Any,
Callable,
Dict,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)

import functools
import json
Expand Down Expand Up @@ -86,10 +99,10 @@ def fit_predict_try_except_decorator(


def get_cost_of_crash(
metric: Union[Scorer, List[Scorer], Tuple[Scorer]]
metric: Union[Scorer | Sequence[Scorer]],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like wise here, Union[X | Y] == X | Y, the | essentially is just the infix operator for Union in the same way you have + instead of add(x, y).

i.e. metric: Scorer | Sequence[Scorer]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching.

) -> Union[float, List[float]]:

if isinstance(metric, (List, Tuple)):
if isinstance(metric, Sequence):
return [cast(float, get_cost_of_crash(metric_)) for metric_ in metric]
elif not isinstance(metric, Scorer):
raise ValueError("The metric must be stricly be an instance of Scorer")
Expand Down Expand Up @@ -129,7 +142,7 @@ def __init__(
resampling_strategy: Union[
str, BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit
],
metric: Union[Scorer, List[Scorer], Tuple[Scorer]],
metric: Union[Scorer | Sequence[Scorer]],
cost_for_crash: float,
abort_on_first_run_crash: bool,
port: int,
Expand Down
7 changes: 5 additions & 2 deletions autosklearn/evaluation/abstract_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Dict, List, Optional, TextIO, Tuple, Type, Union, cast
from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, TextIO, Tuple, Type, Union, cast

import logging
import multiprocessing
Expand Down Expand Up @@ -184,7 +186,7 @@ def __init__(
self,
backend: Backend,
queue: multiprocessing.Queue,
metric: Union[Scorer, List[Scorer], Tuple[Scorer]],
metric: Union[Scorer | Sequence[Scorer]],
additional_components: Dict[str, ThirdPartyComponents],
port: Optional[int],
configuration: Optional[Union[int, Configuration]] = None,
Expand Down Expand Up @@ -338,6 +340,7 @@ def _loss(
y_true
"""
if not isinstance(self.configuration, Configuration):
# Dummy prediction
if self.scoring_functions:
if isinstance(self.metric, Scorer):
return {self.metric.name: self.metric._worst_possible_result}
Expand Down
8 changes: 5 additions & 3 deletions autosklearn/evaluation/test_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# -*- encoding: utf-8 -*-
from typing import Any, Dict, List, Optional, Tuple, Union
from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import multiprocessing

Expand All @@ -23,7 +25,7 @@ def __init__(
self,
backend: Backend,
queue: multiprocessing.Queue,
metric: Scorer,
metric: Union[Scorer | Sequence[Scorer]],
additional_components: Dict[str, ThirdPartyComponents],
port: Optional[int],
configuration: Optional[Union[int, Configuration]] = None,
Expand Down Expand Up @@ -111,7 +113,7 @@ def eval_t(
queue: multiprocessing.Queue,
config: Union[int, Configuration],
backend: Backend,
metric: Scorer,
metric: Union[Scorer | Sequence[Scorer]],
seed: int,
num_run: int,
instance: Dict[str, Any],
Expand Down
18 changes: 10 additions & 8 deletions autosklearn/evaluation/train_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from __future__ import annotations

from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

import copy
import json
Expand Down Expand Up @@ -182,7 +184,7 @@ def __init__(
self,
backend: Backend,
queue: multiprocessing.Queue,
metric: Union[Scorer, List[Scorer], Tuple[Scorer]],
metric: Union[Scorer | Sequence[Scorer]],
additional_components: Dict[str, ThirdPartyComponents],
port: Optional[int],
configuration: Optional[Union[int, Configuration]] = None,
Expand Down Expand Up @@ -1328,7 +1330,7 @@ def eval_holdout(
str, BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit
],
resampling_strategy_args: Dict[str, Optional[Union[float, int, str]]],
metric: Union[Scorer, List[Scorer], Tuple[Scorer]],
metric: Union[Scorer | Sequence[Scorer]],
seed: int,
num_run: int,
instance: str,
Expand Down Expand Up @@ -1375,7 +1377,7 @@ def eval_iterative_holdout(
str, BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit
],
resampling_strategy_args: Dict[str, Optional[Union[float, int, str]]],
metric: Union[Scorer, List[Scorer], Tuple[Scorer]],
metric: Union[Scorer | Sequence[Scorer]],
seed: int,
num_run: int,
instance: str,
Expand Down Expand Up @@ -1422,7 +1424,7 @@ def eval_partial_cv(
str, BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit
],
resampling_strategy_args: Dict[str, Optional[Union[float, int, str]]],
metric: Union[Scorer, List[Scorer], Tuple[Scorer]],
metric: Union[Scorer | Sequence[Scorer]],
seed: int,
num_run: int,
instance: str,
Expand Down Expand Up @@ -1475,7 +1477,7 @@ def eval_partial_cv_iterative(
str, BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit
],
resampling_strategy_args: Dict[str, Optional[Union[float, int, str]]],
metric: Union[Scorer, List[Scorer], Tuple[Scorer]],
metric: Union[Scorer | Sequence[Scorer]],
seed: int,
num_run: int,
instance: str,
Expand Down Expand Up @@ -1523,7 +1525,7 @@ def eval_cv(
str, BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit
],
resampling_strategy_args: Dict[str, Optional[Union[float, int, str]]],
metric: Union[Scorer, List[Scorer], Tuple[Scorer]],
metric: Union[Scorer | Sequence[Scorer]],
seed: int,
num_run: int,
instance: str,
Expand Down Expand Up @@ -1571,7 +1573,7 @@ def eval_iterative_cv(
str, BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit
],
resampling_strategy_args: Dict[str, Optional[Union[float, int, str]]],
metric: Union[Scorer, List[Scorer], Tuple[Scorer]],
metric: Union[Scorer | Sequence[Scorer]],
seed: int,
num_run: int,
instance: str,
Expand Down
12 changes: 7 additions & 5 deletions autosklearn/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from abc import ABCMeta, abstractmethod
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Callable, Dict, List, Optional, Sequence, Union, cast

from functools import partial
from itertools import product
Expand Down Expand Up @@ -388,7 +390,7 @@ def calculate_score(
solution: np.ndarray,
prediction: np.ndarray,
task_type: int,
metric: Union[Scorer, List[Scorer], Tuple[Scorer]],
metric: Scorer | Sequence[Scorer],
scoring_functions: Optional[List[Scorer]] = None,
) -> Union[float, Dict[str, float]]:
"""
Expand Down Expand Up @@ -420,7 +422,7 @@ def calculate_score(
to_score = []
if scoring_functions:
to_score.extend(scoring_functions)
if isinstance(metric, (list, tuple)):
if isinstance(metric, Sequence):
to_score.extend(metric)
else:
to_score.append(metric)
Expand Down Expand Up @@ -480,7 +482,7 @@ def calculate_loss(
solution: np.ndarray,
prediction: np.ndarray,
task_type: int,
metric: Union[Scorer, List[Scorer], Tuple[Scorer]],
metric: Scorer | Sequence[Scorer],
scoring_functions: Optional[List[Scorer]] = None,
) -> Union[float, Dict[str, float]]:
"""
Expand Down Expand Up @@ -516,7 +518,7 @@ def calculate_loss(
scoring_functions=scoring_functions,
)

if scoring_functions or isinstance(metric, (list, tuple)):
if scoring_functions or isinstance(metric, Sequence):
score = cast(Dict, score)
scoring_functions = cast(List, scoring_functions)
metric_list = list(cast(List, metric)) # Please mypy
Expand Down
10 changes: 5 additions & 5 deletions autosklearn/smbo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import typing
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, List, Optional, Sequence

import copy
import json
Expand Down Expand Up @@ -260,7 +262,7 @@ def __init__(
total_walltime_limit,
func_eval_time_limit,
memory_limit,
metric: Union[Scorer, List[Scorer], Tuple[Scorer]],
metric: Scorer | Sequence[Scorer],
stopwatch: StopWatch,
n_jobs,
dask_client: dask.distributed.Client,
Expand Down Expand Up @@ -362,9 +364,7 @@ def collect_metalearning_suggestions(self, meta_base):
meta_base=meta_base,
basename=self.dataset_name,
metric=(
self.metric[0]
if isinstance(self.metric, (List, Tuple))
else self.metric
self.metric[0] if isinstance(self.metric, Sequence) else self.metric
),
configuration_space=self.config_space,
task=self.task,
Expand Down