-
Notifications
You must be signed in to change notification settings - Fork 121
Feature mrmr #622
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Feature mrmr #622
Changes from 22 commits
Commits
Show all changes
41 commits
Select commit
Hold shift + click to select a range
0e2e45f
Wip 1
14ff603
new ideas
4d8e4ae
new ideas
e324389
commit
fabioscantamburlo 2a78dc1
commit
fabioscantamburlo 55fcd63
exclude venv
fabioscantamburlo 82d9719
Wip2
fabioscantamburlo bfe8ba3
Wip3
fabioscantamburlo 5397d34
Wip 3.5
fabioscantamburlo 7ae7d10
Pushing some optim
fabioscantamburlo 365a78d
Doc string and examples
fabioscantamburlo acf356a
Docstring WIP
fabioscantamburlo b5160af
Docstring WIP2
fabioscantamburlo 5d12ada
Adding something
fabioscantamburlo 899f315
Bugfix
fabioscantamburlo a0f0470
Mkdocs and small fixes
fabioscantamburlo 7591f85
Added tests
fabioscantamburlo 0053735
Added scripts
fabioscantamburlo 8fecf4a
Wip4
fabioscantamburlo b08b04b
removing tests
fabioscantamburlo 9d8c20b
Added tests and some bugifx
fabioscantamburlo 077b0d1
revert pandastransformer
fabioscantamburlo 5d55a2e
Update sklego/feature_selection/mrmr.py
fabioscantamburlo 2e107ef
Resolving comments on PR
fabioscantamburlo 1d4340c
features
fabioscantamburlo 8f4481e
venv
fabioscantamburlo a6713d6
Add missing file
fabioscantamburlo 25f5613
Wip userguide
fabioscantamburlo 774f170
Merge branch 'FEATURE-MRMR-UserGuide' into FEATURE-MRMR
fabioscantamburlo 4454266
Merge branch 'main' into FEATURE-MRMR
fabioscantamburlo 3f21b0a
typing
fabioscantamburlo 249d17f
Update sklego/feature_selection/mrmr.py
fabioscantamburlo 6c772d8
Update sklego/feature_selection/mrmr.py
fabioscantamburlo b9b5bfc
Update docs/user-guide/feature-selection.md
fabioscantamburlo 86e0dc6
Update sklego/feature_selection/mrmr.py
fabioscantamburlo f788026
resolve comments
fabioscantamburlo 87da50c
clean
fabioscantamburlo 287977e
suggestions + general rephrase
fabioscantamburlo 8b4c40a
Typo
fabioscantamburlo a1df5fe
Update docs/user-guide/feature-selection.md
fabioscantamburlo 3b2bc7a
Merge branch 'main' into FEATURE-MRMR
fabioscantamburlo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,138 @@ | ||
| from pathlib import Path | ||
|
|
||
| _file = Path(__file__) | ||
| print(f"Executing {_file}") | ||
|
|
||
| _static_path = Path("_static") / _file.stem | ||
| _static_path.mkdir(parents=True, exist_ok=True) | ||
|
|
||
| import matplotlib.pyplot as plt | ||
fabioscantamburlo marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| import numpy as np | ||
| from sklearn.datasets import fetch_openml | ||
| from sklearn.ensemble import HistGradientBoostingClassifier | ||
| from sklearn.feature_selection import f_classif, mutual_info_classif | ||
| from sklearn.metrics import f1_score | ||
| from sklearn.model_selection import train_test_split | ||
|
|
||
| from sklego.feature_selection.mrmr import MaximumRelevanceMinimumRedundancy | ||
|
|
||
| # --8<-- [start:mrmr] | ||
|
|
||
| # Download MNIST dataset using scikit-learn | ||
| mnist = fetch_openml("mnist_784", cache=True) | ||
|
|
||
| # Assign features and labels | ||
| X_pd, y_pd = mnist["data"], mnist["target"] | ||
|
|
||
| X, y = X_pd.to_numpy(), y_pd.to_numpy() | ||
| X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=10000, random_state=42) | ||
| X_train = X_train.reshape(60000, 28 * 28) | ||
| X_test = X_test.reshape(10000, 28 * 28) | ||
|
|
||
|
|
||
| from scipy.spatial.distance import cosine | ||
|
|
||
|
|
||
| def _redundancy_cosine_scipy(X, selected, left): | ||
fabioscantamburlo marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if len(selected) == 0: | ||
| return np.ones(len(left)) | ||
|
|
||
| score = np.array([np.sum([cosine(X[:, _s], X[:, _l]) for _s in selected]) for _l in left]) | ||
| score[np.isclose(score, 0.0, atol=np.finfo(float).eps, rtol=0)] = np.finfo(float).eps | ||
| return np.array(score) | ||
|
|
||
|
|
||
| def smile_relevance(X, y): | ||
| rows = 28 | ||
| cols = 28 | ||
| smiling_face = np.zeros((rows, cols), dtype=int) | ||
|
|
||
| # Set the values for the eyes, nose, and mouth with adjusted positions and sizes | ||
| smiling_face[10:13, 8:10] = 1 # Left eye | ||
| smiling_face[10:13, 18:20] = 1 # Right eye | ||
| smiling_face[16:18, 10:18] = 1 # Upper part of the mouth | ||
| smiling_face[18:20, 8:10] = 1 # Left edge of the open mouth | ||
| smiling_face[18:20, 18:20] = 1 # Right edge of the open mouth | ||
|
|
||
| # Add the nose as four pixels one pixel higher | ||
| smiling_face[14, 13:15] = 1 | ||
| smiling_face[27, :] = 1 | ||
| return smiling_face.reshape( | ||
| rows * cols, | ||
| ) | ||
|
|
||
|
|
||
| def smile_redundancy(X, selected, left): | ||
| return np.ones(len(left)) | ||
|
|
||
|
|
||
| K = 35 | ||
| mrmr = MaximumRelevanceMinimumRedundancy(k=K, kind="auto", redundancy_func="p", relevance_func="f") | ||
| mrmr_cosine = MaximumRelevanceMinimumRedundancy( | ||
| k=K, kind="auto", redundancy_func=_redundancy_cosine_scipy, relevance_func="f" | ||
| ) | ||
| mrmr_smile = MaximumRelevanceMinimumRedundancy(k=K, redundancy_func=smile_redundancy, relevance_func=smile_relevance) | ||
|
|
||
| f = f_classif( | ||
| X_train, | ||
| y_train.reshape( | ||
| 60000, | ||
| ), | ||
| )[0] | ||
| f_features = np.argsort(np.nan_to_num(f, nan=np.finfo(float).eps))[-K:] | ||
| mi = mutual_info_classif( | ||
| X_train, | ||
| y_train.reshape( | ||
| 60000, | ||
| ), | ||
| ) | ||
| mi_features = np.argsort(np.nan_to_num(mi, nan=np.finfo(float).eps))[-K:] | ||
| mrmr_features = mrmr.fit(X_train, y_train).selected_features_ | ||
| mrmr_cos_features = mrmr_cosine.fit(X_train, y_train).selected_features_ | ||
| mrmr_smile_features = mrmr_smile.fit(X_train, y_train).selected_features_ | ||
|
|
||
|
|
||
| features = { | ||
| "f_classif": f_features, | ||
| "mutual_info": mi_features, | ||
| "mrmr": mrmr_features, | ||
| "mrmr_cosine": mrmr_cos_features, | ||
| "mrmr_smile": mrmr_smile_features, | ||
| } | ||
| for name, s_f in features.items(): | ||
| model = HistGradientBoostingClassifier() | ||
| model.fit(X_train[:, s_f], y_train.squeeze()) | ||
| y_pred = model.predict(X_test[:, s_f]) | ||
| print(f1_score(y_test, y_pred, average="weighted")) | ||
|
|
||
| import matplotlib.pyplot as plt | ||
| import numpy as np | ||
|
|
||
| # Create figure and axes for the plots | ||
| fig, axes = plt.subplots(2, 3, figsize=(12, 8)) | ||
|
|
||
| # Define features dictionary | ||
| features = { | ||
| "f_classif": f_features, | ||
| "mutual_info": mi_features, | ||
| "mrmr": mrmr_features, | ||
| "mrmr_cos": mrmr_cos_features, | ||
| "mrmr_smile": mrmr_smile_features, | ||
| } | ||
|
|
||
| # Iterate through the features dictionary and plot the images | ||
| for idx, (name, s_f) in enumerate(features.items()): | ||
| row = idx // 3 # Calculate the row index | ||
| col = idx % 3 # Calculate the column index | ||
|
|
||
| a = np.zeros(28 * 28) | ||
| a[s_f] = 1 | ||
| ax = axes[row, col] | ||
| ax.imshow(a.reshape(28, 28), cmap="binary", vmin=0, vmax=1) | ||
| ax.set_title(name) | ||
|
|
||
| # --8<-- [end:mrmr] | ||
|
|
||
| plt.tight_layout() | ||
| plt.savefig(_static_path / "mrmr-feature-selection-mnist.png") | ||
| plt.clf() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| # Features Selection | ||
|
|
||
| :::sklego.feature_selection.mrmr.MaximumRelevanceMinimumRedundancy | ||
| options: | ||
| show_root_full_path: true | ||
| show_root_heading: true |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,216 @@ | ||
| import warnings | ||
|
|
||
| import numpy as np | ||
| from sklearn.base import BaseEstimator | ||
| from sklearn.feature_selection import f_classif, f_regression | ||
| from sklearn.feature_selection._base import SelectorMixin | ||
| from sklearn.utils.validation import check_is_fitted, check_X_y | ||
|
|
||
|
|
||
| def _redundancy_pearson(X, selected, left): | ||
| """Redundancy function for the MRMR feature selector algorithm | ||
|
|
||
| Parameters | ||
| ---------- | ||
| X : array-like, shape=(n_samples, n_features,) | ||
| Training data. Used to compute redundancy of the training features. | ||
| selected : array-like. | ||
| List of indexes of the selected features at iteration i-th. | ||
| left : array-like. | ||
| List of indexes of the left features at iteration i-th. Mrmr will select a feature | ||
| from this list. | ||
|
|
||
| Returns | ||
| ------- | ||
| np.ndarray, shape = (len(left), ) | ||
| The array containing the redundancy score using pearson correlation. | ||
| """ | ||
| # if len(selected) == 0: | ||
| # return np.ones(len(left)) | ||
fabioscantamburlo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| X_norm = X - np.mean(X, axis=0, keepdims=True) | ||
| Xs = X_norm[:, selected] | ||
| Xl = X_norm[:, left] | ||
|
|
||
| num = (Xl[:, None, :] * Xs[:, :, None]).sum(axis=0) | ||
| den = np.sqrt((Xl[:, None, :] ** 2).sum(axis=0)) * np.sqrt((Xs[:, :, None] ** 2).sum(axis=0)) | ||
|
|
||
| return np.sum(np.abs(np.nan_to_num(num / den, nan=np.finfo(float).eps)), axis=0) | ||
|
|
||
|
|
||
| class MaximumRelevanceMinimumRedundancy(SelectorMixin, BaseEstimator): | ||
| """Maximum Relevance Minimum Redundancy (MRMR) is an iterative feature selection method commonly used in data | ||
fabioscantamburlo marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
fabioscantamburlo marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| science to select a subset of features from a larger feature set. The goal of MRMR is to choose features that | ||
| have high relevance to the target variable while minimizing redundancy among the already selected features. | ||
|
|
||
| How MRMR works: | ||
|
|
||
| 1. Compute the relevance of each feature to the target variable: The relevance of a feature is typically | ||
| measured using a metric such as mutual information, correlation coefficient, or another appropriate measure of | ||
| dependence between the feature and the target variable. | ||
|
|
||
| 2. Compute the redundancy between each pair of features: Redundancy is the degree of similarity or overlap between | ||
| features. It can be measured using metrics such as mutual information, correlation coefficient, or other similarity | ||
| measures. | ||
|
|
||
| 3. Select features based on the maximum relevance and minimum redundancy criteria: MRMR aims to maximize the | ||
| relevance of selected features to the target variable while minimizing redundancy among them. | ||
|
|
||
| 4. Construct the final subset of features: MRMR iteratively adds features to the selected subset until a predefined | ||
| number of features is reached. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| k : int | ||
| Number of feature the model should use. | ||
| relevance_func : str | Callable,, default= "f"(f_classif or f_regression from sklearn.feature_selection) | ||
| [f_classif](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_classif.html) | ||
| [f_regression](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_regression.html) | ||
| The relevance function to use. | ||
fabioscantamburlo marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| redundancy_func : str | Callable, default="p" (Pearson correlation) | ||
| The redundancy function to use. | ||
| kind : Literal["auto", "classficiation", "regression"], default="auto". | ||
| 'classification' or 'regression' or 'auto' if auto the model | ||
| will try to infer the type of problem looking at the y data type, by default "auto". | ||
|
|
||
| !! warning: | ||
| If a custom relevance_func is provided it must have this signature: | ||
| Callable[[np.ndarray, np.ndarray], np.ndarray] | ||
fabioscantamburlo marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| It should accept X, y as arguments and it should compute the score for each feature of X | ||
| and return an array of shape (n_features_in_,). | ||
| !! warning: | ||
| If a custom redundancy_func is provided it must have the same signature as the method _redundancy_pearson | ||
|
|
||
| Attributes | ||
| ---------- | ||
| _y_dtype : data type of y | ||
fabioscantamburlo marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| selected_features_ : array-like of shape (k,) | ||
| Indexes of the selected features. | ||
| scores_ : array-like of shape (k,) | ||
| Scores of the selected features. | ||
|
|
||
| Examples | ||
| -------- | ||
| ```py | ||
| from sklego.feature_selection.mrmr import MaximumRelevanceMinimumRedundancy | ||
fabioscantamburlo marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
fabioscantamburlo marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| mrmr = MaximumRelevanceMinimumRedundancy(k=4, | ||
| kind='auto', | ||
| redundancy_func='p', | ||
| relevance_func='f') | ||
fabioscantamburlo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| X, y = ... | ||
fabioscantamburlo marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| # Fit mrmr model | ||
| mrmr = mrmr.fit(X, y) | ||
|
|
||
| # Selected features | ||
| selected_features = mrmr.selected_features_ | ||
|
|
||
| # Get the scores of the selected features | ||
| feature_scores = mrmr.scores_ | ||
| ``` | ||
| """ | ||
|
|
||
| def __init__(self, k, *, relevance_func="f", redundancy_func="p", kind="auto"): | ||
| self.k = k | ||
| self.relevance_func = relevance_func | ||
| self.redundancy_func = redundancy_func | ||
| self.kind = kind | ||
|
|
||
| def _get_support_mask(self): | ||
| """SelectorMixin base function to get the selected features mask | ||
|
|
||
| Returns | ||
| ------- | ||
| np.ndarray | ||
| Array of boolean, mask indicating if feature n is selected by mrmr or not. | ||
| """ | ||
| check_is_fitted(self, ["selected_features_"]) | ||
| all_features = np.arange(0, self.n_features_in_) | ||
| return np.isin(all_features, self.selected_features_) | ||
|
|
||
| @property | ||
| def _get_relevance(self): | ||
| """get relevance function from init values.""" | ||
| if self.relevance_func == "f": | ||
| if (self.kind == "auto" and np.issubdtype(self._y_dtype, np.integer)) | (self.kind == "classification"): | ||
| return lambda X, y: np.nan_to_num(f_classif(X, y)[0]) | ||
| elif (self.kind == "auto" and np.issubdtype(self._y_dtype, np.floating)) | (self.kind == "regression"): | ||
| return lambda X, y: np.nan_to_num(f_regression(X, y)[0]) | ||
| else: | ||
| raise ValueError( | ||
| "`kind` parameter must be 'auto', 'classification' or 'regression' and y dtype must be numeric" | ||
| ) | ||
| elif callable(self.relevance_func): | ||
| return self.relevance_func | ||
| else: | ||
| raise ValueError(f"Relevance function supported are 'f' or Callable, got {self.relevance_func}") | ||
|
|
||
| @property | ||
| def _get_redundancy(self): | ||
| """get redundancy function from init values.""" | ||
| if self.redundancy_func == "p": | ||
| return _redundancy_pearson | ||
| elif callable(self.redundancy_func): | ||
| return self.redundancy_func | ||
| else: | ||
| raise ValueError(f"Redundancy function supported are 'p' or Callable, got {self.redundancy_func}") | ||
|
|
||
| def fit(self, X, y): | ||
| """Fit the underlying feature selection algorithm on the training data `X` and `y` | ||
| using the provided redundancy and relevance functions. | ||
|
|
||
|
|
||
fabioscantamburlo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| Parameters | ||
| ---------- | ||
| X : array-like of shape (n_samples, n_features) | ||
| Training data. | ||
| y : array-like of shape (n_samples,) | ||
| Target values. | ||
|
|
||
| Returns | ||
| ------- | ||
| self : MaximumRelevanceMinimumRedundancy | ||
| The fitted estimator. | ||
|
|
||
| Raises | ||
| ------ | ||
| ValueError | ||
| if: | ||
|
|
||
| k parameter is not integer type or is < n_features_in (X.shape[1]) or < 1 | ||
fabioscantamburlo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """ | ||
| X, y = check_X_y(X, y) | ||
| self._y_dtype = y.dtype | ||
|
|
||
| relevance = self._get_relevance | ||
| redundancy = self._get_redundancy | ||
|
|
||
| self.n_features_in_ = X.shape[1] | ||
| left_features = list(range(self.n_features_in_)) | ||
| selected_features = [] | ||
| selected_scores = [] | ||
|
|
||
| if not isinstance(self.k, int): | ||
| raise ValueError("k parameter must be integer type") | ||
| if self.k > self.n_features_in_: | ||
| raise ValueError(f"k ({self.k}) parameter must be less than n_features_in_ ({self.n_features_in_})") | ||
| elif self.k == self.n_features_in_: | ||
| warnings.warn("k parameter is equal to n_features_in, no feature selection is applied") | ||
| return np.asarray(left_features) | ||
| elif self.k < 1: | ||
| raise ValueError(f"k ({self.k}) parameter must be greater than or equal to 1") | ||
|
|
||
| # computed one time for all features | ||
| rel_score = relevance(X, y) | ||
|
|
||
| for i in range(self.k): | ||
| red_i = redundancy(X, selected_features, left_features) / i if i > 0 else 1 | ||
| mrmr_score_i = rel_score[left_features] / red_i | ||
| selected_index = np.argmax(mrmr_score_i) | ||
| selected_features += [left_features.pop(selected_index)] | ||
| selected_scores += [mrmr_score_i[selected_index]] | ||
| self.selected_features_ = np.asarray(selected_features) | ||
| self.scores_ = np.asarray(selected_scores) | ||
| return self | ||
Empty file.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.