Skip to content
6 changes: 6 additions & 0 deletions dask_planner/src/sql/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::sql::types::rel_data_type::RelDataType;
use crate::sql::types::rel_data_type_field::RelDataTypeField;

mod aggregate;
mod create_memory_table;
mod cross_join;
mod empty_relation;
mod explain;
Expand Down Expand Up @@ -123,6 +124,11 @@ impl PyLogicalPlan {
to_py_plan(self.current_node.as_ref())
}

/// LogicalPlan::CreateMemoryTable as PyCreateMemoryTable
///
/// Hands the current node to `to_py_plan` and returns whatever that
/// conversion yields.
/// NOTE(review): presumably `to_py_plan` goes through
/// `TryFrom<LogicalPlan> for PyCreateMemoryTable`, in which case a
/// `LogicalPlan::CreateView` node is accepted here as well - confirm
/// against `to_py_plan`.
pub fn create_memory_table(&self) -> PyResult<create_memory_table::PyCreateMemoryTable> {
to_py_plan(self.current_node.as_ref())
}

/// Gets the "input" for the current LogicalPlan
pub fn get_inputs(&mut self) -> PyResult<Vec<PyLogicalPlan>> {
let mut py_inputs: Vec<PyLogicalPlan> = Vec::new();
Expand Down
78 changes: 78 additions & 0 deletions dask_planner/src/sql/logical/create_memory_table.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use crate::sql::exceptions::py_type_err;
use crate::sql::logical::PyLogicalPlan;
use datafusion_expr::{logical_plan::CreateMemoryTable, logical_plan::CreateView, LogicalPlan};
use pyo3::prelude::*;

/// Python-facing wrapper, exported as `CreateMemoryTable` in the
/// `dask_planner` module, around a DataFusion `CreateMemoryTable` or
/// `CreateView` logical plan node.
///
/// The `TryFrom<LogicalPlan>` conversion in this file fills exactly one of
/// the two fields and leaves the other as `None`.
#[pyclass(name = "CreateMemoryTable", module = "dask_planner", subclass)]
#[derive(Clone)]
pub struct PyCreateMemoryTable {
/// Set when the wrapped plan node is `LogicalPlan::CreateMemoryTable`.
create_memory_table: Option<CreateMemoryTable>,
/// Set when the wrapped plan node is `LogicalPlan::CreateView`.
create_view: Option<CreateView>,
}

#[pymethods]
impl PyCreateMemoryTable {
#[pyo3(name = "getName")]
pub fn get_name(&self) -> PyResult<String> {
match &self.create_memory_table {
Some(create_memory_table) => Ok(format!("{}", create_memory_table.name)),
None => match &self.create_view {
Some(create_view) => Ok(format!("{}", create_view.name)),
None => panic!("Encountered a non CreateMemoryTable/CreateView type in get_name"),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should try replacing the panic with a python error

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point - might also be worth it to do a follow up PR applying this change to some of the other plan types, which also panic in similar situations

},
}
}

#[pyo3(name = "getInput")]
pub fn get_input(&self) -> PyResult<PyLogicalPlan> {
Ok(PyLogicalPlan {
original_plan: match &self.create_memory_table {
Some(create_memory_table) => (*create_memory_table.input).clone(),
None => match &self.create_view {
Some(create_view) => (*create_view.input).clone(),
None => {
panic!("Encountered a non CreateMemoryTable/CreateView type in get_input")
}
},
},
current_node: None,
})
}

/// Reports the `if_not_exists` flag of the wrapped `CreateMemoryTable` node.
///
/// Yields `false` whenever no `CreateMemoryTable` node is wrapped; the
/// `CreateView` field is not consulted.
#[pyo3(name = "getIfNotExists")]
pub fn get_if_not_exists(&self) -> PyResult<bool> {
    let flag = self
        .create_memory_table
        .as_ref()
        .map_or(false, |table| table.if_not_exists);
    Ok(flag)
}

#[pyo3(name = "getOrReplace")]
pub fn get_or_replace(&self) -> PyResult<bool> {
match &self.create_memory_table {
Some(create_memory_table) => Ok(create_memory_table.or_replace),
None => match &self.create_view {
Some(create_view) => Ok(create_view.or_replace),
None => panic!("Encountered a non CreateMemoryTable/CreateView type in get_name"),
},
}
}
}

/// Builds a `PyCreateMemoryTable` from a `CreateMemoryTable` or `CreateView`
/// logical plan node; every other plan variant is rejected.
impl TryFrom<LogicalPlan> for PyCreateMemoryTable {
    type Error = PyErr;

    /// Stores the node in the matching field and leaves the other one `None`.
    /// Fails with `py_type_err("unexpected plan")` for any other variant.
    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {
        let (create_memory_table, create_view) = match logical_plan {
            LogicalPlan::CreateMemoryTable(table) => (Some(table), None),
            LogicalPlan::CreateView(view) => (None, Some(view)),
            _ => return Err(py_type_err("unexpected plan")),
        };

        Ok(PyCreateMemoryTable {
            create_memory_table,
            create_view,
        })
    }
}
55 changes: 40 additions & 15 deletions dask_sql/physical/rel/custom/create_table_as.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import logging
from collections import Counter
from typing import TYPE_CHECKING

from dask_sql.datacontainer import DataContainer
from dask_sql.physical.rel.base import BaseRelPlugin
from dask_sql.physical.rel.convert import RelConverter

if TYPE_CHECKING:
import dask_sql
from dask_sql.java import org
from dask_planner import LogicalPlan

logger = logging.getLogger(__name__)

Expand All @@ -32,29 +34,52 @@ class CreateTableAsPlugin(BaseRelPlugin):
Nothing is returned.
"""

class_name = "com.dask.sql.parser.SqlCreateTableAs"
class_name = ["CreateMemoryTable", "CreateView"]

def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
    """
    Register the result of the plan's input query as a table in the context.

    Args:
        rel: logical plan whose current node is a CreateMemoryTable/CreateView.
        context: the dask_sql context the new table is registered in.

    Raises:
        RuntimeError: if a table with this name already exists and neither
            "if not exists" nor "or replace" was requested.

    Nothing is returned.
    """
    # Rust create_memory_table instance handle
    create_memory_table = rel.create_memory_table()

    # can we avoid hardcoding the schema name?
    schema_name, table_name = context.schema_name, create_memory_table.getName()

    if table_name in context.schema[schema_name].tables:
        if create_memory_table.getIfNotExists():
            # The table is already there and the statement tolerates that.
            return
        elif not create_memory_table.getOrReplace():
            raise RuntimeError(
                f"A table with the name {table_name} is already present."
            )

    input_rel = create_memory_table.getInput()

    logger.debug(
        f"Creating new table with name {table_name} and logical plan {input_rel}"
    )

    dc = RelConverter.convert(input_rel, context=context)
    select_names = list(input_rel.getRowType().getFieldList())

    if select_names:
        # Use FQ name if not unique and simple name if it is unique. If a join contains the same column
        # names the output col is prepended with the fully qualified column name
        field_counts = Counter([field.getName() for field in select_names])
        select_names = [
            field.getQualifiedName()
            if field_counts[field.getName()] > 1
            else field.getName()
            for field in select_names
        ]

        cc = dc.column_container
        cc = cc.rename(
            {
                df_col: select_name
                for df_col, select_name in zip(cc.columns, select_names)
            }
        )
        dc = DataContainer(dc.df, cc)

    df = dc.assign()

    context.create_table(table_name, df, schema_name=schema_name)
1 change: 0 additions & 1 deletion tests/integration/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ def test_wrong_create(c):
)


@pytest.mark.skip(reason="WIP DataFusion")
def test_create_from_query(c, df):
c.sql(
"""
Expand Down