diff --git a/dask_sql/context.py b/dask_sql/context.py index c6030814c..13432dc51 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -125,6 +125,7 @@ def create_table( format: str = None, persist: bool = True, schema_name: str = None, + gpu: bool = False, **kwargs, ): """ @@ -199,6 +200,7 @@ def create_table( table_name=table_name, format=format, persist=persist, + gpu=gpu, **kwargs, ) self.schema[schema_name].tables[table_name.lower()] = dc diff --git a/dask_sql/input_utils/convert.py b/dask_sql/input_utils/convert.py index 2bf258835..5b2e8a797 100644 --- a/dask_sql/input_utils/convert.py +++ b/dask_sql/input_utils/convert.py @@ -37,6 +37,7 @@ def to_dc( table_name: str, format: str = None, persist: bool = True, + gpu: bool = False, **kwargs, ) -> DataContainer: """ @@ -45,7 +46,7 @@ def to_dc( maybe persist them to cluster memory before. """ filled_get_dask_dataframe = lambda *args: cls._get_dask_dataframe( - *args, table_name=table_name, format=format, **kwargs, + *args, table_name=table_name, format=format, gpu=gpu, **kwargs, ) if isinstance(input_item, list): @@ -60,7 +61,12 @@ def to_dc( @classmethod def _get_dask_dataframe( - cls, input_item: InputType, table_name: str, format: str = None, **kwargs, + cls, + input_item: InputType, + table_name: str, + format: str = None, + gpu: bool = False, + **kwargs, ): plugin_list = cls.get_plugins() @@ -69,7 +75,7 @@ def _get_dask_dataframe( input_item, table_name=table_name, format=format, **kwargs ): return plugin.to_dc( - input_item, table_name=table_name, format=format, **kwargs + input_item, table_name=table_name, format=format, gpu=gpu, **kwargs ) raise ValueError(f"Do not understand the input type {type(input_item)}") diff --git a/dask_sql/input_utils/intake.py b/dask_sql/input_utils/intake.py index 241f1de33..2204ae855 100644 --- a/dask_sql/input_utils/intake.py +++ b/dask_sql/input_utils/intake.py @@ -18,11 +18,21 @@ def is_correct_input( isinstance(input_item, intake.catalog.Catalog) or format == "intake" ) - def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs): + def to_dc( + self, + input_item: Any, + table_name: str, + format: str = None, + gpu: bool = False, + **kwargs + ): table_name = kwargs.pop("intake_table_name", table_name) catalog_kwargs = kwargs.pop("catalog_kwargs", {}) if isinstance(input_item, str): input_item = intake.open_catalog(input_item, **catalog_kwargs) - return input_item[table_name].to_dask(**kwargs) + if gpu: + raise Exception("Intake does not support gpu") + else: + return input_item[table_name].to_dask(**kwargs) diff --git a/dask_sql/input_utils/location.py b/dask_sql/input_utils/location.py index 7d1ff1067..03e606176 100644 --- a/dask_sql/input_utils/location.py +++ b/dask_sql/input_utils/location.py @@ -15,7 +15,14 @@ def is_correct_input( ): return isinstance(input_item, str) - def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs): + def to_dc( + self, + input_item: Any, + table_name: str, + format: str = None, + gpu: bool = False, + **kwargs, + ): if format == "memory": client = default_client() @@ -27,7 +34,12 @@ def to_dc(self, input_item: Any, table_name: str, format: str = None, **kwargs): format = extension.lstrip(".") try: - read_function = getattr(dd, f"read_{format}") + if gpu: + import dask_cudf + + read_function = getattr(dask_cudf, f"read_{format}") + else: + read_function = getattr(dd, f"read_{format}") except AttributeError: raise AttributeError(f"Can not read files of format {format}") diff --git a/dask_sql/input_utils/pandas.py b/dask_sql/input_utils/pandas.py index bcef06c51..14767357b 100644 --- a/dask_sql/input_utils/pandas.py +++ b/dask_sql/input_utils/pandas.py @@ -12,6 +12,21 @@ def is_correct_input( ): return isinstance(input_item, pd.DataFrame) or format == "dask" - def to_dc(self, input_item, table_name: str, format: str = None, **kwargs): + def to_dc( + self, + input_item, + table_name: str, + format: str = None, + gpu: bool = False, + **kwargs, + ): npartitions = kwargs.pop("npartitions", 1) - return dd.from_pandas(input_item, npartitions=npartitions, **kwargs) + if gpu: + import cudf + import dask_cudf + + return dask_cudf.from_cudf( + cudf.from_pandas(input_item), npartitions=npartitions, **kwargs, + ) + else: + return dd.from_pandas(input_item, npartitions=npartitions, **kwargs) diff --git a/dask_sql/physical/rel/custom/create_table.py b/dask_sql/physical/rel/custom/create_table.py index 6151b3e17..d459b4849 100644 --- a/dask_sql/physical/rel/custom/create_table.py +++ b/dask_sql/physical/rel/custom/create_table.py @@ -62,11 +62,13 @@ def convert( except KeyError: raise AttributeError("Parameters must include a 'location' parameter.") + gpu = kwargs.pop("gpu", False) context.create_table( table_name, location, format=format, persist=persist, schema_name=schema_name, + gpu=gpu, **kwargs, )