21 changes: 20 additions & 1 deletion dask_sql/input_utils/dask.py
@@ -4,6 +4,11 @@
 
 from dask_sql.input_utils.base import BaseInputPlugin
 
+try:
+    import dask_cudf
+except ImportError:
+    dask_cudf = None
+
 
 class DaskInputPlugin(BaseInputPlugin):
     """Input Plugin for Dask DataFrames, just keeping them"""
@@ -13,5 +18,19 @@ def is_correct_input(
     ):
         return isinstance(input_item, dd.DataFrame) or format == "dask"
 
-    def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs):
+    def to_dc(
+        self,
+        input_item: Any,
+        table_name: str,
+        format: str = None,
+        gpu: bool = False,
+        **kwargs
+    ):
+        if gpu:  # pragma: no cover
+            if not dask_cudf:
+                raise ModuleNotFoundError(
+                    "Setting `gpu=True` for table creation requires dask_cudf"
+                )
+            if not isinstance(input_item, dask_cudf.DataFrame):
+                input_item = dask_cudf.from_dask_dataframe(input_item, **kwargs)
         return input_item
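For illustration only (not part of the diff), a minimal sketch of the new `gpu=True` branch, calling the plugin directly. It assumes a CUDA environment where `dask_cudf` is importable; the sample frame and the table name `"t"` are placeholders.

```python
import pandas as pd
import dask.dataframe as dd

from dask_sql.input_utils.dask import DaskInputPlugin

ddf = dd.from_pandas(pd.DataFrame({"x": [1, 2, 3]}), npartitions=1)
plugin = DaskInputPlugin()

cpu_ddf = plugin.to_dc(ddf, table_name="t")            # gpu=False: returned unchanged
gpu_ddf = plugin.to_dc(ddf, table_name="t", gpu=True)  # converted via dask_cudf.from_dask_dataframe
```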
8 changes: 4 additions & 4 deletions dask_sql/input_utils/intake.py
@@ -26,13 +26,13 @@ def to_dc(
         gpu: bool = False,
         **kwargs,
     ):
+        if gpu:  # pragma: no cover
+            raise NotImplementedError("Intake does not support gpu")
+
         table_name = kwargs.pop("intake_table_name", table_name)
         catalog_kwargs = kwargs.pop("catalog_kwargs", {})
 
         if isinstance(input_item, str):
             input_item = intake.open_catalog(input_item, **catalog_kwargs)
 
-        if gpu:  # pragma: no cover
-            raise Exception("Intake does not support gpu")
-        else:
-            return input_item[table_name].to_dask(**kwargs)
+        return input_item[table_name].to_dask(**kwargs)
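As an illustration (not in the diff): the GPU check now fails fast with a `NotImplementedError` before any catalog work happens. The sketch below assumes intake is installed, that the plugin class in this file is named `IntakeCatalogInputPlugin`, and uses placeholder catalog/table names.

```python
from dask_sql.input_utils.intake import IntakeCatalogInputPlugin  # class name assumed

plugin = IntakeCatalogInputPlugin()
try:
    # The catalog path is never opened on the GPU path, so this raises immediately.
    plugin.to_dc("catalog.yaml", table_name="my_table", gpu=True)
except NotImplementedError as exc:
    print(exc)  # -> "Intake does not support gpu"
```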
21 changes: 16 additions & 5 deletions dask_sql/input_utils/location.py
@@ -5,6 +5,12 @@
 from distributed.client import default_client
 
 from dask_sql.input_utils.base import BaseInputPlugin
+from dask_sql.input_utils.convert import InputUtil
+
+try:
+    import dask_cudf
+except ImportError:
+    dask_cudf = None
 
 
 class LocationInputPlugin(BaseInputPlugin):
@@ -23,20 +29,25 @@ def to_dc(
         gpu: bool = False,
         **kwargs,
     ):
-
         if format == "memory":
             client = default_client()
-            return client.get_dataset(input_item, **kwargs)
+            df = client.get_dataset(input_item, **kwargs)
+
+            plugin_list = InputUtil.get_plugins()
+
+            for plugin in plugin_list:
+                if plugin.is_correct_input(df, table_name, format, **kwargs):
+                    return plugin.to_dc(df, table_name, format, gpu, **kwargs)
Review comment from the PR author on the lines above: This seems like the most sensible way to handle published datasets without duplicating code, but I am not sure if there's a cleaner way this could be done.
         if not format:
             _, extension = os.path.splitext(input_item)
 
             format = extension.lstrip(".")
 
         try:
             if gpu:  # pragma: no cover
-                import dask_cudf
-
+                if not dask_cudf:
+                    raise ModuleNotFoundError(
+                        "Setting `gpu=True` for table creation requires dask-cudf"
+                    )
                 read_function = getattr(dask_cudf, f"read_{format}")
             else:
                 read_function = getattr(dd, f"read_{format}")
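A usage sketch of the two paths shown above (not part of the diff): the published-dataset branch (`format="memory"`) now re-dispatches through the regular input plugins so that `gpu=True` is honoured, while file paths still resolve to a `read_<format>` function on either `dask.dataframe` or `dask_cudf`. The dataset name `"df"` and the parquet path are placeholders; the memory branch assumes a running distributed client with that dataset published, and the GPU branch assumes `dask_cudf` is installed.

```python
from dask_sql.input_utils.location import LocationInputPlugin

plugin = LocationInputPlugin()

# Published dataset: fetched via client.get_dataset, then handed to whichever
# plugin accepts the resulting object (DaskInputPlugin for a dask DataFrame).
ddf = plugin.to_dc("df", table_name="df", format="memory", gpu=False)

# File path: the extension picks dd.read_parquet (or dask_cudf.read_parquet with gpu=True).
ddf = plugin.to_dc("/path/to/data.parquet", table_name="t")
```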
29 changes: 14 additions & 15 deletions dask_sql/input_utils/pandaslike.py
@@ -1,22 +1,24 @@
 import dask.dataframe as dd
 import pandas as pd
 
+from dask_sql.input_utils.base import BaseInputPlugin
+
 try:
     import cudf
-except ImportError:  # pragma: no cover
+except ImportError:
     cudf = None
 
-from dask_sql.input_utils.base import BaseInputPlugin
-
 
 class PandasLikeInputPlugin(BaseInputPlugin):
     """Input Plugin for Pandas Like DataFrames, which get converted to dask DataFrames"""
 
     def is_correct_input(
         self, input_item, table_name: str, format: str = None, **kwargs
     ):
-        is_cudf_type = cudf and isinstance(input_item, cudf.DataFrame)
-        return is_cudf_type or isinstance(input_item, pd.DataFrame) or format == "dask"
+        return (
+            dd.utils.is_dataframe_like(input_item)
+            and not isinstance(input_item, dd.DataFrame)
+        ) or format == "dask"
 
     def to_dc(
         self,
@@ -28,15 +30,12 @@ def to_dc(
     ):
         npartitions = kwargs.pop("npartitions", 1)
         if gpu:  # pragma: no cover
-            import dask_cudf
+            if not cudf:
+                raise ModuleNotFoundError(
+                    "Setting `gpu=True` for table creation requires cudf"
+                )
 
             if isinstance(input_item, pd.DataFrame):
-                return dask_cudf.from_cudf(
-                    cudf.from_pandas(input_item), npartitions=npartitions, **kwargs,
-                )
-            else:
-                return dask_cudf.from_cudf(
-                    input_item, npartitions=npartitions, **kwargs,
-                )
-        else:
-            return dd.from_pandas(input_item, npartitions=npartitions, **kwargs)
+                input_item = cudf.from_pandas(input_item)
+
+        return dd.from_pandas(input_item, npartitions=npartitions, **kwargs)
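A small illustrative sketch of the reworked pandas-like path (not from the diff): anything `dd.utils.is_dataframe_like` accepts, except an actual Dask collection, is handled here and partitioned with `dd.from_pandas`; with `gpu=True`, pandas input is first converted to `cudf`. The sample frame and table name are placeholders, and the commented GPU line assumes `cudf` is installed.

```python
import pandas as pd

from dask_sql.input_utils.pandaslike import PandasLikeInputPlugin

pdf = pd.DataFrame({"a": [1, 2, 3]})
plugin = PandasLikeInputPlugin()

print(plugin.is_correct_input(pdf, table_name="t"))  # True: pandas is dataframe-like

ddf = plugin.to_dc(pdf, table_name="t", npartitions=2)  # CPU-backed dask DataFrame
# gddf = plugin.to_dc(pdf, table_name="t", gpu=True)    # GPU-backed path, requires cudf
```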
10 changes: 9 additions & 1 deletion dask_sql/input_utils/sqlalchemy.py
@@ -16,8 +16,16 @@ def is_correct_input(
         return correct_prefix
 
     def to_dc(
-        self, input_item: Any, table_name: str, format: str = None, **kwargs
+        self,
+        input_item: Any,
+        table_name: str,
+        format: str = None,
+        gpu: bool = False,
+        **kwargs
     ):  # pragma: no cover
+        if gpu:
+            raise NotImplementedError("Hive does not support gpu")
+
         import sqlalchemy
 
         engine_kwargs = {}
14 changes: 1 addition & 13 deletions tests/integration/test_create.py
@@ -37,19 +37,7 @@ def test_create_from_csv(c, df, temporary_data_file, gpu):
 
 
 @pytest.mark.parametrize(
-    "gpu",
-    [
-        False,
-        pytest.param(
-            True,
-            marks=[
-                pytest.mark.gpu,
-                pytest.mark.xfail(
-                    reason="dataframes on memory currently aren't being converted to dask-cudf"
-                ),
-            ],
-        ),
-    ],
+    "gpu", [False, pytest.param(True, marks=pytest.mark.gpu),],
 )
 def test_cluster_memory(client, c, df, gpu):
     client.publish_dataset(df=dd.from_pandas(df, npartitions=1))
24 changes: 2 additions & 22 deletions tests/unit/test_context.py
@@ -85,17 +85,7 @@ def test_explain(gpu):
 
 
 @pytest.mark.parametrize(
-    "gpu",
-    [
-        False,
-        pytest.param(
-            True,
-            marks=(
-                pytest.mark.gpu,
-                pytest.mark.xfail(reason="create_table(gpu=True) doesn't work"),
-            ),
-        ),
-    ],
+    "gpu", [False, pytest.param(True, marks=pytest.mark.gpu,),],
 )
 def test_sql(gpu):
     c = Context()
@@ -119,17 +109,7 @@ def test_sql(gpu):
 
 
 @pytest.mark.parametrize(
-    "gpu",
-    [
-        False,
-        pytest.param(
-            True,
-            marks=(
-                pytest.mark.gpu,
-                pytest.mark.xfail(reason="create_table(gpu=True) doesn't work"),
-            ),
-        ),
-    ],
+    "gpu", [False, pytest.param(True, marks=pytest.mark.gpu,),],
 )
 def test_input_types(temporary_data_file, gpu):
     c = Context()
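The xfail markers above are removed because `create_table(..., gpu=True)` is now expected to work for in-memory inputs. A hedged end-to-end sketch of the behaviour these tests cover (requires a CUDA environment with cudf/dask_cudf; the table and column names below are placeholders):

```python
import pandas as pd

from dask_sql import Context

c = Context()
c.create_table("my_table", pd.DataFrame({"a": [1, 2, 3]}), gpu=True)  # GPU-backed table
result = c.sql("SELECT SUM(a) AS total FROM my_table")  # lazy dask/dask_cudf result
print(result.compute())
```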