10 changes: 10 additions & 0 deletions dask_sql/context.py
@@ -8,6 +8,7 @@
 import pandas as pd
 from dask import config as dask_config
 from dask.base import optimize
+from dask.utils_test import hlg_layer
 
 from dask_planner.rust import (
     DaskSchema,
@@ -247,6 +248,15 @@ def create_table(
         if type(input_table) == str:
             dc.filepath = input_table
             self.schema[schema_name].filepaths[table_name.lower()] = input_table
+        elif "dask" in str(type(input_table)):
+            try:
+                dask_filepath = hlg_layer(
+                    input_table.dask, "read-parquet"
+                ).creation_info["args"][0]
+                dc.filepath = dask_filepath
+                self.schema[schema_name].filepaths[table_name.lower()] = dask_filepath
+            except KeyError:
+                logger.debug("Expected 'read-parquet' layer")
 
         if parquet_statistics and not statistics:
             statistics = parquet_statistics(dc.df)
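For reference, the new branch relies on two pieces of Dask's graph API: every Dask collection exposes its HighLevelGraph as `.dask`, and `dask.utils_test.hlg_layer` returns the first layer whose name starts with a given prefix, raising `KeyError` when none matches. IO layers built by `dd.read_parquet` record their creation arguments, which is how the original path is read back. A minimal sketch of the lookup, assuming a Dask version where `read_parquet` layers carry `creation_info` (the temporary directory below is purely illustrative):

```python
import os
import tempfile

import dask.dataframe as dd
import pandas as pd
from dask.utils_test import hlg_layer

# Write a small Parquet dataset to an illustrative temporary location.
path = os.path.join(tempfile.mkdtemp(), "example_parquet")
dd.from_pandas(pd.DataFrame({"a": range(6)}), npartitions=2).to_parquet(path)

# read_parquet builds a HighLevelGraph containing a layer whose name
# starts with "read-parquet"; hlg_layer raises KeyError if none exists.
ddf = dd.read_parquet(path)
layer = hlg_layer(ddf.dask, "read-parquet")

# The IO layer records how it was created; args[0] is the original path.
assert layer.creation_info["args"][0] == path
```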
28 changes: 28 additions & 0 deletions tests/unit/test_context.py
@@ -331,6 +331,7 @@ def test_filepath(tmpdir):
         },
     )
     dd.from_pandas(parquet_df, npartitions=3).to_parquet(parquet_path)
+    # Create table with string (Parquet filepath)
    c.create_table("parquet_df", parquet_path, format="parquet")
 
     assert c.schema["root"].tables["parquet_df"].filepath == parquet_path
@@ -342,3 +343,30 @@ def test_filepath(tmpdir):
     assert c.schema["root"].tables["df"].filepath is None
     with pytest.raises(KeyError):
         c.schema["root"].filepaths["df"]
+
+
+def test_ddf_filepath(tmpdir):
Review comment (Collaborator): Does the `parquet_ddf` fixture work for the test here, or does it explicitly need to be redefined here?
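A hypothetical sketch of the reuse being suggested, assuming a `parquet_ddf` fixture were available from the shared conftest.py (its actual definition is not shown in this diff, so the body below is illustrative):

```python
import os

import dask.dataframe as dd
import pandas as pd
import pytest


@pytest.fixture
def parquet_ddf(tmpdir):
    # Hypothetical fixture: write a small Parquet dataset and return a
    # DataFrame read back via read_parquet, so its graph carries the
    # "read-parquet" layer that create_table inspects.
    path = os.path.join(tmpdir, "parquet")
    dd.from_pandas(pd.DataFrame({"a": range(6)}), npartitions=2).to_parquet(path)
    return dd.read_parquet(path)
```

The new test could then accept `parquet_ddf` as an argument instead of rebuilding the dataset inline.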

+    c = Context()
+
+    parquet_path = os.path.join(tmpdir, "parquet")
+    parquet_df = pd.DataFrame(
+        {
+            "a": [1, 2, 3] * 5,
+            "b": range(15),
+            "c": ["A"] * 15,
+            "d": [
+                pd.Timestamp("2013-08-01 23:00:00"),
+                pd.Timestamp("2014-09-01 23:00:00"),
+                pd.Timestamp("2015-10-01 23:00:00"),
+            ]
+            * 5,
+            "index": range(15),
+        },
+    )
+    dd.from_pandas(parquet_df, npartitions=3).to_parquet(parquet_path)
+    # Create table with Dask DataFrame (created from read_parquet)
+    ddf = dd.read_parquet(parquet_path)
+    c.create_table("parquet_df", ddf)
+
+    assert c.schema["root"].tables["parquet_df"].filepath == parquet_path
+    assert c.schema["root"].filepaths["parquet_df"] == parquet_path
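Complementing the happy path above, a minimal sketch of the fallback branch, mirroring the `from_pandas` assertions already present in `test_filepath`: a Dask DataFrame whose graph has no `read-parquet` layer trips the `KeyError` handler in `create_table`, so no filepath is recorded.

```python
import dask.dataframe as dd
import pandas as pd
from dask_sql import Context

c = Context()

# from_pandas produces no "read-parquet" layer, so create_table's
# KeyError branch fires and filepath stays None.
ddf = dd.from_pandas(pd.DataFrame({"a": range(6)}), npartitions=2)
c.create_table("df", ddf)
assert c.schema["root"].tables["df"].filepath is None
```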