diff --git a/.flake8 b/.flake8 index f3a26a3b56..d01cdb93b6 100644 --- a/.flake8 +++ b/.flake8 @@ -5,6 +5,7 @@ extend-exclude = venv .venv build + autosklearn/automl_common extend-ignore = # No whitespace before ':' in [x : y] E203 diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index be2cb7481c..735305307d 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -137,6 +137,8 @@ jobs: - name: Check for files left behind by test if: ${{ always() }} run: | + # Deleting `.pytest_chache` as it's used during testing and not deleted + rm -rf ".pytest_cache" before="${{ steps.status-before.outputs.BEFORE }}" after="$(git status --porcelain -b)" if [[ "$before" != "$after" ]]; then diff --git a/autosklearn/automl.py b/autosklearn/automl.py index e0aae596e1..1496aa2224 100644 --- a/autosklearn/automl.py +++ b/autosklearn/automl.py @@ -1,5 +1,6 @@ -# -*- encoding: utf-8 -*- -from typing import Any, Dict, List, Mapping, Optional, Tuple, Union, cast +from __future__ import annotations + +from typing import Any, Callable, Iterable, Mapping, Optional, Tuple import copy import io @@ -11,12 +12,9 @@ import sys import tempfile import time -import unittest.mock import uuid import warnings -import dask -import dask.distributed import distro import joblib import numpy as np @@ -25,8 +23,9 @@ import pkg_resources import scipy.stats import sklearn.utils -from ConfigSpace.configuration_space import Configuration +from ConfigSpace.configuration_space import Configuration, ConfigurationSpace from ConfigSpace.read_and_write import json as cs_json +from dask.distributed import Client, LocalCluster from scipy.sparse import spmatrix from sklearn.base import BaseEstimator from sklearn.dummy import DummyClassifier, DummyRegressor @@ -39,9 +38,11 @@ ) from sklearn.utils import check_random_state from sklearn.utils.validation import check_is_fitted +from smac.callbacks import IncorporateRunResultCallback from smac.runhistory.runhistory import RunInfo, RunValue from smac.stats.stats import Stats from smac.tae import StatusType +from typing_extensions import Literal from autosklearn.automl_common.common.utils.backend import Backend, create from autosklearn.constants import ( @@ -99,6 +100,8 @@ from autosklearn.util.single_thread_client import SingleThreadedClient from autosklearn.util.stopwatch import StopWatch +import unittest.mock + def _model_predict( model: Any, @@ -138,7 +141,7 @@ def _model_predict( The predictions produced by the model """ # Copy the array and ensure is has the attr 'shape' - X_ = np.asarray(X) if isinstance(X, List) else X.copy() + X_ = np.asarray(X) if isinstance(X, list) else X.copy() assert X_.shape[0] >= 1, f"X must have more than 1 sample but has {X_.shape[0]}" @@ -181,173 +184,139 @@ def _model_predict( class AutoML(BaseEstimator): + """Base class for handling the AutoML procedure""" + def __init__( self, - time_left_for_this_task, - per_run_time_limit, + time_left_for_this_task: int, + per_run_time_limit: int, temporary_directory: Optional[str] = None, delete_tmp_folder_after_terminate: bool = True, - initial_configurations_via_metalearning=25, - ensemble_size=1, - ensemble_nbest=1, - max_models_on_disc=1, - seed=1, - memory_limit=3072, - metadata_directory=None, - debug_mode=False, - include=None, - exclude=None, - resampling_strategy="holdout-iterative-fit", - resampling_strategy_arguments=None, - n_jobs=None, - dask_client: Optional[dask.distributed.Client] = None, - precision=32, - disable_evaluator_output=False, - 
get_smac_object_callback=None, - smac_scenario_args=None, - logging_config=None, - metric=None, - scoring_functions=None, - get_trials_callback=None, - dataset_compression: Union[bool, Mapping[str, Any]] = True, + initial_configurations_via_metalearning: int = 25, + ensemble_size: int = 1, + ensemble_nbest: int = 1, + max_models_on_disc: int = 1, + seed: int = 1, + memory_limit: int = 3072, + metadata_directory: Optional[str] = None, + include: Optional[dict[str, list[str]]] = None, + exclude: Optional[dict[str, list[str]]] = None, + resampling_strategy: str | Any = "holdout-iterative-fit", + resampling_strategy_arguments: Mapping[str, Any] = None, + n_jobs: Optional[int] = None, + dask_client: Optional[Client] = None, + precision: Literal[16, 32, 64] = 32, + disable_evaluator_output: bool | Iterable[str] = False, + get_smac_object_callback: Optional[Callable] = None, + smac_scenario_args: Optional[Mapping] = None, + logging_config: Optional[Mapping] = None, + metric: Optional[Scorer] = None, + scoring_functions: Optional[list[Scorer]] = None, + get_trials_callback: Optional[IncorporateRunResultCallback] = None, + dataset_compression: bool | Mapping[str, Any] = True, allow_string_features: bool = True, ): - super(AutoML, self).__init__() - self.configuration_space = None - self._backend: Optional[Backend] = None - self._temporary_directory = temporary_directory + super().__init__() + + if isinstance(disable_evaluator_output, Iterable): + disable_evaluator_output = list(disable_evaluator_output) # Incase iterator + allowed = set(["model", "cv_model", "y_optimization", "y_test", "y_valid"]) + unknown = allowed - set(disable_evaluator_output) + if any(unknown): + raise ValueError( + f"Unknown arg {unknown} for '_disable_evaluator_output'," + f" must be one of {allowed}" + ) + + # Validate dataset_compression and set its values + self._dataset_compression: Optional[DatasetCompressionSpec] + if isinstance(dataset_compression, bool): + if dataset_compression is True: + self._dataset_compression = default_dataset_compression_arg + else: + self._dataset_compression = None + else: + self._dataset_compression = validate_dataset_compression_arg( + dataset_compression, + memory_limit=memory_limit, + ) + self._delete_tmp_folder_after_terminate = delete_tmp_folder_after_terminate - # self._tmp_dir = tmp_dir self._time_for_task = time_left_for_this_task self._per_run_time_limit = per_run_time_limit - self._initial_configurations_via_metalearning = ( - initial_configurations_via_metalearning - ) + self._metric = metric self._ensemble_size = ensemble_size self._ensemble_nbest = ensemble_nbest self._max_models_on_disc = max_models_on_disc self._seed = seed self._memory_limit = memory_limit - self._data_memory_limit = None self._metadata_directory = metadata_directory self._include = include self._exclude = exclude self._resampling_strategy = resampling_strategy - self._scoring_functions = ( - scoring_functions if scoring_functions is not None else [] - ) - self._resampling_strategy_arguments = ( - resampling_strategy_arguments - if resampling_strategy_arguments is not None - else {} - ) - self._n_jobs = n_jobs - self._dask_client = dask_client - - self.precision = precision self._disable_evaluator_output = disable_evaluator_output - # Check arguments prior to doing anything! - if not isinstance(self._disable_evaluator_output, (bool, List)): - raise ValueError( - "disable_evaluator_output must be of type bool " "or list." 
- ) - if isinstance(self._disable_evaluator_output, List): - allowed_elements = [ - "model", - "cv_model", - "y_optimization", - "y_test", - "y_valid", - ] - for element in self._disable_evaluator_output: - if element not in allowed_elements: - raise ValueError( - "List member '%s' for argument " - "'disable_evaluator_output' must be one " - "of " + str(allowed_elements) - ) self._get_smac_object_callback = get_smac_object_callback self._get_trials_callback = get_trials_callback self._smac_scenario_args = smac_scenario_args self.logging_config = logging_config + self.precision = precision + self.allow_string_features = allow_string_features + self._initial_configurations_via_metalearning = ( + initial_configurations_via_metalearning + ) - # Validate dataset_compression and set its values - self._dataset_compression: Optional[DatasetCompressionSpec] - if isinstance(dataset_compression, bool): - if dataset_compression is True: - self._dataset_compression = default_dataset_compression_arg - else: - self._dataset_compression = None + self._scoring_functions = scoring_functions or {} + self._resampling_strategy_arguments = resampling_strategy_arguments or {} + + # Single core, local runs should use fork to prevent the __main__ requirements + # in examples. Nevertheless, multi-process runs have spawn as requirement to + # reduce the possibility of a deadlock + if n_jobs == 1 and dask_client is None: + self._multiprocessing_context = "fork" + self._dask_client = SingleThreadedClient() + self._n_jobs = 1 else: - self._dataset_compression = validate_dataset_compression_arg( - dataset_compression, memory_limit=self._memory_limit - ) - self.allow_string_features = allow_string_features + self._multiprocessing_context = "forkserver" + self._dask_client = dask_client + self._n_jobs = n_jobs + # Create the backend + self._backend: Backend = create( + temporary_directory=temporary_directory, + output_directory=None, + prefix="auto-sklearn", + delete_output_folder_after_terminate=delete_tmp_folder_after_terminate, + ) + + self._data_memory_limit = None # TODO: dead variable? Always None self._datamanager = None self._dataset_name = None self._feat_type = None - self._stopwatch = StopWatch() - self._logger = None + self._logger: Optional[PicklableClientLogger] = None self._task = None - - self._metric = metric - self._label_num = None self._parser = None - self.models_ = None - self.cv_models_ = None - self.ensemble_ = None self._can_predict = False - self._debug_mode = debug_mode - - self.InputValidator = None # type: Optional[InputValidator] + self.models_: Optional[dict] = None + self.cv_models_: Optional[dict] = None + self.ensemble_ = None + self.InputValidator: Optional[InputValidator] = None + self.configuration_space = None # The ensemble performance history through time - self.ensemble_performance_history = [] - self.fitted = False - - # Single core, local runs should use fork - # to prevent the __main__ requirements in - # examples. 
Nevertheless, multi-process runs - # have spawn as requirement to reduce the - # possibility of a deadlock - self._multiprocessing_context = "forkserver" - if self._n_jobs == 1 and self._dask_client is None: - self._multiprocessing_context = "fork" - self._dask_client = SingleThreadedClient() - - if not isinstance(self._time_for_task, int): - raise ValueError( - "time_left_for_this_task not of type integer, " - "but %s" % str(type(self._time_for_task)) - ) - if not isinstance(self._per_run_time_limit, int): - raise ValueError( - "per_run_time_limit not of type integer, but %s" - % str(type(self._per_run_time_limit)) - ) - - # By default try to use the TCP logging port or get a new port + self._stopwatch = StopWatch() self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT + self.ensemble_performance_history = [] - # Num_run tell us how many runs have been launched - # It can be seen as an identifier for each configuration - # saved to disk + # Num_run tell us how many runs have been launched. It can be seen as an + # identifier for each configuration saved to disk self.num_run = 0 + self.fitted = False - def _create_backend(self) -> Backend: - return create( - temporary_directory=self._temporary_directory, - output_directory=None, - prefix="auto-sklearn", - delete_tmp_folder_after_terminate=self._delete_tmp_folder_after_terminate, - ) - - def _create_dask_client(self): + def _create_dask_client(self) -> None: self._is_dask_client_internally_created = True - self._dask_client = dask.distributed.Client( - dask.distributed.LocalCluster( + self._dask_client = Client( + LocalCluster( n_workers=self._n_jobs, processes=False, threads_per_worker=1, @@ -365,11 +334,9 @@ def _create_dask_client(self): heartbeat_interval=10000, ) - def _close_dask_client(self): - if ( - hasattr(self, "_is_dask_client_internally_created") - and self._is_dask_client_internally_created - and self._dask_client + def _close_dask_client(self, force: bool = False) -> None: + if getattr(self, "_dask_client", None) is not None and ( + force or getattr(self, "_is_dask_client_internally_created", False) ): self._dask_client.shutdown() self._dask_client.close() @@ -378,7 +345,7 @@ def _close_dask_client(self): self._is_dask_client_internally_created = False del self._is_dask_client_internally_created - def _get_logger(self, name): + def _get_logger(self, name: str) -> PicklableClientLogger: logger_name = "AutoML(%d):%s" % (self._seed, name) # Setup the configuration for the logger @@ -432,7 +399,7 @@ def _get_logger(self, name): port=self._logger_port, ) - def _clean_logger(self): + def _clean_logger(self) -> None: if not hasattr(self, "stop_logging_server") or self.stop_logging_server is None: return @@ -451,28 +418,34 @@ def _clean_logger(self): del self.stop_logging_server @staticmethod - def _start_task(watcher, task_name): + def _start_task(watcher: StopWatch, task_name: str) -> None: watcher.start_task(task_name) @staticmethod - def _stop_task(watcher, task_name): + def _stop_task(watcher: StopWatch, task_name: str) -> None: watcher.stop_task(task_name) @staticmethod - def _print_load_time(basename, time_left_for_this_task, time_for_load_data, logger): - - time_left_after_reading = max(0, time_left_for_this_task - time_for_load_data) - logger.info( - "Remaining time after reading %s %5.2f sec" - % (basename, time_left_after_reading) - ) + def _print_load_time( + basename: str, + time_left_for_this_task: float, + time_for_load_data: float, + logger: PicklableClientLogger, + ) -> float: + time_left = max(0, 
time_left_for_this_task - time_for_load_data) + logger.info(f"Remaining time after reading {basename} {time_left:5.2f} sec") return time_for_load_data - def _do_dummy_prediction(self, datamanager: XYDataManager, num_run: int) -> int: - + def _do_dummy_prediction(self) -> None: # When using partial-cv it makes no sense to do dummy predictions if self._resampling_strategy in ["partial-cv", "partial-cv-iterative-fit"]: - return num_run + return + + if self._metric is None: + raise ValueError("Metric was not set") + + # Dummy prediction always have num_run set to 1 + dummy_run_num = 1 self._logger.info("Starting to create dummy predictions.") @@ -491,7 +464,7 @@ def _do_dummy_prediction(self, datamanager: XYDataManager, num_run: int) -> int: autosklearn_seed=self._seed, multi_objectives=["cost"], resampling_strategy=self._resampling_strategy, - initial_num_run=num_run, + initial_num_run=dummy_run_num, stats=stats, metric=self._metric, memory_limit=memory_limit, @@ -504,7 +477,8 @@ def _do_dummy_prediction(self, datamanager: XYDataManager, num_run: int) -> int: ) status, cost, runtime, additional_info = ta.run( - num_run, cutoff=self._time_for_task + config=dummy_run_num, + cutoff=self._time_for_task, ) if status == StatusType.SUCCESS: self._logger.info("Finished creating dummy predictions.") @@ -528,7 +502,7 @@ def _do_dummy_prediction(self, datamanager: XYDataManager, num_run: int) -> int: self._logger.error(msg) raise ValueError(msg) - return num_run + return @classmethod def _task_type_id(cls, task_type: str) -> int: @@ -545,7 +519,7 @@ def fit( task: Optional[int] = None, X_test: Optional[SUPPORTED_FEAT_TYPES] = None, y_test: Optional[SUPPORTED_TARGET_TYPES] = None, - feat_type: Optional[List[str]] = None, + feat_type: Optional[list[str]] = None, dataset_name: Optional[str] = None, only_return_configuration_space: bool = False, load_models: bool = True, @@ -590,7 +564,6 @@ def fit( Parameters ---------- - X : {array-like, sparse matrix}, shape (n_samples, n_features) The training input samples. @@ -610,7 +583,7 @@ def fit( of all models. This allows to evaluate the performance of Auto-sklearn over time. - feat_type : Optional[List], + feat_type : Optional[list], List of str of `len(X.shape[1])` describing the attribute type. Possible types are `Categorical` and `Numerical`. `Categorical` attributes will be automatically One-Hot encoded. 
The values @@ -637,7 +610,6 @@ def fit( Returns ------- self - """ if (X_test is not None) ^ (y_test is not None): raise ValueError("Must provide both X_test and y_test together") @@ -664,9 +636,6 @@ def fit( if dataset_name is None: dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) - # Create the backend - self._backend = self._create_backend() - # By default try to use the TCP logging port or get a new port self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT self._logger = self._get_logger(dataset_name) @@ -707,9 +676,7 @@ def fit( memory_allocation = self._dataset_compression["memory_allocation"] # Remove precision reduction if we can't perform it - if X.dtype not in supported_precision_reductions and "precision" in cast( - List[str], methods - ): # Removable with TypedDict + if "precision" in methods and X.dtype not in supported_precision_reductions: methods = [method for method in methods if method != "precision"] with warnings_to(self._logger): @@ -764,96 +731,11 @@ def fit( # Take the feature types from the validator self._feat_type = self.InputValidator.feature_validator.feat_type - # Produce debug information to the logfile - self._logger.debug("Starting to print environment information") - self._logger.debug(" Python version: %s", sys.version.split("\n")) - try: - self._logger.debug( - f"\tDistribution: {distro.id()}-{distro.version()}-{distro.name()}" - ) - except AttributeError: - pass - - self._logger.debug(" System: %s", platform.system()) - self._logger.debug(" Machine: %s", platform.machine()) - self._logger.debug(" Platform: %s", platform.platform()) - # UNAME appears to leak sensible information - # self._logger.debug(' uname: %s', platform.uname()) - self._logger.debug(" Version: %s", platform.version()) - self._logger.debug(" Mac version: %s", platform.mac_ver()) - requirements = pkg_resources.resource_string("autosklearn", "requirements.txt") - requirements = requirements.decode("utf-8") - requirements = [requirement for requirement in requirements.split("\n")] - for requirement in requirements: - if not requirement: - continue - match = RE_PATTERN.match(requirement) - if match: - name = match.group("name") - module_dist = pkg_resources.get_distribution(name) - self._logger.debug(" %s", module_dist) - else: - raise ValueError("Unable to read requirement: %s" % requirement) - self._logger.debug("Done printing environment information") - self._logger.debug("Starting to print arguments to auto-sklearn") - self._logger.debug( - " tmp_folder: %s", self._backend.context._temporary_directory - ) - self._logger.debug(" time_left_for_this_task: %f", self._time_for_task) - self._logger.debug(" per_run_time_limit: %f", self._per_run_time_limit) - self._logger.debug( - " initial_configurations_via_metalearning: %d", - self._initial_configurations_via_metalearning, - ) - self._logger.debug(" ensemble_size: %d", self._ensemble_size) - self._logger.debug(" ensemble_nbest: %f", self._ensemble_nbest) - self._logger.debug(" max_models_on_disc: %s", str(self._max_models_on_disc)) - self._logger.debug(" seed: %d", self._seed) - self._logger.debug(" memory_limit: %s", str(self._memory_limit)) - self._logger.debug(" metadata_directory: %s", self._metadata_directory) - self._logger.debug(" debug_mode: %s", self._debug_mode) - self._logger.debug(" include: %s", str(self._include)) - self._logger.debug(" exclude: %s", str(self._exclude)) - self._logger.debug(" resampling_strategy: %s", str(self._resampling_strategy)) - self._logger.debug( - " resampling_strategy_arguments: %s", - 
str(self._resampling_strategy_arguments), - ) - self._logger.debug(" n_jobs: %s", str(self._n_jobs)) - self._logger.debug( - " multiprocessing_context: %s", str(self._multiprocessing_context) - ) - self._logger.debug(" dask_client: %s", str(self._dask_client)) - self._logger.debug(" precision: %s", str(self.precision)) - self._logger.debug( - " disable_evaluator_output: %s", str(self._disable_evaluator_output) - ) - self._logger.debug( - " get_smac_objective_callback: %s", str(self._get_smac_object_callback) - ) - self._logger.debug(" smac_scenario_args: %s", str(self._smac_scenario_args)) - self._logger.debug(" logging_config: %s", str(self.logging_config)) - self._logger.debug(" metric: %s", str(self._metric)) - self._logger.debug("Done printing arguments to auto-sklearn") - self._logger.debug("Starting to print available components") - for choice in ( - ClassifierChoice, - RegressorChoice, - FeaturePreprocessorChoice, - OHEChoice, - RescalingChoice, - CoalescenseChoice, - ): - self._logger.debug( - "%s: %s", - choice.__name__, - choice.get_components(), - ) - self._logger.debug("Done printing available components") + self._log_fit_setup() datamanager = XYDataManager( - X, - y, + X=X, + y=y, X_test=X_test, y_test=y_test, task=self._task, @@ -867,16 +749,6 @@ def fit( # == Pickle the data manager to speed up loading self._backend.save_datamanager(datamanager) - time_for_load_data = self._stopwatch.wall_elapsed(self._dataset_name) - - if self._debug_mode: - self._print_load_time( - self._dataset_name, - self._time_for_task, - time_for_load_data, - self._logger, - ) - # = Create a searchspace # Do this before One Hot Encoding to make sure that it creates a # search space for a dense classifier even if one hot encoding would @@ -896,8 +768,8 @@ def fit( return self.configuration_space # == Perform dummy predictions - # Dummy prediction always have num_run set to 1 - self.num_run += self._do_dummy_prediction(datamanager, num_run=1) + self.num_run += 1 + self._do_dummy_prediction() # == RUN ensemble builder # Do this before calculating the meta-features to make sure that the @@ -1036,9 +908,12 @@ def fit( list(entry[:2]) + [entry[2].get_dictionary()] + list(entry[3:]) for entry in self.trajectory_ ] + with open(trajectory_filename, "w") as fh: json.dump(saveable_trajectory, fh) + except Exception as e: + self._fit_cleanup() self._logger.exception(e) raise @@ -1078,10 +953,96 @@ def fit( return self + def _log_fit_setup(self) -> None: + # Produce debug information to the logfile + self._logger.debug("Starting to print environment information") + self._logger.debug(" Python version: %s", sys.version.split("\n")) + try: + self._logger.debug( + f"\tDistribution: {distro.id()}-{distro.version()}-{distro.name()}" + ) + except AttributeError: + pass + + self._logger.debug(" System: %s", platform.system()) + self._logger.debug(" Machine: %s", platform.machine()) + self._logger.debug(" Platform: %s", platform.platform()) + # UNAME appears to leak sensible information + # self._logger.debug(' uname: %s', platform.uname()) + self._logger.debug(" Version: %s", platform.version()) + self._logger.debug(" Mac version: %s", platform.mac_ver()) + requirements = pkg_resources.resource_string("autosklearn", "requirements.txt") + requirements = requirements.decode("utf-8") + requirements = [requirement for requirement in requirements.split("\n")] + for requirement in requirements: + if not requirement: + continue + match = RE_PATTERN.match(requirement) + if match: + name = match.group("name") + module_dist = 
pkg_resources.get_distribution(name) + self._logger.debug(" %s", module_dist) + else: + raise ValueError("Unable to read requirement: %s" % requirement) + + self._logger.debug("Done printing environment information") + self._logger.debug("Starting to print arguments to auto-sklearn") + self._logger.debug(" tmp_folder: %s", self._backend.temporary_directory) + self._logger.debug(" time_left_for_this_task: %f", self._time_for_task) + self._logger.debug(" per_run_time_limit: %f", self._per_run_time_limit) + self._logger.debug( + " initial_configurations_via_metalearning: %d", + self._initial_configurations_via_metalearning, + ) + self._logger.debug(" ensemble_size: %d", self._ensemble_size) + self._logger.debug(" ensemble_nbest: %f", self._ensemble_nbest) + self._logger.debug(" max_models_on_disc: %s", str(self._max_models_on_disc)) + self._logger.debug(" seed: %d", self._seed) + self._logger.debug(" memory_limit: %s", str(self._memory_limit)) + self._logger.debug(" metadata_directory: %s", self._metadata_directory) + self._logger.debug(" include: %s", str(self._include)) + self._logger.debug(" exclude: %s", str(self._exclude)) + self._logger.debug(" resampling_strategy: %s", str(self._resampling_strategy)) + self._logger.debug( + " resampling_strategy_arguments: %s", + str(self._resampling_strategy_arguments), + ) + self._logger.debug(" n_jobs: %s", str(self._n_jobs)) + self._logger.debug( + " multiprocessing_context: %s", str(self._multiprocessing_context) + ) + self._logger.debug(" dask_client: %s", str(self._dask_client)) + self._logger.debug(" precision: %s", str(self.precision)) + self._logger.debug( + " disable_evaluator_output: %s", str(self._disable_evaluator_output) + ) + self._logger.debug( + " get_smac_objective_callback: %s", str(self._get_smac_object_callback) + ) + self._logger.debug(" smac_scenario_args: %s", str(self._smac_scenario_args)) + self._logger.debug(" logging_config: %s", str(self.logging_config)) + self._logger.debug(" metric: %s", str(self._metric)) + self._logger.debug("Done printing arguments to auto-sklearn") + self._logger.debug("Starting to print available components") + for choice in ( + ClassifierChoice, + RegressorChoice, + FeaturePreprocessorChoice, + OHEChoice, + RescalingChoice, + CoalescenseChoice, + ): + self._logger.debug( + "%s: %s", + choice.__name__, + choice.get_components(), + ) + self._logger.debug("Done printing available components") + def __sklearn_is_fitted__(self) -> bool: return self.fitted - def _fit_cleanup(self): + def _fit_cleanup(self) -> None: self._logger.info("Closing the dask infrastructure") self._close_dask_client() self._logger.info("Finished closing the dask infrastructure") @@ -1170,16 +1131,36 @@ def _check_resampling_strategy( return - def refit(self, X, y): - # AutoSklearn does not handle sparse y for now - y = convert_if_sparse(y) + def refit( + self, + X: SUPPORTED_FEAT_TYPES, + y: SUPPORTED_TARGET_TYPES, + max_reshuffles: int = 10, + ) -> AutoML: + """Refit the models to a new given set of data + + Parameters + ---------- + X : SUPPORTED_FEAT_TYPES + The data to dit to + + y : SUPPORTED_TARGET_TYPES + The targets to fit to + + max_reshuffles : int = 10 + How many times to try reshuffle the data. If fitting fails, shuffle the + data. This can alleviate the problem in algorithms that depend on the + ordering of the data. 
+ + Returns + ------- + AutoML + Self + """ + check_is_fitted(self) + y = convert_if_sparse(y) # AutoSklearn does not handle sparse y for now # Make sure input data is valid - if self.InputValidator is None or not self.InputValidator._is_fitted: - raise ValueError( - "refit() is only supported after calling fit. Kindly call first " - "the estimator fit() method." - ) X, y = self.InputValidator.transform(X, y) if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None: @@ -1190,15 +1171,9 @@ def refit(self, X, y): raise ValueError("Refit can only be called if 'ensemble_size != 0'") random_state = check_random_state(self._seed) - for identifier in self.models_: - model = self.models_[identifier] - # this updates the model inplace, it can then later be used in - # predict method - - # try to fit the model. If it fails, shuffle the data. This - # could alleviate the problem in algorithms that depend on - # the ordering of the data. - for i in range(10): + + for identifier, model in self.models_.items(): + for i in range(max_reshuffles): try: if self._budget_type is None: _fit_and_suppress_warnings(self._logger, model, X, y) @@ -1220,7 +1195,7 @@ def refit(self, X, y): X = X[indices] y = y[indices] - if i == 9: + if i == (max_reshuffles - 1): raise e self._can_predict = True @@ -1229,15 +1204,15 @@ def refit(self, X, y): def fit_pipeline( self, X: SUPPORTED_FEAT_TYPES, - y: Union[SUPPORTED_TARGET_TYPES, spmatrix], + y: SUPPORTED_TARGET_TYPES | spmatrix, is_classification: bool, - config: Union[Configuration, Dict[str, Union[str, float, int]]], + config: Configuration | dict[str, str | float | int], task: Optional[int] = None, dataset_name: Optional[str] = None, X_test: Optional[SUPPORTED_FEAT_TYPES] = None, - y_test: Optional[Union[SUPPORTED_TARGET_TYPES, spmatrix]] = None, - feat_type: Optional[List[str]] = None, - **kwargs: Dict, + y_test: Optional[SUPPORTED_TARGET_TYPES | spmatrix] = None, + feat_type: Optional[list[str]] = None, + **kwargs: dict, ) -> Tuple[Optional[BasePipeline], RunInfo, RunValue]: """Fits and individual pipeline configuration and returns the result to the user. @@ -1258,7 +1233,7 @@ def fit_pipeline( If provided, the testing performance will be tracked on this features. y_test: array-like If provided, the testing performance will be tracked on this labels - config: Union[Configuration, Dict[str, Union[str, float, int]]] + config: Configuration | dict[str, str | float | int] A configuration object used to define the pipeline steps. If a dict is passed, a configuration is created based on this dict. dataset_name: Optional[str] @@ -1403,6 +1378,8 @@ def predict(self, X, batch_size=None, n_jobs=1): Parallelize the predictions across the models with n_jobs processes. """ + check_is_fitted(self) + if ( self._resampling_strategy not in ("holdout", "holdout-iterative-fit", "cv", "cv-iterative-fit") @@ -1410,7 +1387,7 @@ def predict(self, X, batch_size=None, n_jobs=1): ): raise NotImplementedError( "Predict is currently not implemented for resampling " - "strategy %s, please call refit()." % self._resampling_strategy + f"strategy {self._resampling_strategy}, please call refit()." 
) if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None: @@ -1485,16 +1462,25 @@ def predict(self, X, batch_size=None, n_jobs=1): def fit_ensemble( self, - y, - task=None, - precision=32, - dataset_name=None, - ensemble_nbest=None, - ensemble_size=None, + y: SUPPORTED_TARGET_TYPES, + task: Optional[int] = None, + precision: Literal[16, 32, 64] = 32, + dataset_name: Optional[str] = None, + ensemble_nbest: Optional[int] = None, + ensemble_size: Optional[int] = None, ): + check_is_fitted(self) + # check for the case when ensemble_size is less than 0 - if not ensemble_size > 0: - raise ValueError("ensemble_size must be greater than 0 for fit_ensemble") + if ensemble_size is not None and ensemble_size <= 0: + raise ValueError("`ensemble_size` must be >= 0 for `fit_ensemble`") + + if ensemble_size is None and ( + self._ensemble_size is None or self._ensemble_size <= 0 + ): + raise ValueError( + "Please pass `ensemble_size` to `fit_ensemble` if not setting in init" + ) # AutoSklearn does not handle sparse y for now y = convert_if_sparse(y) @@ -1509,11 +1495,6 @@ def fit_ensemble( self._logger = self._get_logger(dataset_name) # Make sure that input is valid - if self.InputValidator is None or not self.InputValidator._is_fitted: - raise ValueError( - "fit_ensemble() can only be called after fit. Please call the " - "estimator fit() method prior to fit_ensemble()." - ) y = self.InputValidator.target_validator.transform(y) # Create a client if needed @@ -1553,13 +1534,17 @@ def fit_ensemble( "line output for error messages." ) self.ensemble_performance_history, _, _, _, _ = result + self._ensemble_size = ensemble_size self._load_models() self._close_dask_client() return self def _load_models(self): - self.ensemble_ = self._backend.load_ensemble(self._seed) + if self._ensemble_size > 0: + self.ensemble_ = self._backend.load_ensemble(self._seed) + else: + self.ensemble_ = None # If no ensemble is loaded, try to get the best performing model if not self.ensemble_: @@ -1568,17 +1553,20 @@ def _load_models(self): if self.ensemble_: identifiers = self.ensemble_.get_selected_model_identifiers() self.models_ = self._backend.load_models_by_identifiers(identifiers) + if self._resampling_strategy in ("cv", "cv-iterative-fit"): self.cv_models_ = self._backend.load_cv_models_by_identifiers( identifiers ) else: self.cv_models_ = None + if len(self.models_) == 0 and self._resampling_strategy not in [ "partial-cv", "partial-cv-iterative-fit", ]: raise ValueError("No models fitted!") + if ( self._resampling_strategy in ["cv", "cv-iterative-fit"] and len(self.cv_models_) == 0 @@ -1586,7 +1574,7 @@ def _load_models(self): raise ValueError("No models fitted!") elif self._disable_evaluator_output is False or ( - isinstance(self._disable_evaluator_output, List) + isinstance(self._disable_evaluator_output, list) and "model" not in self._disable_evaluator_output ): model_names = self._backend.list_all_models(self._seed) @@ -1597,10 +1585,14 @@ def _load_models(self): ]: raise ValueError("No models fitted!") + self.ensemble_ = None self.models_ = [] + self.cv_models_ = None else: + self.ensemble_ = None self.models_ = [] + self.cv_models_ = None def _load_best_individual_model(self): """ @@ -1610,7 +1602,6 @@ def _load_best_individual_model(self): This is a robust mechanism to be able to predict, even though no ensemble was found by ensemble builder. 
""" - # We also require that the model is fit and a task is defined # The ensemble size must also be greater than 1, else it means # that the user intentionally does not want an ensemble @@ -1641,14 +1632,8 @@ def score(self, X, y): # The reason is we do not want to trigger the # check for changing input types on successive # input validator calls + check_is_fitted(self) prediction = self.predict(X) - - # Make sure that input is valid - if self.InputValidator is None or not self.InputValidator._is_fitted: - raise ValueError( - "score() is only supported after calling fit. Kindly call first " - "the estimator fit() method." - ) y = self.InputValidator.target_validator.transform(y) # Encode the prediction using the input validator @@ -1701,6 +1686,7 @@ def _get_runhistory_models_performance(self): @property def performance_over_time_(self): + check_is_fitted(self) individual_performance_frame = self._get_runhistory_models_performance() best_values = pd.Series( { @@ -1747,6 +1733,7 @@ def performance_over_time_(self): @property def cv_results_(self): + check_is_fitted(self) results = dict() # Missing in contrast to scikit-learn @@ -1868,7 +1855,8 @@ def cv_results_(self): return results - def sprint_statistics(self): + def sprint_statistics(self) -> str: + check_is_fitted(self) cv_results = self.cv_results_ sio = io.StringIO() sio.write("auto-sklearn results:\n") @@ -1913,13 +1901,14 @@ def sprint_statistics(self): ) return sio.getvalue() - def get_models_with_weights(self): + def get_models_with_weights(self) -> list[Tuple[float, BasePipeline]]: + check_is_fitted(self) if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None: self._load_models() return self.ensemble_.get_models_with_weights(self.models_) - def show_models(self) -> Dict[int, Any]: + def show_models(self) -> dict[int, Any]: """Returns a dictionary containing dictionaries of ensemble models. Each model in the ensemble can be accessed by giving its ``model_id`` as key. @@ -1984,13 +1973,12 @@ def show_models(self) -> Dict[int, Any]: Returns ------- - Dict(int, Any) : dictionary of length = number of models in the ensemble + dict[int, Any] : dictionary of length = number of models in the ensemble A dictionary of models in the ensemble, where ``model_id`` is the key. 
""" # noqa: E501 + check_is_fitted(self) + ensemble_dict = {} - # check for condition whether autosklearn is fitted if not raise runtime error - if not self.__sklearn_is_fitted__(): - raise RuntimeError("AutoSklearn has not been fitted") # check for ensemble_size == 0 if self._ensemble_size == 0: @@ -2082,12 +2070,12 @@ def has_key(rv, key): def _create_search_space( self, - tmp_dir, - backend, - datamanager, - include: Optional[Dict[str, List[str]]] = None, - exclude: Optional[Dict[str, List[str]]] = None, - ): + tmp_dir: str, + backend: Backend, + datamanager: XYDataManager, + include: Optional[Mapping[str, list[str]]] = None, + exclude: Optional[Mapping[str, list[str]]] = None, + ) -> Tuple[ConfigurationSpace, str]: task_name = "CreateConfigSpace" self._stopwatch.start_task(task_name) @@ -2097,27 +2085,23 @@ def _create_search_space( include=include, exclude=exclude, ) - configuration_space = self.configuration_space_created_hook( - datamanager, configuration_space - ) backend.write_txt_file( - configspace_path, cs_json.write(configuration_space), "Configuration space" + configspace_path, + cs_json.write(configuration_space), + "Configuration space", ) self._stopwatch.stop_task(task_name) return configuration_space, configspace_path - def configuration_space_created_hook(self, datamanager, configuration_space): - return configuration_space - - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: # Cannot serialize a client! self._dask_client = None self.logging_server = None self.stop_logging_server = None return self.__dict__ - def __del__(self): + def __del__(self) -> None: # Clean up the logger self._clean_logger() @@ -2143,14 +2127,14 @@ def _supports_task_type(cls, task_type: str) -> bool: def fit( self, X: SUPPORTED_FEAT_TYPES, - y: Union[SUPPORTED_TARGET_TYPES, spmatrix], + y: SUPPORTED_TARGET_TYPES | spmatrix, X_test: Optional[SUPPORTED_FEAT_TYPES] = None, - y_test: Optional[Union[SUPPORTED_TARGET_TYPES, spmatrix]] = None, - feat_type: Optional[List[bool]] = None, + y_test: Optional[SUPPORTED_TARGET_TYPES | spmatrix] = None, + feat_type: Optional[list[bool]] = None, dataset_name: Optional[str] = None, only_return_configuration_space: bool = False, load_models: bool = True, - ): + ) -> AutoMLClassifier: return super().fit( X, y, @@ -2166,12 +2150,12 @@ def fit( def fit_pipeline( self, X: SUPPORTED_FEAT_TYPES, - y: Union[SUPPORTED_TARGET_TYPES, spmatrix], - config: Union[Configuration, Dict[str, Union[str, float, int]]], + y: SUPPORTED_TARGET_TYPES | spmatrix, + config: Configuration | dict[str, str | float | int], dataset_name: Optional[str] = None, X_test: Optional[SUPPORTED_FEAT_TYPES] = None, - y_test: Optional[Union[SUPPORTED_TARGET_TYPES, spmatrix]] = None, - feat_type: Optional[List[str]] = None, + y_test: Optional[SUPPORTED_TARGET_TYPES | spmatrix] = None, + feat_type: Optional[list[str]] = None, **kwargs, ) -> Tuple[Optional[BasePipeline], RunInfo, RunValue]: return super().fit_pipeline( @@ -2186,16 +2170,18 @@ def fit_pipeline( **kwargs, ) - def predict(self, X, batch_size=None, n_jobs=1): + def predict( + self, + X: SUPPORTED_FEAT_TYPES, + batch_size: Optional[int] = None, + n_jobs: int = 1, + ) -> np.ndarray: + check_is_fitted(self) + predicted_probabilities = super().predict( X, batch_size=batch_size, n_jobs=n_jobs ) - if self.InputValidator is None or not self.InputValidator._is_fitted: - raise ValueError( - "predict() is only supported after calling fit. Kindly call first " - "the estimator fit() method." 
- ) if self.InputValidator.target_validator.is_single_column_target(): predicted_indexes = np.argmax(predicted_probabilities, axis=1) else: @@ -2203,7 +2189,12 @@ def predict(self, X, batch_size=None, n_jobs=1): return self.InputValidator.target_validator.inverse_transform(predicted_indexes) - def predict_proba(self, X, batch_size=None, n_jobs=1): + def predict_proba( + self, + X: SUPPORTED_FEAT_TYPES, + batch_size: Optional[int] = None, + n_jobs: int = 1, + ) -> np.ndarray: return super().predict(X, batch_size=batch_size, n_jobs=n_jobs) @@ -2226,14 +2217,14 @@ def _supports_task_type(cls, task_type: str) -> bool: def fit( self, X: SUPPORTED_FEAT_TYPES, - y: Union[SUPPORTED_TARGET_TYPES, spmatrix], + y: SUPPORTED_TARGET_TYPES | spmatrix, X_test: Optional[SUPPORTED_FEAT_TYPES] = None, - y_test: Optional[Union[SUPPORTED_TARGET_TYPES, spmatrix]] = None, - feat_type: Optional[List[bool]] = None, + y_test: Optional[SUPPORTED_TARGET_TYPES | spmatrix] = None, + feat_type: Optional[list[bool]] = None, dataset_name: Optional[str] = None, only_return_configuration_space: bool = False, load_models: bool = True, - ): + ) -> AutoMLRegressor: return super().fit( X, y, @@ -2249,13 +2240,13 @@ def fit( def fit_pipeline( self, X: SUPPORTED_FEAT_TYPES, - y: Union[SUPPORTED_TARGET_TYPES, spmatrix], - config: Union[Configuration, Dict[str, Union[str, float, int]]], + y: SUPPORTED_TARGET_TYPES | spmatrix, + config: Configuration | dict[str, str | float | int], dataset_name: Optional[str] = None, X_test: Optional[SUPPORTED_FEAT_TYPES] = None, - y_test: Optional[Union[SUPPORTED_TARGET_TYPES, spmatrix]] = None, - feat_type: Optional[List[str]] = None, - **kwargs: Dict, + y_test: Optional[SUPPORTED_TARGET_TYPES | spmatrix] = None, + feat_type: Optional[list[str]] = None, + **kwargs: dict, ) -> Tuple[Optional[BasePipeline], RunInfo, RunValue]: return super().fit_pipeline( X=X, diff --git a/autosklearn/evaluation/__init__.py b/autosklearn/evaluation/__init__.py index 8eb997c571..89c61d144d 100644 --- a/autosklearn/evaluation/__init__.py +++ b/autosklearn/evaluation/__init__.py @@ -324,6 +324,7 @@ def run( if not (instance_specific is None or instance_specific == "0"): raise ValueError(instance_specific) + init_params = {"instance": instance} if self.init_params is not None: init_params.update(self.init_params) @@ -542,6 +543,7 @@ def run( else: origin = getattr(config, "origin", "UNKNOWN") config_id = config.config_id + additional_run_info["configuration_origin"] = origin runtime = float(obj.wall_clock_time) diff --git a/autosklearn/evaluation/abstract_evaluator.py b/autosklearn/evaluation/abstract_evaluator.py index bc0be0e8d8..7843de6a8a 100644 --- a/autosklearn/evaluation/abstract_evaluator.py +++ b/autosklearn/evaluation/abstract_evaluator.py @@ -30,9 +30,6 @@ ) from autosklearn.util.logging_ import PicklableClientLogger, get_named_client_logger -__all__ = ["AbstractEvaluator"] - - # General TYPE definitions for numpy TYPE_ADDITIONAL_INFO = Dict[str, Union[int, float, str, Dict, List, Tuple]] @@ -49,9 +46,10 @@ def __init__( ): self.config = config if config == 1: - super(MyDummyClassifier, self).__init__(strategy="uniform") + super().__init__(strategy="uniform") else: - super(MyDummyClassifier, self).__init__(strategy="most_frequent") + super().__init__(strategy="most_frequent") + self.random_state = random_state self.init_params = init_params self.dataset_properties = dataset_properties @@ -59,8 +57,11 @@ def __init__( self.exclude = exclude def pre_transform( - self, X: np.ndarray, y: np.ndarray, fit_params: 
Optional[Dict[str, Any]] = None - ) -> Tuple[np.ndarray, Dict[str, Any]]: # pylint: disable=R0201 + self, + X: np.ndarray, + y: np.ndarray, + fit_params: Optional[Dict[str, Any]] = None, + ) -> Tuple[np.ndarray, Dict[str, Any]]: if fit_params is None: fit_params = {} return X, fit_params @@ -76,22 +77,23 @@ def fit( ) def fit_estimator( - self, X: np.ndarray, y: np.ndarray, fit_params: Optional[Dict[str, Any]] = None + self, + X: np.ndarray, + y: np.ndarray, + fit_params: Optional[Dict[str, Any]] = None, ) -> DummyClassifier: return self.fit(X, y) def predict_proba(self, X: np.ndarray, batch_size: int = 1000) -> np.ndarray: new_X = np.ones((X.shape[0], 1)) - probas = super(MyDummyClassifier, self).predict_proba(new_X) + probas = super().predict_proba(new_X) probas = convert_multioutput_multiclass_to_multilabel(probas).astype(np.float32) return probas - def estimator_supports_iterative_fit(self) -> bool: # pylint: disable=R0201 + def estimator_supports_iterative_fit(self) -> bool: return False - def get_additional_run_info( - self, - ) -> Optional[TYPE_ADDITIONAL_INFO]: # pylint: disable=R0201 + def get_additional_run_info(self) -> Optional[TYPE_ADDITIONAL_INFO]: return None @@ -107,9 +109,9 @@ def __init__( ): self.config = config if config == 1: - super(MyDummyRegressor, self).__init__(strategy="mean") + super().__init__(strategy="mean") else: - super(MyDummyRegressor, self).__init__(strategy="median") + super().__init__(strategy="median") self.random_state = random_state self.init_params = init_params self.dataset_properties = dataset_properties @@ -117,8 +119,11 @@ def __init__( self.exclude = exclude def pre_transform( - self, X: np.ndarray, y: np.ndarray, fit_params: Optional[Dict[str, Any]] = None - ) -> Tuple[np.ndarray, Dict[str, Any]]: # pylint: disable=R0201 + self, + X: np.ndarray, + y: np.ndarray, + fit_params: Optional[Dict[str, Any]] = None, + ) -> Tuple[np.ndarray, Dict[str, Any]]: if fit_params is None: fit_params = {} return X, fit_params @@ -129,25 +134,24 @@ def fit( y: np.ndarray, sample_weight: Optional[Union[np.ndarray, List]] = None, ) -> DummyRegressor: - return super(MyDummyRegressor, self).fit( - np.ones((X.shape[0], 1)), y, sample_weight=sample_weight - ) + return super().fit(np.ones((X.shape[0], 1)), y, sample_weight=sample_weight) def fit_estimator( - self, X: np.ndarray, y: np.ndarray, fit_params: Optional[Dict[str, Any]] = None + self, + X: np.ndarray, + y: np.ndarray, + fit_params: Optional[Dict[str, Any]] = None, ) -> DummyRegressor: return self.fit(X, y) def predict(self, X: np.ndarray, batch_size: int = 1000) -> np.ndarray: new_X = np.ones((X.shape[0], 1)) - return super(MyDummyRegressor, self).predict(new_X).astype(np.float32) + return super().predict(new_X).astype(np.float32) - def estimator_supports_iterative_fit(self) -> bool: # pylint: disable=R0201 + def estimator_supports_iterative_fit(self) -> bool: return False - def get_additional_run_info( - self, - ) -> Optional[TYPE_ADDITIONAL_INFO]: # pylint: disable=R0201 + def get_additional_run_info(self) -> Optional[TYPE_ADDITIONAL_INFO]: return None diff --git a/autosklearn/pipeline/util.py b/autosklearn/pipeline/util.py index 228c31357d..38cf1aa344 100644 --- a/autosklearn/pipeline/util.py +++ b/autosklearn/pipeline/util.py @@ -2,7 +2,6 @@ import inspect import os import pkgutil -import unittest import numpy as np import scipy.sparse @@ -10,6 +9,8 @@ import sklearn.base import sklearn.datasets +import unittest + def find_sklearn_classes(class_): classifiers = set() diff --git 
a/autosklearn/util/functional.py b/autosklearn/util/functional.py new file mode 100644 index 0000000000..55f38ddf5d --- /dev/null +++ b/autosklearn/util/functional.py @@ -0,0 +1,54 @@ +from typing import Optional + +import numpy as np + + +def normalize(x: np.ndarray, axis: Optional[int] = None) -> np.ndarray: + """Normalizes an array along an axis + + Note + ---- + TODO: Only works for positive numbers + + ..code:: python + + x = np.ndarray([ + [1, 1, 1], + [2, 2, 2], + [7, 7, 7], + ]) + + print(normalize(x, axis=0)) + + np.ndarray([ + [.1, .1, .1] + [.2, .2, .2] + [.7, .7, .7] + ]) + + print(normalize(x, axis=1)) + + np.ndarray([ + [.333, .333, .333] + [.333, .333, .333] + [.333, .333, .333] + ]) + + Note + ---- + Does not account for 0 sums along an axis + + Parameters + ---------- + x : np.ndarray + The array to normalize + + axis : Optional[int] = None + The axis to normalize across + + Returns + ------- + np.ndarray + The normalized array + """ + return x / x.sum(axis=axis, keepdims=True) diff --git a/autosklearn/util/logging_.py b/autosklearn/util/logging_.py index a85e4a80d6..e9ee676b79 100644 --- a/autosklearn/util/logging_.py +++ b/autosklearn/util/logging_.py @@ -1,4 +1,7 @@ -# -*- encoding: utf-8 -*- +""" +For accessing a logger, please default to using +`get_named_client_logger(name, host, port)` +""" from typing import Any, Dict, Iterator, Optional, TextIO, Type, cast import logging @@ -78,7 +81,7 @@ def get_logger(name: str) -> "PickableLoggerAdapter": return logger -class PickableLoggerAdapter(object): +class PickableLoggerAdapter: def __init__(self, name: str): self.name = name self.logger = _create_logger(name) diff --git a/pyproject.toml b/pyproject.toml index 0e48e3fc5f..3e3aafb1f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,9 +28,19 @@ py_version = "37" profile = "black" # Play nicely with black src_paths = ["autosklearn", "test"] known_types = ["typing", "abc"] # We put these in their own section TYPES +known_testlibs = ["unittest", "pytest", "pytest_cases"] # Put test libs in their own section known_first_party = ["autosklearn"] # Say that autosklearn is FIRSTPARTY known_test = ["test"] # Say that test.* is TEST -sections = ["FUTURE", "TYPES", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "TEST", "LOCALFOLDER"] # section ordering +sections = [ + "FUTURE", + "TYPES", + "STDLIB", + "THIRDPARTY", + "FIRSTPARTY", + "TESTLIBS", + "TEST", + "LOCALFOLDER" +] # section ordering multi_line_output = 3 # https://pycqa.github.io/isort/docs/configuration/multi_line_output_modes.html [tool.pydocstyle] @@ -147,6 +157,10 @@ module = [ "setuptools.*", "pkg_resources.*", "yaml.*", + "psutil.*" ] ignore_missing_imports = true +[[tool.mypy.overrides]] +module = ["test.*"] +disallow_untyped_decorators = false # Test decorators are not properly typed diff --git a/setup.py b/setup.py index 003b573bd4..e182cd716b 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ "pytest-cov", "pytest-xdist", "pytest-timeout", + "pytest-cases", "mypy", "isort", "black", diff --git a/test/conftest.py b/test/conftest.py index 16a285b9df..36cc77b3ad 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,167 +1,227 @@ -import os -import shutil -import time -import unittest.mock +""" +Testing +======= +The following are some features, guidelines and functionality for testing which makes +updating, adding and refactoring tests easier, especially as features and functionality +changes. 
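As a brief usage sketch for the `normalize` helper added in `autosklearn/util/functional.py` above (the array values here are illustrative), it simply divides by the sum along the requested axis and, per its docstring, assumes positive entries and non-zero sums:

..code:: python

    import numpy as np

    from autosklearn.util.functional import normalize

    x = np.array([[1, 1, 1], [2, 2, 2], [7, 7, 7]], dtype=float)

    normalize(x, axis=0)  # columns sum to 1: rows become 0.1, 0.2, 0.7
    normalize(x, axis=1)  # rows sum to 1: every entry becomes 1/3
    normalize(x)          # the whole array sums to 1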
+ +**Marks** +* todo - ``pytest.mark.todo``` to mark a test which xfails as it's todo +* slow - ``pytest.mark.slow``` to mark a test which is skipped if `pytest --fast` + +**Documenting Tests** +To ease in understanding of tests, what is being tested and what's expected of the test, +each test should be documented with what it's parameters/fixtures are as well as what +the test expects to happen, regardless of the tests implemntation. + + Parameters + ---------- + param1: Type + ... + + param2: Type + ... + + Fixtures + -------- + make_something: Callable[..., Something] + Factory to make Something + + Expects + ------- + * Something should raise a ValueError when called with X as X is not handled by the + validator Y. + +**Test same module across files** +When a module has many complicated avenues to be tested, create a folder and split the +tests according to each avenue. See `test/test_automl` for example as the `automl.py` +module is quite complicated to test and all tests in a single file become difficult to +follow and change. + +**pytest_cases** +Using pytest_cases, we seperate a `case`, something that defines the state of the +object, from the actual `test`, which tests properties of these cases. + +A complicated example can be seen at `test/test_automl/cases.py` where we have +autoML instances that are classifier/regressor, fitted or not, with cv or holdout, +or fitted with no ensemble. TODO: Easier example. + +Docs: https://smarie.github.io/python-pytest-cases/ + +**Caching** +Uses pytest's cache functionality for long training models so they can be shared between +tests and between different test runs. This is primarly used with `cases` so that tests +requiring the same kind of expensive case and used cached values. + +Use `pytest --cache-clear` to clear the cahce + +See `test/test_automl/cases.py` for example of how the fixtures from +`test/fixtures/caching.py` can be used to cache objects between tests. + +**Fixtures** +All fixtures in "test/fixtures" are known in every test file. We try to make use +of fixture `factories` which can be used to construct objects in complicated ways, +removing these complications from the tests themselves, importantly, keeping tests +short. A convention we use is to prefix them with `make`, for example, +`make_something`. This is useful for making data, e.g. `test/fixtures/data::make_data` + +..code:: python + + # Example of fixture factory + @fixture + def make_something(): + def _make(...args): + # ... complicated setup + # ... more complications + # ... make some sub objects which are complicated + return something + + return _make + + @parametrize("arg1", ['a', 'b', 'c']) + def test_something_does_x(arg1, make_something): + something = make_something(arg1, ...) 
+ result = something.run() + assert something == expected +""" +from typing import Any, Iterator, List, Optional + +import re +import signal +from pathlib import Path import psutil -import pytest -from dask.distributed import Client, get_client - -from autosklearn.automl import AutoML -from autosklearn.automl_common.common.utils.backend import Backend, create - -class AutoMLStub(AutoML): - def __init__(self): - self.__class__ = AutoML - self._task = None - self._dask_client = None - self._is_dask_client_internally_created = False +import pytest +from pytest import ExitCode, Item, Session - def __del__(self): - pass +DEFAULT_SEED = 0 -@pytest.fixture(scope="function") -def automl_stub(request): - automl = AutoMLStub() - automl._seed = 42 - automl._backend = unittest.mock.Mock(spec=Backend) - automl._backend.context = unittest.mock.Mock() - automl._delete_output_directories = lambda: 0 - return automl +HERE = Path(__file__) +AUTOSKLEARN_CACHE_NAME = "autosklearn" -@pytest.fixture(scope="function") -def backend(request): +def walk(path: Path, include: Optional[str] = None) -> Iterator[Path]: + """Yeilds all files, iterating over directory - test_dir = os.path.dirname(__file__) - tmp = os.path.join( - test_dir, ".tmp__%s__%s" % (request.module.__name__, request.node.name) - ) + Parameters + ---------- + path: Path + The root path to walk from - for dir in (tmp,): - for i in range(10): - if os.path.exists(dir): - try: - shutil.rmtree(dir) - break - except OSError: - time.sleep(1) - - # Make sure the folders we wanna create do not already exist. - backend = create( - temporary_directory=tmp, output_directory=None, prefix="auto-sklearn" - ) + include: Optional[str] = None + Include only directories which match this string - def get_finalizer(tmp_dir): - def session_run_at_end(): - for dir in (tmp_dir,): - for i in range(10): - if os.path.exists(dir): - try: - shutil.rmtree(dir) - break - except OSError: - time.sleep(1) - - return session_run_at_end - - request.addfinalizer(get_finalizer(tmp)) + Returns + ------- + Iterator[Path] + All file paths that could be found from this walk + """ + for p in path.iterdir(): + if p.is_dir(): + if include is None or re.match(include, p.name): + yield from walk(p, include) + else: + yield p.resolve() - return backend +def is_fixture(path: Path) -> bool: + """Whether a path is a fixture""" + return path.name.endswith("fixtures.py") -@pytest.fixture(scope="function") -def tmp_dir(request): - return _dir_fixture("tmp", request) +def as_module(path: Path) -> str: + """Convert a path to a module as seen from here""" + root = HERE.parent.parent + parts = path.relative_to(root).parts + return ".".join(parts).replace(".py", "") -def _dir_fixture(dir_type, request): - test_dir = os.path.dirname(__file__) - dirname = f".{dir_type}__{request.module.__name__}__{request.node.name}" - dir = os.path.join(test_dir, dirname) +def fixture_modules() -> List[str]: + """Get all fixture modules""" + fixtures_folder = HERE.parent / "fixtures" + return [ + as_module(path) for path in walk(fixtures_folder) if path.name.endswith(".py") + ] - for i in range(10): - if os.path.exists(dir): - try: - shutil.rmtree(dir) - break - except OSError: - pass - def get_finalizer(dir): - def session_run_at_end(): - for i in range(10): - if os.path.exists(dir): - try: - shutil.rmtree(dir) - break - except OSError: - time.sleep(1) +def pytest_runtest_setup(item: Item) -> None: + """Run before each test""" + todos = [marker for marker in item.iter_markers(name="todo")] + if todos: + pytest.xfail(f"Test 
needs to be implemented, {item.location}") - return session_run_at_end - request.addfinalizer(get_finalizer(dir)) +def pytest_sessionstart(session: Session) -> None: + """Called after the ``Session`` object has been created and before performing collection + and entering the run test loop. - return dir + Parameters + ---------- + session : Session + The pytest session object + """ + return -@pytest.fixture(scope="function") -def dask_client(request): - """ - Create a dask client with two workers. +def pytest_sessionfinish(session: Session, exitstatus: ExitCode) -> None: + """Clean up any child processes""" + proc = psutil.Process() + kill_signal = signal.SIGTERM + for child in proc.children(recursive=True): - Workers are in subprocesses to not create deadlocks with the pynisher and logging. - """ + # https://stackoverflow.com/questions/57336095/access-verbosity-level-in-a-pytest-helper-function + if session.config.getoption("verbose") > 0: + print(child, child.cmdline()) - client = Client(n_workers=2, threads_per_worker=1, processes=False) - print("Started Dask client={}\n".format(client)) + # https://psutil.readthedocs.io/en/latest/#kill-process-tree + try: + child.send_signal(kill_signal) + except psutil.NoSuchProcess: + pass - def get_finalizer(address): - def session_run_at_end(): - client = get_client(address) - print("Closed Dask client={}\n".format(client)) - client.shutdown() - client.close() - del client - return session_run_at_end +Config = Any # Can't find import? - request.addfinalizer(get_finalizer(client.scheduler_info()["address"])) - return client +def pytest_collection_modifyitems( + session: Session, + config: Config, + items: List[Item], +) -> None: + """Modifys the colelction of tests that are captured""" + if config.getoption("--fast"): + skip = pytest.mark.skip(reason="Test marked `slow` and `--fast` arg used") + slow_items = [item for item in items if "slow" in item.keywords] + for item in slow_items: + item.add_marker(skip) -@pytest.fixture(scope="function") -def dask_client_single_worker(request): - """ - Same as above, but only with a single worker. - Using this might cause deadlocks with the pynisher and the logging module. However, - it is used very rarely to avoid this issue as much as possible. - """ +def pytest_configure(config: Config) -> None: + """Used to register marks""" + config.addinivalue_line("markers", "todo: Mark test as todo") + config.addinivalue_line("markers", "slow: Mark test as slow") - client = Client(n_workers=1, threads_per_worker=1, processes=False) - print("Started Dask client={}\n".format(client)) - def get_finalizer(address): - def session_run_at_end(): - client = get_client(address) - print("Closed Dask client={}\n".format(client)) - client.shutdown() - client.close() - del client +pytest_plugins = fixture_modules() - return session_run_at_end - request.addfinalizer(get_finalizer(client.scheduler_info()["address"])) +Parser = Any # Can't find import? 
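As a small sketch of how these hooks are consumed from a test module (the test names here are hypothetical), the `todo` and `slow` marks registered above pair with the `--fast` option added below:

..code:: python

    import pytest

    @pytest.mark.todo
    def test_not_yet_implemented() -> None:
        """Automatically xfailed by ``pytest_runtest_setup``."""

    @pytest.mark.slow
    def test_expensive_training_run() -> None:
        """Collected normally, but skipped when invoked as ``pytest --fast``."""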
- return client +def pytest_addoption(parser: Parser) -> None: + """ -def pytest_sessionfinish(session, exitstatus): - proc = psutil.Process() - for child in proc.children(recursive=True): - print(child, child.cmdline()) + Parameters + ---------- + parser : Parser + The parser to add options to + """ + parser.addoption( + "--fast", + action="store_true", + default=False, + help="Disable tests marked as slow", + ) diff --git a/test/fixtures/__init__.py b/test/fixtures/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/fixtures/automl.py b/test/fixtures/automl.py new file mode 100644 index 0000000000..abf31d304d --- /dev/null +++ b/test/fixtures/automl.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from typing import Any, Callable, Dict, Tuple, Type + +from functools import partial + +from autosklearn.automl import AutoML, AutoMLClassifier, AutoMLRegressor +from autosklearn.automl_common.common.utils.backend import Backend + +from pytest import fixture +from unittest.mock import Mock + +from test.conftest import DEFAULT_SEED +from test.fixtures.dask import create_test_dask_client + + +def _create_automl( + automl_type: Type[AutoML] = AutoML, + **kwargs: Any, +) -> AutoML: + """ + + Parameters + ---------- + automl_type : Type[AutoML] = AutoML + The type of AutoML object to use + + **kwargs: Any + Options to pass on to the AutoML type for construction + + Returns + ------- + AutoML + The constructed class and a close method for dask, if it exists + """ + test_defaults = { + "n_jobs": 2, + "time_left_for_this_task": 30, + "per_run_time_limit": 5, + "seed": DEFAULT_SEED, + "n_jobs": 2, + } + + opts: Dict[str, Any] = {**test_defaults, **kwargs} + + if "dask_client" not in opts: + client = create_test_dask_client(n_workers=opts["n_jobs"]) + opts["dask_client"] = client + + auto = automl_type(**opts) + return auto + + +@fixture +def make_automl() -> Callable[..., Tuple[AutoML, Callable]]: + """See `_create_automl`""" + yield partial(_create_automl, automl_type=AutoML) + + +@fixture +def make_automl_classifier() -> Callable[..., AutoMLClassifier]: + """See `_create_automl`""" + yield partial(_create_automl, automl_type=AutoMLClassifier) + + +@fixture +def make_automl_regressor() -> Callable[..., AutoMLRegressor]: + """See `_create_automl`""" + yield partial(_create_automl, automl_type=AutoMLRegressor) + + +class AutoMLStub(AutoML): + def __init__(self) -> None: + self.__class__ = AutoML + self._task = None + self._dask_client = None # type: ignore + self._is_dask_client_internally_created = False + + def __del__(self) -> None: + pass + + +@fixture(scope="function") +def automl_stub() -> AutoMLStub: + """TODO remove""" + automl = AutoMLStub() + automl._seed = 42 + automl._backend = Mock(spec=Backend) + automl._backend.context = Mock() + automl._delete_output_directories = lambda: 0 + return automl diff --git a/test/fixtures/backend.py b/test/fixtures/backend.py new file mode 100644 index 0000000000..3ee4626199 --- /dev/null +++ b/test/fixtures/backend.py @@ -0,0 +1,76 @@ +from typing import Callable, Union + +import os +from pathlib import Path + +from autosklearn.automl_common.common.utils.backend import Backend, create + +from pytest import fixture + + +# TODO Update to return path once everything can use a path +@fixture +def tmp_dir(tmp_path: Path) -> str: + """ + Fixtures + -------- + tmp_path : Path + Built in pytest fixture + + Returns + ------- + str + The directory as a str + """ + return str(tmp_path) + + +@fixture +def make_backend() -> 
Callable[..., Backend]: + """Make a backend + + Parameters + ---------- + path: Union[str, Path] + The path to place the backend at + + Returns + ------- + Backend + The created backend object + """ + # TODO redo once things use paths + def _make(path: Union[str, Path]) -> Backend: + _path = Path(path) if not isinstance(path, Path) else path + assert not _path.exists() + + backend = create( + temporary_directory=str(_path), + output_directory=None, + prefix="auto-sklearn", + ) + + return backend + + return _make + + +@fixture(scope="function") +def backend(tmp_dir: str, make_backend: Callable) -> Backend: + """A backend object + + Fixtures + -------- + tmp_dir : str + A directory to place the backend at + + make_backend : Callable + Factory to make a backend + + Returns + ------- + Backend + A backend object + """ + backend_path = os.path.join(tmp_dir, "backend") + return make_backend(path=backend_path) diff --git a/test/fixtures/caching.py b/test/fixtures/caching.py new file mode 100644 index 0000000000..b6a4ed2fdc --- /dev/null +++ b/test/fixtures/caching.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +from typing import Any, Callable, Optional + +import pickle +import shutil +import traceback +from functools import partial +from pathlib import Path + +from autosklearn.automl import AutoML + +from pytest import FixtureRequest +from pytest_cases import fixture + + +class Cache: + """Used for the below fixtures. + + Mainly used with cases so they don' need to be retrained at every invocation. + The cache can be cleared using `pytest`'s built in mechanism: + + pytest --clear-cache + + To view cached items use: + + pytest --cache-show + + ..code:: python + + def case_fitted_model(cache, ...): + key = "some key unique to this test" + cache = cache(key) + if "model" in cache: + return cache.load("model") + + # ... fit model + + cache.save(model, "model") + return model + + If multiple items are required, they can be access in much the same ways + + ..code:: python + + def case_requires_multiple_things(cache, ...): + + cache1 = cache("key1") + cache2 = cache("key2") + + If multiple things need to be stored in one location, you can access a given path + for a given named thing inside a cache. + + ..code:: python + + def case_fitted_model_and_populated_backend(cache, ...): + cache = cache("some key") + + """ + + def __init__(self, key: str, cache_dir: Path, verbose: int = 0): + """ + Parameters + ---------- + key : str + The key of the item stored + + cache_dir : Path + The dir where the cache should be located + + verbose : int = 0 + Whether to be verbose or not. 
Currently only has one level (> 0) + """ + self.dir = cache_dir / key + self.verbose = verbose > 0 + + def items(self) -> list[Path]: + """Get any paths associated to items in this dir""" + return list(self.dir.iterdir()) + + def __contains__(self, name: str) -> bool: + return self.path(name).exists() + + def path(self, name: str) -> Path: + """Path to an item for this cache""" + return self.dir / name + + def _load(self, name: str) -> Any: + """Load an item from the cache with a given name""" + if self.verbose: + print(f"Loading cached item {self.path(name)}") + + with self.path(name).open("rb") as f: + return pickle.load(f) + + def _save(self, item: Any, name: str) -> None: + """Dump an item to cache with a name""" + if self.verbose: + print(f"Saving cached item {self.path(name)}") + + with self.path(name).open("wb") as f: + pickle.dump(item, f) + + def reset(self) -> None: + """Delete this caches items""" + shutil.rmtree(self.dir) + self.dir.mkdir() + + +class AutoMLCache(Cache): + def save(self, model: AutoML) -> None: + """Save the model""" + self._save(model, "model") + + def model(self) -> Optional[AutoML]: + """Returns the saved model if it can. + + In the case of an issue loading an existing model file, it will delete + this cache item. + """ + if "model" not in self: + return None + + # Try to load the model, if there was an issue, delete all cached items + # for the model and return None + try: + model = self._load("model") + except Exception: + model = None + print(traceback.format_exc()) + self.reset() + finally: + return model + + def backend_path(self) -> Path: + """The path for the backend of the automl model""" + return self.path("backend") + + +@fixture +def cache(request: FixtureRequest) -> Callable[[str], Cache]: + """Gives the access to a cache.""" + pytest_cache = request.config.cache + assert pytest_cache is not None + + cache_dir = pytest_cache.mkdir("autosklearn-cache") + return partial(Cache, cache_dir=cache_dir) + + +@fixture +def automl_cache(request: FixtureRequest) -> Callable[[str], AutoMLCache]: + """Gives access to an automl cache""" + pytest_cache = request.config.cache + assert pytest_cache is not None + + cache_dir = pytest_cache.mkdir("autosklearn-cache") + verbosity = request.config.getoption("verbose") + return partial(AutoMLCache, cache_dir=cache_dir, verbose=verbosity) diff --git a/test/fixtures/dask.py b/test/fixtures/dask.py new file mode 100644 index 0000000000..0c1f112800 --- /dev/null +++ b/test/fixtures/dask.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +from typing import Callable + +from dask.distributed import Client, get_client + +from pytest import FixtureRequest, fixture + +# Terrible practice but we need to close dask clients somehow +active_clients: dict[str, Callable] = {} + + +@fixture(autouse=True) +def clean_up_any_dask_clients(request: FixtureRequest) -> None: + """Auto injected fixture to close dask clients after each test""" + yield + if any(active_clients): + for adr in list(active_clients.keys()): + if request.config.getoption("verbose") > 1: + print(f"\nFixture closing dask_client at {adr}") + + close = active_clients[adr] + close() + del active_clients[adr] + + +def create_test_dask_client(n_workers: int = 2) -> Client: + """Factory to make a Dask client and a function to close it + them + + Parameters + ---------- + n_workers: int = 2 + inside asklea + inside AutoML. 
+ + Returns + ------- + Client, Callable + The client and a function to call to close that client + """ + # Workers are in subprocesses to not create deadlocks with the pynisher + # and logging. + client = Client( + n_workers=n_workers, + threads_per_worker=1, + processes=False, + scheduler_port=0, # Set to 0 so it chooses a random one + dashboard_address=None, # Disable dashboarding + ) + adr = client.scheduler_info()["address"] + + def close() -> None: + try: + client = get_client(adr, timeout=1) + client.shutdown() + except Exception: + pass + + active_clients[adr] = close + + return client + + +@fixture +def make_dask_client() -> Callable[[int], Client]: + """Factory to make a Dask client and a function to close it + + Parameters + ---------- + n_workers: int = 1 + How many workers to have in the dask client + + Returns + ------- + Client, Callable + The client and a function to call to close that client + """ + return create_test_dask_client + + +# TODO remove in favour of make_dask_client +@fixture(scope="function") +def dask_client(make_dask_client: Callable) -> Client: + """Create a dask client with two workers.""" + client = make_dask_client(n_workers=2) + yield client + + +# TODO remove in favour of make_dask_client +@fixture(scope="function") +def dask_client_single_worker(make_dask_client: Callable) -> Client: + """Dask client with only 1 worker + + Note + ---- + May create deadlocks with logging and pynisher + """ + client = make_dask_client(n_workers=1) + yield client diff --git a/test/fixtures/datasets.py b/test/fixtures/datasets.py new file mode 100644 index 0000000000..39d948e5a9 --- /dev/null +++ b/test/fixtures/datasets.py @@ -0,0 +1,278 @@ +from __future__ import annotations + +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +import numpy as np +import pandas as pd +from scipy.sparse import csr_matrix +from sklearn.utils import check_random_state + +from autosklearn.constants import ( + BINARY_CLASSIFICATION, + MULTICLASS_CLASSIFICATION, + MULTILABEL_CLASSIFICATION, + MULTIOUTPUT_REGRESSION, + REGRESSION, +) +from autosklearn.data.validation import SUPPORTED_FEAT_TYPES, SUPPORTED_TARGET_TYPES +from autosklearn.data.xy_data_manager import XYDataManager +from autosklearn.pipeline.util import get_dataset +from autosklearn.util.functional import normalize + +from pytest import fixture + +from test.conftest import DEFAULT_SEED + +Data = Tuple[SUPPORTED_FEAT_TYPES, SUPPORTED_TARGET_TYPES] + + +def astype( + t: np.ndarray | list | csr_matrix | pd.DataFrame | pd.Series, + x: Any, +) -> Any: + """Convert data to allowed types""" + if t == np.ndarray: + return np.asarray(x) + else: + return t(x) # type: ignore + + +# TODO Remove the implementation in autosklearn.pipeline.util and just put here +@fixture +def make_sklearn_dataset() -> Callable: + """ + + Parameters + ---------- + name : str = "iris" + Name of the dataset to get + + make_sparse : bool = False + Wehther to make the data sparse + + add_NaNs : bool = False + Whether to add NaNs to the data + + train_size_maximum : int = 150 + THe maximum size of training data + + make_multilabel : bool = False + Whether to force the data into being multilabel + + make_binary : bool = False + Whether to force the data into being binary + + as_datamanager: bool = False + Wether to return the information as an XYDataManager + + Returns + ------- + (X_train, Y_train, X_test, Y_Test) | XYDataManager + """ + + def _make( + name: str = "iris", + make_sparse: bool = False, + add_NaNs: bool = False, + 
train_size_maximum: int = 150, + make_multilabel: bool = False, + make_binary: bool = False, + as_datamanager: bool = False, + task: Optional[int] = None, + feat_type: Optional[Dict | str] = None, + ) -> Any: + X, y, Xt, yt = get_dataset( + dataset=name, + make_sparse=make_sparse, + add_NaNs=add_NaNs, + train_size_maximum=train_size_maximum, + make_multilabel=make_multilabel, + make_binary=make_binary, + ) + + if not as_datamanager: + return (X, y, Xt, yt) + else: + assert task is not None and feat_type is not None + if isinstance(feat_type, str): + feat_type = {i: feat_type for i in range(X.shape[1])} + + return XYDataManager( + X, + y, + Xt, + yt, + task=task, + dataset_name=name, + feat_type=feat_type, + ) + + return _make + + +def _make_binary_data( + dims: Tuple[int, ...] = (100, 3), + weights: Optional[Sequence[float] | np.ndarray] = None, + types: Tuple[ + np.ndarray | csr_matrix | pd.DataFrame | list, + np.ndarray | csr_matrix | pd.DataFrame | list | pd.Series, + ] = (np.ndarray, np.ndarray), + random_state: int | np.random.RandomState = DEFAULT_SEED, +) -> Data: + X_type, y_type = types + rs = check_random_state(random_state) + + classes = [0, 1] + + if not weights: + weights = np.ones_like(classes) / len(classes) + + assert len(weights) == len(classes) + weights = normalize(np.asarray(weights)) + + X = rs.rand(*dims) + y = rs.choice([0, 1], dims[0], p=weights) + + return astype(X_type, X), astype(y_type, y) + + +def _make_multiclass_data( + dims: Tuple[int, ...] = (100, 3), + classes: int | np.ndarray | List = 3, + weights: Optional[np.ndarray | List[float]] = None, + types: Tuple[ + np.ndarray | csr_matrix | pd.DataFrame | list, + np.ndarray | csr_matrix | pd.DataFrame | list | pd.Series, + ] = (np.ndarray, np.ndarray), + random_state: int | np.random.RandomState = DEFAULT_SEED, +) -> Data: + X_type, y_type = types + + if isinstance(classes, int): + classes = np.asarray(list(range(classes))) + + rs = check_random_state(random_state) + + if not weights: + weights = np.ones_like(classes) / len(classes) + + assert len(weights) == len(classes) + weights = normalize(np.asarray(weights)) + + X = rs.rand(*dims) + y = rs.choice(classes, dims[0], p=weights) + + return astype(X_type, X), astype(y_type, y) + + +def _make_multilabel_data( + dims: Tuple[int, ...] = (100, 3), + classes: np.ndarray | List = [[0, 0], [0, 1], [1, 0], [1, 1]], + weights: Optional[np.ndarray | List[float]] = None, + types: Tuple[ + np.ndarray | csr_matrix | pd.DataFrame | list, + np.ndarray | csr_matrix | pd.DataFrame | list | pd.Series, + ] = (np.ndarray, np.ndarray), + random_state: int | np.random.RandomState = DEFAULT_SEED, +) -> Data: + X_type, y_type = types + + classes = np.asarray(classes) + assert classes.ndim > 1 and classes.shape[1] > 1 + + rs = check_random_state(random_state) + + # Weights indicate each label tuple, and not the weights of individual labels + # in that tuple + if not weights: + weights = np.ones(classes.shape[0]) / len(classes) + + assert len(weights) == len(classes) + weights = normalize(np.asarray(weights)) + + X = rs.rand(*dims) + + class_indices = rs.choice(len(classes), dims[0], p=weights) + y = classes[class_indices] + + return astype(X_type, X), astype(y_type, y) + + +def _make_regression_data( + dims: Tuple[int, ...] 
= (100, 3), + types: Tuple[ + np.ndarray | csr_matrix | pd.DataFrame | list, + np.ndarray | csr_matrix | pd.DataFrame | list | pd.Series, + ] = (np.ndarray, np.ndarray), + random_state: int | np.random.RandomState = DEFAULT_SEED, +) -> Data: + X_type, y_type = types + rs = check_random_state(random_state) + + if X_type == csr_matrix: + X = rs.choice([0, 1], dims) + else: + X = rs.rand(*dims) + + y = rs.rand(dims[0]) + + return astype(X_type, X), astype(y_type, y) + + +def _make_multioutput_regression_data( + dims: Tuple[int, ...] = (100, 3), + targets: int = 2, + types: Tuple[ + np.ndarray | csr_matrix | pd.DataFrame | list, + np.ndarray | csr_matrix | pd.DataFrame | list | pd.Series, + ] = (np.ndarray, np.ndarray), + random_state: int | np.random.RandomState = DEFAULT_SEED, +) -> Data: + X_type, y_type = types + + rs = check_random_state(random_state) + + if X_type == csr_matrix: + X = rs.choice([0, 1], dims) + else: + X = rs.rand(*dims) + + y = rs.rand(dims[0], targets) + + return astype(X_type, X), astype(y_type, y) + + +@fixture +def make_data() -> Callable[..., Data]: + """Generate some arbitrary x,y data + + Parameters + ---------- + kind: int = BINARY_CLASSIFICATION + The task type, one of BINARY_CLASSIFICATION, MULTICLASS_CLASSIFICATION, ... + + **kwargs: Any + See the corresponding `_make_` + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + The generated data + """ + + def _make( + kind: int = BINARY_CLASSIFICATION, + **kwargs: Any, + ) -> Data: + dispatches = { + BINARY_CLASSIFICATION: _make_binary_data, + MULTICLASS_CLASSIFICATION: _make_multiclass_data, + MULTILABEL_CLASSIFICATION: _make_multilabel_data, + REGRESSION: _make_regression_data, + MULTIOUTPUT_REGRESSION: _make_multioutput_regression_data, + } + + f = dispatches[kind] + return f(**kwargs) # type: ignore + + return _make diff --git a/test/fixtures/ensembles.py b/test/fixtures/ensembles.py new file mode 100644 index 0000000000..467c53822f --- /dev/null +++ b/test/fixtures/ensembles.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from typing import Callable, Collection, Optional, Union + +import numpy as np +from sklearn.ensemble import VotingClassifier, VotingRegressor + +from autosklearn.data.validation import SUPPORTED_FEAT_TYPES, SUPPORTED_TARGET_TYPES +from autosklearn.evaluation.abstract_evaluator import ( + MyDummyClassifier, + MyDummyRegressor, +) +from autosklearn.pipeline.components.base import ( + AutoSklearnClassificationAlgorithm, + AutoSklearnRegressionAlgorithm, +) + +from pytest_cases import fixture + +from test.conftest import DEFAULT_SEED + + +@fixture +def make_voting_classifier() -> Callable[..., VotingClassifier]: + """ + Parameters + ---------- + X: Optional[SUPPORTED_FEAT_TYPES] = None + The X data to fit models on, if None, no fitting occurs + + y: Optional[SUPPORTED_FEAT_TYPES] = None + The y data to fit models on, if None, no fitting occurs + + models: Optional[Collection[AutoSklearnClassificationAlgorithm]] = None + Any collection of algorithms to use, if None, DummyClassifiers are used + """ + + def _make( + X: Optional[SUPPORTED_FEAT_TYPES] = None, + y: Optional[SUPPORTED_TARGET_TYPES] = None, + models: Optional[Collection[AutoSklearnClassificationAlgorithm]] = None, + seed: Union[int, None, np.random.RandomState] = DEFAULT_SEED, + ) -> VotingClassifier: + assert not (X is None) ^ (y is None) + + if not models: + models = [MyDummyClassifier(config=1, random_state=seed) for _ in range(5)] + + if X is not None: + for model in models: + model.fit(X, y) + + voter = 
VotingClassifier(estimators=None, voting="soft") + voter.estimators_ = models + return voter + + return _make + + +@fixture +def make_voting_regressor() -> Callable[..., VotingRegressor]: + """ + Parameters + ---------- + X: Optional[SUPPORTED_FEAT_TYPES] = None + The X data to fit models on, if None, no fitting occurs + + y: Optional[SUPPORTED_FEAT_TYPES] = None + The y data to fit models on, if None, no fitting occurs + + models: Optional[Collection[AutoSklearnRegressionAlgorithm]] = None + Any collection of algorithms to use, if None, DummyRegressors are used + """ + + def _make( + X: Optional[SUPPORTED_FEAT_TYPES] = None, + y: Optional[SUPPORTED_TARGET_TYPES] = None, + models: Optional[Collection[AutoSklearnRegressionAlgorithm]] = None, + seed: Union[int, None, np.random.RandomState] = DEFAULT_SEED, + ) -> VotingRegressor: + assert not (X is None) ^ (y is None) + + if not models: + models = [MyDummyRegressor(config=1, random_state=seed) for _ in range(5)] + + if X is not None: + for model in models: + model.fit(X, y) + + voter = VotingRegressor(estimators=None) + voter.estimators_ = models + return voter + + return _make diff --git a/test/fixtures/logging.py b/test/fixtures/logging.py new file mode 100644 index 0000000000..778738087b --- /dev/null +++ b/test/fixtures/logging.py @@ -0,0 +1,9 @@ +from pytest_cases import fixture + +from test.mocks.logging import MockLogger + + +@fixture +def mock_logger() -> MockLogger: + """A mock logger with some mock defaults""" + return MockLogger() diff --git a/test/mocks/__init__.py b/test/mocks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/mocks/logging.py b/test/mocks/logging.py new file mode 100644 index 0000000000..e61ca2c870 --- /dev/null +++ b/test/mocks/logging.py @@ -0,0 +1,38 @@ +from typing import Optional + +from autosklearn.util.logging_ import PicklableClientLogger + +from unittest.mock import Mock + +MOCKNAME = "mock" +MOCKHOST = "mockhost" +MOCKPORT = 9020 + + +class MockLogger(PicklableClientLogger): + """Should not be used for testing the actual loggers functionality + + Overwrites all methods with mock objects that can be queries + * All logging methods do nothing + * isEnabledFor returns True for everything as it's part of the logging config we + don't have access to + * __setstate__ and __getstate__ remain the same and are not mocked + """ + + def __init__( + self, + name: Optional[str] = None, + host: Optional[str] = None, + port: Optional[int] = None, + ): + self.name = name or MOCKNAME + self.host = host or MOCKHOST + self.port = port or MOCKPORT + + # Overwrite the logging implementations with mocks + attrs = ["debug", "info", "warning", "error", "exception", "critical", "log"] + for attr in attrs: + setattr(self, attr, Mock(return_value=None)) + + # This mock logger is enabled for all levels + setattr(self, "isEnabledFor", Mock(return_value=True)) diff --git a/test/test_automl/automl_utils.py b/test/test_automl/automl_utils.py index 577ea97359..7246b26fe5 100644 --- a/test/test_automl/automl_utils.py +++ b/test/test_automl/automl_utils.py @@ -16,7 +16,6 @@ def print_debug_information(automl): - # In case it is called with estimator, # Get the automl object if hasattr(automl, "automl_"): @@ -63,12 +62,6 @@ def print_debug_information(automl): return os.linesep.join(content) -def _includes(scores, all_scores): - return all(score in all_scores for score in scores) and len(scores) == len( - all_scores - ) - - def count_succeses(cv_results): return np.sum( [ @@ -78,7 +71,7 @@ def 
count_succeses(cv_results): ) -def includes_all_scores(scores): +def includes_all_scores(scores) -> bool: all_scores = ( scores_dict["train_single"] + scores_dict["test_single"] @@ -86,24 +79,24 @@ def includes_all_scores(scores): + scores_dict["test_ensamble"] + ["Timestamp"] ) - return _includes(scores, all_scores) + return set(scores) == set(all_scores) -def include_single_scores(scores): +def include_single_scores(scores) -> bool: all_scores = ( scores_dict["train_single"] + scores_dict["test_single"] + ["Timestamp"] ) - return _includes(scores, all_scores) + return set(scores) == set(all_scores) -def includes_train_scores(scores): +def includes_train_scores(scores) -> bool: all_scores = ( scores_dict["train_single"] + scores_dict["train_ensamble"] + ["Timestamp"] ) - return _includes(scores, all_scores) + return set(scores) == set(all_scores) -def performance_over_time_is_plausible(poT): +def performance_over_time_is_plausible(poT) -> bool: if len(poT) < 1: return False if len(poT.drop(columns=["Timestamp"]).dropna()) < 1: diff --git a/test/test_automl/cases.py b/test/test_automl/cases.py new file mode 100644 index 0000000000..70d68c4b73 --- /dev/null +++ b/test/test_automl/cases.py @@ -0,0 +1,196 @@ +""" +Here we define the several different setups and cache them to allow for easier unit +testing. The caching mechanism is only per session and does not persist over sessions. +There's only really a point for caching fitted models that will be tested multiple times +for different properties. + +Anything using a cached model must not destroy any backend resources, although it can +mock if required. + +Tags: + {classifier, regressor} - The type of AutoML object + classifier - will be fit on "iris" + regressor - will be fit on "boston" + {fitted} - If the automl case has been fitted + {cv, holdout} - Whether explicitly cv or holdout was used + {no_ensemble} - Fit with no ensemble size +""" +from typing import Callable, Tuple + +from pathlib import Path + +import numpy as np + +from autosklearn.automl import AutoMLClassifier, AutoMLRegressor + +from pytest_cases import case, parametrize + +from test.fixtures.caching import AutoMLCache + + +@case(tags=["classifier"]) +def case_classifier( + tmp_dir: str, + make_automl_classifier: Callable[..., AutoMLClassifier], +) -> AutoMLClassifier: + """Case basic unfitted AutoMLClassifier""" + dir = Path(tmp_dir) / "backend" + model = make_automl_classifier(temporary_directory=str(dir)) + return model + + +@case(tags=["classifier"]) +def case_regressor( + tmp_dir: str, + make_automl_regressor: Callable[..., AutoMLRegressor], +) -> AutoMLRegressor: + """Case basic unfitted AutoMLClassifier""" + dir = Path(tmp_dir) / "backend" + model = make_automl_regressor(temporary_directory=str(dir)) + return model + + +# ################################### +# The following are fitted and cached +# ################################### +@case(tags=["classifier", "fitted", "holdout"]) +@parametrize("dataset", ["iris"]) +def case_classifier_fitted_holdout( + automl_cache: Callable[[str], AutoMLCache], + dataset: str, + make_automl_classifier: Callable[..., AutoMLClassifier], + make_sklearn_dataset: Callable[..., Tuple[np.ndarray, ...]], +) -> AutoMLClassifier: + """Case of a holdout fitted classifier""" + resampling_strategy = "holdout-iterative-fit" + + cache = automl_cache(f"case_classifier_{resampling_strategy}_{dataset}") + + model = cache.model() + if model is not None: + return model + + X, y, Xt, yt = make_sklearn_dataset(name=dataset) + + model = 
make_automl_classifier( + temporary_directory=cache.path("backend"), + delete_tmp_folder_after_terminate=False, + resampling_strategy=resampling_strategy, + ) + model.fit(X, y, dataset_name=dataset) + + cache.save(model) + return model + + +@case(tags=["classifier", "fitted", "cv"]) +@parametrize("dataset", ["iris"]) +def case_classifier_fitted_cv( + automl_cache: Callable[[str], AutoMLCache], + dataset: str, + make_automl_classifier: Callable[..., AutoMLClassifier], + make_sklearn_dataset: Callable[..., Tuple[np.ndarray, ...]], +) -> AutoMLClassifier: + """Case of a fitted cv AutoMLClassifier""" + resampling_strategy = "cv" + cache = automl_cache(f"case_classifier_{resampling_strategy}_{dataset}") + + model = cache.model() + if model is not None: + return model + + X, y, Xt, yt = make_sklearn_dataset(name=dataset) + model = make_automl_classifier( + resampling_strategy=resampling_strategy, + temporary_directory=cache.path("backend"), + delete_tmp_folder_after_terminate=False, + ) + model.fit(X, y, dataset_name=dataset) + + cache.save(model) + return model + + +@case(tags=["regressor", "fitted", "holdout"]) +@parametrize("dataset", ["boston"]) +def case_regressor_fitted_holdout( + automl_cache: Callable[[str], AutoMLCache], + dataset: str, + make_automl_regressor: Callable[..., AutoMLRegressor], + make_sklearn_dataset: Callable[..., Tuple[np.ndarray, ...]], +) -> AutoMLRegressor: + """Case of fitted regressor with cv resampling""" + resampling_strategy = "holdout" + cache = automl_cache(f"case_regressor_{resampling_strategy}_{dataset}") + + model = cache.model() + if model is not None: + return model + + X, y, Xt, yt = make_sklearn_dataset(name=dataset) + model = make_automl_regressor( + resampling_strategy=resampling_strategy, + temporary_directory=cache.path("backend"), + delete_tmp_folder_after_terminate=False, + ) + model.fit(X, y, dataset_name=dataset) + + cache.save(model) + return model + + +@case(tags=["regressor", "fitted", "cv"]) +@parametrize("dataset", ["boston"]) +def case_regressor_fitted_cv( + automl_cache: Callable[[str], AutoMLCache], + dataset: str, + make_automl_regressor: Callable[..., AutoMLRegressor], + make_sklearn_dataset: Callable[..., Tuple[np.ndarray, ...]], +) -> AutoMLRegressor: + """Case of fitted regressor with cv resampling""" + resampling_strategy = "cv" + + cache = automl_cache(f"case_regressor_{resampling_strategy}_{dataset}") + model = cache.model() + if model is not None: + return model + + X, y, Xt, yt = make_sklearn_dataset(name=dataset) + + model = make_automl_regressor( + temporary_directory=cache.path("backend"), + delete_tmp_folder_after_terminate=False, + resampling_strategy=resampling_strategy, + ) + model.fit(X, y, dataset_name=dataset) + + cache.save(model) + return model + + +@case(tags=["classifier", "fitted", "no_ensemble"]) +@parametrize("dataset", ["iris"]) +def case_classifier_fitted_no_ensemble( + automl_cache: Callable[[str], AutoMLCache], + dataset: str, + make_automl_classifier: Callable[..., AutoMLClassifier], + make_sklearn_dataset: Callable[..., Tuple[np.ndarray, ...]], +) -> AutoMLClassifier: + """Case of a fitted classifier but enemble_size was set to 0""" + cache = automl_cache(f"case_classifier_fitted_no_ensemble_{dataset}") + + model = cache.model() + if model is not None: + return model + + X, y, Xt, yt = make_sklearn_dataset(name=dataset) + + model = make_automl_classifier( + temporary_directory=cache.path("backend"), + delete_tmp_folder_after_terminate=False, + ensemble_size=0, + ) + model.fit(X, y, dataset_name=dataset) + 
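+    # The fitted model is pickled into the pytest cache directory
+    # (`autosklearn-cache`), and its backend was placed under
+    # `cache.path("backend")` above, so other tests requesting this case
+    # can reuse it instead of refitting.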
+ cache.save(model) + return model diff --git a/test/test_automl/test_automl.py b/test/test_automl/test_automl.py deleted file mode 100644 index 3c62f9a4da..0000000000 --- a/test/test_automl/test_automl.py +++ /dev/null @@ -1,1232 +0,0 @@ -# -*- encoding: utf-8 -*- -from typing import Dict, List, Union - -import glob -import itertools -import os -import pickle -import time -import unittest -import unittest.mock -import warnings - -import numpy as np -import pandas as pd -import pytest -import sklearn.datasets -from scipy.sparse import csr_matrix, spmatrix -from sklearn.ensemble import VotingClassifier, VotingRegressor -from smac.facade.roar_facade import ROAR -from smac.scenario.scenario import Scenario -from smac.tae import StatusType - -import autosklearn.automl -import autosklearn.pipeline.util as putil -from autosklearn.automl import AutoML, AutoMLClassifier, AutoMLRegressor, _model_predict -from autosklearn.constants import ( - BINARY_CLASSIFICATION, - CLASSIFICATION_TASKS, - MULTICLASS_CLASSIFICATION, - MULTILABEL_CLASSIFICATION, - MULTIOUTPUT_REGRESSION, - REGRESSION, -) -from autosklearn.data.validation import InputValidator -from autosklearn.data.xy_data_manager import XYDataManager -from autosklearn.evaluation.abstract_evaluator import ( - MyDummyClassifier, - MyDummyRegressor, -) -from autosklearn.metrics import ( - accuracy, - balanced_accuracy, - default_metric_for_task, - log_loss, -) -from autosklearn.util.data import default_dataset_compression_arg -from autosklearn.util.logging_ import PickableLoggerAdapter - -from test.test_automl.automl_utils import ( - AutoMLLogParser, - count_succeses, - includes_train_scores, - performance_over_time_is_plausible, - print_debug_information, -) - - -class AutoMLStub(AutoML): - def __init__(self, classifier: bool = False): - self._task = None - self._dask_client = None - self._is_dask_client_internally_created = False - self._classifier = classifier - - def __del__(self): - pass - - -def test_fit(dask_client): - X_train, Y_train, X_test, Y_test = putil.get_dataset("iris") - automl = autosklearn.automl.AutoML( - seed=0, - time_left_for_this_task=30, - per_run_time_limit=5, - metric=accuracy, - dask_client=dask_client, - ) - - automl.fit(X_train, Y_train, task=MULTICLASS_CLASSIFICATION) - - score = automl.score(X_test, Y_test) - assert score > 0.8 - assert count_succeses(automl.cv_results_) > 0 - assert includes_train_scores(automl.performance_over_time_.columns) is True - assert performance_over_time_is_plausible(automl.performance_over_time_) is True - assert automl._task == MULTICLASS_CLASSIFICATION - - del automl - - -def test_ensemble_size_zero(): - """Test if automl.fit_ensemble raises error when ensemble_size == 0""" - X_train, Y_train, X_test, Y_test = putil.get_dataset("iris") - automl = autosklearn.automl.AutoML( - seed=0, - time_left_for_this_task=30, - per_run_time_limit=5, - metric=accuracy, - ensemble_size=0, - ) - automl.fit(X_train, Y_train, task=MULTICLASS_CLASSIFICATION) - with pytest.raises(ValueError): - automl.fit_ensemble(Y_test, ensemble_size=0) - - -def test_empty_dict_in_show_models(): - """Test if show_models() returns empty dictionary when ensemble_size == 0""" - X_train, Y_train, X_test, Y_test = putil.get_dataset("iris") - automl = autosklearn.automl.AutoMLClassifier( - seed=0, - time_left_for_this_task=30, - per_run_time_limit=5, - metric=accuracy, - ensemble_size=0, - ) - automl.fit(X_train, Y_train) - assert automl.show_models() == {} - - -def test_fitted_models_in_show_models(): - X_train, Y_train, X_test, 
Y_test = putil.get_dataset("iris") - automl = autosklearn.automl.AutoMLClassifier( - seed=0, - time_left_for_this_task=30, - per_run_time_limit=5, - metric=accuracy, - ensemble_size=0, - ) - with pytest.raises(RuntimeError, match="AutoSklearn has not been fitted"): - automl.show_models() - - -def test_fit_roar(dask_client_single_worker): - def get_roar_object_callback( - scenario_dict, seed, ta, ta_kwargs, dask_client, n_jobs, **kwargs - ): - """Random online adaptive racing. - - http://ml.informatik.uni-freiburg.de/papers/11-LION5-SMAC.pdf""" - scenario = Scenario(scenario_dict) - return ROAR( - scenario=scenario, - rng=seed, - tae_runner=ta, - tae_runner_kwargs=ta_kwargs, - dask_client=dask_client, - n_jobs=n_jobs, - ) - - X_train, Y_train, X_test, Y_test = putil.get_dataset("iris") - automl = autosklearn.automl.AutoML( - time_left_for_this_task=30, - per_run_time_limit=5, - initial_configurations_via_metalearning=0, - get_smac_object_callback=get_roar_object_callback, - metric=accuracy, - dask_client=dask_client_single_worker, - ) - - automl.fit(X_train, Y_train, task=MULTICLASS_CLASSIFICATION) - - score = automl.score(X_test, Y_test) - assert score > 0.8 - assert count_succeses(automl.cv_results_) > 0 - assert includes_train_scores(automl.performance_over_time_.columns) is True - assert automl._task == MULTICLASS_CLASSIFICATION - - del automl - - -def test_refit_shuffle_on_fail(dask_client): - - failing_model = unittest.mock.Mock() - failing_model.fit.side_effect = [ValueError(), ValueError(), None] - failing_model.fit_transformer.side_effect = [ValueError(), ValueError(), (None, {})] - failing_model.get_max_iter.return_value = 100 - - auto = AutoML(30, 5, dask_client=dask_client) - ensemble_mock = unittest.mock.Mock() - ensemble_mock.get_selected_model_identifiers.return_value = [(1, 1, 50.0)] - auto.ensemble_ = ensemble_mock - auto.InputValidator = InputValidator() - for budget_type in [None, "iterations"]: - auto._budget_type = budget_type - - auto.models_ = {(1, 1, 50.0): failing_model} - - # Make sure a valid 2D array is given to automl - X = np.array([1, 2, 3]).reshape(-1, 1) - y = np.array([1, 2, 3]) - auto.InputValidator.fit(X, y) - auto.refit(X, y) - - assert failing_model.fit.call_count == 3 - assert failing_model.fit_transformer.call_count == 3 - - del auto - - -def test_only_loads_ensemble_models(automl_stub): - def side_effect(ids, *args, **kwargs): - return models if ids is identifiers else {} - - # Add a resampling strategy as this is required by load_models - automl_stub._resampling_strategy = "holdout" - identifiers = [(1, 2), (3, 4)] - - models = [42] - load_ensemble_mock = unittest.mock.Mock() - load_ensemble_mock.get_selected_model_identifiers.return_value = identifiers - automl_stub._backend.load_ensemble.return_value = load_ensemble_mock - automl_stub._backend.load_models_by_identifiers.side_effect = side_effect - - automl_stub._load_models() - assert models == automl_stub.models_ - assert automl_stub.cv_models_ is None - - automl_stub._resampling_strategy = "cv" - - models = [42] - automl_stub._backend.load_cv_models_by_identifiers.side_effect = side_effect - - automl_stub._load_models() - assert models == automl_stub.cv_models_ - - -def test_check_for_models_if_no_ensemble(automl_stub): - models = [42] - automl_stub._backend.load_ensemble.return_value = None - automl_stub._backend.list_all_models.return_value = models - automl_stub._disable_evaluator_output = False - - automl_stub._load_models() - - -def test_raises_if_no_models(automl_stub): - 
automl_stub._backend.load_ensemble.return_value = None - automl_stub._backend.list_all_models.return_value = [] - automl_stub._resampling_strategy = "holdout" - - automl_stub._disable_evaluator_output = False - with pytest.raises(ValueError): - automl_stub._load_models() - - automl_stub._disable_evaluator_output = True - automl_stub._load_models() - - -def test_delete_non_candidate_models(dask_client): - - seed = 555 - X, Y, _, _ = putil.get_dataset("iris") - automl = autosklearn.automl.AutoML( - delete_tmp_folder_after_terminate=False, - time_left_for_this_task=60, - per_run_time_limit=5, - ensemble_nbest=3, - seed=seed, - initial_configurations_via_metalearning=0, - resampling_strategy="holdout", - include={"classifier": ["sgd"], "feature_preprocessor": ["no_preprocessing"]}, - metric=accuracy, - dask_client=dask_client, - # Force model to be deleted. That is, from 50 which is the - # default to 3 to make sure we delete models. - max_models_on_disc=3, - ) - - automl.fit(X, Y, task=MULTICLASS_CLASSIFICATION, X_test=X, y_test=Y) - - # Assert at least one model file has been deleted and that there were no - # deletion errors - log_file_path = glob.glob( - os.path.join( - automl._backend.temporary_directory, "AutoML(" + str(seed) + "):*.log" - ) - ) - with open(log_file_path[0]) as log_file: - log_content = log_file.read() - assert "Deleted files of non-candidate model" in log_content, log_content - assert ( - "Failed to delete files of non-candidate model" not in log_content - ), log_content - assert "Failed to lock model" not in log_content, log_content - - # Assert that the files of the models used by the ensemble weren't deleted - model_files = automl._backend.list_all_models(seed=seed) - model_files_idx = set() - for m_file in model_files: - # Extract the model identifiers from the filename - m_file = os.path.split(m_file)[1].replace(".model", "").split(".", 2) - model_files_idx.add((int(m_file[0]), int(m_file[1]), float(m_file[2]))) - ensemble_members_idx = set(automl.ensemble_.identifiers_) - assert ensemble_members_idx.issubset(model_files_idx), ( - ensemble_members_idx, - model_files_idx, - ) - - del automl - - -def test_binary_score_and_include(dask_client): - """ - Test fix for binary classification prediction - taking the index 1 of second dimension in prediction matrix - """ - - data = sklearn.datasets.make_classification( - n_samples=400, - n_features=10, - n_redundant=1, - n_informative=3, - n_repeated=1, - n_clusters_per_class=2, - random_state=1, - ) - X_train = data[0][:200] - Y_train = data[1][:200] - X_test = data[0][200:] - Y_test = data[1][200:] - - automl = autosklearn.automl.AutoML( - 20, - 5, - include={"classifier": ["sgd"], "feature_preprocessor": ["no_preprocessing"]}, - metric=accuracy, - dask_client=dask_client, - ) - - automl.fit(X_train, Y_train, task=BINARY_CLASSIFICATION) - - assert automl._task == BINARY_CLASSIFICATION - - # TODO, the assumption from above is not really tested here - # Also, the score method should be removed, it only makes little sense - score = automl.score(X_test, Y_test) - assert score >= 0.4 - - del automl - - -def test_automl_outputs(dask_client): - - X_train, Y_train, X_test, Y_test = putil.get_dataset("iris") - name = "iris" - auto = autosklearn.automl.AutoML( - 30, - 5, - initial_configurations_via_metalearning=0, - seed=100, - metric=accuracy, - dask_client=dask_client, - delete_tmp_folder_after_terminate=False, - ) - - auto.fit( - X=X_train, - y=Y_train, - X_test=X_test, - y_test=Y_test, - dataset_name=name, - 
task=MULTICLASS_CLASSIFICATION, - ) - - data_manager_file = os.path.join( - auto._backend.temporary_directory, ".auto-sklearn", "datamanager.pkl" - ) - - # pickled data manager (without one hot encoding!) - with open(data_manager_file, "rb") as fh: - D = pickle.load(fh) - assert np.allclose(D.data["X_train"], X_train) - - # Check that all directories are there - fixture = [ - "true_targets_ensemble.npy", - "start_time_100", - "datamanager.pkl", - "ensemble_read_preds.pkl", - "ensemble_read_losses.pkl", - "runs", - "ensembles", - "ensemble_history.json", - ] - assert sorted( - os.listdir(os.path.join(auto._backend.temporary_directory, ".auto-sklearn")) - ) == sorted(fixture) - - # At least one ensemble, one validation, one test prediction and one - # model and one ensemble - fixture = glob.glob( - os.path.join( - auto._backend.temporary_directory, - ".auto-sklearn", - "runs", - "*", - "predictions_ensemble*npy", - ) - ) - assert len(fixture) > 0 - - fixture = glob.glob( - os.path.join( - auto._backend.temporary_directory, - ".auto-sklearn", - "runs", - "*", - "100.*.model", - ) - ) - assert len(fixture) > 0 - - fixture = os.listdir( - os.path.join(auto._backend.temporary_directory, ".auto-sklearn", "ensembles") - ) - assert "100.0000000000.ensemble" in fixture - - # Start time - start_time_file_path = os.path.join( - auto._backend.temporary_directory, ".auto-sklearn", "start_time_100" - ) - with open(start_time_file_path, "r") as fh: - start_time = float(fh.read()) - assert time.time() - start_time >= 10, print_debug_information(auto) - - # Then check that the logger matches the run expectation - logfile = glob.glob(os.path.join(auto._backend.temporary_directory, "AutoML*.log"))[ - 0 - ] - parser = AutoMLLogParser(logfile) - - # The number of ensemble trajectories properly in log file - success_ensemble_iters_auto = len(auto.ensemble_performance_history) - success_ensemble_iters_log = parser.count_ensembler_success_pynisher_calls() - assert success_ensemble_iters_auto == success_ensemble_iters_log, "{} != {}".format( - auto.ensemble_performance_history, - print_debug_information(auto), - ) - - # We also care that no iteration got lost - # This is important because it counts for pynisher calls - # and whether a pynisher call actually called the ensemble - total_ensemble_iterations = parser.count_ensembler_iterations() - assert len(total_ensemble_iterations) > 1 # At least 1 iteration - assert range(1, max(total_ensemble_iterations) + 1), total_ensemble_iterations - - # a point where pynisher is called before budget exhaustion - # Dummy not in run history - total_calls_to_pynisher_log = parser.count_tae_pynisher_calls() - 1 - total_returns_from_pynisher_log = parser.count_tae_pynisher_returns() - 1 - total_elements_rh = len( - [ - run_value - for run_value in auto.runhistory_.data.values() - if run_value.status == StatusType.RUNNING - ] - ) - - # Make sure we register all calls to pynisher - # The less than or equal here is added as a WA as - # https://github.com/automl/SMAC3/pull/712 is not yet integrated - assert total_elements_rh <= total_calls_to_pynisher_log, print_debug_information( - auto - ) - - # Make sure we register all returns from pynisher - assert ( - total_elements_rh <= total_returns_from_pynisher_log - ), print_debug_information(auto) - - # Lastly check that settings are print to logfile - ensemble_size = parser.get_automl_setting_from_log( - auto._dataset_name, "ensemble_size" - ) - assert auto._ensemble_size == int(ensemble_size) - - del auto - - 
-@pytest.mark.parametrize( - "datasets", - [ - ("breast_cancer", BINARY_CLASSIFICATION), - ("wine", MULTICLASS_CLASSIFICATION), - ("diabetes", REGRESSION), - ], -) -def test_do_dummy_prediction(dask_client, datasets): - - name, task = datasets - - X_train, Y_train, X_test, Y_test = putil.get_dataset(name) - datamanager = XYDataManager( - X_train, - Y_train, - X_test, - Y_test, - task=task, - dataset_name=name, - feat_type={i: "numerical" for i in range(X_train.shape[1])}, - ) - - auto = autosklearn.automl.AutoML( - 20, - 5, - initial_configurations_via_metalearning=25, - metric=accuracy, - dask_client=dask_client, - delete_tmp_folder_after_terminate=False, - ) - auto._backend = auto._create_backend() - - # Make a dummy logger - auto._logger_port = 9020 - auto._logger = unittest.mock.Mock() - auto._logger.info.return_value = None - - auto._backend.save_datamanager(datamanager) - D = auto._backend.load_datamanager() - - # Check if data manager is correcly loaded - assert D.info["task"] == datamanager.info["task"] - auto._do_dummy_prediction(D, 1) - - # Ensure that the dummy predictions are not in the current working - # directory, but in the temporary directory. - unexpected_directory = os.path.join(os.getcwd(), ".auto-sklearn") - expected_directory = os.path.join( - auto._backend.temporary_directory, - ".auto-sklearn", - "runs", - "1_1_0.0", - "predictions_ensemble_1_1_0.0.npy", - ) - assert not os.path.exists(unexpected_directory) - assert os.path.exists(expected_directory) - - auto._clean_logger() - - del auto - - -@unittest.mock.patch("autosklearn.evaluation.ExecuteTaFuncWithQueue.run") -def test_fail_if_dummy_prediction_fails(ta_run_mock, dask_client): - - X_train, Y_train, X_test, Y_test = putil.get_dataset("iris") - datamanager = XYDataManager( - X_train, - Y_train, - X_test, - Y_test, - task=2, - feat_type={i: "Numerical" for i in range(X_train.shape[1])}, - dataset_name="iris", - ) - - time_for_this_task = 30 - per_run_time = 10 - auto = autosklearn.automl.AutoML( - time_for_this_task, - per_run_time, - initial_configurations_via_metalearning=25, - metric=accuracy, - dask_client=dask_client, - delete_tmp_folder_after_terminate=False, - ) - auto._backend = auto._create_backend() - auto._backend._make_internals_directory() - auto._backend.save_datamanager(datamanager) - - # Make a dummy logger - auto._logger_port = 9020 - auto._logger = unittest.mock.Mock() - auto._logger.info.return_value = None - - # First of all, check that ta.run() is actually called. - ta_run_mock.return_value = StatusType.SUCCESS, None, None, {} - auto._do_dummy_prediction(datamanager, 1) - ta_run_mock.assert_called_once_with(1, cutoff=time_for_this_task) - - # Case 1. Check that function raises no error when statustype == success. - # ta.run() returns status, cost, runtime, and additional info. - ta_run_mock.return_value = StatusType.SUCCESS, None, None, {} - raised = False - try: - auto._do_dummy_prediction(datamanager, 1) - except ValueError: - raised = True - assert not raised, "Exception raised" - - # Case 2. Check that if statustype returned by ta.run() != success, - # the function raises error. 
- ta_run_mock.return_value = StatusType.CRASHED, None, None, {} - with pytest.raises( - ValueError, - match="Dummy prediction failed with run state StatusType.CRASHED and additional output: {}.", # noqa - ): - auto._do_dummy_prediction(datamanager, 1) - - ta_run_mock.return_value = StatusType.ABORT, None, None, {} - with pytest.raises( - ValueError, - match="Dummy prediction failed with run state StatusType.ABORT " - "and additional output: {}.", - ): - auto._do_dummy_prediction(datamanager, 1) - ta_run_mock.return_value = StatusType.TIMEOUT, None, None, {} - with pytest.raises( - ValueError, - match="Dummy prediction failed with run state StatusType.TIMEOUT " - "and additional output: {}.", - ): - auto._do_dummy_prediction(datamanager, 1) - ta_run_mock.return_value = StatusType.MEMOUT, None, None, {} - with pytest.raises( - ValueError, - match="Dummy prediction failed with run state StatusType.MEMOUT " - "and additional output: {}.", - ): - auto._do_dummy_prediction(datamanager, 1) - ta_run_mock.return_value = StatusType.CAPPED, None, None, {} - with pytest.raises( - ValueError, - match="Dummy prediction failed with run state StatusType.CAPPED " - "and additional output: {}.", - ): - auto._do_dummy_prediction(datamanager, 1) - - ta_run_mock.return_value = StatusType.CRASHED, None, None, {"exitcode": -6} - with pytest.raises( - ValueError, - match="The error suggests that the provided memory limits are too tight.", - ): - auto._do_dummy_prediction(datamanager, 1) - - -@unittest.mock.patch("autosklearn.smbo.AutoMLSMBO.run_smbo") -def test_exceptions_inside_log_in_smbo(smbo_run_mock, dask_client): - - # Below importing and shutdown is a workaround, to make sure - # we reset the port to collect messages. Randomly, when running - # this test with multiple other test at the same time causes this - # test to fail. 
This resets the singletons of the logging class - import logging - - logging.shutdown() - - automl = autosklearn.automl.AutoML( - 20, - 5, - metric=accuracy, - dask_client=dask_client, - delete_tmp_folder_after_terminate=False, - ) - - dataset_name = "test_exceptions_inside_log" - - # Create a custom exception to prevent other errors to slip in - class MyException(Exception): - pass - - X_train, Y_train, X_test, Y_test = putil.get_dataset("iris") - # The first call is on dummy predictor failure - message = str(np.random.randint(100)) + "_run_smbo" - smbo_run_mock.side_effect = MyException(message) - - with pytest.raises(MyException): - automl.fit( - X_train, - Y_train, - task=MULTICLASS_CLASSIFICATION, - dataset_name=dataset_name, - ) - - # make sure that the logfile was created - logger_name = "AutoML(%d):%s" % (1, dataset_name) - logger = logging.getLogger(logger_name) - logfile = os.path.join(automl._backend.temporary_directory, logger_name + ".log") - assert os.path.exists(logfile), print_debug_information(automl) + str( - automl._clean_logger() - ) - - # Give some time for the error message to be printed in the - # log file - found_message = False - for incr_tolerance in range(5): - with open(logfile) as f: - lines = f.readlines() - if any(message in line for line in lines): - found_message = True - break - else: - time.sleep(incr_tolerance) - - # Speed up the closing after forced crash - automl._clean_logger() - - if not found_message: - pytest.fail( - "Did not find {} in the log file {} for logger {}/{}/{}".format( - message, - print_debug_information(automl), - vars(automl._logger.logger), - vars(logger), - vars(logging.getLogger()), - ) - ) - - -@pytest.mark.parametrize("metric", [log_loss, balanced_accuracy]) -def test_load_best_individual_model(metric, dask_client): - - X_train, Y_train, X_test, Y_test = putil.get_dataset("iris") - automl = autosklearn.automl.AutoML( - time_left_for_this_task=30, - per_run_time_limit=5, - metric=metric, - dask_client=dask_client, - delete_tmp_folder_after_terminate=False, - ) - - # We cannot easily mock a function sent to dask - # so for this test we create the whole set of models/ensembles - # but prevent it to be loaded - automl.fit(X_train, Y_train, task=MULTICLASS_CLASSIFICATION) - - automl._backend.load_ensemble = unittest.mock.MagicMock(return_value=None) - - # A memory error occurs in the ensemble construction - assert automl._backend.load_ensemble(automl._seed) is None - - # The load model is robust to this and loads the best model - automl._load_models() - assert automl.ensemble_ is not None - - # Just 1 model is there for ensemble and all weight must be on it - get_models_with_weights = automl.get_models_with_weights() - assert len(get_models_with_weights) == 1 - assert get_models_with_weights[0][0] == 1.0 - - # Match a toy dataset - if metric.name == "balanced_accuracy": - assert automl.score(X_test, Y_test) > 0.9 - elif metric.name == "log_loss": - # Seen values in github actions of 0.6978304740364537 - assert automl.score(X_test, Y_test) < 0.7 - else: - raise ValueError(metric.name) - - del automl - - -def test_fail_if_feat_type_on_pandas_input(dask_client): - """We do not support feat type when pandas - is provided as an input - """ - automl = autosklearn.automl.AutoML( - time_left_for_this_task=30, - per_run_time_limit=5, - metric=accuracy, - dask_client=dask_client, - ) - - X_train = pd.DataFrame({"a": [1, 1], "c": [1, 2]}) - y_train = [1, 0] - msg = ( - "providing the option feat_type to the fit method is not supported" - " when 
using a Dataframe." - ) - with pytest.raises(ValueError, match=msg): - automl.fit( - X_train, - y_train, - task=BINARY_CLASSIFICATION, - feat_type={1: "Categorical", 2: "Numerical"}, - ) - - -def data_input_and_target_types(): - n_rows = 100 - - # Create valid inputs - X_ndarray = np.random.random(size=(n_rows, 5)) - X_ndarray[X_ndarray < 0.9] = 0 - - # Binary Classificaiton - y_binary_ndarray = np.random.random(size=n_rows) - y_binary_ndarray[y_binary_ndarray >= 0.5] = 1 - y_binary_ndarray[y_binary_ndarray < 0.5] = 0 - - # Multiclass classification - y_multiclass_ndarray = np.random.random(size=n_rows) - y_multiclass_ndarray[y_multiclass_ndarray > 0.66] = 2 - y_multiclass_ndarray[ - (y_multiclass_ndarray <= 0.66) & (y_multiclass_ndarray >= 0.33) - ] = 1 - y_multiclass_ndarray[y_multiclass_ndarray < 0.33] = 0 - - # Multilabel classificaiton - y_multilabel_ndarray = np.random.random(size=(n_rows, 3)) - y_multilabel_ndarray[y_multilabel_ndarray > 0.5] = 1 - y_multilabel_ndarray[y_multilabel_ndarray <= 0.5] = 0 - - # Regression - y_regression_ndarray = np.random.random(size=n_rows) - - # Multioutput Regression - y_multioutput_regression_ndarray = np.random.random(size=(n_rows, 3)) - - xs = [ - X_ndarray, - list(X_ndarray), - csr_matrix(X_ndarray), - pd.DataFrame(data=X_ndarray), - ] - - ys_binary = [ - y_binary_ndarray, - list(y_binary_ndarray), - csr_matrix(y_binary_ndarray), - pd.Series(y_binary_ndarray), - pd.DataFrame(data=y_binary_ndarray), - ] - - ys_multiclass = [ - y_multiclass_ndarray, - list(y_multiclass_ndarray), - csr_matrix(y_multiclass_ndarray), - pd.Series(y_multiclass_ndarray), - pd.DataFrame(data=y_multiclass_ndarray), - ] - - ys_multilabel = [ - y_multilabel_ndarray, - list(y_multilabel_ndarray), - csr_matrix(y_multilabel_ndarray), - # pd.Series(y_multilabel_ndarray) - pd.DataFrame(data=y_multilabel_ndarray), - ] - - ys_regression = [ - y_regression_ndarray, - list(y_regression_ndarray), - csr_matrix(y_regression_ndarray), - pd.Series(y_regression_ndarray), - pd.DataFrame(data=y_regression_ndarray), - ] - - ys_multioutput_regression = [ - y_multioutput_regression_ndarray, - list(y_multioutput_regression_ndarray), - csr_matrix(y_multioutput_regression_ndarray), - # pd.Series(y_multioutput_regression_ndarray), - pd.DataFrame(data=y_multioutput_regression_ndarray), - ] - - # [ (X, y, X_test, y_test, task), ... 
] - return ( - (X, y, X, y, task) - for X in xs - for y, task in itertools.chain( - itertools.product(ys_binary, [BINARY_CLASSIFICATION]), - itertools.product(ys_multiclass, [MULTICLASS_CLASSIFICATION]), - itertools.product(ys_multilabel, [MULTILABEL_CLASSIFICATION]), - itertools.product(ys_regression, [REGRESSION]), - itertools.product(ys_multioutput_regression, [MULTIOUTPUT_REGRESSION]), - ) - ) - - -@pytest.mark.parametrize("X, y, X_test, y_test, task", data_input_and_target_types()) -def test_input_and_target_types(dask_client, X, y, X_test, y_test, task): - - if task in CLASSIFICATION_TASKS: - automl = AutoMLClassifier( - time_left_for_this_task=15, - per_run_time_limit=5, - dask_client=dask_client, - ) - else: - automl = AutoMLRegressor( - time_left_for_this_task=15, - per_run_time_limit=5, - dask_client=dask_client, - ) - # To save time fitting and only validate the inputs we only return - # the configuration space - automl.fit( - X=X, y=y, X_test=X_test, y_test=y_test, only_return_configuration_space=True - ) - assert automl._task == task - assert automl._metric.name == default_metric_for_task[task].name - - -def data_test_model_predict_outsputs_correct_shapes(): - datasets = sklearn.datasets - binary = datasets.make_classification(n_samples=5, n_classes=2, random_state=0) - multiclass = datasets.make_classification( - n_samples=5, n_informative=3, n_classes=3, random_state=0 - ) - multilabel = datasets.make_multilabel_classification( - n_samples=5, n_classes=3, random_state=0 - ) - regression = datasets.make_regression(n_samples=5, random_state=0) - multioutput = datasets.make_regression(n_samples=5, n_targets=3, random_state=0) - - # TODO issue 1169 - # While testing output shapes, realised all models are wrapped to provide - # a special predict_proba that outputs a different shape than usual. - # This includes DummyClassifier and DummyRegressor which are wrapped as - # `MyDummyClassifier/Regressor` and require a config object. - # config == 1 : Classifier uses 'uniform', Regressor uses 'mean' - # else : Classifier uses 'most_frequent', Regressor uses 'median' - # - # This wrapping of probabilities with - # `convert_multioutput_multiclass_to_multilabel` - # can probably be just put into a base class which queries subclasses - # as to whether it's needed. 
- # - # tldr; thats why we use MyDummyX here instead of the default models - # from sklearn - def classifier(X, y): - return MyDummyClassifier(config=1, random_state=0).fit(X, y) - - def regressor(X, y): - return MyDummyRegressor(config=1, random_state=0).fit(X, y) - - # How cross validation models are currently grouped together - def voting_classifier(X, y): - classifiers = [ - MyDummyClassifier(config=1, random_state=0).fit(X, y) for _ in range(5) - ] - vc = VotingClassifier(estimators=None, voting="soft") - vc.estimators_ = classifiers - return vc - - def voting_regressor(X, y): - regressors = [ - MyDummyRegressor(config=1, random_state=0).fit(X, y) for _ in range(5) - ] - vr = VotingRegressor(estimators=None) - vr.estimators_ = regressors - return vr - - test_data = { - BINARY_CLASSIFICATION: { - "models": [classifier(*binary), voting_classifier(*binary)], - "data": binary, - # prob of false/true for the one class - "expected_output_shape": (len(binary[0]), 2), - }, - MULTICLASS_CLASSIFICATION: { - "models": [classifier(*multiclass), voting_classifier(*multiclass)], - "data": multiclass, - # prob of true for each possible class - "expected_output_shape": (len(multiclass[0]), 3), - }, - MULTILABEL_CLASSIFICATION: { - "models": [classifier(*multilabel), voting_classifier(*multilabel)], - "data": multilabel, - # probability of true for each binary label - "expected_output_shape": (len(multilabel[0]), 3), # type: ignore - }, - REGRESSION: { - "models": [regressor(*regression), voting_regressor(*regression)], - "data": regression, - # array of single outputs - "expected_output_shape": (len(regression[0]),), - }, - MULTIOUTPUT_REGRESSION: { - "models": [regressor(*multioutput), voting_regressor(*multioutput)], - "data": multioutput, - # array of vector otuputs - "expected_output_shape": (len(multioutput[0]), 3), - }, - } - - return itertools.chain.from_iterable( - [ - (model, cfg["data"], task, cfg["expected_output_shape"]) - for model in cfg["models"] - ] - for task, cfg in test_data.items() - ) - - -@pytest.mark.parametrize( - "model, data, task, expected_output_shape", - data_test_model_predict_outsputs_correct_shapes(), -) -def test_model_predict_outputs_correct_shapes(model, data, task, expected_output_shape): - X, y = data - prediction = _model_predict(model=model, X=X, task=task) - assert prediction.shape == expected_output_shape - - -def test_model_predict_outputs_warnings_to_logs(): - X = list(range(20)) - task = REGRESSION - logger = PickableLoggerAdapter("test_model_predict_correctly_outputs_warnings") - logger.warning = unittest.mock.Mock() - - class DummyModel: - def predict(self, x): - warnings.warn("test warning", Warning) - return x - - model = DummyModel() - - _model_predict(model, X, task, logger=logger) - - assert logger.warning.call_count == 1, "Logger should have had warning called" - - -def test_model_predict_outputs_to_stdout_if_no_logger(): - X = list(range(20)) - task = REGRESSION - - class DummyModel: - def predict(self, x): - warnings.warn("test warning", Warning) - return x - - model = DummyModel() - - with warnings.catch_warnings(record=True) as w: - _model_predict(model, X, task, logger=None) - - assert len(w) == 1, "One warning sould have been emmited" - - -@pytest.mark.parametrize("dataset_compression", [False]) -def test_param_dataset_compression_false(dataset_compression: bool) -> None: - """ - Parameters - ---------- - dataset_compression: bool - The dataset_compression arg set as False - - Expects - ------- - * Should set the private attribute to None - 
""" - auto = AutoMLRegressor( - time_left_for_this_task=30, - per_run_time_limit=5, - dataset_compression=dataset_compression, - ) - - assert auto._dataset_compression is None - - -@pytest.mark.parametrize("dataset_compression", [True]) -def test_construction_param_dataset_compression_true(dataset_compression: bool) -> None: - """ - Parameters - ---------- - dataset_compression: bool - The dataset_compression arg set as True - - Expects - ------- - * Should set the private attribute to the defaults - """ - auto = AutoMLRegressor( - time_left_for_this_task=30, - per_run_time_limit=5, - dataset_compression=dataset_compression, - ) - - assert auto._dataset_compression == default_dataset_compression_arg - - -@pytest.mark.parametrize("dataset_compression", [{"memory_allocation": 0.2}]) -def test_construction_param_dataset_compression_valid_dict( - dataset_compression: Dict, -) -> None: - """ - Parameters - ---------- - dataset_compression: Dict - The dataset_compression arg set partially specified - - Expects - ------- - * Should set the private attribute to the passed dataset_compression arg + defaults - """ - auto = AutoMLRegressor( - time_left_for_this_task=30, - per_run_time_limit=5, - dataset_compression=dataset_compression, - ) - - expected_memory_allocation = dataset_compression["memory_allocation"] - expected_methods = default_dataset_compression_arg["methods"] - - assert auto._dataset_compression is not None - assert auto._dataset_compression["memory_allocation"] == expected_memory_allocation - assert auto._dataset_compression["methods"] == expected_methods - - -@pytest.mark.parametrize( - "dataset_compression", [{"methods": ["precision", "subsample"]}] -) -@pytest.mark.parametrize("X", [np.ones((100, 10), dtype=int)]) -@pytest.mark.parametrize("y", [np.random.random((100,))]) -@unittest.mock.patch("autosklearn.automl.reduce_dataset_size_if_too_large") -def test_fit_performs_dataset_compression_without_precision_with_int( - mock_reduce_dataset: unittest.mock.MagicMock, - dataset_compression: Dict, - X: np.ndarray, - y: np.ndarray, -) -> None: - """We can't reduce the precision of ints as we do with floats. Suppose someone - was to pass a column with `max_int64` and `min_int64`, any reduction of bits will - cause this information to be lost and not simply reduce precision as it does with - floats. 
- - Parameters - ---------- - mock_reduce_dataset: MagicMock - A mock function to check the parameters that were passed in - - dataset_compression: Dict - The dataset_compression arg with "precision" set in it - - X: np.ndarray - An array of ints which we can't reduce precision of - - y: np.ndarray - An array of floats as the regression target - - Expects - ------- - * Should call reduce_dataset_size_if_too_large - * "precision" should have been removed from the "methods" passed to the keyword - argument "operations" of `reduce_dataset_size_if_too_large` - """ - # We just return the data - mock_reduce_dataset.return_value = X, y - - auto = AutoMLRegressor( - time_left_for_this_task=30, # not used but required - per_run_time_limit=5, # not used but required - dataset_compression=dataset_compression, - ) - - # To prevent fitting anything we use `only_return_configuration_space` - auto.fit(X, y, only_return_configuration_space=True) - - assert mock_reduce_dataset.call_count == 1 - - args, kwargs = mock_reduce_dataset.call_args - assert kwargs["operations"] == ["subsample"] - - -@pytest.mark.parametrize("dataset_compression", [True]) -@pytest.mark.parametrize( - "X", - [ - np.empty((10, 10)), - csr_matrix(np.identity(10)), - pytest.param( - np.empty((10, 10)).tolist(), - marks=pytest.mark.xfail(reason="Converted to dataframe by InputValidator"), - ), - pytest.param( - pd.DataFrame(np.empty((10, 10))), - marks=pytest.mark.xfail( - reason="No pandas support yet for dataset compression" - ), - ), - ], -) -@pytest.mark.parametrize( - "y", - [ - np.random.random((10, 1)), - np.random.random((10, 1)).tolist(), - pytest.param( - pd.Series(np.random.random((10,))), - marks=pytest.mark.xfail( - reason="No pandas support yet for dataset compression" - ), - ), - pytest.param( - pd.DataFrame(np.random.random((10, 10))), - marks=pytest.mark.xfail( - reason="No pandas support yet for dataset compression" - ), - ), - ], -) -@unittest.mock.patch("autosklearn.automl.reduce_dataset_size_if_too_large") -def test_fit_performs_dataset_compression( - mock_reduce_dataset: unittest.mock.MagicMock, - dataset_compression: bool, - X: Union[np.ndarray, spmatrix, List, pd.DataFrame], - y: Union[np.ndarray, List, pd.Series, pd.DataFrame], -) -> None: - """ - Parameters - ---------- - mock_reduce_dataset: MagicMock - A mock function to view call count - - dataset_compression: bool - Dataset compression set to True - - X: Union[np.ndarray, spmatrix, List, pd.Dataframe] - Feature to reduce - - y: Union[np.ndarray, List, pd.Series, pd.Dataframe] - Target to reduce (regression values) - - Expects - ------- - * Should call reduce_dataset_size_if_too_large - """ - # We just return the data - mock_reduce_dataset.return_value = X, y - - auto = AutoMLRegressor( - time_left_for_this_task=30, # not used but required - per_run_time_limit=5, # not used but required - dataset_compression=dataset_compression, - ) - - # To prevent fitting anything we use `only_return_configuration_space` - auto.fit(X, y, only_return_configuration_space=True) - - assert mock_reduce_dataset.called diff --git a/test/test_automl/test_construction.py b/test/test_automl/test_construction.py new file mode 100644 index 0000000000..5b15812acd --- /dev/null +++ b/test/test_automl/test_construction.py @@ -0,0 +1,95 @@ +"""Property based Tests + +These test are for checking properties of already fitted models. Any test that does +tests using cases should not modify the state as these models are cached between tests +to reduce training time. 
+""" +from typing import Any, Dict, Optional, Union + +from autosklearn.automl import AutoML +from autosklearn.util.data import default_dataset_compression_arg +from autosklearn.util.single_thread_client import SingleThreadedClient + +import pytest +from pytest_cases import parametrize + + +@parametrize("disable_evaluator_output", [("hello", "world"), ("model", "other")]) +def test_invalid_disable_eval_output_options(disable_evaluator_output: Any) -> None: + """ + Parameters + ---------- + disable_evaluator_output : Iterable[str] + An iterator of invalid options + + Expects + ------- + * Should raise an error about invalid options + """ + with pytest.raises(ValueError, match="Unknown arg"): + AutoML( + time_left_for_this_task=30, + per_run_time_limit=5, + disable_evaluator_output=disable_evaluator_output, + ) + + +@parametrize( + "dataset_compression, expected", + [ + (False, None), + (True, default_dataset_compression_arg), + ( + {"memory_allocation": 0.2}, + {**default_dataset_compression_arg, **{"memory_allocation": 0.2}}, + ), + ], +) +def test_param_dataset_compression_args( + dataset_compression: Union[bool, Dict], + expected: Optional[Dict], +) -> None: + """ + Parameters + ---------- + dataset_compression: Union[bool, Dict] + The dataset_compression arg used + + expected: Optional[Dict] + The expected internal variable setting + + Expects + ------- + * Setting the compression arg should result in the expected value + * False -> None, No dataset compression + * True -> default, The default settings + * dict -> default updated, The default should be updated with the args used + """ + auto = AutoML( + time_left_for_this_task=30, + per_run_time_limit=5, + dataset_compression=dataset_compression, + ) + assert auto._dataset_compression == expected + + +def test_single_job_and_no_dask_client_sets_correct_multiprocessing_context() -> None: + """ + Expects + ------- + * With n_jobs set to 1 and no dask client, we default to a SingleThreadedClient + with a "fork" _multiprocessing_context + """ + n_jobs = 1 + dask_client = None + + automl = AutoML( + time_left_for_this_task=30, + per_run_time_limit=5, + n_jobs=n_jobs, + dask_client=dask_client, + ) + + assert automl._multiprocessing_context == "fork" + assert automl._n_jobs == 1 + assert isinstance(automl._dask_client, SingleThreadedClient) diff --git a/test/test_automl/test_dataset_compression.py b/test/test_automl/test_dataset_compression.py new file mode 100644 index 0000000000..d50869ebbf --- /dev/null +++ b/test/test_automl/test_dataset_compression.py @@ -0,0 +1,135 @@ +from typing import Any, Callable, Dict + +import numpy as np +import pandas as pd +from scipy.sparse import csr_matrix + +from autosklearn.automl import AutoML +from autosklearn.constants import BINARY_CLASSIFICATION + +from pytest_cases import parametrize +from unittest.mock import patch + +from test.util import skip + + +@parametrize("dataset_compression", [{"methods": ["precision", "subsample"]}]) +def test_fit_performs_dataset_compression_without_precision_when_int( + dataset_compression: Dict, + make_automl: Callable[..., AutoML], +) -> None: + """ + Parameters + ---------- + dataset_compression: Dict + The dataset_compression arg with "precision" set in it + + Fixtures + -------- + make_automl: Callable[..., AutoML] + Makes an automl instance + + + Expects + ------- + * Should call reduce_dataset_size_if_too_large + * "precision" should have been removed from the "methods" passed to the keyword + argument "operations" of `reduce_dataset_size_if_too_large`. 
+ + Note + ---- + * Only done with int's as we can't reduce precision of ints in a meaningful way + """ + X = np.ones((100, 10), dtype=int) + y = np.random.random((100,)) + + auto = make_automl(dataset_compression=dataset_compression) + + with patch( + "autosklearn.automl.reduce_dataset_size_if_too_large", return_value=(X, y) + ) as mck: + # To prevent fitting anything we use `only_return_configuration_space` + auto.fit(X, y, only_return_configuration_space=True, task=BINARY_CLASSIFICATION) + + assert mck.call_count == 1 + + args, kwargs = mck.call_args + assert kwargs["operations"] == ["subsample"] + + +@parametrize( + "X_type", + [ + np.ndarray, + csr_matrix, + skip( + list, + "dataset_compression does not support pandas types yet and list gets" + " converted in InputValidator", + ), + skip(pd.DataFrame, "dataset_compression does not support pandas types yet"), + ], +) +@parametrize( + "y_type", + [ + np.ndarray, + skip(csr_matrix, "See TODO note in `test_fit_performs_dataset_compression`"), + list, + skip(pd.DataFrame, "dataset_compression does not support pandas types yet"), + skip(pd.Series, "dataset_compression does not support pandas types yet"), + ], +) +def test_fit_performs_dataset_compression( + X_type: Any, + y_type: Any, + make_automl: Callable[..., AutoML], + make_data: Callable[..., Any], +) -> None: + """ + Parameters + ---------- + mock_reduce_dataset: MagicMock + A mock function to view call + + X_type: Union[np.ndarray, csr_matrix, list, pd.Dataframe] + Feature to reduce + + y_type: Union[np.ndarray, csr_matrix, list, pd.Series, pd.Dataframe] + Target to reduce (regression values) + + Fixtures + -------- + make_automl: Callable[..., AutoML] + Factory to make automl instance + + make_data: Callable + Factory to make data + + Expects + ------- + * Should call reduce_dataset_size_if_too_large + + # TODO not sure how to keep function behaviour and just use the mock object so + # that we can assert it was called. 
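One way the TODO above could be addressed, sketched here under the assumption that wrapping the real function with `unittest.mock` is acceptable in this suite: `patch.object(..., wraps=...)` keeps the real `reduce_dataset_size_if_too_large` behaviour (so `fit` still receives the converted data) while recording the call for the assertion. This is illustrative only and not part of the diff; `auto`, `X` and `y` stand in for the objects built in the test itself.

import autosklearn.automl as automl_module
from unittest.mock import patch

def fit_and_assert_compression_ran(auto, X, y):
    """Fit while keeping the real compression behaviour, but record the call."""
    real = automl_module.reduce_dataset_size_if_too_large
    with patch.object(
        automl_module, "reduce_dataset_size_if_too_large", wraps=real
    ) as mck:
        # The wrapped function delegates to the real implementation, so any
        # conversion it performs still reaches the datamanager later on
        auto.fit(X, y, only_return_configuration_space=True)
    assert mck.called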
+ # + # * `fit` will convert sparse `y` + # * This gets passed to `reduce_dataset_size_if_too_large` + # * The de-sparsified `y` is required for the datamanager later on + # + # Mocking away the functionality and just returning the X, y we see here will means + # that the datamanager will get the sparse y and crash, hence we manually convert + # here + """ + X, y = make_data(types=(X_type, y_type)) + + auto = make_automl(dataset_compression=True) + + with patch( + "autosklearn.automl.reduce_dataset_size_if_too_large", return_value=(X, y) + ) as mck: + # To prevent fitting anything we use `only_return_configuration_space` + auto.fit(X, y, only_return_configuration_space=True, task=BINARY_CLASSIFICATION) + + assert mck.called + del auto diff --git a/test/test_automl/test_dummy_predictions.py b/test/test_automl/test_dummy_predictions.py new file mode 100644 index 0000000000..9a268d1a2c --- /dev/null +++ b/test/test_automl/test_dummy_predictions.py @@ -0,0 +1,185 @@ +from typing import Callable, Tuple + +from pathlib import Path + +import numpy as np +from smac.tae import StatusType + +from autosklearn.automl import AutoML +from autosklearn.constants import ( + BINARY_CLASSIFICATION, + MULTICLASS_CLASSIFICATION, + REGRESSION, +) +from autosklearn.data.xy_data_manager import XYDataManager +from autosklearn.metrics import Scorer, accuracy, r2 +from autosklearn.util.logging_ import PicklableClientLogger + +import pytest +from pytest_cases import parametrize +from unittest.mock import patch + + +@parametrize( + "dataset, metric, task", + [ + ("breast_cancer", accuracy, BINARY_CLASSIFICATION), + ("wine", accuracy, MULTICLASS_CLASSIFICATION), + ("diabetes", r2, REGRESSION), + ], +) +def test_produces_correct_output( + dataset: str, + task: int, + metric: Scorer, + mock_logger: PicklableClientLogger, + make_automl: Callable[..., AutoML], + make_sklearn_dataset: Callable[..., XYDataManager], +) -> None: + """ + Parameters + ---------- + dataset: str + The name of the dataset + + task : int + The task type of the dataset + + metric: Scorer + Metric to use, required as fit usually determines the metric to use + + Fixtures + -------- + mock_logger: PickleableClientLogger + A mock logger to use + + make_automl : Callable[..., AutoML] + Factory to make an AutoML object + + make_sklearn_dataset : Callable[..., XYDataManager] + Factory to get an sklearn dataset + + Expects + ------- + * There should only be one output created with one dummy predictions + * It should be named "1337_1_0.0" with {seed}_{num_run}_{budget} + * It should produce predictions "predictions_ensemble_1337_1_0.0.npy" + """ + seed = 1337 + automl = make_automl(metric=metric, seed=seed) + automl._logger = mock_logger + + datamanager = make_sklearn_dataset( + dataset, + as_datamanager=True, + task=task, + feat_type="numerical", + ) + automl._backend.save_datamanager(datamanager) + automl._do_dummy_prediction() + + path = Path(automl._backend.get_runs_directory()) + run_paths = list(path.iterdir()) + assert len(run_paths) == 1 + + dummy_run_path = run_paths[0] + assert dummy_run_path.name == f"{seed}_1_0.0" + + predictions_path = dummy_run_path / f"predictions_ensemble_{seed}_1_0.0.npy" + assert predictions_path.exists() + + +def test_runs_with_correct_args( + mock_logger: PicklableClientLogger, + make_sklearn_dataset: Callable[..., Tuple[np.ndarray, ...]], + make_automl: Callable[..., AutoML], +) -> None: + """ + Fixtures + -------- + mock_logger: PickleableClientLogger + A mock logger to use + + make_sklearn_dataset : Callable[..., 
Tuple[np.ndarray, ...]] + Factory to make dataset + + make_automl : Callable[..., AutoML] + Factory to make automl + + Expects + ------- + * The mock run should be called once with: + * config = 1 (the value always given for the dummy) + * cutoff = `automl._time_for_task` (the full time for the task) + """ + dataset = "iris" + task = MULTICLASS_CLASSIFICATION + + automl = make_automl(metric=accuracy) + automl._logger = mock_logger + + datamanager = make_sklearn_dataset( + dataset, + as_datamanager=True, + task=task, + feat_type="numerical", + ) + automl._backend.save_datamanager(datamanager) + + with patch("autosklearn.evaluation.ExecuteTaFuncWithQueue.run") as mck: + mck.return_value = (StatusType.SUCCESS, None, None, {}) + automl._do_dummy_prediction() + + mck.assert_called_once_with(config=1, cutoff=automl._time_for_task) + + +def test_crash_due_to_memory_exception( + mock_logger: PicklableClientLogger, + make_sklearn_dataset: Callable[..., Tuple[np.ndarray, ...]], + make_automl: Callable[..., AutoML], +) -> None: + """ + Fixtures + -------- + mock_logger: PicklableClientLogger + A mock logger to use + + make_sklearn_dataset : Callable[..., Tuple[np.ndarray, ...]] + Factory to make dataset + + make_automl : Callable[..., AutoML] + Factory to make automl + + Expects + ------- + * The dummy prediction should raise when encountering StatusType.CRASHED + * The error message should indicate it's a memory issue when `{'exitcode': -6}` is + encountered + """ + dataset = "iris" + task = MULTICLASS_CLASSIFICATION + + automl = make_automl(metric=accuracy) + automl._logger = mock_logger + + datamanager = make_sklearn_dataset( + dataset, + as_datamanager=True, + task=task, + feat_type="numerical", + ) + + automl._backend.save_datamanager(datamanager) + + with patch("autosklearn.evaluation.ExecuteTaFuncWithQueue.run") as mck: + mck.return_value = (StatusType.CRASHED, None, None, {"exitcode": -6}) + msg = "The error suggests that the provided memory limits are too tight."
+ + with pytest.raises(ValueError, match=msg): + automl._do_dummy_prediction() + + +def test_raises_if_no_metric_set(make_automl: Callable[..., AutoML]) -> None: + automl = make_automl() + with pytest.raises(ValueError, match="Metric was not set"): + automl._do_dummy_prediction() diff --git a/test/test_automl/test_fit.py b/test/test_automl/test_fit.py new file mode 100644 index 0000000000..2defa2518b --- /dev/null +++ b/test/test_automl/test_fit.py @@ -0,0 +1,82 @@ +from typing import Any, Callable, Dict, Optional, Tuple, Union + +import numpy as np +from dask.distributed import Client +from smac.facade.roar_facade import ROAR +from smac.scenario.scenario import Scenario + +from autosklearn.automl import AutoML +from autosklearn.constants import MULTICLASS_CLASSIFICATION + +from pytest_cases import parametrize + + +@parametrize("dataset, task, bounds", [("iris", MULTICLASS_CLASSIFICATION, (0.8, 1.0))]) +def test_fit_roar( + dataset: str, + task: int, + bounds: Tuple[float, float], + dask_client_single_worker: Client, + make_automl: Callable[..., AutoML], + make_sklearn_dataset: Callable[..., Tuple[np.ndarray, ...]], +) -> None: + """ + Parameters + ---------- + dataset : str + The name of the dataset + + task : int + The task type of the dataset + + bounds : Tuple[float, float] + The bounds the final score should be in, (lowest, upper) + + Fixtures + -------- + make_automl : Callable[..., AutoML] + Factory for making an AutoML instance + + make_sklearn_dataset : Callable[..., Tuple[np.ndarray, ...]] + Factory for getting a dataset + + Expects + ------- + * Should fit without a problem using a different smac object + """ + + def get_roar_object_callback( + scenario_dict: Dict, + seed: Optional[Union[int, np.random.RandomState]], + ta: Callable, + ta_kwargs: Dict, + dask_client: Client, + n_jobs: int, + **kwargs: Any, + ) -> ROAR: + """Random online adaptive racing. 
+ + http://ml.informatik.uni-freiburg.de/papers/11-LION5-SMAC.pdf + """ + scenario = Scenario(scenario_dict) + return ROAR( + run_id=seed, + scenario=scenario, + rng=seed, + tae_runner=ta, + tae_runner_kwargs=ta_kwargs, + dask_client=dask_client, + n_jobs=n_jobs, + ) + + X_train, Y_train, X_test, Y_test = make_sklearn_dataset(dataset) + automl = make_automl( + initial_configurations_via_metalearning=0, + get_smac_object_callback=get_roar_object_callback, + dask_client=dask_client_single_worker, + ) + + automl.fit(X_train, Y_train, task=task) + + score = automl.score(X_test, Y_test) + assert score > 0.8 diff --git a/test/test_automl/test_fit_ensemble.py b/test/test_automl/test_fit_ensemble.py new file mode 100644 index 0000000000..14bc36192a --- /dev/null +++ b/test/test_automl/test_fit_ensemble.py @@ -0,0 +1,35 @@ +import numpy as np + +from autosklearn.automl import AutoML + +import pytest +from pytest_cases import filters as ft +from pytest_cases import parametrize, parametrize_with_cases + +import test.test_automl.cases as cases + + +@parametrize("ensemble_size", [-10, -1, 0]) +@parametrize_with_cases("automl", cases=cases, filter=~ft.has_tag("fitted")) +def test_non_positive_ensemble_size_raises( + tmp_dir: str, + automl: AutoML, + ensemble_size: int, +) -> None: + """ + Parameters + ---------- + automl: AutoML + The AutoML object to test + + ensemble_size : int + The ensemble size to use + + Expects + ------- + * Can't fit ensemble with non-positive ensemble size + """ + dummy_data = np.array([1, 1, 1]) + + with pytest.raises(ValueError): + automl.fit_ensemble(dummy_data, ensemble_size=ensemble_size) diff --git a/test/test_automl/test_fit_pipeline.py b/test/test_automl/test_fit_pipeline.py new file mode 100644 index 0000000000..137b57a5c3 --- /dev/null +++ b/test/test_automl/test_fit_pipeline.py @@ -0,0 +1 @@ +"""TODO""" diff --git a/test/test_automl/test_model_predict.py b/test/test_automl/test_model_predict.py new file mode 100644 index 0000000000..a301d1a9a5 --- /dev/null +++ b/test/test_automl/test_model_predict.py @@ -0,0 +1,151 @@ +from typing import Callable, Dict, Tuple + +import warnings + +import numpy as np +from sklearn.ensemble import VotingClassifier, VotingRegressor + +from autosklearn.automl import _model_predict +from autosklearn.constants import ( + BINARY_CLASSIFICATION, + MULTICLASS_CLASSIFICATION, + MULTILABEL_CLASSIFICATION, + MULTIOUTPUT_REGRESSION, + REGRESSION, +) +from autosklearn.util.logging_ import PicklableClientLogger + +from pytest_cases import parametrize + + +class WarningModel: + def predict(self, X: np.ndarray) -> np.ndarray: + warnings.warn("shout") + return X + + +@parametrize( + "dataspec, expected_shape", + [ + ({"kind": BINARY_CLASSIFICATION, "dims": (100, 5)}, (100, 2)), + ({"kind": MULTICLASS_CLASSIFICATION, "dims": (100, 5), "classes": 3}, (100, 3)), + ( + { + "kind": MULTILABEL_CLASSIFICATION, + "dims": (100, 5), + "classes": [[0, 0], [0, 1], [1, 0], [1, 1]], + }, + (100, 2), # TODO seems wrong + ), + ], +) +def test_classifier_output_shape( + dataspec: Dict, + expected_shape: Tuple[int, ...], + make_voting_classifier: Callable[..., VotingClassifier], + make_data: Callable[..., Tuple[np.ndarray, np.ndarray]], +) -> None: + """ + Parameters + ---------- + dataspec : Dict + The spec to make data of + + expected_shape : Tuple[int, ...] 
+ The expected shape of the output of _model_predict + + Fixtures + -------- + make_voting_classifier : Callable[..., VotingClassifier] + Factory to make a voting classifier which _model_predict expects + + make_data : Callable[..., Tuple[np.ndarray, np.ndarray]] + Factory to make data according to a spec + + Expects + ------- + * The output shape after predicting should be the expected shape + + Note + ---- + * The output shape for MULTILABEL_CLASSIFICATION seems wrong according to + + """ + task = dataspec["kind"] + X, y = make_data(**dataspec) + + voter = make_voting_classifier(X=X, y=y) + + output = _model_predict(voter, X, task=task) + assert output.shape == expected_shape + + +@parametrize( + "dataspec, expected_shape", + [ + ({"kind": REGRESSION, "dims": (100, 5)}, (100,)), + ({"kind": MULTIOUTPUT_REGRESSION, "dims": (100, 5), "targets": 3}, (100, 3)), + ], +) +def test_regressor_output_shape( + dataspec: Dict, + expected_shape: Tuple[int, ...], + make_voting_regressor: Callable[..., VotingRegressor], + make_data: Callable[..., Tuple[np.ndarray, np.ndarray]], +) -> None: + """ + Parameters + ---------- + dataspec : Dict + The spec to make data of + + expected_shape : Tuple[int, ...] + The expected shape of the output of _model_predict + + Fixtures + -------- + make_voting_regressor: Callable[..., VotingRegressor] + Factory to make a voting regressor which _model_predict expects + + make_data : Callable[..., Tuple[np.ndarray, np.ndarray]] + Factory to make data according to a spec + """ + task = dataspec["kind"] + X, y = make_data(**dataspec) + + voter = make_voting_regressor(X=X, y=y) + + output = _model_predict(voter, X, task=task) + assert output.shape == expected_shape + + +def test_outputs_warnings_to_logs( + mock_logger: PicklableClientLogger, +) -> None: + """ + Fixtures + -------- + mock_logger : PicklableClientLogger + A mock logger that can be queried for call counts + + Expects + ------- + * Any warning emitted by a model should be redirected to the logger + """ + _model_predict( + model=WarningModel(), X=np.eye(5), task=REGRESSION, logger=mock_logger + ) + + assert mock_logger.warning.call_count == 1  # type: ignore + + +def test_outputs_to_stdout_if_no_logger() -> None: + """ + Expects + ------- + * With no logger, any warning emitted by a model goes to standard out + """ + with warnings.catch_warnings(record=True) as w: + _model_predict(model=WarningModel(), X=np.eye(5), task=REGRESSION, logger=None) + + assert len(w) == 1, "One warning should have been emitted" diff --git a/test/test_automl/test_outputs.py b/test/test_automl/test_outputs.py new file mode 100644 index 0000000000..458347c145 --- /dev/null +++ b/test/test_automl/test_outputs.py @@ -0,0 +1,107 @@ +from pathlib import Path + +from autosklearn.automl import AutoML + +from pytest import mark +from pytest_cases import parametrize_with_cases +from pytest_cases.filters import has_tag + +import test.test_automl.cases as cases +from test.conftest import DEFAULT_SEED + +# Some filters +has_ensemble = has_tag("fitted") & ~has_tag("no_ensemble") +no_ensemble = has_tag("fitted") & has_tag("no_ensemble") + + +@mark.todo +def test_datamanager_stored_contents() -> None: + ...
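The `has_ensemble` and `no_ensemble` filters above select from the case functions in `test/test_automl/cases.py`, which this diff does not show. A minimal sketch of how such tagged cases and filters fit together; the case names, tags, and string return values below are assumptions for illustration only (real cases would return fitted AutoML instances).

from pytest_cases import case, parametrize_with_cases
from pytest_cases.filters import has_tag

@case(tags=["fitted", "holdout"])
def case_fitted_holdout() -> str:
    # A real case would return a fitted AutoML instance; a string stands in here
    return "fitted-holdout"

@case(tags=["fitted", "no_ensemble"])
def case_fitted_no_ensemble() -> str:
    return "fitted-no-ensemble"

# `has_tag("fitted") & ~has_tag("no_ensemble")` selects only the first case, which is
# how a filter like `has_ensemble` narrows tests to runs that built an ensemble.
@parametrize_with_cases(
    "automl", cases=".", filter=has_tag("fitted") & ~has_tag("no_ensemble")
)
def test_only_runs_on_ensembled_cases(automl: str) -> None:
    assert automl == "fitted-holdout"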
+ + +@parametrize_with_cases("automl", cases=cases, has_tag="fitted") +def test_paths_created(automl: AutoML) -> None: + """ + Parameters + ---------- + automl : AutoML + A previously fitted automl + + Expects + ------- + * The given paths should exist after the automl has been run and fitted + """ + assert automl._backend is not None + + partial = Path(automl._backend.internals_directory) + expected = [ + partial / fixture + for fixture in ( + "true_targets_ensemble.npy", + f"start_time_{DEFAULT_SEED}", + "datamanager.pkl", + "runs", + ) + ] + + for path in expected: + assert path.exists() + + +@parametrize_with_cases("automl", cases=cases, filter=has_ensemble) +def test_paths_created_with_ensemble(automl: AutoML) -> None: + """ + Parameters + ---------- + automl : AutoML + A previously fitted automl + + Expects + ------- + * The given paths for an automl with an ensemble should include paths + specific to ensemble building + """ + assert automl._backend is not None + + partial = Path(automl._backend.internals_directory) + expected = [ + partial / fixture + for fixture in ( + "ensemble_read_preds.pkl", + "ensemble_read_losses.pkl", + "ensembles", + "ensemble_history.json", + ) + ] + + for path in expected: + assert path.exists() + + +@parametrize_with_cases("automl", cases=cases, has_tag="fitted") +def test_at_least_one_model_and_predictions(automl: AutoML) -> None: + assert automl._backend is not None + runs_dir = Path(automl._backend.get_runs_directory()) + + runs = list(runs_dir.iterdir()) + assert len(runs) > 0 + + at_least_one = False + for run in runs: + prediction_files = run.glob("predictions_ensemble*.npy") + model_files = run.glob("*.*.model") + + if any(prediction_files): + at_least_one = True + assert any(model_files), "Run produced prediction but no model" + + assert at_least_one, "No runs produced predictions" + + +@parametrize_with_cases("automl", cases=cases, filter=has_ensemble) +def test_at_least_one_ensemble(automl: AutoML) -> None: + assert automl._backend is not None + ens_dir = Path(automl._backend.get_ensemble_dir()) + + # TODO make more generic + assert len(list(ens_dir.glob("*.ensemble"))) > 0 diff --git a/test/test_automl/test_performance.py b/test/test_automl/test_performance.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/test_automl/test_performance_over_time.py b/test/test_automl/test_performance_over_time.py new file mode 100644 index 0000000000..d5cc327a41 --- /dev/null +++ b/test/test_automl/test_performance_over_time.py @@ -0,0 +1,38 @@ +from autosklearn.automl import AutoML + +from pytest_cases import parametrize_with_cases +from pytest_cases.filters import has_tag + +import test.test_automl.cases as cases + + +@parametrize_with_cases( + "automl", + cases=cases, + filter=has_tag("fitted") & ~has_tag("no_ensemble"), +) +def test_performance_over_time_with_ensemble(automl: AutoML) -> None: + """ + Parameters + ---------- + automl: AutoMLClassifier + The fitted automl instance with an ensemble + + Expects + ------- + * Performance over time should include only the given columns + * The performance over time should have at least one entry that isn't NaN + * The timestamps should be monotonic + """ + expected_performance_columns = { + "single_best_train_score", + "single_best_optimization_score", + "ensemble_optimization_score", + "Timestamp", + } + columns = automl.performance_over_time_.columns + assert set(columns) == set(expected_performance_columns) + + perf_over_time = automl.performance_over_time_ + assert 
len(perf_over_time.drop(columns="Timestamp").dropna()) != 0 + assert perf_over_time["Timestamp"].is_monotonic diff --git a/test/test_automl/test_post_fit.py b/test/test_automl/test_post_fit.py new file mode 100644 index 0000000000..674a452d02 --- /dev/null +++ b/test/test_automl/test_post_fit.py @@ -0,0 +1,62 @@ +from autosklearn.automl import AutoML + +from pytest_cases import parametrize_with_cases + +import test.test_automl.cases as cases + + +@parametrize_with_cases("automl", cases=cases, has_tag=["fitted", "holdout"]) +def test_holdout_loaded_models(automl: AutoML) -> None: + """ + Parameters + ---------- + automl : AutoML + The fitted automl object to test + + Expects + ------- + * The ensemble should not be empty + * The models_ should contain the identifiers of what's in the ensemble + * The cv_models_ attr should remain None + """ + assert automl.ensemble_ is not None + assert set(automl.models_.keys()) == set(automl.ensemble_.identifiers_) + assert automl.cv_models_ is None + + +@parametrize_with_cases("automl", cases=cases, has_tag=["fitted", "cv"]) +def test_cv_loaded_models(automl: AutoML) -> None: + """ + Parameters + ---------- + automl : AutoML + The fitted automl object to test + + Expects + ------- + * The ensemble should not be empty + * The models_ should contain the identifiers of what's in the ensemble + * The cv_models_ should contain the identifiers of what's in the ensemble + """ + assert automl.ensemble_ is not None + assert set(automl.models_.keys()) == set(automl.ensemble_.identifiers_) + assert set(automl.cv_models_.keys()) == set(automl.ensemble_.identifiers_) + + +@parametrize_with_cases("automl", cases=cases, has_tag=["fitted", "no_ensemble"]) +def test_no_ensemble(automl: AutoML) -> None: + """ + Parameters + ---------- + automl : AutoML + A fitted automl object with ensemble size specified as 0 + + Expects + ------- + * The ensemble should remain None + * The models_ should be empty + * The cv_models_ should remain None + """ + assert automl.ensemble_ is None + assert automl.models_ == [] + assert automl.cv_models_ is None diff --git a/test/test_automl/test_predict.py b/test/test_automl/test_predict.py new file mode 100644 index 0000000000..137b57a5c3 --- /dev/null +++ b/test/test_automl/test_predict.py @@ -0,0 +1 @@ +"""TODO""" diff --git a/test/test_automl/test_refit.py b/test/test_automl/test_refit.py new file mode 100644 index 0000000000..341486ab13 --- /dev/null +++ b/test/test_automl/test_refit.py @@ -0,0 +1,61 @@ +from typing import Callable, Union + +from itertools import repeat + +import numpy as np + +from autosklearn.automl import AutoML +from autosklearn.data.validation import InputValidator + +from pytest_cases import parametrize +from unittest.mock import Mock + + +@parametrize("budget_type", [None, "iterations"]) +def test_shuffle_on_fail( + budget_type: Union[None, str], + make_automl: Callable[..., AutoML], +) -> None: + """ + Parameters + ---------- + budget_type : Union[None, str] + The budget type to use + + Fixtures + -------- + make_automl : Callable[..., AutoML] + Factory to make an AutoML instance + + Expects + ------- + * The automl should not be able to predict before `refit` + * Fitting the model should be attempted `n_fails` times before succeeding once + after + * The automl should be able to predict after `refit` + """ + n_fails = 3 + failing_model = Mock() + failing_model.fit.side_effect = [ValueError()] * n_fails + [None]  # type: ignore + failing_model.estimator_supports_iterative_fit.side_effect = repeat(False) + + ensemble_mock
= Mock() + ensemble_mock.get_selected_model_identifiers.return_value = [(1, 1, 50.0)] + + X = np.ones((3, 2)) + y = np.ones((3,)) + + input_validator = InputValidator() + input_validator.fit(X, y) + + auto = make_automl() + auto.ensemble_ = ensemble_mock # type: ignore + auto.models_ = {(1, 1, 50.0): failing_model} + auto._budget_type = budget_type + auto.InputValidator = input_validator + + assert not auto._can_predict + auto.refit(X, y) + + assert failing_model.fit.call_count == n_fails + 1 + assert auto._can_predict diff --git a/test/test_automl/test_show_models.py b/test/test_automl/test_show_models.py new file mode 100644 index 0000000000..72b2e4f8d6 --- /dev/null +++ b/test/test_automl/test_show_models.py @@ -0,0 +1,20 @@ +from autosklearn.automl import AutoML + +from pytest_cases import parametrize_with_cases + +import test.test_automl.cases as cases + + +@parametrize_with_cases("automl", cases=cases, has_tag=["fitted", "no_ensemble"]) +def test_no_ensemble_produces_empty_show_models(automl: AutoML) -> None: + """ + Parameters + ---------- + automl : AutoML + The automl object with no ensemble size to test + + Expects + ------- + * Show models should return an empty dict + """ + assert automl.show_models() == {} diff --git a/test/test_automl/test_sklearn_compliance.py b/test/test_automl/test_sklearn_compliance.py new file mode 100644 index 0000000000..ce747e1bb8 --- /dev/null +++ b/test/test_automl/test_sklearn_compliance.py @@ -0,0 +1,76 @@ +""" +Note +---- +This is far from complete at the moment +""" +from typing import List, Union + +from sklearn.exceptions import NotFittedError + +from autosklearn.automl import AutoML + +import pytest +from pytest_cases import parametrize, parametrize_with_cases +from pytest_cases.filters import has_tag + +import test.test_automl.cases as cases + + +@pytest.mark.xfail( + reason="__sklearn_is_fitted__ only supported from sklearn 1.0 onwards" +) +@parametrize_with_cases("automl", cases=cases, filter=~has_tag("fitted")) +@parametrize( + "attr, argnames", + [ + ("refit", ["X", "y"]), + ("predict", ["X"]), + ("fit_ensemble", ["y"]), + ("score", ["X", "y"]), + ("performance_over_time_", None), + ("cv_results_", None), + ("sprint_statistics", []), + ("get_models_with_weights", []), + ("show_models", []), + ], +) +def test_attrs_raise_if_not_fitted( + automl: AutoML, + attr: str, + argnames: Union[List[str], None], +) -> None: + """ + Parameters + ---------- + automl : AutoML + An unfitted automl instance + + attr: str + The attribute to test + + argnames: Union[List[str], None] + The arguments of the the method + * ["arg1", "arg2"] for method with args + * [] for method with no args + * None for property + + Expects + ------- + * Should raise a NotFittedError + + Note + ---- + * This also ensures any validation should be done after the fit check as + NotFittedError should be raised + """ + with pytest.raises(NotFittedError): + + if argnames is None: + property = getattr(automl, attr) # noqa + else: + method = getattr(automl, attr) + args = {name: None for name in argnames} + if len(args) > 0: + method(args) + else: + method() diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py index a06b91b0f3..3637b91e62 100644 --- a/test/test_data/test_feature_validator.py +++ b/test/test_data/test_feature_validator.py @@ -1,6 +1,5 @@ import numpy as np import pandas as pd -import pytest import sklearn.datasets import sklearn.model_selection from pandas.api.types import is_categorical_dtype, is_numeric_dtype, 
is_string_dtype @@ -8,6 +7,8 @@ from autosklearn.data.feature_validator import FeatureValidator +import pytest + # Fixtures to be used in this class. By default all elements have 100 datapoints @pytest.fixture diff --git a/test/test_data/test_target_validator.py b/test/test_data/test_target_validator.py index e57f464c72..7c08dba20a 100644 --- a/test/test_data/test_target_validator.py +++ b/test/test_data/test_target_validator.py @@ -1,6 +1,5 @@ import numpy as np import pandas as pd -import pytest import sklearn.datasets import sklearn.model_selection from pandas.api.types import is_bool_dtype, is_numeric_dtype @@ -9,6 +8,8 @@ from autosklearn.data.target_validator import TargetValidator +import pytest + # Fixtures to be used in this class. By default all elements have 100 datapoints @pytest.fixture diff --git a/test/test_data/test_validation.py b/test/test_data/test_validation.py index 4d09c65075..251a8405f2 100644 --- a/test/test_data/test_validation.py +++ b/test/test_data/test_validation.py @@ -1,12 +1,13 @@ import numpy as np import pandas as pd -import pytest import sklearn.datasets import sklearn.model_selection from scipy import sparse from autosklearn.data.validation import InputValidator +import pytest + @pytest.mark.parametrize("openmlid", [2, 40975, 40984]) @pytest.mark.parametrize("as_frame", [True, False]) diff --git a/test/test_ensemble_builder/ensemble_utils.py b/test/test_ensemble_builder/ensemble_utils.py index fa0f22e9e7..7a3cd7f252 100644 --- a/test/test_ensemble_builder/ensemble_utils.py +++ b/test/test_ensemble_builder/ensemble_utils.py @@ -1,7 +1,5 @@ import os import shutil -import unittest -import unittest.mock import numpy as np @@ -11,6 +9,9 @@ from autosklearn.ensemble_builder import EnsembleBuilder from autosklearn.metrics import make_scorer +import unittest +import unittest.mock + def scorer_function(a, b): return 0.9 diff --git a/test/test_ensemble_builder/test_ensemble.py b/test/test_ensemble_builder/test_ensemble.py index 3533da37cd..469f617fb0 100644 --- a/test/test_ensemble_builder/test_ensemble.py +++ b/test/test_ensemble_builder/test_ensemble.py @@ -3,12 +3,10 @@ import shutil import sys import time -import unittest.mock import dask.distributed import numpy as np import pandas as pd -import pytest from smac.runhistory.runhistory import RunHistory, RunKey, RunValue from autosklearn.constants import BINARY_CLASSIFICATION, MULTILABEL_CLASSIFICATION @@ -22,6 +20,9 @@ from autosklearn.ensembles.singlebest_ensemble import SingleBest from autosklearn.metrics import accuracy, log_loss, roc_auc +import pytest +import unittest.mock + this_directory = os.path.dirname(__file__) sys.path.append(this_directory) from ensemble_utils import ( # noqa (E402: module level import not at top of file) diff --git a/test/test_ensemble_builder/test_ensemble_selection.py b/test/test_ensemble_builder/test_ensemble_selection.py index 44e00229fb..07c972c59f 100644 --- a/test/test_ensemble_builder/test_ensemble_selection.py +++ b/test/test_ensemble_builder/test_ensemble_selection.py @@ -1,10 +1,11 @@ import numpy as np -import pytest from autosklearn.constants import BINARY_CLASSIFICATION, REGRESSION from autosklearn.ensembles.ensemble_selection import EnsembleSelection from autosklearn.metrics import accuracy, root_mean_squared_error +import pytest + def testEnsembleSelection(): """ diff --git a/test/test_estimators/__init__.py b/test/test_estimators/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/test_estimators/cases.py 
b/test/test_estimators/cases.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/test_automl/test_estimators.py b/test/test_estimators/test_estimators.py similarity index 97% rename from test/test_automl/test_estimators.py rename to test/test_estimators/test_estimators.py index ac60e51472..cd4b0922de 100644 --- a/test/test_automl/test_estimators.py +++ b/test/test_estimators/test_estimators.py @@ -8,16 +8,12 @@ import os import pickle import re -import sys import tempfile -import unittest -import unittest.mock import joblib import numpy as np import numpy.ma as npma import pandas as pd -import pytest import sklearn import sklearn.datasets import sklearn.dummy @@ -42,8 +38,11 @@ from autosklearn.metrics import accuracy, f1_macro, mean_squared_error, r2 from autosklearn.smbo import get_smac_object -sys.path.append(os.path.dirname(__file__)) -from automl_utils import ( # noqa (E402: module level import not at top of file) +import pytest +import unittest +import unittest.mock + +from test.test_automl.automl_utils import ( count_succeses, include_single_scores, includes_all_scores, @@ -54,7 +53,6 @@ def test_fit_n_jobs(tmp_dir): - X_train, Y_train, X_test, Y_test = putil.get_dataset("breast_cancer") # test parallel Classifier to predict classes, not only indices @@ -77,7 +75,7 @@ def __call__(self, *args, **kwargs): delete_tmp_folder_after_terminate=False, time_left_for_this_task=30, per_run_time_limit=5, - tmp_folder=tmp_dir, + tmp_folder=os.path.join(tmp_dir, "backend"), seed=1, initial_configurations_via_metalearning=0, ensemble_size=5, @@ -147,18 +145,18 @@ def test_feat_type_wrong_arguments(): X = np.zeros((100, 100)) y = np.zeros((100,)) - cls = AutoSklearnClassifier(ensemble_size=0) + cls = AutoSklearnClassifier() expected_msg = r".*feat_type does not have same number of " "variables as X has features. 1 vs 100.*" with pytest.raises(ValueError, match=expected_msg): cls.fit(X=X, y=y, feat_type=[True]) - cls = AutoSklearnClassifier(ensemble_size=0) + cls = AutoSklearnClassifier() expected_msg = r".*feat_type must only contain strings.*" with pytest.raises(ValueError, match=expected_msg): cls.fit(X=X, y=y, feat_type=[True] * 100) - cls = AutoSklearnClassifier(ensemble_size=0) + cls = AutoSklearnClassifier() expected_msg = r".*Only `Categorical`, `Numerical` and `String` are" "valid feature types, you passed `Car`.*" with pytest.raises(ValueError, match=expected_msg): @@ -206,7 +204,7 @@ def test_type_of_target(mock_estimator): ] ) - cls = AutoSklearnClassifier(ensemble_size=0) + cls = AutoSklearnClassifier() cls.automl_ = unittest.mock.Mock() cls.automl_.InputValidator = unittest.mock.Mock() cls.automl_.InputValidator.target_validator = unittest.mock.Mock() @@ -248,7 +246,7 @@ def test_type_of_target(mock_estimator): ) # Test that regressor raises error for illegal target types. 
- reg = AutoSklearnRegressor(ensemble_size=0) + reg = AutoSklearnRegressor() # Illegal target types for regression: multilabel-indicator # multiclass-multioutput expected_msg = r".*Regression with data of type" @@ -300,10 +298,10 @@ def test_performance_over_time_no_ensemble(tmp_dir): cls = AutoSklearnClassifier( time_left_for_this_task=30, per_run_time_limit=5, - tmp_folder=tmp_dir, + tmp_folder=os.path.join(tmp_dir, "backend"), seed=1, - initial_configurations_via_metalearning=0, ensemble_size=0, + initial_configurations_via_metalearning=0, ) cls.fit(X_train, Y_train, X_test, Y_test) @@ -321,10 +319,9 @@ def test_cv_results(tmp_dir): cls = AutoSklearnClassifier( time_left_for_this_task=30, per_run_time_limit=5, - tmp_folder=tmp_dir, + tmp_folder=os.path.join(tmp_dir, "backend"), seed=1, initial_configurations_via_metalearning=0, - ensemble_size=0, scoring_functions=[autosklearn.metrics.precision, autosklearn.metrics.roc_auc], ) @@ -416,7 +413,10 @@ def test_leaderboard( X_train, Y_train, _, _ = putil.get_dataset(dataset_name) model = estimator_type( - time_left_for_this_task=30, per_run_time_limit=5, tmp_folder=tmp_dir, seed=1 + time_left_for_this_task=30, + per_run_time_limit=5, + tmp_folder=os.path.join(tmp_dir, "backend"), + seed=1, ) model.fit(X_train, Y_train) @@ -552,7 +552,7 @@ def test_show_models_with_holdout( automl = estimator( time_left_for_this_task=60, per_run_time_limit=5, - tmp_folder=tmp_dir, + tmp_folder=os.path.join(tmp_dir, "backend"), resampling_strategy=resampling_strategy, dask_client=dask_client, ) @@ -627,7 +627,7 @@ def test_show_models_with_cv( automl = estimator( time_left_for_this_task=120, per_run_time_limit=5, - tmp_folder=tmp_dir, + tmp_folder=os.path.join(tmp_dir, "backend"), resampling_strategy=resampling_strategy, dask_client=dask_client, ) @@ -684,7 +684,7 @@ def test_show_models_with_cv( @unittest.mock.patch("autosklearn.estimators.AutoSklearnEstimator.build_automl") def test_fit_n_jobs_negative(build_automl_patch): n_cores = cpu_count() - cls = AutoSklearnEstimator(n_jobs=-1, ensemble_size=0) + cls = AutoSklearnEstimator(n_jobs=-1) cls.fit() assert cls._n_jobs == n_cores @@ -764,7 +764,7 @@ def test_can_pickle_classifier(tmp_dir, dask_client): time_left_for_this_task=30, delete_tmp_folder_after_terminate=False, per_run_time_limit=5, - tmp_folder=tmp_dir, + tmp_folder=os.path.join(tmp_dir, "backend"), dask_client=dask_client, ) @@ -810,7 +810,7 @@ def test_multilabel(tmp_dir, dask_client): automl = AutoSklearnClassifier( time_left_for_this_task=30, per_run_time_limit=5, - tmp_folder=tmp_dir, + tmp_folder=os.path.join(tmp_dir, "backend"), dask_client=dask_client, ) @@ -836,7 +836,7 @@ def test_binary(tmp_dir, dask_client): time_left_for_this_task=40, delete_tmp_folder_after_terminate=False, per_run_time_limit=10, - tmp_folder=tmp_dir, + tmp_folder=os.path.join(tmp_dir, "backend"), dask_client=dask_client, ) @@ -878,7 +878,7 @@ def test_classification_pandas_support(tmp_dir, dask_client): exclude={"classifier": ["libsvm_svc"]}, dask_client=dask_client, seed=5, - tmp_folder=tmp_dir, + tmp_folder=os.path.join(tmp_dir, "backend"), ) automl.fit(X, y) @@ -905,7 +905,7 @@ def test_regression(tmp_dir, dask_client): automl = AutoSklearnRegressor( time_left_for_this_task=30, per_run_time_limit=5, - tmp_folder=tmp_dir, + tmp_folder=os.path.join(tmp_dir, "backend"), dask_client=dask_client, ) @@ -938,7 +938,7 @@ def test_cv_regression(tmp_dir, dask_client): time_left_for_this_task=60, per_run_time_limit=10, resampling_strategy="cv", - tmp_folder=tmp_dir, + 
tmp_folder=os.path.join(tmp_dir, "backend"), dask_client=dask_client, ) @@ -967,7 +967,7 @@ def test_regression_pandas_support(tmp_dir, dask_client): time_left_for_this_task=40, per_run_time_limit=5, dask_client=dask_client, - tmp_folder=tmp_dir, + tmp_folder=os.path.join(tmp_dir, "backend"), ) # Make sure we error out because y is not encoded @@ -987,8 +987,7 @@ def test_regression_pandas_support(tmp_dir, dask_client): def test_autosklearn_classification_methods_returns_self(dask_client): - """ - Currently this method only tests that the methods of AutoSklearnClassifier + """Currently this method only tests that the methods of AutoSklearnClassifier is able to fit using fit(), fit_ensemble() and refit() """ X_train, y_train, X_test, y_test = putil.get_dataset("iris") @@ -996,7 +995,6 @@ def test_autosklearn_classification_methods_returns_self(dask_client): time_left_for_this_task=60, delete_tmp_folder_after_terminate=False, per_run_time_limit=10, - ensemble_size=0, dask_client=dask_client, exclude={"feature_preprocessor": ["fast_ica"]}, ) @@ -1022,7 +1020,6 @@ def test_autosklearn_regression_methods_returns_self(dask_client): delete_tmp_folder_after_terminate=False, per_run_time_limit=5, dask_client=dask_client, - ensemble_size=0, ) automl_fitted = automl.fit(X_train, y_train) @@ -1039,7 +1036,6 @@ def test_autosklearn2_classification_methods_returns_self(dask_client): X_train, y_train, X_test, y_test = putil.get_dataset("iris") automl = AutoSklearn2Classifier( time_left_for_this_task=60, - ensemble_size=0, delete_tmp_folder_after_terminate=False, dask_client=dask_client, ) @@ -1069,7 +1065,6 @@ def test_autosklearn2_classification_methods_returns_self_sparse(dask_client): ) automl = AutoSklearn2Classifier( time_left_for_this_task=60, - ensemble_size=0, delete_tmp_folder_after_terminate=False, dask_client=dask_client, ) @@ -1201,7 +1196,6 @@ def test_fit_pipeline(dask_client, task_type, resampling_strategy, disable_file_ # Time left for task plays no role # only per run time limit per_run_time_limit=30, - ensemble_size=0, dask_client=dask_client, include=include, seed=seed, @@ -1283,7 +1277,6 @@ def test_fit_pipeline(dask_client, task_type, resampling_strategy, disable_file_ def test_pass_categorical_and_numeric_columns_to_pipeline( dask_client, data_type, include_categorical ): - # Prepare the training data X, y = sklearn.datasets.make_classification(random_state=0) X = cast(np.ndarray, X) @@ -1319,7 +1312,6 @@ def test_pass_categorical_and_numeric_columns_to_pipeline( delete_tmp_folder_after_terminate=False, time_left_for_this_task=120, per_run_time_limit=30, - ensemble_size=0, seed=0, dask_client=dask_client, include={"classifier": ["random_forest"]}, @@ -1384,7 +1376,6 @@ def test_autosklearn_anneal(as_frame): X, y = sklearn.datasets.fetch_openml(data_id=2, return_X_y=True, as_frame=as_frame) automl = AutoSklearnClassifier( time_left_for_this_task=60, - ensemble_size=0, delete_tmp_folder_after_terminate=False, initial_configurations_via_metalearning=0, smac_scenario_args={"runcount_limit": 6}, diff --git a/test/test_evaluation/evaluation_util.py b/test/test_evaluation/evaluation_util.py index d8bf017c35..62623a50ba 100644 --- a/test/test_evaluation/evaluation_util.py +++ b/test/test_evaluation/evaluation_util.py @@ -1,7 +1,6 @@ import functools import tempfile import traceback -import unittest import numpy as np import sklearn.datasets @@ -34,6 +33,8 @@ from autosklearn.pipeline.util import get_dataset from autosklearn.util.data import convert_to_bin +import unittest + SCORER_LIST = [ 
accuracy, balanced_accuracy, diff --git a/test/test_evaluation/test_abstract_evaluator.py b/test/test_evaluation/test_abstract_evaluator.py index c668a82ffd..7f88383bcd 100644 --- a/test/test_evaluation/test_abstract_evaluator.py +++ b/test/test_evaluation/test_abstract_evaluator.py @@ -4,8 +4,6 @@ import shutil import sys import tempfile -import unittest -import unittest.mock import numpy as np import sklearn.dummy @@ -16,6 +14,9 @@ from autosklearn.metrics import accuracy from autosklearn.pipeline.components.base import _addons +import unittest +import unittest.mock + this_directory = os.path.dirname(__file__) sys.path.append(this_directory) from evaluation_util import get_multiclass_classification_datamanager # noqa E402 diff --git a/test/test_evaluation/test_custom_splitters.py b/test/test_evaluation/test_custom_splitters.py index 64f9dc2f18..96670923dd 100644 --- a/test/test_evaluation/test_custom_splitters.py +++ b/test/test_evaluation/test_custom_splitters.py @@ -1,5 +1,4 @@ import numpy as np -import pytest from autosklearn.constants import ( BINARY_CLASSIFICATION, @@ -8,6 +7,8 @@ ) from autosklearn.evaluation.splitter import CustomStratifiedShuffleSplit +import pytest + @pytest.mark.parametrize( "task, X, y", diff --git a/test/test_evaluation/test_dummy_pipelines.py b/test/test_evaluation/test_dummy_pipelines.py index 3d5f1d0f59..8d1005e178 100644 --- a/test/test_evaluation/test_dummy_pipelines.py +++ b/test/test_evaluation/test_dummy_pipelines.py @@ -1,5 +1,4 @@ import numpy as np -import pytest from sklearn.base import clone from sklearn.datasets import make_classification, make_regression from sklearn.utils.validation import check_is_fitted @@ -9,9 +8,11 @@ MyDummyRegressor, ) +import pytest + @pytest.mark.parametrize("task_type", ["classification", "regression"]) -def test_dummy_pipeline(task_type): +def test_dummy_pipeline(task_type: str) -> None: if task_type == "classification": estimator_class = MyDummyClassifier data_maker = make_classification diff --git a/test/test_evaluation/test_evaluation.py b/test/test_evaluation/test_evaluation.py index 5df1f5fe50..1723b208f2 100644 --- a/test/test_evaluation/test_evaluation.py +++ b/test/test_evaluation/test_evaluation.py @@ -4,8 +4,6 @@ import shutil import sys import time -import unittest -import unittest.mock import numpy as np import pynisher @@ -17,6 +15,9 @@ from autosklearn.evaluation import ExecuteTaFuncWithQueue, get_cost_of_crash from autosklearn.metrics import accuracy, log_loss +import unittest +import unittest.mock + this_directory = os.path.dirname(__file__) sys.path.append(this_directory) from evaluation_util import ( # noqa E402 diff --git a/test/test_evaluation/test_test_evaluator.py b/test/test_evaluation/test_test_evaluator.py index 0a1b67faa9..8615682ce7 100644 --- a/test/test_evaluation/test_test_evaluator.py +++ b/test/test_evaluation/test_test_evaluator.py @@ -7,8 +7,6 @@ import shutil import sys import tempfile -import unittest -import unittest.mock import numpy as np from smac.tae import StatusType @@ -25,6 +23,9 @@ from autosklearn.metrics import accuracy, f1_macro, r2 from autosklearn.util.pipeline import get_configuration_space +import unittest +import unittest.mock + this_directory = os.path.dirname(__file__) sys.path.append(this_directory) from evaluation_util import ( # noqa (E402: module level import not at top of file) diff --git a/test/test_evaluation/test_train_evaluator.py b/test/test_evaluation/test_train_evaluator.py index 92e3cfcc10..afed8b5ce1 100644 --- 
a/test/test_evaluation/test_train_evaluator.py +++ b/test/test_evaluation/test_train_evaluator.py @@ -7,8 +7,6 @@ import shutil import sys import tempfile -import unittest -import unittest.mock import numpy as np import sklearn.model_selection @@ -53,6 +51,9 @@ from autosklearn.metrics import accuracy, f1_macro, r2 from autosklearn.util.pipeline import get_configuration_space +import unittest +import unittest.mock + this_directory = os.path.dirname(__file__) sys.path.append(this_directory) from evaluation_util import ( # noqa (E402: module level import not at top of file) diff --git a/test/test_metalearning/pyMetaLearn/metalearning/test_kND.py b/test/test_metalearning/pyMetaLearn/metalearning/test_kND.py index 4877379440..cb9dc80ab0 100644 --- a/test/test_metalearning/pyMetaLearn/metalearning/test_kND.py +++ b/test/test_metalearning/pyMetaLearn/metalearning/test_kND.py @@ -1,5 +1,4 @@ import logging -import unittest import numpy as np import pandas as pd @@ -7,6 +6,8 @@ from autosklearn.metalearning.metalearning.kNearestDatasets.kND import KNearestDatasets from autosklearn.metalearning.metalearning.metrics.misc import get_random_metric +import unittest + class kNDTest(unittest.TestCase): _multiprocess_can_split_ = True diff --git a/test/test_metalearning/pyMetaLearn/test_meta_base.py b/test/test_metalearning/pyMetaLearn/test_meta_base.py index 1c6788e816..3f06ad07be 100644 --- a/test/test_metalearning/pyMetaLearn/test_meta_base.py +++ b/test/test_metalearning/pyMetaLearn/test_meta_base.py @@ -1,12 +1,13 @@ import logging import os -import unittest import pandas as pd import autosklearn.pipeline.classification from autosklearn.metalearning.metalearning.meta_base import MetaBase +import unittest + class MetaBaseTest(unittest.TestCase): _multiprocess_can_split_ = True diff --git a/test/test_metalearning/pyMetaLearn/test_meta_features.py b/test/test_metalearning/pyMetaLearn/test_meta_features.py index 6a9bec4dcf..40048c708a 100644 --- a/test/test_metalearning/pyMetaLearn/test_meta_features.py +++ b/test/test_metalearning/pyMetaLearn/test_meta_features.py @@ -1,12 +1,10 @@ import logging import os import tempfile -import unittest import arff import numpy as np import pandas as pd -import pytest from joblib import Memory from sklearn.datasets import fetch_openml, make_multilabel_classification @@ -16,6 +14,9 @@ FeatTypeSplit, ) +import pytest +import unittest + @pytest.fixture(scope="class", params=("pandas", "numpy")) def multilabel_train_data(request): diff --git a/test/test_metalearning/pyMetaLearn/test_meta_features_sparse.py b/test/test_metalearning/pyMetaLearn/test_meta_features_sparse.py index 34a2c8e8d1..992032a349 100644 --- a/test/test_metalearning/pyMetaLearn/test_meta_features_sparse.py +++ b/test/test_metalearning/pyMetaLearn/test_meta_features_sparse.py @@ -3,7 +3,6 @@ import arff import numpy as np -import pytest from scipy import sparse from sklearn.impute import SimpleImputer from sklearn.preprocessing import StandardScaler @@ -13,6 +12,8 @@ FeatTypeSplit, ) +import pytest + @pytest.fixture def sparse_data(): diff --git a/test/test_metalearning/pyMetaLearn/test_metalearner.py b/test/test_metalearning/pyMetaLearn/test_metalearner.py index a8b7d604cb..42d27d49da 100644 --- a/test/test_metalearning/pyMetaLearn/test_metalearner.py +++ b/test/test_metalearning/pyMetaLearn/test_metalearner.py @@ -1,6 +1,5 @@ import logging import os -import unittest import numpy as np import pandas as pd @@ -10,6 +9,8 @@ import autosklearn.pipeline.classification from 
autosklearn.metalearning.metalearning.meta_base import MetaBase +import unittest + logging.basicConfig() diff --git a/test/test_metalearning/pyMetaLearn/test_optimizer_base.py b/test/test_metalearning/pyMetaLearn/test_optimizer_base.py index 63dc2184da..654f0026e5 100644 --- a/test/test_metalearning/pyMetaLearn/test_optimizer_base.py +++ b/test/test_metalearning/pyMetaLearn/test_optimizer_base.py @@ -1,8 +1,9 @@ -import unittest from collections import OrderedDict from autosklearn.metalearning.optimizers import optimizer_base +import unittest + class OptimizerBaseTest(unittest.TestCase): _multiprocess_can_split_ = True diff --git a/test/test_metalearning/test_metalearning.py b/test/test_metalearning/test_metalearning.py index 3ec847a8f5..e9e6e4ca1a 100644 --- a/test/test_metalearning/test_metalearning.py +++ b/test/test_metalearning/test_metalearning.py @@ -1,6 +1,4 @@ # -*- encoding: utf-8 -*- -import unittest - from sklearn.datasets import load_breast_cancer from autosklearn.classification import AutoSklearnClassifier @@ -10,6 +8,8 @@ from autosklearn.smbo import _calculate_metafeatures, _calculate_metafeatures_encoded from autosklearn.util.pipeline import get_configuration_space +import unittest + class MetafeatureValueDummy(object): def __init__(self, name, value): diff --git a/test/test_metric/test_metrics.py b/test/test_metric/test_metrics.py index 334a485fe3..541b2d6783 100644 --- a/test/test_metric/test_metrics.py +++ b/test/test_metric/test_metrics.py @@ -1,8 +1,6 @@ -import unittest import warnings import numpy as np -import pytest import sklearn.metrics from smac.utils.constants import MAXINT @@ -10,6 +8,9 @@ from autosklearn.constants import BINARY_CLASSIFICATION, REGRESSION from autosklearn.metrics import calculate_loss, calculate_metric, calculate_score +import pytest +import unittest + class TestScorer(unittest.TestCase): def test_predict_scorer_binary(self): diff --git a/test/test_optimizer/test_smbo.py b/test/test_optimizer/test_smbo.py index fafd7b5a42..eb23d3a932 100644 --- a/test/test_optimizer/test_smbo.py +++ b/test/test_optimizer/test_smbo.py @@ -1,6 +1,5 @@ import logging.handlers -import pytest from ConfigSpace.configuration_space import Configuration import autosklearn.metrics @@ -11,6 +10,8 @@ from autosklearn.smbo import AutoMLSMBO from autosklearn.util.stopwatch import StopWatch +import pytest + @pytest.mark.parametrize("context", ["fork", "forkserver"]) def test_smbo_metalearning_configurations(backend, context, dask_client): diff --git a/test/test_pipeline/components/classification/test_base.py b/test/test_pipeline/components/classification/test_base.py index a524759bc5..9fc54f4dba 100644 --- a/test/test_pipeline/components/classification/test_base.py +++ b/test/test_pipeline/components/classification/test_base.py @@ -1,7 +1,5 @@ from typing import Dict, Optional -import unittest - import numpy as np import sklearn.metrics @@ -12,6 +10,8 @@ _test_classifier_predict_proba, ) +import unittest + from test.test_pipeline.ignored_warnings import classifier_warnings, ignore_warnings diff --git a/test/test_pipeline/components/data_preprocessing/test_balancing.py b/test/test_pipeline/components/data_preprocessing/test_balancing.py index cf8dc103b8..6a76ce419c 100644 --- a/test/test_pipeline/components/data_preprocessing/test_balancing.py +++ b/test/test_pipeline/components/data_preprocessing/test_balancing.py @@ -1,7 +1,6 @@ __author__ = "feurerm" import copy -import unittest import numpy as np import sklearn.datasets @@ -33,6 +32,8 @@ LibLinear_Preprocessor, ) 
+import unittest + class BalancingComponentTest(unittest.TestCase): def test_balancing_get_weights_treed_single_label(self): diff --git a/test/test_pipeline/components/data_preprocessing/test_categorical_imputation.py b/test/test_pipeline/components/data_preprocessing/test_categorical_imputation.py index d50e8cf842..41c383bc09 100644 --- a/test/test_pipeline/components/data_preprocessing/test_categorical_imputation.py +++ b/test/test_pipeline/components/data_preprocessing/test_categorical_imputation.py @@ -1,12 +1,13 @@ import numpy as np import pandas as pd -import pytest from scipy import sparse from autosklearn.pipeline.components.data_preprocessing.imputation.categorical_imputation import ( # noqa: E501 CategoricalImputation, ) +import pytest + @pytest.fixture def input_data_imputation(request): diff --git a/test/test_pipeline/components/data_preprocessing/test_category_shift.py b/test/test_pipeline/components/data_preprocessing/test_category_shift.py index ce637f50d4..d97b510fc4 100644 --- a/test/test_pipeline/components/data_preprocessing/test_category_shift.py +++ b/test/test_pipeline/components/data_preprocessing/test_category_shift.py @@ -1,5 +1,3 @@ -import unittest - import numpy as np import scipy.sparse @@ -7,6 +5,8 @@ CategoryShift, ) +import unittest + class CategoryShiftTest(unittest.TestCase): def test_data_type_consistency(self): diff --git a/test/test_pipeline/components/data_preprocessing/test_data_preprocessing.py b/test/test_pipeline/components/data_preprocessing/test_data_preprocessing.py index ac8e9abbe2..4c7ad8383c 100644 --- a/test/test_pipeline/components/data_preprocessing/test_data_preprocessing.py +++ b/test/test_pipeline/components/data_preprocessing/test_data_preprocessing.py @@ -1,5 +1,3 @@ -import unittest - import numpy as np from scipy import sparse @@ -7,6 +5,8 @@ FeatTypeSplit, ) +import unittest + class PreprocessingPipelineTest(unittest.TestCase): def do_a_fit_transform(self, sparse_input): diff --git a/test/test_pipeline/components/data_preprocessing/test_data_preprocessing_categorical.py b/test/test_pipeline/components/data_preprocessing/test_data_preprocessing_categorical.py index 1d693eb150..029b72d183 100644 --- a/test/test_pipeline/components/data_preprocessing/test_data_preprocessing_categorical.py +++ b/test/test_pipeline/components/data_preprocessing/test_data_preprocessing_categorical.py @@ -1,13 +1,13 @@ -import unittest - import numpy as np -import pytest from scipy import sparse from autosklearn.pipeline.components.data_preprocessing.feature_type_categorical import ( # noqa: E501 CategoricalPreprocessingPipeline, ) +import pytest +import unittest + class CategoricalPreprocessingPipelineTest(unittest.TestCase): def test_data_type_consistency(self): diff --git a/test/test_pipeline/components/data_preprocessing/test_data_preprocessing_numerical.py b/test/test_pipeline/components/data_preprocessing/test_data_preprocessing_numerical.py index 5a0a840501..d25cef2a2b 100644 --- a/test/test_pipeline/components/data_preprocessing/test_data_preprocessing_numerical.py +++ b/test/test_pipeline/components/data_preprocessing/test_data_preprocessing_numerical.py @@ -1,5 +1,3 @@ -import unittest - import numpy as np from scipy import sparse @@ -7,6 +5,8 @@ NumericalPreprocessingPipeline, ) +import unittest + class NumericalPreprocessingPipelineTest(unittest.TestCase): def test_data_type_consistency(self): diff --git a/test/test_pipeline/components/data_preprocessing/test_data_preprocessing_text.py 
b/test/test_pipeline/components/data_preprocessing/test_data_preprocessing_text.py index c0729b1dfc..eed5b01bea 100644 --- a/test/test_pipeline/components/data_preprocessing/test_data_preprocessing_text.py +++ b/test/test_pipeline/components/data_preprocessing/test_data_preprocessing_text.py @@ -1,5 +1,3 @@ -import unittest - import numpy as np import pandas as pd @@ -10,6 +8,8 @@ BagOfWordEncoder as BOW_distinct, ) +import unittest + class TextPreprocessingPipelineTest(unittest.TestCase): def test_fit_transform(self): diff --git a/test/test_pipeline/components/data_preprocessing/test_minority_coalescence.py b/test/test_pipeline/components/data_preprocessing/test_minority_coalescence.py index 8e73e963ab..7fa4c24720 100644 --- a/test/test_pipeline/components/data_preprocessing/test_minority_coalescence.py +++ b/test/test_pipeline/components/data_preprocessing/test_minority_coalescence.py @@ -1,5 +1,3 @@ -import unittest - import numpy as np import scipy.sparse @@ -10,6 +8,8 @@ NoCoalescence, ) +import unittest + class MinorityCoalescerTest(unittest.TestCase): def test_data_type_consistency(self): diff --git a/test/test_pipeline/components/data_preprocessing/test_one_hot_encoding.py b/test/test_pipeline/components/data_preprocessing/test_one_hot_encoding.py index 08d2cadd9e..989ed6784e 100644 --- a/test/test_pipeline/components/data_preprocessing/test_one_hot_encoding.py +++ b/test/test_pipeline/components/data_preprocessing/test_one_hot_encoding.py @@ -1,5 +1,3 @@ -import unittest - import numpy as np from scipy import sparse @@ -11,6 +9,8 @@ ) from autosklearn.pipeline.util import _test_preprocessing +import unittest + def create_X(instances=1000, n_feats=10, categs_per_feat=5, seed=0): rs = np.random.RandomState(seed) diff --git a/test/test_pipeline/components/data_preprocessing/test_scaling.py b/test/test_pipeline/components/data_preprocessing/test_scaling.py index 7f8249e3f1..b87223d14d 100644 --- a/test/test_pipeline/components/data_preprocessing/test_scaling.py +++ b/test/test_pipeline/components/data_preprocessing/test_scaling.py @@ -1,11 +1,11 @@ -import unittest - import numpy as np import sklearn.datasets from autosklearn.pipeline.components.data_preprocessing.rescaling import RescalingChoice from autosklearn.pipeline.util import get_dataset +import unittest + class ScalingComponentTest(unittest.TestCase): def _test_helper(self, Preprocessor, dataset=None, make_sparse=False): diff --git a/test/test_pipeline/components/feature_preprocessing/test_choice.py b/test/test_pipeline/components/feature_preprocessing/test_choice.py index 516cf318bf..89272aafcc 100644 --- a/test/test_pipeline/components/feature_preprocessing/test_choice.py +++ b/test/test_pipeline/components/feature_preprocessing/test_choice.py @@ -1,7 +1,7 @@ -import unittest - import autosklearn.pipeline.components.feature_preprocessing as fp +import unittest + class FeatureProcessingTest(unittest.TestCase): def test_get_available_components(self): diff --git a/test/test_pipeline/components/feature_preprocessing/test_fast_ica.py b/test/test_pipeline/components/feature_preprocessing/test_fast_ica.py index a38097a60e..717c2cce36 100644 --- a/test/test_pipeline/components/feature_preprocessing/test_fast_ica.py +++ b/test/test_pipeline/components/feature_preprocessing/test_fast_ica.py @@ -1,5 +1,3 @@ -import unittest - import sklearn.metrics from sklearn.linear_model import Ridge @@ -10,6 +8,8 @@ get_dataset, ) +import unittest + class FastICAComponentTest(PreprocessingTestCase): def test_default_configuration(self): diff 
--git a/test/test_pipeline/components/feature_preprocessing/test_kernel_pca.py b/test/test_pipeline/components/feature_preprocessing/test_kernel_pca.py index 2c5a8c865b..6af5cb88d8 100644 --- a/test/test_pipeline/components/feature_preprocessing/test_kernel_pca.py +++ b/test/test_pipeline/components/feature_preprocessing/test_kernel_pca.py @@ -1,5 +1,3 @@ -import unittest - import sklearn.metrics from sklearn.linear_model import RidgeClassifier @@ -10,6 +8,8 @@ get_dataset, ) +import unittest + class KernelPCAComponentTest(PreprocessingTestCase): def test_default_configuration(self): diff --git a/test/test_pipeline/components/feature_preprocessing/test_kitchen_sinks.py b/test/test_pipeline/components/feature_preprocessing/test_kitchen_sinks.py index 16ef41198d..52b2d0acee 100644 --- a/test/test_pipeline/components/feature_preprocessing/test_kitchen_sinks.py +++ b/test/test_pipeline/components/feature_preprocessing/test_kitchen_sinks.py @@ -1,10 +1,10 @@ -import unittest - from autosklearn.pipeline.components.feature_preprocessing.kitchen_sinks import ( RandomKitchenSinks, ) from autosklearn.pipeline.util import PreprocessingTestCase, _test_preprocessing +import unittest + class KitchenSinkComponent(PreprocessingTestCase): def test_default_configuration(self): diff --git a/test/test_pipeline/components/feature_preprocessing/test_nystroem_sampler.py b/test/test_pipeline/components/feature_preprocessing/test_nystroem_sampler.py index d6244c362f..5fa1269cf5 100644 --- a/test/test_pipeline/components/feature_preprocessing/test_nystroem_sampler.py +++ b/test/test_pipeline/components/feature_preprocessing/test_nystroem_sampler.py @@ -1,5 +1,3 @@ -import unittest - import numpy as np import sklearn.preprocessing @@ -8,6 +6,8 @@ ) from autosklearn.pipeline.util import _test_preprocessing, get_dataset +import unittest + class NystroemComponentTest(unittest.TestCase): def test_default_configuration(self): diff --git a/test/test_pipeline/components/feature_preprocessing/test_random_trees_embedding.py b/test/test_pipeline/components/feature_preprocessing/test_random_trees_embedding.py index f84675dc1a..82feaca5ec 100644 --- a/test/test_pipeline/components/feature_preprocessing/test_random_trees_embedding.py +++ b/test/test_pipeline/components/feature_preprocessing/test_random_trees_embedding.py @@ -1,5 +1,3 @@ -import unittest - import numpy as np import scipy.sparse @@ -8,6 +6,8 @@ ) from autosklearn.pipeline.util import _test_preprocessing, get_dataset +import unittest + class RandomTreesEmbeddingComponentTest(unittest.TestCase): def test_default_configuration(self): diff --git a/test/test_pipeline/components/feature_preprocessing/test_select_percentile_classification.py b/test/test_pipeline/components/feature_preprocessing/test_select_percentile_classification.py index b177e4f4ba..f0dbb3e947 100644 --- a/test/test_pipeline/components/feature_preprocessing/test_select_percentile_classification.py +++ b/test/test_pipeline/components/feature_preprocessing/test_select_percentile_classification.py @@ -1,5 +1,3 @@ -import unittest - import numpy as np import scipy.sparse import sklearn.preprocessing @@ -9,6 +7,8 @@ ) from autosklearn.pipeline.util import _test_preprocessing, get_dataset +import unittest + class SelectPercentileClassificationTest(unittest.TestCase): def test_default_configuration(self): diff --git a/test/test_pipeline/components/feature_preprocessing/test_select_percentile_regression.py b/test/test_pipeline/components/feature_preprocessing/test_select_percentile_regression.py index 
0fd335fd83..98bc50a690 100644 --- a/test/test_pipeline/components/feature_preprocessing/test_select_percentile_regression.py +++ b/test/test_pipeline/components/feature_preprocessing/test_select_percentile_regression.py @@ -1,5 +1,3 @@ -import unittest - import numpy as np from autosklearn.pipeline.components.feature_preprocessing.select_percentile_regression import ( # noqa: E501 @@ -7,6 +5,8 @@ ) from autosklearn.pipeline.util import _test_preprocessing, get_dataset +import unittest + class SelectPercentileRegressionTest(unittest.TestCase): def test_default_configuration(self): diff --git a/test/test_pipeline/components/feature_preprocessing/test_select_rates_classification.py b/test/test_pipeline/components/feature_preprocessing/test_select_rates_classification.py index 2d1c2aaf78..03c6a45983 100644 --- a/test/test_pipeline/components/feature_preprocessing/test_select_rates_classification.py +++ b/test/test_pipeline/components/feature_preprocessing/test_select_rates_classification.py @@ -1,5 +1,3 @@ -import unittest - import numpy as np import scipy.sparse import sklearn.preprocessing @@ -9,6 +7,8 @@ ) from autosklearn.pipeline.util import _test_preprocessing, get_dataset +import unittest + class SelectClassificationRatesComponentTest(unittest.TestCase): def test_default_configuration(self): diff --git a/test/test_pipeline/components/feature_preprocessing/test_select_rates_regression.py b/test/test_pipeline/components/feature_preprocessing/test_select_rates_regression.py index 869d7fbee2..826d05e53e 100644 --- a/test/test_pipeline/components/feature_preprocessing/test_select_rates_regression.py +++ b/test/test_pipeline/components/feature_preprocessing/test_select_rates_regression.py @@ -1,5 +1,3 @@ -import unittest - import numpy as np import scipy.sparse import sklearn.preprocessing @@ -9,6 +7,8 @@ ) from autosklearn.pipeline.util import _test_preprocessing, get_dataset +import unittest + class SelectRegressionRatesComponentTest(unittest.TestCase): def test_default_configuration(self): diff --git a/test/test_pipeline/components/feature_preprocessing/test_truncatedSVD.py b/test/test_pipeline/components/feature_preprocessing/test_truncatedSVD.py index 7e16fa7fa5..6f09368b6c 100644 --- a/test/test_pipeline/components/feature_preprocessing/test_truncatedSVD.py +++ b/test/test_pipeline/components/feature_preprocessing/test_truncatedSVD.py @@ -1,5 +1,3 @@ -import unittest - import sklearn.metrics from sklearn.linear_model import RidgeClassifier @@ -12,6 +10,8 @@ get_dataset, ) +import unittest + class TruncatedSVDComponentTest(PreprocessingTestCase): def test_default_configuration(self): diff --git a/test/test_pipeline/components/regression/test_base.py b/test/test_pipeline/components/regression/test_base.py index dcf7770332..7ed4afe79a 100644 --- a/test/test_pipeline/components/regression/test_base.py +++ b/test/test_pipeline/components/regression/test_base.py @@ -1,9 +1,6 @@ from typing import Container, Type -import unittest - import numpy as np -import pytest import sklearn.metrics from autosklearn.pipeline.components.regression import RegressorChoice, _regressors @@ -11,6 +8,9 @@ from autosklearn.pipeline.constants import SPARSE from autosklearn.pipeline.util import _test_regressor, _test_regressor_iterative_fit +import pytest +import unittest + from test.test_pipeline.ignored_warnings import ignore_warnings, regressor_warnings diff --git a/test/test_pipeline/components/test_base.py b/test/test_pipeline/components/test_base.py index 1e6ddbbd14..f8fcf6b398 100644 --- 
a/test/test_pipeline/components/test_base.py +++ b/test/test_pipeline/components/test_base.py @@ -1,12 +1,13 @@ import os import sys -import unittest from autosklearn.pipeline.components.base import ( AutoSklearnClassificationAlgorithm, find_components, ) +import unittest + this_dir = os.path.dirname(os.path.abspath(__file__)) sys.path.append(this_dir) diff --git a/test/test_pipeline/implementations/test_CategoryShift.py b/test/test_pipeline/implementations/test_CategoryShift.py index 1b5e1451e6..e6cbced71e 100644 --- a/test/test_pipeline/implementations/test_CategoryShift.py +++ b/test/test_pipeline/implementations/test_CategoryShift.py @@ -1,10 +1,10 @@ -import unittest - import numpy as np import scipy.sparse from autosklearn.pipeline.implementations.CategoryShift import CategoryShift +import unittest + class CategoryShiftTest(unittest.TestCase): def test_dense(self): diff --git a/test/test_pipeline/implementations/test_MinorityCoalescer.py b/test/test_pipeline/implementations/test_MinorityCoalescer.py index 7bdca8f1aa..d7058453e5 100644 --- a/test/test_pipeline/implementations/test_MinorityCoalescer.py +++ b/test/test_pipeline/implementations/test_MinorityCoalescer.py @@ -1,10 +1,10 @@ -import unittest - import numpy as np import scipy.sparse from autosklearn.pipeline.implementations.MinorityCoalescer import MinorityCoalescer +import unittest + class MinorityCoalescerTest(unittest.TestCase): @property diff --git a/test/test_pipeline/implementations/test_SparseOneHotEncoder.py b/test/test_pipeline/implementations/test_SparseOneHotEncoder.py index 91f1827c06..ed9bc07a89 100644 --- a/test/test_pipeline/implementations/test_SparseOneHotEncoder.py +++ b/test/test_pipeline/implementations/test_SparseOneHotEncoder.py @@ -1,5 +1,3 @@ -import unittest - import numpy as np import scipy.sparse import sklearn.datasets @@ -12,6 +10,8 @@ from autosklearn.pipeline.implementations.CategoryShift import CategoryShift from autosklearn.pipeline.implementations.SparseOneHotEncoder import SparseOneHotEncoder +import unittest + sparse1 = scipy.sparse.csc_matrix( ([3, 2, 1, 1, 2, 3], ((1, 4, 5, 2, 3, 5), (0, 0, 0, 1, 1, 1))), shape=(6, 2) ) diff --git a/test/test_pipeline/implementations/test_util.py b/test/test_pipeline/implementations/test_util.py index 58412e0b0c..d6b6530569 100644 --- a/test/test_pipeline/implementations/test_util.py +++ b/test/test_pipeline/implementations/test_util.py @@ -1,9 +1,9 @@ -import unittest - import numpy as np from autosklearn.pipeline.implementations.util import softmax +import unittest + class UtilTest(unittest.TestCase): def test_softmax_binary(self): diff --git a/test/test_pipeline/test_base.py b/test/test_pipeline/test_base.py index f1efed23b4..f8cfe26912 100644 --- a/test/test_pipeline/test_base.py +++ b/test/test_pipeline/test_base.py @@ -1,6 +1,3 @@ -import unittest -import unittest.mock - import ConfigSpace.configuration_space import autosklearn.pipeline.base @@ -8,6 +5,9 @@ import autosklearn.pipeline.components.classification as classification import autosklearn.pipeline.components.feature_preprocessing as feature_preprocessing +import unittest +import unittest.mock + class BasePipelineMock(autosklearn.pipeline.base.BasePipeline): def __init__(self): diff --git a/test/test_pipeline/test_classification.py b/test/test_pipeline/test_classification.py index 1ce93ebb0d..7be8038119 100644 --- a/test/test_pipeline/test_classification.py +++ b/test/test_pipeline/test_classification.py @@ -5,8 +5,6 @@ import os import resource import tempfile -import unittest -import 
unittest.mock import numpy as np import sklearn.datasets @@ -40,6 +38,9 @@ ) from autosklearn.pipeline.util import get_dataset +import unittest +import unittest.mock + from test.test_pipeline.ignored_warnings import classifier_warnings, ignore_warnings diff --git a/test/test_pipeline/test_create_searchspace_util_classification.py b/test/test_pipeline/test_create_searchspace_util_classification.py index a830430097..1b09de2bb7 100644 --- a/test/test_pipeline/test_create_searchspace_util_classification.py +++ b/test/test_pipeline/test_create_searchspace_util_classification.py @@ -1,4 +1,3 @@ -import unittest from collections import OrderedDict import numpy @@ -19,6 +18,8 @@ TruncatedSVD, ) +import unittest + class TestCreateClassificationSearchspace(unittest.TestCase): _multiprocess_can_split_ = True diff --git a/test/test_pipeline/test_regression.py b/test/test_pipeline/test_regression.py index 501b73ec5d..3a50decb8c 100644 --- a/test/test_pipeline/test_regression.py +++ b/test/test_pipeline/test_regression.py @@ -2,8 +2,6 @@ import itertools import resource import tempfile -import unittest -import unittest.mock import numpy as np import sklearn.datasets @@ -34,6 +32,9 @@ from autosklearn.pipeline.regression import SimpleRegressionPipeline from autosklearn.pipeline.util import get_dataset +import unittest +import unittest.mock + from test.test_pipeline.ignored_warnings import ignore_warnings, regressor_warnings diff --git a/test/test_scripts/test_metadata_generation.py b/test/test_scripts/test_metadata_generation.py index 6c6ba70ef5..929b90e029 100644 --- a/test/test_scripts/test_metadata_generation.py +++ b/test/test_scripts/test_metadata_generation.py @@ -4,13 +4,14 @@ import shutil import socket import subprocess -import unittest import arff import numpy as np from autosklearn.metrics import CLASSIFICATION_METRICS, REGRESSION_METRICS +import unittest + class TestMetadataGeneration(unittest.TestCase): def setUp(self): diff --git a/test/test_util/test_StopWatch.py b/test/test_util/test_StopWatch.py index d45ecbf55d..a59940d0c4 100644 --- a/test/test_util/test_StopWatch.py +++ b/test/test_util/test_StopWatch.py @@ -6,11 +6,12 @@ """ import time -import unittest -import unittest.mock from autosklearn.util.stopwatch import StopWatch +import unittest +import unittest.mock + class Test(unittest.TestCase): _multiprocess_can_split_ = True diff --git a/test/test_util/test_backend.py b/test/test_util/test_backend.py index 0673370b97..719b356e89 100644 --- a/test/test_util/test_backend.py +++ b/test/test_util/test_backend.py @@ -1,10 +1,11 @@ # -*- encoding: utf-8 -*- import builtins -import unittest -import unittest.mock from autosklearn.automl_common.common.utils.backend import Backend +import unittest +import unittest.mock + class BackendModelsTest(unittest.TestCase): class BackendStub(Backend): diff --git a/test/test_util/test_common.py b/test/test_util/test_common.py index 33fa4cee31..af17a5c259 100644 --- a/test/test_util/test_common.py +++ b/test/test_util/test_common.py @@ -1,9 +1,10 @@ # -*- encoding: utf-8 -*- import os -import unittest from autosklearn.util.common import check_pid +import unittest + class TestUtilsCommon(unittest.TestCase): _multiprocess_can_split_ = True diff --git a/test/test_util/test_data.py b/test/test_util/test_data.py index 2bceac804a..14a8ec44e6 100644 --- a/test/test_util/test_data.py +++ b/test/test_util/test_data.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd -import pytest import sklearn.datasets from scipy.sparse import csr_matrix, spmatrix @@ -28,6 
+27,8 @@ validate_dataset_compression_arg, ) +import pytest + parametrize = pytest.mark.parametrize diff --git a/test/test_util/test_dependencies.py b/test/test_util/test_dependencies.py index 1c59dad51b..d51e849e8b 100644 --- a/test/test_util/test_dependencies.py +++ b/test/test_util/test_dependencies.py @@ -1,6 +1,4 @@ import re -import unittest -from unittest.mock import Mock, patch import numpy as np import pkg_resources @@ -11,6 +9,9 @@ verify_packages, ) +import unittest +from unittest.mock import Mock, patch + @patch("pkg_resources.get_distribution") class VerifyPackagesTests(unittest.TestCase): diff --git a/test/test_util/test_logging.py b/test/test_util/test_logging.py index d824aecc02..c046df12f3 100644 --- a/test/test_util/test_logging.py +++ b/test/test_util/test_logging.py @@ -2,12 +2,13 @@ import logging.config import os import tempfile -import unittest import yaml from autosklearn.util import logging_ +import unittest + class LoggingTest(unittest.TestCase): def test_setup_logger(self): diff --git a/test/test_util/test_single_thread_client.py b/test/test_util/test_single_thread_client.py index 770ff9f04a..e7163a36b1 100644 --- a/test/test_util/test_single_thread_client.py +++ b/test/test_util/test_single_thread_client.py @@ -1,9 +1,10 @@ import dask.distributed -import pytest from distributed.utils_test import inc from autosklearn.util.single_thread_client import SingleThreadedClient +import pytest + def test_single_thread_client_like_dask_client(): single_thread_client = SingleThreadedClient() diff --git a/test/test_util/test_trials_callback.py b/test/test_util/test_trials_callback.py index d1bfe6b748..b1328b9489 100644 --- a/test/test_util/test_trials_callback.py +++ b/test/test_util/test_trials_callback.py @@ -1,6 +1,5 @@ import os import tempfile -import unittest import pandas as pd from smac.callbacks import IncorporateRunResultCallback @@ -11,6 +10,8 @@ import autosklearn.pipeline.util as putil from autosklearn.classification import AutoSklearnClassifier +import unittest + class AutoMLTrialsCallBack(IncorporateRunResultCallback): def __init__(self, fname):
diff --git a/test/util.py b/test/util.py
new file mode 100644
index 0000000000..eaf1a34d19
--- /dev/null
+++ b/test/util.py
@@ -0,0 +1,49 @@
+from typing import Any
+
+from pytest import mark, param
+
+
+def fails(arg: Any, reason: str = "No reason given") -> Any:
+    """Mark a parameter for pytest parametrize as expected to fail
+
+    .. code:: python
+
+        @parametrize("number", [2, 3, fails(5, "some reason")])
+
+    Parameters
+    ----------
+    arg : Any
+        The arg that should fail
+
+    reason : str = "No reason given"
+        The reason for the expected failure
+
+    Returns
+    -------
+    Any
+        The param object
+    """
+    return param(arg, marks=mark.xfail(reason=reason))
+
+
+def skip(arg: Any, reason: str = "No reason given") -> Any:
+    """Mark a parameter for pytest parametrize that should be skipped
+
+    .. code:: python
+
+        @parametrize("number", [2, 3, skip(5, "some reason")])
+
+    Parameters
+    ----------
+    arg : Any
+        The arg that should be skipped
+
+    reason : str = "No reason given"
+        The reason for skipping it
+
+    Returns
+    -------
+    Any
+        The param object
+    """
+    return param(arg, marks=mark.skip(reason=reason))
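
For reference, the two helpers added in test/util.py are thin wrappers around pytest.param. The sketch below shows how they might be used in a parametrized test; the test itself, its values, and the import path test.util are illustrative assumptions and are not part of this patch.

    # Illustrative only: the test, its values, and the reasons are invented.
    import pytest

    from test.util import fails, skip  # helpers introduced by this patch


    @pytest.mark.parametrize(
        "n",
        [
            2,
            4,
            # Equivalent to pytest.param(5, marks=pytest.mark.xfail(reason=...))
            fails(5, "odd numbers are expected to fail this check"),
            # Equivalent to pytest.param(7, marks=pytest.mark.skip(reason=...))
            skip(7, "not worth running on every CI job"),
        ],
    )
    def test_is_even(n: int) -> None:
        assert n % 2 == 0

Because fails relies on xfail rather than skip, the marked case still runs and is reported as XPASS (or as a failure when xfail_strict is enabled) if it unexpectedly starts passing.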
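
The bulk of this patch relocates import unittest, import unittest.mock, and import pytest into their own block below the autosklearn imports across the test modules. One way an import sorter can produce and enforce that grouping is an isort custom section for test tooling; the pyproject.toml fragment below is only a sketch of such a configuration and is not the project's verified settings.

    # Hypothetical isort configuration (pyproject.toml); the real settings may differ.
    [tool.isort]
    profile = "black"
    known_first_party = ["autosklearn"]
    # Custom section collecting test-only tooling after the first-party imports
    known_testing = ["unittest", "pytest"]
    sections = [
        "FUTURE",
        "STDLIB",
        "THIRDPARTY",
        "FIRSTPARTY",
        "TESTING",
        "LOCALFOLDER",
    ]

With a configuration along these lines, isort --check-only would flag files where the test-tooling imports sit above the autosklearn block, which matches the ordering this patch establishes.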