diff --git a/dask_sql/input_utils/dask.py b/dask_sql/input_utils/dask.py index 3cfc33996..2da11e701 100644 --- a/dask_sql/input_utils/dask.py +++ b/dask_sql/input_utils/dask.py @@ -4,6 +4,11 @@ from dask_sql.input_utils.base import BaseInputPlugin +try: + import dask_cudf +except ImportError: + dask_cudf = None + class DaskInputPlugin(BaseInputPlugin): """Input Plugin for Dask DataFrames, just keeping them""" @@ -13,5 +18,19 @@ def is_correct_input( ): return isinstance(input_item, dd.DataFrame) or format == "dask" - def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs): + def to_dc( + self, + input_item: Any, + table_name: str, + format: str = None, + gpu: bool = False, + **kwargs + ): + if gpu: # pragma: no cover + if not dask_cudf: + raise ModuleNotFoundError( + "Setting `gpu=True` for table creation requires dask_cudf" + ) + if not isinstance(input_item, dask_cudf.DataFrame): + input_item = dask_cudf.from_dask_dataframe(input_item, **kwargs) return input_item diff --git a/dask_sql/input_utils/intake.py b/dask_sql/input_utils/intake.py index 709072d67..cd88ff951 100644 --- a/dask_sql/input_utils/intake.py +++ b/dask_sql/input_utils/intake.py @@ -26,13 +26,13 @@ def to_dc( gpu: bool = False, **kwargs, ): + if gpu: # pragma: no cover + raise NotImplementedError("Intake does not support gpu") + table_name = kwargs.pop("intake_table_name", table_name) catalog_kwargs = kwargs.pop("catalog_kwargs", {}) if isinstance(input_item, str): input_item = intake.open_catalog(input_item, **catalog_kwargs) - if gpu: # pragma: no cover - raise Exception("Intake does not support gpu") - else: - return input_item[table_name].to_dask(**kwargs) + return input_item[table_name].to_dask(**kwargs) diff --git a/dask_sql/input_utils/location.py b/dask_sql/input_utils/location.py index a7f8a036f..8cb23a444 100644 --- a/dask_sql/input_utils/location.py +++ b/dask_sql/input_utils/location.py @@ -5,6 +5,12 @@ from distributed.client import default_client from dask_sql.input_utils.base import BaseInputPlugin +from dask_sql.input_utils.convert import InputUtil + +try: + import dask_cudf +except ImportError: + dask_cudf = None class LocationInputPlugin(BaseInputPlugin): @@ -23,20 +29,25 @@ def to_dc( gpu: bool = False, **kwargs, ): - if format == "memory": client = default_client() - return client.get_dataset(input_item, **kwargs) + df = client.get_dataset(input_item, **kwargs) + + plugin_list = InputUtil.get_plugins() + for plugin in plugin_list: + if plugin.is_correct_input(df, table_name, format, **kwargs): + return plugin.to_dc(df, table_name, format, gpu, **kwargs) if not format: _, extension = os.path.splitext(input_item) format = extension.lstrip(".") - try: if gpu: # pragma: no cover - import dask_cudf - + if not dask_cudf: + raise ModuleNotFoundError( + "Setting `gpu=True` for table creation requires dask-cudf" + ) read_function = getattr(dask_cudf, f"read_{format}") else: read_function = getattr(dd, f"read_{format}") diff --git a/dask_sql/input_utils/pandaslike.py b/dask_sql/input_utils/pandaslike.py index 681c2d459..32d7ff5ea 100644 --- a/dask_sql/input_utils/pandaslike.py +++ b/dask_sql/input_utils/pandaslike.py @@ -1,13 +1,13 @@ import dask.dataframe as dd import pandas as pd +from dask_sql.input_utils.base import BaseInputPlugin + try: import cudf -except ImportError: # pragma: no cover +except ImportError: cudf = None -from dask_sql.input_utils.base import BaseInputPlugin - class PandasLikeInputPlugin(BaseInputPlugin): """Input Plugin for Pandas Like DataFrames, which get converted to dask DataFrames""" @@ -15,8 +15,10 @@ class PandasLikeInputPlugin(BaseInputPlugin): def is_correct_input( self, input_item, table_name: str, format: str = None, **kwargs ): - is_cudf_type = cudf and isinstance(input_item, cudf.DataFrame) - return is_cudf_type or isinstance(input_item, pd.DataFrame) or format == "dask" + return ( + dd.utils.is_dataframe_like(input_item) + and not isinstance(input_item, dd.DataFrame) + ) or format == "dask" def to_dc( self, @@ -28,15 +30,12 @@ def to_dc( ): npartitions = kwargs.pop("npartitions", 1) if gpu: # pragma: no cover - import dask_cudf + if not cudf: + raise ModuleNotFoundError( + "Setting `gpu=True` for table creation requires cudf" + ) if isinstance(input_item, pd.DataFrame): - return dask_cudf.from_cudf( - cudf.from_pandas(input_item), npartitions=npartitions, **kwargs, - ) - else: - return dask_cudf.from_cudf( - input_item, npartitions=npartitions, **kwargs, - ) - else: - return dd.from_pandas(input_item, npartitions=npartitions, **kwargs) + input_item = cudf.from_pandas(input_item) + + return dd.from_pandas(input_item, npartitions=npartitions, **kwargs) diff --git a/dask_sql/input_utils/sqlalchemy.py b/dask_sql/input_utils/sqlalchemy.py index 9a8199bb8..c45120317 100644 --- a/dask_sql/input_utils/sqlalchemy.py +++ b/dask_sql/input_utils/sqlalchemy.py @@ -16,8 +16,16 @@ def is_correct_input( return correct_prefix def to_dc( - self, input_item: Any, table_name: str, format: str = None, **kwargs + self, + input_item: Any, + table_name: str, + format: str = None, + gpu: bool = False, + **kwargs ): # pragma: no cover + if gpu: + raise NotImplementedError("Hive does not support gpu") + import sqlalchemy engine_kwargs = {} diff --git a/tests/integration/test_create.py b/tests/integration/test_create.py index 0e118e6e6..3a893cafb 100644 --- a/tests/integration/test_create.py +++ b/tests/integration/test_create.py @@ -37,19 +37,7 @@ def test_create_from_csv(c, df, temporary_data_file, gpu): @pytest.mark.parametrize( - "gpu", - [ - False, - pytest.param( - True, - marks=[ - pytest.mark.gpu, - pytest.mark.xfail( - reason="dataframes on memory currently aren't being converted to dask-cudf" - ), - ], - ), - ], + "gpu", [False, pytest.param(True, marks=pytest.mark.gpu),], ) def test_cluster_memory(client, c, df, gpu): client.publish_dataset(df=dd.from_pandas(df, npartitions=1)) diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 34f25db4e..15d91657e 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -85,17 +85,7 @@ def test_explain(gpu): @pytest.mark.parametrize( - "gpu", - [ - False, - pytest.param( - True, - marks=( - pytest.mark.gpu, - pytest.mark.xfail(reason="create_table(gpu=True) doesn't work"), - ), - ), - ], + "gpu", [False, pytest.param(True, marks=pytest.mark.gpu,),], ) def test_sql(gpu): c = Context() @@ -119,17 +109,7 @@ def test_sql(gpu): @pytest.mark.parametrize( - "gpu", - [ - False, - pytest.param( - True, - marks=( - pytest.mark.gpu, - pytest.mark.xfail(reason="create_table(gpu=True) doesn't work"), - ), - ), - ], + "gpu", [False, pytest.param(True, marks=pytest.mark.gpu,),], ) def test_input_types(temporary_data_file, gpu): c = Context()