Merged
41 commits
0e2e45f
Wip 1
Jan 28, 2024
14ff603
new ideas
Jan 29, 2024
4d8e4ae
new ideas
Jan 29, 2024
e324389
commit
fabioscantamburlo Jan 30, 2024
2a78dc1
commit
fabioscantamburlo Jan 30, 2024
55fcd63
exclude venv
fabioscantamburlo Jan 30, 2024
82d9719
Wip2
fabioscantamburlo Jan 31, 2024
bfe8ba3
Wip3
fabioscantamburlo Feb 1, 2024
5397d34
Wip 3.5
fabioscantamburlo Feb 1, 2024
7ae7d10
Pushing some optim
fabioscantamburlo Feb 5, 2024
365a78d
Doc string and examples
fabioscantamburlo Feb 6, 2024
acf356a
Docstring WIP
fabioscantamburlo Feb 6, 2024
b5160af
Docstring WIP2
fabioscantamburlo Feb 6, 2024
5d12ada
Adding something
fabioscantamburlo Feb 7, 2024
899f315
Bugfix
fabioscantamburlo Feb 12, 2024
a0f0470
Mkdocs and small fixes
fabioscantamburlo Feb 12, 2024
7591f85
Added tests
fabioscantamburlo Feb 13, 2024
0053735
Added scripts
fabioscantamburlo Feb 13, 2024
8fecf4a
Wip4
fabioscantamburlo Feb 19, 2024
b08b04b
removing tests
fabioscantamburlo Feb 19, 2024
9d8c20b
Added tests and some bugifx
fabioscantamburlo Feb 20, 2024
077b0d1
revert pandastransformer
fabioscantamburlo Feb 20, 2024
5d55a2e
Update sklego/feature_selection/mrmr.py
fabioscantamburlo Feb 22, 2024
2e107ef
Resolving comments on PR
fabioscantamburlo Feb 22, 2024
1d4340c
features
fabioscantamburlo Feb 28, 2024
8f4481e
venv
fabioscantamburlo Feb 28, 2024
a6713d6
Add missing file
fabioscantamburlo Feb 28, 2024
25f5613
Wip userguide
fabioscantamburlo Feb 28, 2024
774f170
Merge branch 'FEATURE-MRMR-UserGuide' into FEATURE-MRMR
fabioscantamburlo Mar 1, 2024
4454266
Merge branch 'main' into FEATURE-MRMR
fabioscantamburlo Mar 1, 2024
3f21b0a
typing
fabioscantamburlo Mar 1, 2024
249d17f
Update sklego/feature_selection/mrmr.py
fabioscantamburlo Mar 3, 2024
6c772d8
Update sklego/feature_selection/mrmr.py
fabioscantamburlo Mar 3, 2024
b9b5bfc
Update docs/user-guide/feature-selection.md
fabioscantamburlo Mar 3, 2024
86e0dc6
Update sklego/feature_selection/mrmr.py
fabioscantamburlo Mar 3, 2024
f788026
resolve comments
fabioscantamburlo Mar 3, 2024
87da50c
clean
fabioscantamburlo Mar 3, 2024
287977e
suggestions + general rephrase
fabioscantamburlo Mar 3, 2024
8b4c40a
Typo
fabioscantamburlo Mar 4, 2024
a1df5fe
Update docs/user-guide/feature-selection.md
fabioscantamburlo Mar 9, 2024
3b2bc7a
Merge branch 'main' into FEATURE-MRMR
fabioscantamburlo Mar 9, 2024
3 changes: 2 additions & 1 deletion .gitignore
@@ -98,6 +98,7 @@ venv/
ENV/
env.bak/
venv.bak/
venv*/

# Spyder project settings
.spyderproject
@@ -120,4 +121,4 @@ dmypy.json
.DS_Store

# Local Netlify folder
.netlify
.netlify
138 changes: 138 additions & 0 deletions docs/_scripts/feature-selection.py
@@ -0,0 +1,138 @@
from pathlib import Path

_file = Path(__file__)
print(f"Executing {_file}")

_static_path = Path("_static") / _file.stem
_static_path.mkdir(parents=True, exist_ok=True)

import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.feature_selection import f_classif, mutual_info_classif
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split

from sklego.feature_selection.mrmr import MaximumRelevanceMinimumRedundancy

# --8<-- [start:mrmr]

# Download MNIST dataset using scikit-learn
mnist = fetch_openml("mnist_784", cache=True)

# Assign features and labels
X_pd, y_pd = mnist["data"], mnist["target"]

# OpenML returns the MNIST targets as strings: cast them to integers so that
# `kind="auto"` can recognize the task as a classification one.
X, y = X_pd.to_numpy(), y_pd.to_numpy().astype(int)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=10000, random_state=42)
# mnist_784 is already flat (28 * 28 = 784 columns), so no further reshaping is needed.


from scipy.spatial.distance import cosine


def _redundancy_cosine_scipy(X, selected, left):
    """Redundancy as the summed cosine distance between each remaining feature and the selected ones."""
    if len(selected) == 0:
        return np.ones(len(left))

    score = np.array([np.sum([cosine(X[:, _s], X[:, _l]) for _s in selected]) for _l in left])
    # Scores that are numerically zero are clipped to machine epsilon to avoid dividing by zero.
    score[np.isclose(score, 0.0, atol=np.finfo(float).eps, rtol=0)] = np.finfo(float).eps
    return score


def smile_relevance(X, y):
rows = 28
cols = 28
smiling_face = np.zeros((rows, cols), dtype=int)

# Set the values for the eyes, nose, and mouth with adjusted positions and sizes
smiling_face[10:13, 8:10] = 1 # Left eye
smiling_face[10:13, 18:20] = 1 # Right eye
smiling_face[16:18, 10:18] = 1 # Upper part of the mouth
smiling_face[18:20, 8:10] = 1 # Left edge of the open mouth
smiling_face[18:20, 18:20] = 1 # Right edge of the open mouth

    # Add the nose, two pixels wide, one row above the mouth
    smiling_face[14, 13:15] = 1
    # Mark the bottom row of the image
    smiling_face[27, :] = 1
return smiling_face.reshape(
rows * cols,
)


def smile_redundancy(X, selected, left):
return np.ones(len(left))


K = 35
mrmr = MaximumRelevanceMinimumRedundancy(k=K, kind="auto", redundancy_func="p", relevance_func="f")
mrmr_cosine = MaximumRelevanceMinimumRedundancy(
k=K, kind="auto", redundancy_func=_redundancy_cosine_scipy, relevance_func="f"
)
mrmr_smile = MaximumRelevanceMinimumRedundancy(k=K, redundancy_func=smile_redundancy, relevance_func=smile_relevance)

f = f_classif(X_train, y_train)[0]
f_features = np.argsort(np.nan_to_num(f, nan=np.finfo(float).eps))[-K:]
mi = mutual_info_classif(X_train, y_train)
mi_features = np.argsort(np.nan_to_num(mi, nan=np.finfo(float).eps))[-K:]
mrmr_features = mrmr.fit(X_train, y_train).selected_features_
mrmr_cos_features = mrmr_cosine.fit(X_train, y_train).selected_features_
mrmr_smile_features = mrmr_smile.fit(X_train, y_train).selected_features_


features = {
"f_classif": f_features,
"mutual_info": mi_features,
"mrmr": mrmr_features,
"mrmr_cosine": mrmr_cos_features,
"mrmr_smile": mrmr_smile_features,
}
for name, s_f in features.items():
    model = HistGradientBoostingClassifier()
    model.fit(X_train[:, s_f], y_train)
    y_pred = model.predict(X_test[:, s_f])
    print(f"{name}: {f1_score(y_test, y_pred, average='weighted'):.4f}")


# Create figure and axes for the plots
fig, axes = plt.subplots(2, 3, figsize=(12, 8))

# Define features dictionary
features = {
"f_classif": f_features,
"mutual_info": mi_features,
"mrmr": mrmr_features,
"mrmr_cos": mrmr_cos_features,
"mrmr_smile": mrmr_smile_features,
}

# Iterate through the features dictionary and plot the images
for idx, (name, s_f) in enumerate(features.items()):
row = idx // 3 # Calculate the row index
col = idx % 3 # Calculate the column index

a = np.zeros(28 * 28)
a[s_f] = 1
ax = axes[row, col]
ax.imshow(a.reshape(28, 28), cmap="binary", vmin=0, vmax=1)
ax.set_title(name)

# --8<-- [end:mrmr]

plt.tight_layout()
plt.savefig(_static_path / "mrmr-feature-selection-mnist.png")
plt.clf()
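The script above plugs custom callables into `MaximumRelevanceMinimumRedundancy`. A minimal sketch of the signatures such callables are expected to follow (the function names and scoring choices here are illustrative, not part of the library):

```py
import numpy as np


def my_relevance(X, y):
    # One score per feature of X, shape (n_features_in_,): here, the absolute
    # correlation between each column of X and the target y.
    return np.abs(np.corrcoef(X.T, y)[-1, :-1])


def my_redundancy(X, selected, left):
    # `selected` and `left` are lists of column indexes; return one score per
    # feature in `left`, shape (len(left),). Constant ones make redundancy a no-op.
    return np.ones(len(left))
```

Any pair of functions with these shapes can be passed as `relevance_func` and `redundancy_func`, exactly like `smile_relevance` and `smile_redundancy` above.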
6 changes: 6 additions & 0 deletions docs/api/features-selection.md
@@ -0,0 +1,6 @@
# Feature Selection

:::sklego.feature_selection.mrmr.MaximumRelevanceMinimumRedundancy
options:
show_root_full_path: true
show_root_heading: true
1 change: 1 addition & 0 deletions mkdocs.yaml
@@ -148,6 +148,7 @@ nav:
- Metrics: api/metrics.md
- Mixture: api/mixture.md
- Model Selection: api/model-selection.md
- Feature Selection: api/features-selection.md
- Naive Bayes: api/naive-bayes.md
- Neighbors: api/neighbors.md
- Pandas Utils: api/pandas-utils.md
Empty file.
216 changes: 216 additions & 0 deletions sklego/feature_selection/mrmr.py
@@ -0,0 +1,216 @@
import warnings

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.feature_selection import f_classif, f_regression
from sklearn.feature_selection._base import SelectorMixin
from sklearn.utils.validation import check_is_fitted, check_X_y


def _redundancy_pearson(X, selected, left):
"""Redundancy function for the MRMR feature selector algorithm

Parameters
----------
X : array-like, shape=(n_samples, n_features,)
Training data. Used to compute redundancy of the training features.
    selected : array-like.
        List of indexes of the features already selected at the i-th iteration.
    left : array-like.
        List of indexes of the remaining features at the i-th iteration. MRMR will select the
        next feature from this list.

Returns
-------
    np.ndarray, shape=(len(left),)
        The array of redundancy scores computed with the Pearson correlation.
"""
    X_norm = X - np.mean(X, axis=0, keepdims=True)
    Xs = X_norm[:, selected]
    Xl = X_norm[:, left]

    # Pairwise Pearson correlations between every remaining feature and every selected
    # feature, computed in one vectorized pass over the centered columns.
    num = (Xl[:, None, :] * Xs[:, :, None]).sum(axis=0)
    den = np.sqrt((Xl[:, None, :] ** 2).sum(axis=0)) * np.sqrt((Xs[:, :, None] ** 2).sum(axis=0))

    # NaNs (e.g. from constant features) are mapped to machine epsilon, then the absolute
    # correlations with the selected set are summed per remaining feature.
    return np.sum(np.abs(np.nan_to_num(num / den, nan=np.finfo(float).eps)), axis=0)


class MaximumRelevanceMinimumRedundancy(SelectorMixin, BaseEstimator):
"""Maximum Relevance Minimum Redundancy (MRMR) is an iterative feature selection method commonly used in data
science to select a subset of features from a larger feature set. The goal of MRMR is to choose features that
have high relevance to the target variable while minimizing redundancy among the already selected features.

How MRMR works:

1. Compute the relevance of each feature to the target variable: The relevance of a feature is typically
measured using a metric such as mutual information, correlation coefficient, or another appropriate measure of
dependence between the feature and the target variable.

2. Compute the redundancy between each pair of features: Redundancy is the degree of similarity or overlap between
features. It can be measured using metrics such as mutual information, correlation coefficient, or other similarity
measures.

3. Select features based on the maximum relevance and minimum redundancy criteria: MRMR aims to maximize the
relevance of selected features to the target variable while minimizing redundancy among them.

4. Construct the final subset of features: MRMR iteratively adds features to the selected subset until a predefined
number of features is reached.

    Parameters
    ----------
    k : int
        Number of features the model should select.
    relevance_func : str | Callable, default="f"
        The relevance function to use. The default "f" maps to scikit-learn's
        [f_classif](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_classif.html) or
        [f_regression](https://scikit-learn.org/stable/modules/generated/sklearn.feature_selection.f_regression.html),
        depending on the kind of problem.
    redundancy_func : str | Callable, default="p"
        The redundancy function to use. The default "p" is the Pearson correlation.
    kind : Literal["auto", "classification", "regression"], default="auto"
        Whether the task is a classification or a regression one. With "auto" the model infers
        the problem type from the dtype of `y`.

    !!! warning
        If a custom `relevance_func` is provided, it must have the signature
        `Callable[[np.ndarray, np.ndarray], np.ndarray]`: it accepts `X` and `y` as arguments,
        computes a score for each feature of `X`, and returns an array of shape `(n_features_in_,)`.

    !!! warning
        If a custom `redundancy_func` is provided, it must have the same signature as the
        `_redundancy_pearson` function.

Attributes
----------
_y_dtype : data type of y
selected_features_ : array-like of shape (k,)
Indexes of the selected features.
scores_ : array-like of shape (k,)
Scores of the selected features.

Examples
--------
```py
from sklego.feature_selection.mrmr import MaximumRelevanceMinimumRedundancy

mrmr = MaximumRelevanceMinimumRedundancy(k=4,
kind='auto',
redundancy_func='p',
relevance_func='f')

X, y = ...

# Fit mrmr model
mrmr = mrmr.fit(X, y)

# Selected features
selected_features = mrmr.selected_features_

# Get the scores of the selected features
feature_scores = mrmr.scores_
```
"""

def __init__(self, k, *, relevance_func="f", redundancy_func="p", kind="auto"):
self.k = k
self.relevance_func = relevance_func
self.redundancy_func = redundancy_func
self.kind = kind

def _get_support_mask(self):
"""SelectorMixin base function to get the selected features mask

Returns
-------
        np.ndarray of shape (n_features_in_,)
            Boolean mask indicating whether each feature is selected by MRMR.
"""
check_is_fitted(self, ["selected_features_"])
all_features = np.arange(0, self.n_features_in_)
return np.isin(all_features, self.selected_features_)

@property
def _get_relevance(self):
"""get relevance function from init values."""
        if self.relevance_func == "f":
            if (self.kind == "auto" and np.issubdtype(self._y_dtype, np.integer)) or (self.kind == "classification"):
                return lambda X, y: np.nan_to_num(f_classif(X, y)[0])
            elif (self.kind == "auto" and np.issubdtype(self._y_dtype, np.floating)) or (self.kind == "regression"):
                return lambda X, y: np.nan_to_num(f_regression(X, y)[0])
            else:
                raise ValueError(
                    "`kind` parameter must be 'auto', 'classification' or 'regression' and y dtype must be numeric"
                )
        elif callable(self.relevance_func):
            return self.relevance_func
        else:
            raise ValueError(f"Supported relevance functions are 'f' or a Callable, got {self.relevance_func}")

@property
def _get_redundancy(self):
"""get redundancy function from init values."""
if self.redundancy_func == "p":
return _redundancy_pearson
elif callable(self.redundancy_func):
return self.redundancy_func
else:
raise ValueError(f"Redundancy function supported are 'p' or Callable, got {self.redundancy_func}")

def fit(self, X, y):
"""Fit the underlying feature selection algorithm on the training data `X` and `y`
using the provided redundancy and relevance functions.


Parameters
----------
X : array-like of shape (n_samples, n_features)
Training data.
y : array-like of shape (n_samples,)
Target values.

Returns
-------
self : MaximumRelevanceMinimumRedundancy
The fitted estimator.

        Raises
        ------
        ValueError
            If `k` is not an integer, if `k` is greater than `n_features_in_` (i.e. `X.shape[1]`),
            or if `k` is less than 1.
X, y = check_X_y(X, y)
self._y_dtype = y.dtype

relevance = self._get_relevance
redundancy = self._get_redundancy

self.n_features_in_ = X.shape[1]
left_features = list(range(self.n_features_in_))
selected_features = []
selected_scores = []

        if not isinstance(self.k, int):
            raise ValueError("k parameter must be integer type")
        if self.k > self.n_features_in_:
            raise ValueError(f"k ({self.k}) parameter must be less than n_features_in_ ({self.n_features_in_})")
        elif self.k < 1:
            raise ValueError(f"k ({self.k}) parameter must be greater than or equal to 1")
        elif self.k == self.n_features_in_:
            # Selecting every feature is allowed but is a no-op as far as selection goes; the
            # greedy loop below still runs so that `selected_features_` and `scores_` are set
            # and `fit` consistently returns `self`.
            warnings.warn("k parameter is equal to n_features_in, no feature selection is applied")

        # Relevance is computed once, for all features.
        rel_score = relevance(X, y)

        for i in range(self.k):
            # Average redundancy of each remaining feature against the selected set; at the
            # first iteration nothing is selected yet, so redundancy is neutral (1).
            red_i = redundancy(X, selected_features, left_features) / i if i > 0 else 1
mrmr_score_i = rel_score[left_features] / red_i
selected_index = np.argmax(mrmr_score_i)
selected_features += [left_features.pop(selected_index)]
selected_scores += [mrmr_score_i[selected_index]]
self.selected_features_ = np.asarray(selected_features)
self.scores_ = np.asarray(selected_scores)
return self
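At each iteration the greedy loop above picks the remaining feature with the highest relevance divided by its average redundancy against the already selected set. A minimal usage sketch on synthetic data (not taken from the PR; it simply exercises the public API, with `transform` inherited from `SelectorMixin`):

```py
from sklearn.datasets import make_classification

from sklego.feature_selection.mrmr import MaximumRelevanceMinimumRedundancy

X, y = make_classification(n_samples=500, n_features=20, n_informative=5, random_state=0)

mrmr = MaximumRelevanceMinimumRedundancy(k=5, kind="classification")
X_reduced = mrmr.fit_transform(X, y)  # fit/transform API provided via SelectorMixin

print(mrmr.selected_features_)  # indexes of the 5 selected columns
print(X_reduced.shape)          # (500, 5)
```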
Empty file.