diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py index e7f5e0ef9..e2bd0f471 100644 --- a/tests/unit/test_context.py +++ b/tests/unit/test_context.py @@ -40,13 +40,10 @@ def test_deprecation_warning(gpu): c = Context() data_frame = dd.from_pandas(pd.DataFrame(), npartitions=1) - if gpu: - data_frame = dask_cudf.from_dask_dataframe(data_frame) - with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - c.register_dask_table(data_frame, "table") + c.register_dask_table(data_frame, "table", gpu=gpu) assert len(w) == 1 assert issubclass(w[-1].category, DeprecationWarning)