Skip to content

Commit 031c04c

Browse files
authored
Move / minimize number of cudf / dask-cudf imports (#480)
* Move / minimize number of cudf / dask-cudf imports
* Add tests for GPU-related errors
* Fix unbound local error
* Fix ddf value error
1 parent 95b0dd0 commit 031c04c

File tree

6 files changed

+42
-33
lines changed

6 files changed

+42
-33
lines changed

dask_sql/input_utils/dask.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -4,11 +4,6 @@
44

55
from dask_sql.input_utils.base import BaseInputPlugin
66

7-
try:
8-
import dask_cudf
9-
except ImportError:
10-
dask_cudf = None
11-
127

138
class DaskInputPlugin(BaseInputPlugin):
149
"""Input Plugin for Dask DataFrames, just keeping them"""
@@ -27,7 +22,9 @@ def to_dc(
2722
**kwargs
2823
):
2924
if gpu: # pragma: no cover
30-
if not dask_cudf:
25+
try:
26+
import dask_cudf
27+
except ImportError:
3128
raise ModuleNotFoundError(
3229
"Setting `gpu=True` for table creation requires dask_cudf"
3330
)

dask_sql/input_utils/location.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,11 +7,6 @@
77
from dask_sql.input_utils.base import BaseInputPlugin
88
from dask_sql.input_utils.convert import InputUtil
99

10-
try:
11-
import dask_cudf
12-
except ImportError:
13-
dask_cudf = None
14-
1510

1611
class LocationInputPlugin(BaseInputPlugin):
1712
"""Input Plugin for everything, which can be read in from a file (on disk, remote etc.)"""
@@ -44,7 +39,9 @@ def to_dc(
4439
format = extension.lstrip(".")
4540
try:
4641
if gpu: # pragma: no cover
47-
if not dask_cudf:
42+
try:
43+
import dask_cudf
44+
except ImportError:
4845
raise ModuleNotFoundError(
4946
"Setting `gpu=True` for table creation requires dask-cudf"
5047
)

dask_sql/input_utils/pandaslike.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -3,11 +3,6 @@
33

44
from dask_sql.input_utils.base import BaseInputPlugin
55

6-
try:
7-
import cudf
8-
except ImportError:
9-
cudf = None
10-
116

127
class PandasLikeInputPlugin(BaseInputPlugin):
138
"""Input Plugin for Pandas Like DataFrames, which get converted to dask DataFrames"""
@@ -30,7 +25,9 @@ def to_dc(
3025
):
3126
npartitions = kwargs.pop("npartitions", 1)
3227
if gpu: # pragma: no cover
33-
if not cudf:
28+
try:
29+
import cudf
30+
except ImportError:
3431
raise ModuleNotFoundError(
3532
"Setting `gpu=True` for table creation requires cudf"
3633
)

dask_sql/physical/rel/logical/aggregate.py

Lines changed: 1 addition & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -8,11 +8,6 @@
88
import pandas as pd
99
from dask import config as dask_config
1010

11-
try:
12-
import dask_cudf
13-
except ImportError:
14-
dask_cudf = None
15-
1611
from dask_sql.datacontainer import ColumnContainer, DataContainer
1712
from dask_sql.physical.rel.base import BaseRelPlugin
1813
from dask_sql.physical.rex.core.call import IsNullOperation
@@ -83,7 +78,7 @@ def get_supported_aggregation(self, series):
8378

8479
if pd.api.types.is_string_dtype(series.dtype):
8580
# If dask_cudf strings dtype, return built-in aggregation
86-
if dask_cudf is not None and isinstance(series, dask_cudf.Series):
81+
if "cudf" in str(series._partition_type):
8782
return built_in_aggregation
8883

8984
# With pandas StringDtype built-in aggregations work

dask_sql/physical/utils/sort.py

Lines changed: 1 addition & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,6 @@
66

77
from dask_sql.utils import make_pickable_without_dask_sql
88

9-
try:
10-
import dask_cudf
11-
except ImportError:
12-
dask_cudf = None
13-
149

1510
def apply_sort(
1611
df: dd.DataFrame,
@@ -35,10 +30,7 @@ def apply_sort(
3530

3631
# dask / dask-cudf don't support lists of ascending / null positions
3732
if len(sort_columns) == 1 or (
38-
dask_cudf is not None
39-
and isinstance(df, dask_cudf.DataFrame)
40-
and single_ascending
41-
and single_null_first
33+
"cudf" in str(df._partition_type) and single_ascending and single_null_first
4234
):
4335
try:
4436
return df.sort_values(

tests/integration/test_create.py

Lines changed: 31 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -363,3 +363,34 @@ def test_drop(c):
363363

364364
with pytest.raises(dask_sql.utils.ParsingException):
365365
c.sql("SELECT a FROM new_table")
366+
367+
368+
def test_create_gpu_error(c, df, temporary_data_file):
369+
try:
370+
import cudf
371+
except ImportError:
372+
cudf = None
373+
374+
if cudf is not None:
375+
pytest.skip("GPU-related import errors only need to be checked on CPU")
376+
377+
with pytest.raises(ModuleNotFoundError):
378+
c.create_table("new_table", df, gpu=True)
379+
380+
with pytest.raises(ModuleNotFoundError):
381+
c.create_table("new_table", dd.from_pandas(df, npartitions=2), gpu=True)
382+
383+
df.to_csv(temporary_data_file, index=False)
384+
385+
with pytest.raises(ModuleNotFoundError):
386+
c.sql(
387+
f"""
388+
CREATE TABLE
389+
new_table
390+
WITH (
391+
location = '{temporary_data_file}',
392+
format = 'csv',
393+
gpu = True
394+
)
395+
"""
396+
)

0 commit comments

Comments
 (0)