diff --git a/dask_sql/context.py b/dask_sql/context.py
index f277c4da0..837f7cd1c 100644
--- a/dask_sql/context.py
+++ b/dask_sql/context.py
@@ -8,6 +8,7 @@
 import pandas as pd
 from dask import config as dask_config
 from dask.base import optimize
+from dask.utils_test import hlg_layer
 
 from dask_planner.rust import (
     DaskSchema,
@@ -247,6 +248,15 @@ def create_table(
         if type(input_table) == str:
             dc.filepath = input_table
             self.schema[schema_name].filepaths[table_name.lower()] = input_table
+        # For a Dask DataFrame built by read_parquet, recover the source path
+        # from the HLG "read-parquet" layer's creation_info so the table keeps
+        # a filepath just like string-path-created tables.
+        elif hasattr(input_table, "dask") and dd.utils.is_dataframe_like(input_table):
+            try:
+                dask_filepath = hlg_layer(
+                    input_table.dask, "read-parquet"
+                ).creation_info["args"][0]
+                dc.filepath = dask_filepath
+                self.schema[schema_name].filepaths[table_name.lower()] = dask_filepath
+            except KeyError:
+                # Not every DataFrame originates from read_parquet; best-effort only.
+                logger.debug("Expected 'read-parquet' layer")
 
         if parquet_statistics and not statistics:
             statistics = parquet_statistics(dc.df)
diff --git a/tests/unit/test_context.py b/tests/unit/test_context.py
index 86d78086c..81bd5f23b 100644
--- a/tests/unit/test_context.py
+++ b/tests/unit/test_context.py
@@ -312,27 +312,15 @@ def test_alter_table(c, df_simple):
     del c.schema[c.schema_name].tables["physics"]
 
 
-def test_filepath(tmpdir):
+def test_filepath(tmpdir, parquet_ddf):
     c = Context()
     parquet_path = os.path.join(tmpdir, "parquet")
-    parquet_df = pd.DataFrame(
-        {
-            "a": [1, 2, 3] * 5,
-            "b": range(15),
-            "c": ["A"] * 15,
-            "d": [
-                pd.Timestamp("2013-08-01 23:00:00"),
-                pd.Timestamp("2014-09-01 23:00:00"),
-                pd.Timestamp("2015-10-01 23:00:00"),
-            ]
-            * 5,
-            "index": range(15),
-        },
-    )
-    dd.from_pandas(parquet_df, npartitions=3).to_parquet(parquet_path)
-    c.create_table("parquet_df", parquet_path, format="parquet")
-    assert c.schema["root"].tables["parquet_df"].filepath == parquet_path
-    assert c.schema["root"].filepaths["parquet_df"] == parquet_path
+
+    # Create table with string (Parquet filepath)
+    c.create_table("parquet_ddf", parquet_path, format="parquet")
+
+    assert c.schema["root"].tables["parquet_ddf"].filepath == parquet_path
+    assert c.schema["root"].filepaths["parquet_ddf"] == parquet_path
 
     df = pd.DataFrame({"a": [2, 1, 2, 3], "b": [3, 3, 1, 3]})
     c.create_table("df", df)
@@ -342,3 +330,14 @@ def test_filepath(tmpdir):
     assert c.schema["root"].tables["df"].filepath is None
     with pytest.raises(KeyError):
         c.schema["root"].filepaths["df"]
+
+
+def test_ddf_filepath(tmpdir, parquet_ddf):
+    c = Context()
+    parquet_path = os.path.join(tmpdir, "parquet")
+
+    # Create table with Dask DataFrame (created from read_parquet)
+    c.create_table("parquet_ddf", parquet_ddf)
+
+    assert c.schema["root"].tables["parquet_ddf"].filepath == parquet_path
+    assert c.schema["root"].filepaths["parquet_ddf"] == parquet_path