diff --git a/dask_sql/context.py b/dask_sql/context.py index a18c5daf6..446f1bda1 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -882,7 +882,7 @@ def _get_tables_from_stack(self): for var_name, variable in frame_info.frame.f_locals.items(): if var_name.startswith("_"): continue - if not isinstance(variable, (pd.DataFrame, dd.DataFrame)): + if not dd.utils.is_dataframe_like(variable): continue # only set them if not defined in an inner context diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index 34f25db4e..f331eff08 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -168,19 +168,7 @@ def assert_correct_output(gpu): @pytest.mark.parametrize( - "gpu", - [ - False, - pytest.param( - True, - marks=( - pytest.mark.gpu, - pytest.mark.xfail( - reason="GPU tables aren't picked up by _get_tables_from_stack" - ), - ), - ), - ], + "gpu", [False, pytest.param(True, marks=pytest.mark.gpu),], ) def test_tables_from_stack(gpu): c = Context()