Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions dask_sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def sql(
sql: str,
return_futures: bool = True,
dataframes: Dict[str, Union[dd.DataFrame, pd.DataFrame]] = None,
gpu: bool = False,
) -> Union[dd.DataFrame, pd.DataFrame]:
"""
Query the registered tables with the given SQL.
Expand All @@ -443,14 +444,16 @@ def sql(
Defaults to returning the dask dataframe.
dataframes (:obj:`Dict[str, dask.dataframe.DataFrame]`): additional Dask or pandas dataframes
to register before executing this query
gpu (:obj:`bool`): Whether or not to load the additional Dask or pandas dataframes (if any) on GPU;
requires cuDF / dask-cuDF if enabled. Defaults to False.

Returns:
:obj:`dask.dataframe.DataFrame`: the created data frame of this query.

"""
if dataframes is not None:
for df_name, df in dataframes.items():
self.create_table(df_name, df)
self.create_table(df_name, df, gpu=gpu)

rel, select_names, _ = self._get_ral(sql)

Expand All @@ -477,7 +480,10 @@ def sql(
return df

def explain(
self, sql: str, dataframes: Dict[str, Union[dd.DataFrame, pd.DataFrame]] = None
self,
sql: str,
dataframes: Dict[str, Union[dd.DataFrame, pd.DataFrame]] = None,
gpu: bool = False,
) -> str:
"""
Return the stringified relational algebra that this query will produce
Expand All @@ -492,14 +498,16 @@ def explain(
sql (:obj:`str`): The query string to use
dataframes (:obj:`Dict[str, dask.dataframe.DataFrame]`): additional Dask or pandas dataframes
to register before executing this query
gpu (:obj:`bool`): Whether or not to load the additional Dask or pandas dataframes (if any) on GPU;
requires cuDF / dask-cuDF if enabled. Defaults to False.

Returns:
:obj:`str`: a description of the created relational algebra.

"""
if dataframes is not None:
for df_name, df in dataframes.items():
self.create_table(df_name, df)
self.create_table(df_name, df, gpu=gpu)

_, _, rel_string = self._get_ral(sql)
return rel_string
Expand Down
11 changes: 4 additions & 7 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,8 @@ def test_explain(gpu):

data_frame = dd.from_pandas(pd.DataFrame({"a": [1, 2, 3]}), npartitions=1)

if gpu:
data_frame = dask_cudf.from_dask_dataframe(data_frame)

sql_string = c.explain(
"SELECT * FROM other_df", dataframes={"other_df": data_frame}
"SELECT * FROM other_df", dataframes={"other_df": data_frame}, gpu=gpu
)

assert sql_string.startswith(
Expand All @@ -107,9 +104,9 @@ def test_sql(gpu):
assert isinstance(result, pd.DataFrame if not gpu else cudf.DataFrame)
dd.assert_eq(result, data_frame)

if gpu:
data_frame = dask_cudf.from_dask_dataframe(data_frame)
result = c.sql("SELECT * FROM other_df", dataframes={"other_df": data_frame})
result = c.sql(
"SELECT * FROM other_df", dataframes={"other_df": data_frame}, gpu=gpu
)
assert isinstance(result, dd.DataFrame if not gpu else dask_cudf.DataFrame)
dd.assert_eq(result, data_frame)

Expand Down