Skip to content
Merged
42 changes: 29 additions & 13 deletions dask_sql/physical/rel/custom/create_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
from typing import TYPE_CHECKING

from dask import delayed

from dask_sql.datacontainer import DataContainer
from dask_sql.java import org
from dask_sql.physical.rel.base import BaseRelPlugin
Expand Down Expand Up @@ -134,6 +136,19 @@ def convert(
wrap_fit = kwargs.pop("wrap_fit", False)
fit_kwargs = kwargs.pop("fit_kwargs", {})

select_query = context._to_sql_string(select)
training_df = context.sql(select_query)

if target_column:
non_target_columns = [
col for col in training_df.columns if col != target_column
]
X = training_df[non_target_columns]
y = training_df[target_column]
else:
X = training_df
y = None

try:
ModelClass = import_class(model_class)
except ImportError:
Expand All @@ -156,20 +171,21 @@ def convert(
except ImportError: # pragma: no cover
raise ValueError("Wrapping requires dask-ml to be installed.")

model = ParallelPostFit(estimator=model)
# When `wrap_predict` is set to True we train on single partition frames
# because this is only useful for non dask distributed models
# Training via delayed fit ensures that we dont have to transfer
# data back to the client for training

select_query = context._to_sql_string(select)
training_df = context.sql(select_query)
X_d = X.repartition(npartitions=1).to_delayed()
if y is not None:
y_d = y.repartition(npartitions=1).to_delayed()
else:
y_d = None
Comment on lines +179 to +183
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For more context around this see issue: rapidsai/cuml#4406 .

We were previously training on client which is:

a. Very inefficient and possibly problematic in multi-node clusters and heterogeneous setup.
b. Training non distributed xgboost models on dask collections is not supported.


if target_column:
non_target_columns = [
col for col in training_df.columns if col != target_column
]
X = training_df[non_target_columns]
y = training_df[target_column]
else:
X = training_df
y = None
delayed_model = [delayed(model.fit)(x_p, y_p) for x_p, y_p in zip(X_d, y_d)]
model = delayed_model[0].compute()
model = ParallelPostFit(estimator=model)

model.fit(X, y, **fit_kwargs)
else:
model.fit(X, y, **fit_kwargs)
context.register_model(model_name, model, X.columns, schema_name=schema_name)
20 changes: 20 additions & 0 deletions tests/integration/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from dask_cuda import LocalCUDACluster # noqa: F401
except ImportError:
cudf = None
LocalCUDACluster = None


@pytest.fixture()
Expand Down Expand Up @@ -255,6 +256,25 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs):
return _assert_query_gives_same_result


@pytest.fixture()
def gpu_cluster():
if LocalCUDACluster is None:
pytest.skip("dask_cuda not installed")
return None

cluster = LocalCUDACluster(protocol="tcp")
yield cluster
cluster.close()


@pytest.fixture()
def gpu_client(gpu_cluster):
if gpu_cluster:
client = Client(gpu_cluster)
yield client
client.close()


@pytest.fixture(scope="session", autouse=True)
def setup_dask_client():
"""Setup a dask client if requested"""
Expand Down
90 changes: 89 additions & 1 deletion tests/integration/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@
import pytest
from dask.datasets import timeseries

from tests.integration.fixtures import skip_if_external_scheduler

try:
import cuml
import dask_cudf
import xgboost
except ImportError:
cuml = None
xgboost = None
dask_cudf = None

pytest.importorskip("dask_ml")


Expand Down Expand Up @@ -40,7 +51,16 @@ def training_df(c):
df = timeseries(freq="1d").reset_index(drop=True)
c.create_table("timeseries", df, persist=True)

return training_df
return None


@pytest.fixture()
def gpu_training_df(c):
if dask_cudf:
df = timeseries(freq="1d").reset_index(drop=True)
df = dask_cudf.from_dask_dataframe(df)
c.create_table("timeseries", input_table=df)
return None


def test_training_and_prediction(c, training_df):
Expand All @@ -61,6 +81,74 @@ def test_training_and_prediction(c, training_df):
check_trained_model(c)


@pytest.mark.gpu
def test_cuml_training_and_prediction(c, gpu_training_df):
model_query = """
CREATE OR REPLACE MODEL my_model WITH (
model_class = 'cuml.linear_model.LogisticRegression',
wrap_predict = True,
wrap_fit = False,
target_column = 'target'
) AS (
SELECT x, y, x*y > 0 AS target
FROM timeseries
)
"""
c.sql(model_query)
check_trained_model(c)


@pytest.mark.gpu
@skip_if_external_scheduler
def test_dask_cuml_training_and_prediction(c, gpu_training_df, gpu_client):

model_query = """
CREATE OR REPLACE MODEL my_model WITH (
model_class = 'cuml.dask.linear_model.LinearRegression',
target_column = 'target'
) AS (
SELECT x, y, x*y AS target
FROM timeseries
)
"""
c.sql(model_query)
check_trained_model(c)


@skip_if_external_scheduler
@pytest.mark.gpu
def test_dask_xgboost_training_prediction(c, gpu_training_df, gpu_client):
model_query = """
CREATE OR REPLACE MODEL my_model WITH (
model_class = 'xgboost.dask.DaskXGBRegressor',
target_column = 'target',
tree_method= 'gpu_hist'
) AS (
SELECT x, y, x*y AS target
FROM timeseries
)
"""
c.sql(model_query)
check_trained_model(c)


@pytest.mark.gpu
def test_xgboost_training_prediction(c, gpu_training_df):
model_query = """
CREATE OR REPLACE MODEL my_model WITH (
model_class = 'xgboost.XGBRegressor',
wrap_predict = True,
target_column = 'target',
tree_method= 'gpu_hist'
) AS (
SELECT x, y, x*y AS target
FROM timeseries
)
"""
c.sql(model_query)
check_trained_model(c)


def test_clustering_and_prediction(c, training_df):
c.sql(
"""
Expand Down