Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion .all-contributorsrc
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,14 @@
"contributions": [
"code",
]
}
},
{
"login": "Khushmagrawal",
"name": "Khush Agrawal",
"profile": "https://github.com/Khushmagrawal",
"contributions": [
"code",
]
},
]
}
1 change: 1 addition & 0 deletions docs/source/api_reference/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ Integer support
Hurdle
NegativeBinomial
Poisson
ZeroInflated

Non-parametric and empirical distributions
------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions skpro/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"TruncatedNormal",
"Uniform",
"Weibull",
"ZeroInflated",
]

from skpro.distributions.alpha import Alpha
Expand Down Expand Up @@ -83,3 +84,4 @@
from skpro.distributions.truncated_normal import TruncatedNormal
from skpro.distributions.uniform import Uniform
from skpro.distributions.weibull import Weibull
from skpro.distributions.zeroinflated import ZeroInflated
40 changes: 33 additions & 7 deletions skpro/distributions/truncated.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class TruncatedDistribution(BaseDistribution):
r"""A truncated distribution _not_ including the lower bound.
r"""A truncated distribution.

Given a univariate distribution, this distribution samples from the base
distribution but truncates the values to lie between a specified lower and
Expand All @@ -26,11 +26,15 @@ class TruncatedDistribution(BaseDistribution):
The distribution to truncate.

lower : Union[float, int], optional
The lower bound below which values are truncated, _not_ including it.
The lower bound below which values are truncated.
By default, this bound is exclusive (see ``inclusive_lower``).

upper : Union[float, int], optional
The upper bound above which values are truncated.

inclusive_lower : bool, optional, default = False
If True, the lower bound is inclusive (x >= lower).

Examples
--------
>>> from skpro.distributions import Normal, TruncatedDistribution
Expand Down Expand Up @@ -61,12 +65,14 @@ def __init__(
*,
lower: Union[float, int] = None,
upper: Union[float, int] = None,
inclusive_lower: bool = False,
index=None,
columns=None,
):
self.distribution = distribution
self.lower = lower
self.upper = upper
self.inclusive_lower = inclusive_lower

super().__init__(
index=index if index is not None else distribution.index,
Expand Down Expand Up @@ -98,9 +104,19 @@ def __init__(
self.set_tags(**{"distr:paramtype": inner_paramtype})

def _get_low_high_prob(self) -> Tuple[float, float]:
prob_at_lower = (
self.distribution.cdf(self.lower) if self.lower is not None else 0.0
)
if self.lower is not None:
prob_at_lower = self.distribution.cdf(self.lower)
if self.inclusive_lower:
measure_type = self.get_tag("distr:measuretype")

# If continuous, P(X=lower) is 0, so CDF(lower) is already correct
# If discrete, then P(X < lower) = CDF(lower) - P(X=lower)
if measure_type == "discrete":
prob_mass_at_lower = self.distribution.pmf(self.lower)
prob_at_lower = prob_at_lower - prob_mass_at_lower
else:
prob_at_lower = 0.0

prob_at_upper = (
self.distribution.cdf(self.upper) if self.upper is not None else 1.0
)
Expand All @@ -110,7 +126,13 @@ def _get_low_high_prob(self) -> Tuple[float, float]:
def _calculate_density(self, x: np.ndarray, fun, as_log: bool):
inf = float("inf")

is_valid = (x > (self.lower or -inf)) & (x <= (self.upper or inf))
lower_bound = self.lower if self.lower is not None else -inf
upper_bound = self.upper if self.upper is not None else inf

if self.inclusive_lower:
is_valid = (x >= lower_bound) & (x <= upper_bound)
else:
is_valid = (x > lower_bound) & (x <= upper_bound)

prob_base = fun(x)
cdf_lower, cdf_upper = self._get_low_high_prob()
Expand Down Expand Up @@ -176,6 +198,7 @@ def _iloc(self, rowidx=None, colidx=None):
distribution=distr,
lower=self.lower,
upper=self.upper,
inclusive_lower=self.inclusive_lower,
index=new_index,
columns=new_columns,
)
Expand All @@ -189,6 +212,7 @@ def _iat(self, rowidx=None, colidx=None):
distribution=self_subset.distribution.iat[0, 0],
lower=self.lower,
upper=self.upper,
inclusive_lower=self.inclusive_lower,
)

@classmethod
Expand Down Expand Up @@ -236,5 +260,7 @@ def get_test_params(cls, parameter_set="default"): # noqa: D102
"index": idx,
"columns": cols,
}
# inclusive lower test parameter
params5 = {"distribution": dist, "lower": 0.0, "inclusive_lower": True}

return [params1, params2, params3, params4]
return [params1, params2, params3, params4, params5]
Loading