diff --git a/dask_sql/physical/rel/custom/predict.py b/dask_sql/physical/rel/custom/predict.py index 3a1650d19..e2e8c6e03 100644 --- a/dask_sql/physical/rel/custom/predict.py +++ b/dask_sql/physical/rel/custom/predict.py @@ -1,4 +1,3 @@ -import copy import logging import uuid from typing import TYPE_CHECKING @@ -88,8 +87,7 @@ def convert( else: # pragma: no cover continue - tmp_context = copy.deepcopy(context) - tmp_context.create_table(temporary_table, predicted_df) + context.create_table(temporary_table, predicted_df) sql_ns = org.apache.calcite.sql pos = sql.getParserPosition() @@ -112,8 +110,9 @@ def convert( None, # hints ) - sql_outer_query = tmp_context._to_sql_string(outer_select) - df = tmp_context.sql(sql_outer_query) + sql_outer_query = context._to_sql_string(outer_select) + df = context.sql(sql_outer_query) + context.drop_table(temporary_table) cc = ColumnContainer(df.columns) dc = DataContainer(df, cc) diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py index 6a3a22dbf..7a7faf994 100644 --- a/tests/integration/test_model.py +++ b/tests/integration/test_model.py @@ -26,8 +26,11 @@ def check_trained_model(c, model_name=None): ) """ + tables_before = c.schema["root"].tables.keys() result_df = c.sql(sql).compute() + # assert that there are no additional tables in context from prediction + assert tables_before == c.schema["root"].tables.keys() assert "target" in result_df.columns assert len(result_df["target"]) > 0