Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions entropy_estimators/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,7 @@ def det(array_or_scalar):


@convert_vectors_to_2d_arrays_if_any
def get_h_mvn(x):

def get_h_mvn(x, normalized=False):
"""
Computes the entropy of a multivariate Gaussian distribution:

Expand All @@ -83,14 +82,44 @@ def get_h_mvn(x):
x: (n, d) ndarray
n samples from a d-dimensional multivariate normal distribution

normalized: bool
normalize distribution, if `True` each component is normalized such that
its standard deviation is `1` and the covariance matrix becomes equal to
the Pearson correlation coefficients and the entropy becomes invariant
under scalar multiplication

Returns:
--------
h: float
entropy H(X)
"""

d = x.shape[1]
h = 0.5 * log((2 * np.pi * np.e)**d * det(np.cov(x.T)))
if d == 1:
if normalized:
det_cov = 1
else:
det_cov = np.var(x, ddof=1)
else:
cov = np.cov(x.T)
# corrcoef calculation simplified from source code of np.corrcoef
stddev = np.sqrt(np.diagonal(cov))
corrcoef = cov / stddev[:, None] / stddev[None, :]
det_corrcoef = np.linalg.det(corrcoef)
# check if elements are exactly zero, then variables 100% correlated
if det_corrcoef == 0:
return -np.inf
# check if elements of corrcoef determinent is close to zero considering
# float precision
if np.isclose(det_corrcoef, 0, rtol=0,
atol=100*np.finfo(det_corrcoef).resolution):
return np.nan
if normalized:
det_cov = det_corrcoef
else:
det_cov = np.linalg.det(cov)
h = 0.5 * np.log((2 * np.pi * np.e)**d * det_cov)

return h


Expand Down