diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index 7598df864..67aea8cb9 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -3,6 +3,7 @@ use crate::sql::types::rel_data_type::RelDataType; use crate::sql::types::rel_data_type_field::RelDataTypeField; mod aggregate; +mod create_memory_table; mod cross_join; mod empty_relation; mod explain; @@ -123,6 +124,11 @@ impl PyLogicalPlan { to_py_plan(self.current_node.as_ref()) } + /// LogicalPlan::CreateMemoryTable as PyCreateMemoryTable + pub fn create_memory_table(&self) -> PyResult { + to_py_plan(self.current_node.as_ref()) + } + /// Gets the "input" for the current LogicalPlan pub fn get_inputs(&mut self) -> PyResult> { let mut py_inputs: Vec = Vec::new(); diff --git a/dask_planner/src/sql/logical/create_memory_table.rs b/dask_planner/src/sql/logical/create_memory_table.rs new file mode 100644 index 000000000..2dabde073 --- /dev/null +++ b/dask_planner/src/sql/logical/create_memory_table.rs @@ -0,0 +1,99 @@ +use crate::sql::exceptions::py_type_err; +use crate::sql::logical::PyLogicalPlan; +use datafusion_expr::{logical_plan::CreateMemoryTable, logical_plan::CreateView, LogicalPlan}; +use pyo3::prelude::*; + +#[pyclass(name = "CreateMemoryTable", module = "dask_planner", subclass)] +#[derive(Clone)] +pub struct PyCreateMemoryTable { + create_memory_table: Option, + create_view: Option, +} + +#[pymethods] +impl PyCreateMemoryTable { + #[pyo3(name = "getName")] + pub fn get_name(&self) -> PyResult { + Ok(match &self.create_memory_table { + Some(create_memory_table) => create_memory_table.name.clone(), + None => match &self.create_view { + Some(create_view) => create_view.name.clone(), + None => { + return Err(py_type_err( + "Encountered a non CreateMemoryTable/CreateView type in get_input", + )) + } + }, + }) + } + + #[pyo3(name = "getInput")] + pub fn get_input(&self) -> PyResult { + Ok(match &self.create_memory_table { + Some(create_memory_table) => PyLogicalPlan { + original_plan: (*create_memory_table.input).clone(), + current_node: None, + }, + None => match &self.create_view { + Some(create_view) => PyLogicalPlan { + original_plan: (*create_view.input).clone(), + current_node: None, + }, + None => { + return Err(py_type_err( + "Encountered a non CreateMemoryTable/CreateView type in get_input", + )) + } + }, + }) + } + + #[pyo3(name = "getIfNotExists")] + pub fn get_if_not_exists(&self) -> PyResult { + Ok(match &self.create_memory_table { + Some(create_memory_table) => create_memory_table.if_not_exists, + None => false, // TODO: in the future we may want to set this based on dialect + }) + } + + #[pyo3(name = "getOrReplace")] + pub fn get_or_replace(&self) -> PyResult { + Ok(match &self.create_memory_table { + Some(create_memory_table) => create_memory_table.or_replace, + None => match &self.create_view { + Some(create_view) => create_view.or_replace, + None => { + return Err(py_type_err( + "Encountered a non CreateMemoryTable/CreateView type in get_input", + )) + } + }, + }) + } + + #[pyo3(name = "isTable")] + pub fn is_table(&self) -> PyResult { + Ok(match &self.create_memory_table { + Some(_) => true, + None => false, + }) + } +} + +impl TryFrom for PyCreateMemoryTable { + type Error = PyErr; + + fn try_from(logical_plan: LogicalPlan) -> Result { + Ok(match logical_plan { + LogicalPlan::CreateMemoryTable(create_memory_table) => PyCreateMemoryTable { + create_memory_table: Some(create_memory_table), + create_view: None, + }, + LogicalPlan::CreateView(create_view) => PyCreateMemoryTable { + create_memory_table: None, + create_view: Some(create_view), + }, + _ => return Err(py_type_err("unexpected plan")), + }) + } +} diff --git a/dask_sql/context.py b/dask_sql/context.py index 4c48b13d0..7d2903637 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -41,7 +41,7 @@ from dask_sql.utils import OptimizationException, ParsingException if TYPE_CHECKING: - from dask_planner.rust import Expression + from dask_planner.rust import Expression, LogicalPlan logger = logging.getLogger(__name__) @@ -490,40 +490,9 @@ def sql( for df_name, df in dataframes.items(): self.create_table(df_name, df, gpu=gpu) - rel, select_fields, _ = self._get_ral(sql) - - dc = RelConverter.convert(rel, context=self) - - if rel.get_current_node_type() == "Explain": - return dc - if dc is None: - return - - if select_fields: - # Use FQ name if not unique and simple name if it is unique. If a join contains the same column - # names the output col is prepended with the fully qualified column name - field_counts = Counter([field.getName() for field in select_fields]) - select_names = [ - field.getQualifiedName() - if field_counts[field.getName()] > 1 - else field.getName() - for field in select_fields - ] - - cc = dc.column_container - cc = cc.rename( - { - df_col: select_name - for df_col, select_name in zip(cc.columns, select_names) - } - ) - dc = DataContainer(dc.df, cc) - - df = dc.assign() - if not return_futures: - df = df.compute() + rel, _ = self._get_ral(sql) - return df + return self._compute_table_from_rel(rel, return_futures) def explain( self, @@ -555,7 +524,7 @@ def explain( for df_name, df in dataframes.items(): self.create_table(df_name, df, gpu=gpu) - _, _, rel_string = self._get_ral(sql) + _, rel_string = self._get_ral(sql) return rel_string def visualize(self, sql: str, filename="mydask.png") -> None: # pragma: no cover @@ -817,7 +786,6 @@ def _get_ral(self, sql): sqlTree = self.context.parse_sql(sql) logger.debug(f"_get_ral -> sqlTree: {sqlTree}") - select_names = None rel = sqlTree # TODO: Need to understand if this list here is actually needed? For now just use the first entry. @@ -845,10 +813,44 @@ def _get_ral(self, sql): logger.debug(f"_get_ral -> LogicalPlan: {rel}") logger.debug(f"Extracted relational algebra:\n {rel_string}") + return rel, rel_string + + def _compute_table_from_rel(self, rel: "LogicalPlan", return_futures: bool = True): + dc = RelConverter.convert(rel, context=self) + # Optimization might remove some alias projects. Make sure to keep them here. select_names = [field for field in rel.getRowType().getFieldList()] - return rel, select_names, rel_string + if rel.get_current_node_type() == "Explain": + return dc + if dc is None: + return + + if select_names: + # Use FQ name if not unique and simple name if it is unique. If a join contains the same column + # names the output col is prepended with the fully qualified column name + field_counts = Counter([field.getName() for field in select_names]) + select_names = [ + field.getQualifiedName() + if field_counts[field.getName()] > 1 + else field.getName() + for field in select_names + ] + + cc = dc.column_container + cc = cc.rename( + { + df_col: select_name + for df_col, select_name in zip(cc.columns, select_names) + } + ) + dc = DataContainer(dc.df, cc) + + df = dc.assign() + if not return_futures: + df = df.compute() + + return df def _get_tables_from_stack(self): """Helper function to return all dask/pandas dataframes from the calling stack""" diff --git a/dask_sql/physical/rel/custom/create_table_as.py b/dask_sql/physical/rel/custom/create_table_as.py index 7a0c04044..2819072f5 100644 --- a/dask_sql/physical/rel/custom/create_table_as.py +++ b/dask_sql/physical/rel/custom/create_table_as.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: import dask_sql - from dask_sql.java import org + from dask_planner import LogicalPlan logger = logging.getLogger(__name__) @@ -32,29 +32,36 @@ class CreateTableAsPlugin(BaseRelPlugin): Nothing is returned. """ - class_name = "com.dask.sql.parser.SqlCreateTableAs" + class_name = ["CreateMemoryTable", "CreateView"] - def convert( - self, sql: "org.apache.calcite.sql.SqlNode", context: "dask_sql.Context" - ) -> DataContainer: - schema_name, table_name = context.fqn(sql.getTableName()) + def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer: + # Rust create_memory_table instance handle + create_memory_table = rel.create_memory_table() + + # can we avoid hardcoding the schema name? + schema_name, table_name = context.schema_name, create_memory_table.getName() if table_name in context.schema[schema_name].tables: - if sql.getIfNotExists(): + if create_memory_table.getIfNotExists(): return - elif not sql.getReplace(): + elif not create_memory_table.getOrReplace(): raise RuntimeError( f"A table with the name {table_name} is already present." ) - sql_select = sql.getSelect() - persist = bool(sql.isPersist()) + input_rel = create_memory_table.getInput() + + # TODO: we currently always persist for CREATE TABLE AS and never persist for CREATE VIEW AS; + # should this be configured by the user? https://github.com/dask-contrib/dask-sql/issues/269 + persist = create_memory_table.isTable() logger.debug( - f"Creating new table with name {table_name} and query {sql_select}" + f"Creating new table with name {table_name} and logical plan {input_rel}" ) - sql_select_query = context._to_sql_string(sql_select) - df = context.sql(sql_select_query) - - context.create_table(table_name, df, persist=persist, schema_name=schema_name) + context.create_table( + table_name, + context._compute_table_from_rel(input_rel), + persist=persist, + schema_name=schema_name, + ) diff --git a/tests/integration/test_create.py b/tests/integration/test_create.py index 1e5281e79..d1968fb98 100644 --- a/tests/integration/test_create.py +++ b/tests/integration/test_create.py @@ -119,7 +119,6 @@ def test_wrong_create(c): ) -@pytest.mark.skip(reason="WIP DataFusion") def test_create_from_query(c, df): c.sql( """