diff --git a/dask_sql/context.py b/dask_sql/context.py index 8224fda42..6d84b2436 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -111,7 +111,7 @@ def __init__(self): RexConverter.add_plugin_class(core.RexLiteralPlugin, replace=False) InputUtil.add_plugin_class(input_utils.DaskInputPlugin, replace=False) - InputUtil.add_plugin_class(input_utils.PandasInputPlugin, replace=False) + InputUtil.add_plugin_class(input_utils.PandasLikeInputPlugin, replace=False) InputUtil.add_plugin_class(input_utils.HiveInputPlugin, replace=False) InputUtil.add_plugin_class(input_utils.IntakeCatalogInputPlugin, replace=False) InputUtil.add_plugin_class(input_utils.SqlalchemyHiveInputPlugin, replace=False) diff --git a/dask_sql/input_utils/__init__.py b/dask_sql/input_utils/__init__.py index 4f8399280..5aa588c51 100644 --- a/dask_sql/input_utils/__init__.py +++ b/dask_sql/input_utils/__init__.py @@ -3,7 +3,7 @@ from .hive import HiveInputPlugin from .intake import IntakeCatalogInputPlugin from .location import LocationInputPlugin -from .pandas import PandasInputPlugin +from .pandaslike import PandasLikeInputPlugin from .sqlalchemy import SqlalchemyHiveInputPlugin __all__ = [ @@ -13,6 +13,6 @@ HiveInputPlugin, IntakeCatalogInputPlugin, LocationInputPlugin, - PandasInputPlugin, + PandasLikeInputPlugin, SqlalchemyHiveInputPlugin, ] diff --git a/dask_sql/input_utils/convert.py b/dask_sql/input_utils/convert.py index 2bf258835..7a3365c9f 100644 --- a/dask_sql/input_utils/convert.py +++ b/dask_sql/input_utils/convert.py @@ -14,7 +14,11 @@ dd.DataFrame, pd.DataFrame, str, - Union["sqlalchemy.engine.base.Connection", "hive.Cursor"], + Union[ + "sqlalchemy.engine.base.Connection", + "hive.Cursor", + "cudf.core.dataframe.DataFrame", + ], ] diff --git a/dask_sql/input_utils/pandas.py b/dask_sql/input_utils/pandaslike.py similarity index 52% rename from dask_sql/input_utils/pandas.py rename to dask_sql/input_utils/pandaslike.py index bcef06c51..664cfc19c 100644 --- a/dask_sql/input_utils/pandas.py +++ b/dask_sql/input_utils/pandaslike.py @@ -1,16 +1,22 @@ import dask.dataframe as dd import pandas as pd +try: + import cudf +except ImportError: # pragma: no cover + cudf = None + from dask_sql.input_utils.base import BaseInputPlugin -class PandasInputPlugin(BaseInputPlugin): - """Input Plugin for Pandas DataFrames, which get converted to dask DataFrames""" +class PandasLikeInputPlugin(BaseInputPlugin): + """Input Plugin for Pandas Like DataFrames, which get converted to dask DataFrames""" def is_correct_input( self, input_item, table_name: str, format: str = None, **kwargs ): - return isinstance(input_item, pd.DataFrame) or format == "dask" + is_cudf_type = cudf and isinstance(input_item, cudf.DataFrame) + return is_cudf_type or isinstance(input_item, pd.DataFrame) or format == "dask" def to_dc(self, input_item, table_name: str, format: str = None, **kwargs): npartitions = kwargs.pop("npartitions", 1)