6 changes: 6 additions & 0 deletions dask_planner/src/sql/logical.rs
@@ -3,6 +3,7 @@ use crate::sql::types::rel_data_type::RelDataType;
use crate::sql::types::rel_data_type_field::RelDataTypeField;

mod aggregate;
mod create_memory_table;
mod cross_join;
mod empty_relation;
mod explain;
@@ -123,6 +124,11 @@ impl PyLogicalPlan {
        to_py_plan(self.current_node.as_ref())
    }

    /// LogicalPlan::CreateMemoryTable as PyCreateMemoryTable
    pub fn create_memory_table(&self) -> PyResult<create_memory_table::PyCreateMemoryTable> {
        to_py_plan(self.current_node.as_ref())
    }

    /// Gets the "input" for the current LogicalPlan
    pub fn get_inputs(&mut self) -> PyResult<Vec<PyLogicalPlan>> {
        let mut py_inputs: Vec<PyLogicalPlan> = Vec::new();
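The new accessor follows the same pattern as the existing node wrappers: `to_py_plan` delegates to the `TryFrom<LogicalPlan>` impl of the target type (defined in the new file below), so calling it on any other node type is rejected via `py_type_err`. A rough, hypothetical sketch of the resulting Python-side behavior (the `plan` handle and the exact exception type are assumptions):

# Hypothetical Python-side behavior of the new accessor; `plan` is assumed to be a
# PyLogicalPlan, and py_type_err is assumed to surface as a Python TypeError.
try:
    cmt = plan.create_memory_table()  # succeeds when the node is CreateMemoryTable/CreateView
except TypeError:
    cmt = None                        # any other node type is rejected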
99 changes: 99 additions & 0 deletions dask_planner/src/sql/logical/create_memory_table.rs
@@ -0,0 +1,99 @@
use crate::sql::exceptions::py_type_err;
use crate::sql::logical::PyLogicalPlan;
use datafusion_expr::{logical_plan::CreateMemoryTable, logical_plan::CreateView, LogicalPlan};
use pyo3::prelude::*;

#[pyclass(name = "CreateMemoryTable", module = "dask_planner", subclass)]
#[derive(Clone)]
pub struct PyCreateMemoryTable {
    create_memory_table: Option<CreateMemoryTable>,
    create_view: Option<CreateView>,
}

#[pymethods]
impl PyCreateMemoryTable {
#[pyo3(name = "getName")]
pub fn get_name(&self) -> PyResult<String> {
Ok(match &self.create_memory_table {
Some(create_memory_table) => create_memory_table.name.clone(),
None => match &self.create_view {
Some(create_view) => create_view.name.clone(),
None => {
return Err(py_type_err(
"Encountered a non CreateMemoryTable/CreateView type in get_input",
))
}
},
})
}

#[pyo3(name = "getInput")]
pub fn get_input(&self) -> PyResult<PyLogicalPlan> {
Ok(match &self.create_memory_table {
Some(create_memory_table) => PyLogicalPlan {
original_plan: (*create_memory_table.input).clone(),
current_node: None,
},
None => match &self.create_view {
Some(create_view) => PyLogicalPlan {
original_plan: (*create_view.input).clone(),
current_node: None,
},
None => {
return Err(py_type_err(
"Encountered a non CreateMemoryTable/CreateView type in get_input",
))
}
},
})
}

#[pyo3(name = "getIfNotExists")]
pub fn get_if_not_exists(&self) -> PyResult<bool> {
Ok(match &self.create_memory_table {
Some(create_memory_table) => create_memory_table.if_not_exists,
None => false, // TODO: in the future we may want to set this based on dialect
})
}

#[pyo3(name = "getOrReplace")]
pub fn get_or_replace(&self) -> PyResult<bool> {
Ok(match &self.create_memory_table {
Some(create_memory_table) => create_memory_table.or_replace,
None => match &self.create_view {
Some(create_view) => create_view.or_replace,
None => {
return Err(py_type_err(
"Encountered a non CreateMemoryTable/CreateView type in get_input",
))
}
},
})
}

#[pyo3(name = "isTable")]
pub fn is_table(&self) -> PyResult<bool> {
Ok(match &self.create_memory_table {
Some(_) => true,
None => false,
})
}
}

impl TryFrom<LogicalPlan> for PyCreateMemoryTable {
    type Error = PyErr;

    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {
        Ok(match logical_plan {
            LogicalPlan::CreateMemoryTable(create_memory_table) => PyCreateMemoryTable {
                create_memory_table: Some(create_memory_table),
                create_view: None,
            },
            LogicalPlan::CreateView(create_view) => PyCreateMemoryTable {
                create_memory_table: None,
                create_view: Some(create_view),
            },
            _ => return Err(py_type_err("unexpected plan")),
        })
    }
}
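A minimal usage sketch of the bindings defined above, as they are expected to be called from Python (the `rel` handle and its contents are illustrative; the method names match the `#[pyo3(name = ...)]` attributes):

# Assumes `rel` is a PyLogicalPlan whose current node is a CreateMemoryTable or CreateView.
create_memory_table = rel.create_memory_table()
name = create_memory_table.getName()                  # table or view name
input_plan = create_memory_table.getInput()           # the SELECT sub-plan as a PyLogicalPlan
skip_existing = create_memory_table.getIfNotExists()  # only meaningful for CREATE TABLE
replace_existing = create_memory_table.getOrReplace()
is_table = create_memory_table.isTable()              # True for CREATE TABLE, False for CREATE VIEW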
76 changes: 39 additions & 37 deletions dask_sql/context.py
@@ -41,7 +41,7 @@
from dask_sql.utils import OptimizationException, ParsingException

if TYPE_CHECKING:
    from dask_planner.rust import Expression
    from dask_planner.rust import Expression, LogicalPlan

logger = logging.getLogger(__name__)

@@ -490,40 +490,9 @@ def sql(
        for df_name, df in dataframes.items():
            self.create_table(df_name, df, gpu=gpu)

        rel, select_fields, _ = self._get_ral(sql)

        dc = RelConverter.convert(rel, context=self)

        if rel.get_current_node_type() == "Explain":
            return dc
        if dc is None:
            return

        if select_fields:
            # Use FQ name if not unique and simple name if it is unique. If a join contains the same column
            # names the output col is prepended with the fully qualified column name
            field_counts = Counter([field.getName() for field in select_fields])
            select_names = [
                field.getQualifiedName()
                if field_counts[field.getName()] > 1
                else field.getName()
                for field in select_fields
            ]

            cc = dc.column_container
            cc = cc.rename(
                {
                    df_col: select_name
                    for df_col, select_name in zip(cc.columns, select_names)
                }
            )
            dc = DataContainer(dc.df, cc)

        df = dc.assign()
        if not return_futures:
            df = df.compute()
        rel, _ = self._get_ral(sql)

        return df
        return self._compute_table_from_rel(rel, return_futures)

    def explain(
        self,
@@ -555,7 +524,7 @@ def explain(
        for df_name, df in dataframes.items():
            self.create_table(df_name, df, gpu=gpu)

        _, _, rel_string = self._get_ral(sql)
        _, rel_string = self._get_ral(sql)
        return rel_string

    def visualize(self, sql: str, filename="mydask.png") -> None:  # pragma: no cover
@@ -817,7 +786,6 @@ def _get_ral(self, sql):
        sqlTree = self.context.parse_sql(sql)
        logger.debug(f"_get_ral -> sqlTree: {sqlTree}")

        select_names = None
        rel = sqlTree

        # TODO: Need to understand if this list here is actually needed? For now just use the first entry.
@@ -845,10 +813,44 @@
logger.debug(f"_get_ral -> LogicalPlan: {rel}")
logger.debug(f"Extracted relational algebra:\n {rel_string}")

return rel, rel_string

def _compute_table_from_rel(self, rel: "LogicalPlan", return_futures: bool = True):
dc = RelConverter.convert(rel, context=self)

        # Optimization might remove some alias projections. Make sure to keep them here.
        select_names = [field for field in rel.getRowType().getFieldList()]

        return rel, select_names, rel_string
        if rel.get_current_node_type() == "Explain":
            return dc
        if dc is None:
            return

        if select_names:
            # Use FQ name if not unique and simple name if it is unique. If a join contains the same column
            # names the output col is prepended with the fully qualified column name
            field_counts = Counter([field.getName() for field in select_names])
            select_names = [
                field.getQualifiedName()
                if field_counts[field.getName()] > 1
                else field.getName()
                for field in select_names
            ]

            cc = dc.column_container
            cc = cc.rename(
                {
                    df_col: select_name
                    for df_col, select_name in zip(cc.columns, select_names)
                }
            )
            dc = DataContainer(dc.df, cc)

        df = dc.assign()
        if not return_futures:
            df = df.compute()

        return df

    def _get_tables_from_stack(self):
        """Helper function to return all dask/pandas dataframes from the calling stack"""
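The point of splitting `_get_ral` from `_compute_table_from_rel` is that a rel plugin can now turn a sub-plan into a dask DataFrame without round-tripping through a SQL string. A short sketch under that assumption (the `create_memory_table` handle mirrors the plugin code below):

# Hedged sketch: inside a rel plugin, given the Rust handle for a CREATE TABLE/VIEW node.
input_rel = create_memory_table.getInput()        # LogicalPlan for the SELECT part
df = context._compute_table_from_rel(input_rel)   # dask DataFrame (futures unless computed)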
37 changes: 22 additions & 15 deletions dask_sql/physical/rel/custom/create_table_as.py
@@ -6,7 +6,7 @@

if TYPE_CHECKING:
    import dask_sql
    from dask_sql.java import org
    from dask_planner import LogicalPlan

logger = logging.getLogger(__name__)

@@ -32,29 +32,36 @@ class CreateTableAsPlugin(BaseRelPlugin):
    Nothing is returned.
    """

class_name = "com.dask.sql.parser.SqlCreateTableAs"
class_name = ["CreateMemoryTable", "CreateView"]

    def convert(
        self, sql: "org.apache.calcite.sql.SqlNode", context: "dask_sql.Context"
    ) -> DataContainer:
        schema_name, table_name = context.fqn(sql.getTableName())
    def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
        # Rust create_memory_table instance handle
        create_memory_table = rel.create_memory_table()

        # can we avoid hardcoding the schema name?
        schema_name, table_name = context.schema_name, create_memory_table.getName()

        if table_name in context.schema[schema_name].tables:
            if sql.getIfNotExists():
            if create_memory_table.getIfNotExists():
                return
            elif not sql.getReplace():
            elif not create_memory_table.getOrReplace():
                raise RuntimeError(
                    f"A table with the name {table_name} is already present."
                )

        sql_select = sql.getSelect()
        persist = bool(sql.isPersist())
        input_rel = create_memory_table.getInput()

        # TODO: we currently always persist for CREATE TABLE AS and never persist for CREATE VIEW AS;
        # should this be configured by the user? https://github.com/dask-contrib/dask-sql/issues/269
        persist = create_memory_table.isTable()

        logger.debug(
            f"Creating new table with name {table_name} and query {sql_select}"
            f"Creating new table with name {table_name} and logical plan {input_rel}"
        )

        sql_select_query = context._to_sql_string(sql_select)
        df = context.sql(sql_select_query)

        context.create_table(table_name, df, persist=persist, schema_name=schema_name)
        context.create_table(
            table_name,
            context._compute_table_from_rel(input_rel),
            persist=persist,
            schema_name=schema_name,
        )
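For context, a minimal end-to-end sketch of the statement this plugin now handles (table, column, and DataFrame names are illustrative):

# Hedged sketch; `df` is assumed to be any dask/pandas DataFrame.
from dask_sql import Context

c = Context()
c.create_table("df", df)
c.sql("CREATE TABLE new_table AS (SELECT a + 1 AS b FROM df)")  # routed to CreateTableAsPlugin
result = c.sql("SELECT b FROM new_table")                       # the materialized table is queryable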
1 change: 0 additions & 1 deletion tests/integration/test_create.py
@@ -119,7 +119,6 @@ def test_wrong_create(c):
    )


@pytest.mark.skip(reason="WIP DataFusion")
def test_create_from_query(c, df):
    c.sql(
        """