Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions dask_sql/physical/rel/custom/predict.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import logging
import uuid
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -88,8 +87,7 @@ def convert(
else: # pragma: no cover
continue

tmp_context = copy.deepcopy(context)
tmp_context.create_table(temporary_table, predicted_df)
context.create_table(temporary_table, predicted_df)

sql_ns = org.apache.calcite.sql
pos = sql.getParserPosition()
Expand All @@ -112,8 +110,9 @@ def convert(
None, # hints
)

sql_outer_query = tmp_context._to_sql_string(outer_select)
df = tmp_context.sql(sql_outer_query)
Comment on lines -115 to -116
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems reasonable to me but can we should a test to ensure we have not broken anything. Like are we sure that sql_outer_query does not add anything .

Maybe we add a test to verify that the registered tables still line up and nothing temporary is left around.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, I will add that in now

sql_outer_query = context._to_sql_string(outer_select)
df = context.sql(sql_outer_query)
context.drop_table(temporary_table)

cc = ColumnContainer(df.columns)
dc = DataContainer(df, cc)
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,11 @@ def check_trained_model(c, model_name=None):
)
"""

tables_before = c.schema["root"].tables.keys()
result_df = c.sql(sql).compute()

# assert that there are no additional tables in context from prediction
assert tables_before == c.schema["root"].tables.keys()
assert "target" in result_df.columns
assert len(result_df["target"]) > 0

Expand Down