6 changes: 4 additions & 2 deletions docs/sources/CHANGELOG.md
@@ -17,13 +17,15 @@ The CHANGELOG for the current development version is available at

##### New Features

- Adds multiprocessing support to `StackingCVClassifier`. ([#512](https://github.com/rasbt/mlxtend/pull/512) via [Qiang Gu](https://github.com/qiagu))
- Adds multiprocessing support to `StackingCVRegressor`. ([#512](https://github.com/rasbt/mlxtend/pull/512) via [Qiang Gu](https://github.com/qiagu))
- Now, the `StackingCVRegressor` also enables grid search over the `regressors` argument and even over a single base regressor. When there are level-mixed parameters, `GridSearchCV` will try to replace hyperparameters in a top-down order (see the [documentation](http://rasbt.github.io/mlxtend/user_guide/regressor/StackingCVRegressor/) for examples and details). ([#515](https://github.com/rasbt/mlxtend/pull/515) via [Qiang Gu](https://github.com/qiagu))
- Adds a `verbose` parameter to `apriori` to show the current iteration number as well as the itemset size currently being sampled (a short usage sketch follows this list). ([#519](https://github.com/rasbt/mlxtend/pull/519))
- Adds an optional `class_name` parameter to the confusion matrix function to display class names on the axis as tick marks. ([#487](https://github.com/rasbt/mlxtend/pull/487) via [sandpiturtle](https://github.com/sandpiturtle))
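A minimal usage sketch for the new `verbose` option in `apriori` (the transaction data below is illustrative):

```python
import pandas as pd
from mlxtend.preprocessing import TransactionEncoder
from mlxtend.frequent_patterns import apriori

transactions = [['Milk', 'Onion', 'Bread'],
                ['Milk', 'Bread'],
                ['Onion', 'Bread']]

te = TransactionEncoder()
df = pd.DataFrame(te.fit(transactions).transform(transactions),
                  columns=te.columns_)

# verbose >= 1 prints the current iteration number and the
# itemset size currently being sampled
frequent_itemsets = apriori(df, min_support=0.5, verbose=1)
```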

##### Changes

-
- Due to new features, restructuring, and better scikit-learn support (for `GridSearchCV`, etc.), the `StackingCVRegressor`'s meta-regressor is now accessed via `'meta_regressor__*'` in the parameter grid. E.g., if a `RandomForestRegressor` meta-regressor was previously tuned via `'meta-randomforestregressor__n_estimators'`, this has now changed to `'meta_regressor__n_estimators'`. ([#515](https://github.com/rasbt/mlxtend/pull/515) via [Qiang Gu](https://github.com/qiagu))
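  A minimal before/after sketch of the renaming (the surrounding `GridSearchCV` setup is omitted; grid values are illustrative):

```python
# Previous naming: the meta-regressor was addressed by its prefixed class name
# param_grid = {'meta-randomforestregressor__n_estimators': [10, 100]}

# New naming: the meta-regressor is addressed via the 'meta_regressor__' prefix
param_grid = {'meta_regressor__n_estimators': [10, 100]}
```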


##### Bug Fixes

62 changes: 35 additions & 27 deletions docs/sources/user_guide/regressor/StackingCVRegressor.ipynb
@@ -84,10 +84,10 @@
"text": [
"5-fold cross validation scores:\n",
"\n",
"R^2 Score: 0.45 (+/- 0.29) [SVM]\n",
"R^2 Score: 0.46 (+/- 0.29) [SVM]\n",
"R^2 Score: 0.43 (+/- 0.14) [Lasso]\n",
"R^2 Score: 0.52 (+/- 0.28) [Random Forest]\n",
"R^2 Score: 0.58 (+/- 0.24) [StackingCVRegressor]\n"
"R^2 Score: 0.53 (+/- 0.28) [Random Forest]\n",
"R^2 Score: 0.58 (+/- 0.23) [StackingCVRegressor]\n"
]
}
],
@@ -138,10 +138,10 @@
"text": [
"5-fold cross validation scores:\n",
"\n",
"Neg. MSE Score: -33.69 (+/- 22.36) [SVM]\n",
"Neg. MSE Score: -33.34 (+/- 22.36) [SVM]\n",
"Neg. MSE Score: -35.53 (+/- 16.99) [Lasso]\n",
"Neg. MSE Score: -27.32 (+/- 16.62) [Random Forest]\n",
"Neg. MSE Score: -25.64 (+/- 18.11) [StackingCVRegressor]\n"
"Neg. MSE Score: -27.25 (+/- 16.76) [Random Forest]\n",
"Neg. MSE Score: -25.56 (+/- 18.22) [StackingCVRegressor]\n"
]
}
],
@@ -177,19 +177,27 @@
"source": [
"In this second example we demonstrate how `StackingCVRegressor` works in combination with `GridSearchCV`. The stack still allows tuning hyper parameters of the base and meta models!\n",
"\n",
"To set up a parameter grid for scikit-learn's `GridSearch`, we simply provide the estimator's names in the parameter grid -- in the special case of the meta-regressor, we append the `'meta-'` prefix.\n"
"For instance, we can use `estimator.get_params().keys()` to get a full list of tunable parameters.\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/guq/miniconda3/envs/python3/lib/python3.7/site-packages/sklearn/model_selection/_search.py:841: DeprecationWarning: The default of the `iid` parameter will change from True to False in version 0.22 and will be removed in 0.24. This will change numeric results when test-set sizes are unequal.\n",
" DeprecationWarning)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Best: 0.673590 using {'lasso__alpha': 0.4, 'meta-randomforestregressor__n_estimators': 10, 'ridge__alpha': 0.3}\n"
"Best: 0.674237 using {'lasso__alpha': 1.6, 'meta_regressor__n_estimators': 100, 'ridge__alpha': 0.2}\n"
]
}
],
@@ -203,8 +211,8 @@
"\n",
"X, y = load_boston(return_X_y=True)\n",
"\n",
"ridge = Ridge()\n",
"lasso = Lasso()\n",
"ridge = Ridge(random_state=RANDOM_SEED)\n",
"lasso = Lasso(random_state=RANDOM_SEED)\n",
"rf = RandomForestRegressor(random_state=RANDOM_SEED)\n",
"\n",
"# The StackingCVRegressor uses scikit-learn's check_cv\n",
@@ -224,7 +232,7 @@
" param_grid={\n",
" 'lasso__alpha': [x/5.0 for x in range(1, 10)],\n",
" 'ridge__alpha': [x/20.0 for x in range(1, 10)],\n",
" 'meta-randomforestregressor__n_estimators': [10, 100]\n",
" 'meta_regressor__n_estimators': [10, 100]\n",
" }, \n",
" cv=5,\n",
" refit=True\n",
@@ -244,20 +252,20 @@
"name": "stdout",
"output_type": "stream",
"text": [
"0.622 +/- 0.10 {'lasso__alpha': 0.2, 'meta-randomforestregressor__n_estimators': 10, 'ridge__alpha': 0.05}\n",
"0.649 +/- 0.09 {'lasso__alpha': 0.2, 'meta-randomforestregressor__n_estimators': 10, 'ridge__alpha': 0.1}\n",
"0.650 +/- 0.09 {'lasso__alpha': 0.2, 'meta-randomforestregressor__n_estimators': 10, 'ridge__alpha': 0.15}\n",
"0.667 +/- 0.09 {'lasso__alpha': 0.2, 'meta-randomforestregressor__n_estimators': 10, 'ridge__alpha': 0.2}\n",
"0.629 +/- 0.09 {'lasso__alpha': 0.2, 'meta-randomforestregressor__n_estimators': 10, 'ridge__alpha': 0.25}\n",
"0.663 +/- 0.08 {'lasso__alpha': 0.2, 'meta-randomforestregressor__n_estimators': 10, 'ridge__alpha': 0.3}\n",
"0.633 +/- 0.08 {'lasso__alpha': 0.2, 'meta-randomforestregressor__n_estimators': 10, 'ridge__alpha': 0.35}\n",
"0.637 +/- 0.08 {'lasso__alpha': 0.2, 'meta-randomforestregressor__n_estimators': 10, 'ridge__alpha': 0.4}\n",
"0.649 +/- 0.09 {'lasso__alpha': 0.2, 'meta-randomforestregressor__n_estimators': 10, 'ridge__alpha': 0.45}\n",
"0.653 +/- 0.09 {'lasso__alpha': 0.2, 'meta-randomforestregressor__n_estimators': 100, 'ridge__alpha': 0.05}\n",
"0.648 +/- 0.09 {'lasso__alpha': 0.2, 'meta-randomforestregressor__n_estimators': 100, 'ridge__alpha': 0.1}\n",
"0.645 +/- 0.09 {'lasso__alpha': 0.2, 'meta-randomforestregressor__n_estimators': 100, 'ridge__alpha': 0.15}\n",
"0.616 +/- 0.09 {'lasso__alpha': 0.2, 'meta_regressor__n_estimators': 10, 'ridge__alpha': 0.05}\n",
"0.656 +/- 0.08 {'lasso__alpha': 0.2, 'meta_regressor__n_estimators': 10, 'ridge__alpha': 0.1}\n",
"0.653 +/- 0.09 {'lasso__alpha': 0.2, 'meta_regressor__n_estimators': 10, 'ridge__alpha': 0.15}\n",
"0.669 +/- 0.09 {'lasso__alpha': 0.2, 'meta_regressor__n_estimators': 10, 'ridge__alpha': 0.2}\n",
"0.632 +/- 0.08 {'lasso__alpha': 0.2, 'meta_regressor__n_estimators': 10, 'ridge__alpha': 0.25}\n",
"0.664 +/- 0.08 {'lasso__alpha': 0.2, 'meta_regressor__n_estimators': 10, 'ridge__alpha': 0.3}\n",
"0.632 +/- 0.08 {'lasso__alpha': 0.2, 'meta_regressor__n_estimators': 10, 'ridge__alpha': 0.35}\n",
"0.642 +/- 0.08 {'lasso__alpha': 0.2, 'meta_regressor__n_estimators': 10, 'ridge__alpha': 0.4}\n",
"0.653 +/- 0.09 {'lasso__alpha': 0.2, 'meta_regressor__n_estimators': 10, 'ridge__alpha': 0.45}\n",
"0.657 +/- 0.09 {'lasso__alpha': 0.2, 'meta_regressor__n_estimators': 100, 'ridge__alpha': 0.05}\n",
"0.650 +/- 0.09 {'lasso__alpha': 0.2, 'meta_regressor__n_estimators': 100, 'ridge__alpha': 0.1}\n",
"0.648 +/- 0.09 {'lasso__alpha': 0.2, 'meta_regressor__n_estimators': 100, 'ridge__alpha': 0.15}\n",
"...\n",
"Best parameters: {'lasso__alpha': 0.4, 'meta-randomforestregressor__n_estimators': 10, 'ridge__alpha': 0.3}\n",
"Best parameters: {'lasso__alpha': 1.6, 'meta_regressor__n_estimators': 100, 'ridge__alpha': 0.2}\n",
"Accuracy: 0.67\n"
]
}
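As noted above, `estimator.get_params().keys()` returns the full list of tunable parameter names for such a parameter grid. A minimal sketch (the printed key list is abbreviated and illustrative):

```python
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Lasso, Ridge
from mlxtend.regressor import StackingCVRegressor

stack = StackingCVRegressor(regressors=(Lasso(), Ridge()),
                            meta_regressor=RandomForestRegressor())

# Base regressors are addressed by their lowercased class names,
# the meta-regressor via the 'meta_regressor' key.
print(sorted(stack.get_params().keys()))
# e.g. ['cv', 'lasso', 'lasso__alpha', ..., 'meta_regressor', ..., 'ridge__alpha', ...]
```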
@@ -284,12 +292,12 @@
"source": [
"**Note**\n",
"\n",
"The `StackingCVRegressor` also enables grid search over the `regressors` argument. However, due to the current implementation of `GridSearchCV` in scikit-learn, it is not possible to search over both, different regressors and regressor parameters at the same time. For instance, while the following parameter dictionary works\n",
"The `StackingCVRegressor` also enables grid search over the `regressors` and even a single base regressor. When there are level-mixed hyperparameters, `GridSearchCV` will try to replace hyperparameters in a top-down order, i.e., `regressors` -> single base regressor -> regressor hyperparameter. For instance, given a hyperparameter grid such as\n",
"\n",
" params = {'randomforestregressor__n_estimators': [1, 100],\n",
" 'regressors': [(regr1, regr1, regr1), (regr2, regr3)]}\n",
" \n",
"it will use the instance settings of `regr1`, `regr2`, and `regr3` and not overwrite it with the `'n_estimators'` settings from `'randomforestregressor__n_estimators': [1, 100]`."
"it will first use the instance settings of either `(regr1, regr2, regr3)` or `(regr2, regr3)` . Then it will replace the `'n_estimators'` settings for a matching regressor based on `'randomforestregressor__n_estimators': [1, 100]`."
]
},
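A self-contained sketch of such a level-mixed search (regressor choices and grid values are illustrative; both candidate tuples contain a `RandomForestRegressor` so that the nested key matches either way):

```python
from sklearn.datasets import load_boston
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Lasso, Ridge
from sklearn.model_selection import GridSearchCV
from mlxtend.regressor import StackingCVRegressor

X, y = load_boston(return_X_y=True)

regr1 = Ridge()
regr2 = Lasso()
regr3 = RandomForestRegressor(random_state=42)

stack = StackingCVRegressor(regressors=(regr1, regr2, regr3),
                            meta_regressor=Ridge())

# 'regressors' is replaced first (top level); the nested
# 'randomforestregressor__n_estimators' is then applied to the
# matching regressor inside the chosen tuple.
params = {'randomforestregressor__n_estimators': [1, 100],
          'regressors': [(regr1, regr2, regr3), (regr2, regr3)]}

grid = GridSearchCV(estimator=stack, param_grid=params, cv=5, refit=True)
grid.fit(X, y)
print(grid.best_params_)
```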
{
@@ -605,7 +613,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.6"
"version": "3.7.1"
},
"toc": {
"nav_menu": {},
45 changes: 21 additions & 24 deletions mlxtend/regressor/stacking_cv_regression.py
@@ -14,10 +14,9 @@
# License: BSD 3 clause

from ..externals.estimator_checks import check_is_fitted
from ..externals import six
from ..externals.name_estimators import _name_estimators
from ..utils.base_compostion import _BaseXComposition
from scipy import sparse
from sklearn.base import BaseEstimator
from sklearn.base import RegressorMixin
from sklearn.base import TransformerMixin
from sklearn.base import clone
@@ -27,7 +26,7 @@
import numpy as np


class StackingCVRegressor(BaseEstimator, RegressorMixin, TransformerMixin):
class StackingCVRegressor(_BaseXComposition, RegressorMixin, TransformerMixin):
"""A 'Stacking Cross-Validation' regressor for scikit-learn estimators.

New in mlxtend v0.7.0
Expand Down Expand Up @@ -123,12 +122,6 @@ def __init__(self, regressors, meta_regressor, cv=5,

self.regressors = regressors
self.meta_regressor = meta_regressor
self.named_regressors = {key: value for
key, value in
_name_estimators(regressors)}
self.named_meta_regressor = {'meta-%s' % key: value for
key, value in
_name_estimators([meta_regressor])}
self.cv = cv
self.shuffle = shuffle
self.n_jobs = n_jobs
@@ -273,25 +266,29 @@ def predict_meta_features(self, X):
check_is_fitted(self, 'regr_')
return np.column_stack([regr.predict(X) for regr in self.regr_])

@property
def named_regressors(self):
"""
Returns
-------
List of named estimator tuples, like [('svc', SVC(...))]
"""
return _name_estimators(self.regressors)

def get_params(self, deep=True):
#
# Return estimator parameter names for GridSearch support.
#
if not deep:
return super(StackingCVRegressor, self).get_params(deep=False)
else:
out = self.named_regressors.copy()
for name, step in six.iteritems(self.named_regressors):
for key, value in six.iteritems(step.get_params(deep=True)):
out['%s__%s' % (name, key)] = value
return self._get_params('named_regressors', deep=deep)

out.update(self.named_meta_regressor.copy())
for name, step in six.iteritems(self.named_meta_regressor):
for key, value in six.iteritems(step.get_params(deep=True)):
out['%s__%s' % (name, key)] = value
def set_params(self, **params):
"""Set the parameters of this estimator.

for key, value in six.iteritems(super(StackingCVRegressor,
self).get_params(deep=False)):
out['%s' % key] = value
Valid parameter keys can be listed with ``get_params()``.

return out
Returns
-------
self
"""
self._set_params('regressors', 'named_regressors', **params)
return self
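Since `named_regressors` is now a read-only property derived from `self.regressors`, the name/estimator pairs always reflect the current regressor list. A small sketch of what the property returns (output abbreviated):

```python
from sklearn.linear_model import Lasso, Ridge
from sklearn.svm import SVR
from mlxtend.regressor import StackingCVRegressor

stack = StackingCVRegressor(regressors=(Lasso(), Ridge()),
                            meta_regressor=SVR(gamma='auto'))

# _name_estimators derives the keys from the lowercased class names
print(stack.named_regressors)
# e.g. [('lasso', Lasso(...)), ('ridge', Ridge(...))]
```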
34 changes: 32 additions & 2 deletions mlxtend/regressor/tests/test_stacking_cv_regression.py
@@ -98,7 +98,7 @@ def test_gridsearch_numerate_regr():
params = {'ridge-1__alpha': [0.01, 1.0],
'ridge-2__alpha': [0.01, 1.0],
'svr__C': [0.01, 1.0],
'meta-svr__C': [0.01, 1.0],
'meta_regressor__C': [0.01, 1.0],
'use_features_in_secondary': [True, False]}

grid = GridSearchCV(estimator=stack,
@@ -122,7 +122,6 @@ def test_get_params():
got = sorted(list({s.split('__')[0] for s in stregr.get_params().keys()}))
expect = ['cv',
'linearregression',
'meta-svr',
'meta_regressor',
'n_jobs',
'pre_dispatch',
@@ -332,3 +331,34 @@ def test_weight_unsupported_with_no_weight():
stack = StackingCVRegressor(regressors=[svr_lin, lr, ridge],
meta_regressor=lasso)
stack.fit(X1, y).predict(X1)


def test_gridsearch_replace_mix():
svr_lin = SVR(kernel='linear', gamma='auto')
ridge = Ridge(random_state=1)
svr_rbf = SVR(kernel='rbf', gamma='auto')
lr = LinearRegression()
lasso = Lasso(random_state=1)
stack = StackingCVRegressor(regressors=[svr_lin, lasso, ridge],
meta_regressor=svr_rbf,
shuffle=False)

params = {'regressors': [[svr_lin, lr]],
'linearregression': [None, lasso, ridge],
'svr__kernel': ['poly']}

grid = GridSearchCV(estimator=stack,
param_grid=params,
cv=KFold(5, shuffle=True, random_state=42),
iid=False,
refit=True,
verbose=0)
grid = grid.fit(X1, y)

got1 = round(grid.best_score_, 2)
got2 = len(grid.best_params_['regressors'])
got3 = grid.best_params_['regressors'][0].kernel

assert got1 == 0.73, got1
assert got2 == 2, got2
assert got3 == 'poly', got3
40 changes: 40 additions & 0 deletions mlxtend/utils/base_compostion.py
@@ -0,0 +1,40 @@
"""Utilties to handle estimator list"""

from ..externals import six
from sklearn.utils.metaestimators import _BaseComposition


class _BaseXComposition(_BaseComposition):
"""
Parameter handler for a list of estimators
"""
def _set_params(self, attr, named_attr, **params):
# Ordered parameter replacement
# 1. root parameter
if attr in params:
setattr(self, attr, params.pop(attr))

# 2. single estimator replacement
items = getattr(self, named_attr)
names = []
if items:
names, estimators = zip(*items)
estimators = list(estimators)
for name in list(six.iterkeys(params)):
if '__' not in name and name in names:
# replace single estimator and re-build the
# root estimators list
for i, est_name in enumerate(names):
if est_name == name:
new_val = params.pop(name)
if new_val is None:
del estimators[i]
else:
estimators[i] = new_val
break
# replace the root estimators
setattr(self, attr, estimators)

# 3. estimator parameters and other initialisation arguments
super(_BaseXComposition, self).set_params(**params)
return self
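
A usage sketch of the three replacement steps above, mirroring `test_gridsearch_replace_mix` (estimator choices are illustrative):

```python
from sklearn.linear_model import Lasso, LinearRegression, Ridge
from sklearn.svm import SVR
from mlxtend.regressor import StackingCVRegressor

svr_lin = SVR(kernel='linear', gamma='auto')
lr = LinearRegression()
lasso = Lasso()
stack = StackingCVRegressor(regressors=[svr_lin, lasso],
                            meta_regressor=Ridge())

# Step 1 replaces the whole 'regressors' list; step 2 then swaps the
# single estimator named 'linearregression' in the *new* list for
# `lasso` (passing None would remove it from the list); step 3 sets
# the nested 'svr__kernel' parameter on the matching SVR.
stack.set_params(regressors=[svr_lin, lr],
                 linearregression=lasso,
                 svr__kernel='poly')
```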