Skip to content

Commit 0a4949d

Browse files
sagar-kaushikeddiebergman
authored andcommitted
Changes show_models() function to return a dictionary of models in ensemble (#1321)
* Changed show_models() function to return a dictionary of models in the ensemble instead of a string
1 parent ee664fb commit 0a4949d

File tree

1 file changed

+139
-0
lines changed

1 file changed

+139
-0
lines changed

autosklearn/automl.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1843,6 +1843,145 @@ def show_models(self) -> Dict[int, Any]:
18431843
18441844
A model dictionary contains the following:
18451845
1846+
* ``"model_id"`` - The id given to a model by ``autosklearn``.
1847+
* ``"rank"`` - The rank of the model based on it's ``"cost"``.
1848+
* ``"cost"`` - The loss of the model on the validation set.
1849+
* ``"ensemble_weight"`` - The weight given to the model in the ensemble.
1850+
* ``"voting_model"`` - The ``cv_voting_ensemble`` model (for 'cv' resampling).
1851+
* ``"estimators"`` - List of models (dicts) in ``cv_voting_ensemble`` (for 'cv' resampling).
1852+
* ``"data_preprocessor"`` - The preprocessor used on the data.
1853+
* ``"balancing"`` - The balancing used on the data (for classification).
1854+
* ``"feature_preprocessor"`` - The preprocessor for features types.
1855+
* ``"classifier"`` or ``"regressor"`` - The autosklearn wrapped classifier or regressor.
1856+
* ``"sklearn_classifier"`` or ``"sklearn_regressor"`` - The sklearn classifier or regressor.
1857+
1858+
**Example**
1859+
1860+
.. code-block:: python
1861+
1862+
import sklearn.datasets
1863+
import sklearn.metrics
1864+
import autosklearn.regression
1865+
1866+
X, y = sklearn.datasets.load_diabetes(return_X_y=True)
1867+
1868+
automl = autosklearn.regression.AutoSklearnRegressor(
1869+
time_left_for_this_task=120
1870+
)
1871+
automl.fit(X_train, y_train, dataset_name='diabetes')
1872+
1873+
ensemble_dict = automl.show_models()
1874+
print(ensemble_dict)
1875+
1876+
Output:
1877+
1878+
.. code-block:: text
1879+
1880+
{
1881+
25: {'model_id': 25.0,
1882+
'rank': 1,
1883+
'cost': 0.43667876507897496,
1884+
'ensemble_weight': 0.38,
1885+
'data_preprocessor': <autosklearn.pipeline.components.data_preprocessing....>,
1886+
'feature_preprocessor': <autosklearn.pipeline.components....>,
1887+
'regressor': <autosklearn.pipeline.components.regression....>,
1888+
'sklearn_regressor': SGDRegressor(alpha=0.0006517033225329654,...)
1889+
},
1890+
6: {'model_id': 6.0,
1891+
'rank': 2,
1892+
'cost': 0.4550418898836528,
1893+
'ensemble_weight': 0.3,
1894+
'data_preprocessor': <autosklearn.pipeline.components.data_preprocessing....>,
1895+
'feature_preprocessor': <autosklearn.pipeline.components....>,
1896+
'regressor': <autosklearn.pipeline.components.regression....>,
1897+
'sklearn_regressor': ARDRegression(alpha_1=0.0003701926442639788,...)
1898+
}...
1899+
}
1900+
1901+
Returns
1902+
-------
1903+
Dict(int, Any) : dictionary of length = number of models in the ensemble
1904+
A dictionary of models in the ensemble, where ``model_id`` is the key.
1905+
1906+
"""
1907+
1908+
ensemble_dict = {}
1909+
1910+
def has_key(rv, key):
1911+
return rv.additional_info and key in rv.additional_info
1912+
1913+
table_dict = {}
1914+
for rkey, rval in self.runhistory_.data.items():
1915+
if has_key(rval, 'num_run'):
1916+
model_id = rval.additional_info['num_run']
1917+
table_dict[model_id] = {
1918+
'model_id': model_id,
1919+
'cost': rval.cost
1920+
}
1921+
1922+
# Checking if the dictionary is empty
1923+
if not table_dict:
1924+
raise RuntimeError('No model found. Try increasing \'time_left_for_this_task\'.')
1925+
1926+
for i, weight in enumerate(self.ensemble_.weights_):
1927+
(_, model_id, _) = self.ensemble_.identifiers_[i]
1928+
table_dict[model_id]['ensemble_weight'] = weight
1929+
1930+
table = pd.DataFrame.from_dict(table_dict, orient='index')
1931+
table.sort_values(by='cost', inplace=True)
1932+
1933+
# Checking which resampling strategy is chosen and selecting the appropriate models
1934+
is_cv = (self._resampling_strategy == "cv")
1935+
models = self.cv_models_ if is_cv else self.models_
1936+
1937+
rank = 1 # Initializing rank for the first model
1938+
for (_, model_id, _), model in models.items():
1939+
model_dict = {} # Declaring model dictionary
1940+
1941+
# Inserting model_id, rank, cost and ensemble weight
1942+
model_dict['model_id'] = table.loc[model_id]['model_id'].astype(int)
1943+
model_dict['rank'] = rank
1944+
model_dict['cost'] = table.loc[model_id]['cost']
1945+
model_dict['ensemble_weight'] = table.loc[model_id]['ensemble_weight']
1946+
rank += 1 # Incrementing rank by 1 for the next model
1947+
1948+
# The steps in the models pipeline are as follows:
1949+
# 'data_preprocessor': DataPreprocessor,
1950+
# 'balancing': Balancing,
1951+
# 'feature_preprocessor': FeaturePreprocessorChoice,
1952+
# 'classifier'/'regressor': ClassifierChoice/RegressorChoice (autosklearn wrapped model)
1953+
1954+
# For 'cv' (cross validation) strategy
1955+
if is_cv:
1956+
# Voting model created by cross validation
1957+
cv_voting_ensemble = model
1958+
model_dict['voting_model'] = cv_voting_ensemble
1959+
1960+
# List of models, each trained on one cv fold
1961+
cv_models = []
1962+
for cv_model in cv_voting_ensemble.estimators_:
1963+
estimator = dict(cv_model.steps)
1964+
1965+
# Adding sklearn model to the model dictionary
1966+
model_type, autosklearn_wrapped_model = cv_model.steps[-1]
1967+
estimator[f'sklearn_{model_type}'] = autosklearn_wrapped_model.choice.estimator
1968+
cv_models.append(estimator)
1969+
model_dict['estimators'] = cv_models
1970+
1971+
# For any other strategy
1972+
else:
1973+
steps = dict(model.steps)
1974+
model_dict.update(steps)
1975+
1976+
# Adding sklearn model to the model dictionary
1977+
model_type, autosklearn_wrapped_model = model.steps[-1]
1978+
model_dict[f'sklearn_{model_type}'] = autosklearn_wrapped_model.choice.estimator
1979+
1980+
# Insterting model_dict in the ensemble dictionary
1981+
ensemble_dict[model_id] = model_dict
1982+
1983+
return ensemble_dict
1984+
18461985
def _create_search_space(
18471986
self,
18481987
tmp_dir,

0 commit comments

Comments
 (0)