From f774264e261919e712dc376721c1769b0f449768 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 29 Nov 2021 14:32:25 -0500 Subject: [PATCH 1/2] Circumvent deep copy of context in PredictModelPlugin --- dask_sql/physical/rel/custom/predict.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/dask_sql/physical/rel/custom/predict.py b/dask_sql/physical/rel/custom/predict.py index 3a1650d19..e2e8c6e03 100644 --- a/dask_sql/physical/rel/custom/predict.py +++ b/dask_sql/physical/rel/custom/predict.py @@ -1,4 +1,3 @@ -import copy import logging import uuid from typing import TYPE_CHECKING @@ -88,8 +87,7 @@ def convert( else: # pragma: no cover continue - tmp_context = copy.deepcopy(context) - tmp_context.create_table(temporary_table, predicted_df) + context.create_table(temporary_table, predicted_df) sql_ns = org.apache.calcite.sql pos = sql.getParserPosition() @@ -112,8 +110,9 @@ def convert( None, # hints ) - sql_outer_query = tmp_context._to_sql_string(outer_select) - df = tmp_context.sql(sql_outer_query) + sql_outer_query = context._to_sql_string(outer_select) + df = context.sql(sql_outer_query) + context.drop_table(temporary_table) cc = ColumnContainer(df.columns) dc = DataContainer(df, cc) From 25c121d4e803978e93dc1b6514416a69c413e9d0 Mon Sep 17 00:00:00 2001 From: Charles Blackmon-Luca <20627856+charlesbluca@users.noreply.github.com> Date: Mon, 29 Nov 2021 15:35:14 -0500 Subject: [PATCH 2/2] Check that context is not altered by predict temporary table --- tests/integration/test_model.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/integration/test_model.py b/tests/integration/test_model.py index 6a3a22dbf..7a7faf994 100644 --- a/tests/integration/test_model.py +++ b/tests/integration/test_model.py @@ -26,8 +26,11 @@ def check_trained_model(c, model_name=None): ) """ + tables_before = c.schema["root"].tables.keys() result_df = c.sql(sql).compute() + # assert that there are no additional tables in context from prediction + assert tables_before == c.schema["root"].tables.keys() assert "target" in result_df.columns assert len(result_df["target"]) > 0