-
Notifications
You must be signed in to change notification settings - Fork 76
[ENH] interface ondil OnlineGamlss wrapper #637
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
60308dd
069f9de
116ed41
4fd29bf
60e429b
348e6a6
20f534f
287eca0
ad635c2
5fe54ba
d12212d
d697695
8ea9ad2
6659c7b
4329f82
39c3134
d40e7a9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,212 @@ | ||
| """Interface for ondil OnlineGamlss probabilistic regressor. | ||
|
|
||
| This module provides a lightweight wrapper around ``ondil``'s | ||
| ``OnlineGamlss`` estimator to expose it as an skpro ``BaseProbaRegressor``. | ||
|
|
||
| The wrapper is intentionally lightweight: imports of the optional | ||
| ``ondil`` dependency are performed inside methods so the package is | ||
| optional for users who do not need this estimator. | ||
|
|
||
| The wrapper attempts to be tolerant to different method names used by | ||
| the upstream estimator: it will use ``fit`` when available, otherwise | ||
| fall back to ``partial_fit`` or ``update`` where appropriate. Prediction | ||
| is best-effort: if the upstream ``predict`` method returns distribution | ||
| parameters (e.g., columns for location/scale) these are converted to a | ||
| ``skpro.distributions`` object; otherwise an informative error is raised. | ||
| """ | ||
|
|
||
| from skpro.regression.base import BaseProbaRegressor | ||
|
|
||
|
|
||
class OndilOnlineGamlss(BaseProbaRegressor):
    """Wrapper exposing ondil's ``OnlineGamlss`` as an skpro regressor.

    The upstream class is looked up at fit time in the module
    ``ondil.estimators.online_gamlss``, first under the name
    ``OnlineGamlss`` and then under ``OnlineDistributionalRegression``.

    Parameters
    ----------
    distribution : str, default="Normal"
        Name of distribution to expose via skpro. This is used to map
        parameter names returned by the upstream estimator to skpro's
        distribution constructors. Common value is "Normal".
    ondil_init_params : dict or None, default=None
        Keyword arguments forwarded unchanged to the constructor of the
        upstream ondil estimator. ``None`` is treated as "no arguments".

    Notes
    -----
    * The ondil dependency is optional and imported inside methods.
    * The wrapper uses a best-effort strategy to call the appropriate
      fit/update/predict methods of the upstream estimator. If ondil's
      API changes in incompatible ways, this wrapper may need updates.
    """

    _tags = {
        "authors": ["arnavk23"],
        "maintainers": ["fkiraly"],
        "python_dependencies": ["ondil"],
        "capability:multioutput": False,
        "capability:missing": True,
        "tests:vm": True,
        "capability:update": True,
        "X_inner_mtype": "pd_DataFrame_Table",
        "y_inner_mtype": "pd_DataFrame_Table",
    }

    def __init__(self, distribution="Normal", ondil_init_params=None):
        self.distribution = distribution
        # Every constructor argument is stored unmodified under its own name:
        # parameter inspection (get_params / clone) reads it back via getattr.
        self.ondil_init_params = ondil_init_params
        # Private shallow copy of the kwargs forwarded to the ondil
        # constructor; kept for code that reads ``_ondil_kwargs`` directly.
        self._ondil_kwargs = dict(ondil_init_params or {})

        super().__init__()

    def _fit(self, X, y):
        """Fit the underlying ondil estimator.

        The upstream instance is created from ``ondil_init_params`` and then
        trained with the first of ``fit``, ``partial_fit`` or ``update``
        that it provides, to support different ondil versions.

        Parameters
        ----------
        X : pandas DataFrame
            Feature data, passed through to the upstream estimator.
        y : pandas DataFrame
            Target data, passed through to the upstream estimator. Its
            columns are remembered and used to label predictions.

        Returns
        -------
        self : reference to self

        Raises
        ------
        ImportError
            If the ondil module exposes neither ``OnlineGamlss`` nor
            ``OnlineDistributionalRegression``.
        AttributeError
            If the upstream instance has no fit/partial_fit/update method.
        """
        # defer import to keep ondil optional
        import importlib

        module_str = "ondil.estimators.online_gamlss"
        ondil_mod = importlib.import_module(module_str)
        try:
            ondil_cls = ondil_mod.OnlineGamlss
        except AttributeError:
            try:
                ondil_cls = ondil_mod.OnlineDistributionalRegression
            except AttributeError as exc:
                raise ImportError(
                    "ondil.estimators.online_gamlss does not expose '"
                    "OnlineDistributionalRegression' or 'OnlineGamlss' - "
                    "please install a compatible ondil version"
                ) from exc

        # store y columns for later use as columns of the predicted distribution
        self._y_cols = y.columns

        # Build the kwargs from the public parameter so that the value seen
        # by get_params is the value actually used for fitting.
        init_kwargs = dict(self.ondil_init_params or {})
        self._ondil_kwargs = init_kwargs
        self._ondil = ondil_cls(**init_kwargs)

        # Prefer `fit`, then `partial_fit`, then `update`.
        for method_name in ("fit", "partial_fit", "update"):
            if hasattr(self._ondil, method_name):
                getattr(self._ondil, method_name)(X, y)
                break
        else:
            raise AttributeError(
                "ondil OnlineGamlss instance has no fit/partial_fit/update method"
            )

        return self

    def _update(self, X, y):
        """Update the fitted ondil estimator in online fashion.

        Uses the upstream ``update`` method when present, otherwise
        ``partial_fit``.

        Parameters
        ----------
        X : pandas DataFrame
            New feature data, passed through to the upstream estimator.
        y : pandas DataFrame
            New target data, passed through to the upstream estimator.

        Returns
        -------
        self : reference to self

        Raises
        ------
        RuntimeError
            If no upstream estimator has been created by a previous fit.
        AttributeError
            If the upstream estimator has neither ``update`` nor
            ``partial_fit``.
        """
        if not hasattr(self, "_ondil"):
            raise RuntimeError("Estimator not fitted yet; call fit before update")

        for method_name in ("update", "partial_fit"):
            if hasattr(self._ondil, method_name):
                getattr(self._ondil, method_name)(X, y)
                return self

        raise AttributeError(
            "Upstream ondil estimator has no update/partial_fit method"
        )

    @staticmethod
    def _select_normal_columns(df):
        """Pick the location and scale columns of a Normal parameter frame.

        Columns are matched by common names first; if no name matches,
        the first column is used for location and the second for scale.

        Parameters
        ----------
        df : pandas DataFrame
            Parameter table as returned by the upstream prediction.

        Returns
        -------
        loc_col, scale_col : column labels of ``df``

        Raises
        ------
        ValueError
            If ``df`` has fewer than two columns, so no scale can be inferred.
        """
        if df.shape[1] < 2:
            raise ValueError(
                "Could not infer scale column from ondil predict output"
            )

        def _first_present(candidates):
            return next((c for c in candidates if c in df.columns), None)

        loc_col = _first_present(("loc", "mu", "mean"))
        if loc_col is None:
            loc_col = df.columns[0]

        scale_col = _first_present(("scale", "sigma", "sd", "std"))
        if scale_col is None:
            # Positional fallback; never reuse the column already taken
            # as location, which would silently yield sigma == mu.
            scale_col = df.columns[1]
            if scale_col == loc_col:
                scale_col = df.columns[0]

        return loc_col, scale_col

    def _predict_proba(self, X):
        """Predict distribution parameters and convert to skpro distribution.

        The method is best-effort: it calls ``predict`` (or, if absent,
        ``predict_params``) on the underlying ondil estimator and expects a
        pandas DataFrame, or something convertible to one, holding one
        column per distribution parameter.

        Parameters
        ----------
        X : pandas DataFrame
            Feature data to predict for; its index labels the result.

        Returns
        -------
        skpro distribution object of the class named by ``distribution``,
        with ``index=X.index`` and the target columns seen in fit.

        Raises
        ------
        RuntimeError
            If no upstream estimator has been created by a previous fit.
        AttributeError
            If the upstream estimator has no prediction method.
        TypeError
            If the upstream output cannot be converted to a DataFrame.
        ValueError
            If a Normal is requested but no scale column can be inferred.
        NotImplementedError
            If ``skpro.distributions`` has no class named ``distribution``.
        """
        import importlib

        import pandas as pd

        if not hasattr(self, "_ondil"):
            raise RuntimeError("Estimator not fitted yet; call fit before predict")

        # call predict on upstream estimator
        if hasattr(self._ondil, "predict"):
            params = self._ondil.predict(X)
        elif hasattr(self._ondil, "predict_params"):
            params = self._ondil.predict_params(X)
        else:
            raise AttributeError("Upstream ondil estimator has no predict method")

        # normalize to pandas DataFrame
        if isinstance(params, pd.DataFrame):
            df = params
        else:
            try:
                df = pd.DataFrame(params)
            except Exception as e:
                # chain the cause so the original conversion error stays visible
                raise TypeError(
                    f"Unrecognized predict output from ondil: {e}"
                ) from e

        dist = self.distribution
        # import skpro distributions lazily
        distr_mod = importlib.import_module("skpro.distributions")

        if dist == "Normal":
            loc_col, scale_col = self._select_normal_columns(df)

            # double brackets keep the values 2D, one column per parameter
            loc = df.loc[:, [loc_col]].values
            scale = df.loc[:, [scale_col]].values

            return distr_mod.Normal(
                mu=loc, sigma=scale, index=X.index, columns=self._y_cols
            )

        # fallback: try to call distribution class with all columns as kwargs
        if hasattr(distr_mod, dist):
            distr_cls = getattr(distr_mod, dist)
            # construct args dict using column names as parameter names
            vals = {c: df.loc[:, [c]].values for c in df.columns}
            return distr_cls(**vals, index=X.index, columns=self._y_cols)

        raise NotImplementedError(
            "Mapping to skpro distribution '" + str(dist) + "' not implemented"
        )

    @classmethod
    def get_test_params(cls, parameter_set="default"):
        """Return testing parameter settings for the estimator.

        Parameters
        ----------
        parameter_set : str, default="default"
            Name of the parameter set to return. The same sets are returned
            for every value.

        Returns
        -------
        params : list of dict
            Two small parameter sets, one with default constructor
            arguments for ondil and one forwarding ``ondil_init_params``.
        """
        # minimal constructor params; provide two small parameter sets so
        # the package-level tests exercise different constructor paths.
        return [
            {"distribution": "Normal"},
            {"distribution": "Normal", "ondil_init_params": {"verbose": 0}},
        ]
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| import numpy as np | ||
| import pandas as pd | ||
| import pytest | ||
|
|
||
| try: | ||
arnavk23 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| import ondil # noqa: F401 | ||
| except Exception: | ||
| ondil = None | ||
|
|
||
| from skpro.regression.ondil import OndilOnlineGamlss | ||
|
|
||
| if ondil is None: | ||
| pytest.skip("ondil not installed", allow_module_level=True) | ||
|
|
||
|
|
||
def test_ondil_instantiation_and_get_test_params():
    """Smoke-test construction of the Ondil wrapper from its test params.

    Checks that ``get_test_params`` yields at least one parameter set and
    that the first set can be used to instantiate the estimator.
    """
    param_sets = OndilOnlineGamlss.get_test_params()
    # a single dict is also accepted; wrap it so both forms are handled alike
    param_sets = [param_sets] if isinstance(param_sets, dict) else param_sets
    assert len(param_sets) >= 1

    first_set = param_sets[0]
    estimator = OndilOnlineGamlss(**first_set)
    assert isinstance(estimator, OndilOnlineGamlss)
|
|
||
|
|
||
def test_ondil_fit_smoke():
    """Fit the wrapper on a three-row dataset to validate the wiring.

    Smoke test only: any exception raised by the upstream estimator is
    deliberately left to propagate so that incompatibilities are visible.
    """
    # tiny single-feature dataset with a single-column target
    target_values = np.array([[0.1], [1.1], [1.9]])
    y = pd.DataFrame(target_values)
    X = pd.DataFrame({"a": [0.0, 1.0, 2.0]})

    estimator = OndilOnlineGamlss()

    # no try/except here: a failing fit should fail the test
    estimator.fit(X, y)
    assert estimator.is_fitted
Uh oh!
There was an error while loading. Please reload this page.