Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions dask_planner/src/sql/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::sql::types::rel_data_type_field::RelDataTypeField;
mod aggregate;
mod create_memory_table;
mod cross_join;
mod drop_table;
mod empty_relation;
mod explain;
mod filter;
Expand Down Expand Up @@ -129,6 +130,11 @@ impl PyLogicalPlan {
to_py_plan(self.current_node.as_ref())
}

/// LogicalPlan::DropTable as DropTable
///
/// Extracts the wrapped plan node as a `PyDropTable` for the Python side.
/// Delegates to `to_py_plan`, which goes through `PyDropTable`'s `TryFrom`
/// impl and therefore returns a Python type error when the current node is
/// not a `LogicalPlan::DropTable` variant.
pub fn drop_table(&self) -> PyResult<drop_table::PyDropTable> {
to_py_plan(self.current_node.as_ref())
}

/// Gets the "input" for the current LogicalPlan
pub fn get_inputs(&mut self) -> PyResult<Vec<PyLogicalPlan>> {
let mut py_inputs: Vec<PyLogicalPlan> = Vec::new();
Expand Down
34 changes: 34 additions & 0 deletions dask_planner/src/sql/logical/drop_table.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use datafusion_expr::logical_plan::{DropTable, LogicalPlan};

use crate::sql::exceptions::py_type_err;
use pyo3::prelude::*;

/// Python-visible wrapper around DataFusion's `DropTable` logical plan node,
/// exposed to Python as `dask_planner.DropTable`.
#[pyclass(name = "DropTable", module = "dask_planner", subclass)]
#[derive(Clone)]
pub struct PyDropTable {
// The underlying DataFusion `DropTable` plan node being wrapped.
drop_table: DropTable,
}

#[pymethods]
impl PyDropTable {
    /// Name of the table targeted by the `DROP TABLE` statement,
    /// exposed to Python as `getName()`.
    #[pyo3(name = "getName")]
    pub fn get_name(&self) -> PyResult<String> {
        let table_name = self.drop_table.name.to_owned();
        Ok(table_name)
    }

    /// Whether the statement carried an `IF EXISTS` clause,
    /// exposed to Python as `getIfExists()`.
    #[pyo3(name = "getIfExists")]
    pub fn get_if_exists(&self) -> PyResult<bool> {
        let if_exists = self.drop_table.if_exists;
        Ok(if_exists)
    }
}

impl TryFrom<LogicalPlan> for PyDropTable {
    type Error = PyErr;

    /// Fallible conversion from a generic `LogicalPlan`; succeeds only for
    /// the `LogicalPlan::DropTable` variant and raises a Python type error
    /// for anything else.
    fn try_from(logical_plan: LogicalPlan) -> Result<Self, Self::Error> {
        if let LogicalPlan::DropTable(drop_table) = logical_plan {
            Ok(PyDropTable { drop_table })
        } else {
            Err(py_type_err("unexpected plan"))
        }
    }
}
4 changes: 0 additions & 4 deletions dask_planner/src/sql/types/rel_data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ impl RelDataType {
/// * `case_sensitive` - True if column name matching should be case sensitive and false otherwise
#[pyo3(name = "getField")]
pub fn field(&self, field_name: String, case_sensitive: bool) -> PyResult<RelDataTypeField> {
assert!(!self.field_list.is_empty());
let field_map: HashMap<String, RelDataTypeField> = self.field_map();
if case_sensitive && !field_map.is_empty() {
Ok(field_map.get(&field_name).unwrap().clone())
Expand Down Expand Up @@ -73,14 +72,12 @@ impl RelDataType {
/// Gets the fields in a struct type. The field count is equal to the size of the returned list.
///
/// Returns a cloned copy of the field list; callers own the result.
/// NOTE: deliberately no non-empty assertion here — some logical plans
/// (e.g. DDL nodes such as DROP TABLE) legitimately carry an empty schema,
/// and panicking on them would abort the Python caller.
#[pyo3(name = "getFieldList")]
pub fn field_list(&self) -> Vec<RelDataTypeField> {
    self.field_list.clone()
}

/// Returns the names of all of the columns in a given DaskTable
#[pyo3(name = "getFieldNames")]
pub fn field_names(&self) -> Vec<String> {
assert!(!self.field_list.is_empty());
let mut field_names: Vec<String> = Vec::new();
for field in &self.field_list {
field_names.push(field.qualified_name());
Expand All @@ -91,7 +88,6 @@ impl RelDataType {
/// Returns the number of fields in a struct type.
///
/// Returns 0 for an empty field list rather than panicking: some logical
/// plans (e.g. DDL nodes such as DROP TABLE) legitimately have no fields,
/// so the previous non-empty assertion was removed.
#[pyo3(name = "getFieldCount")]
pub fn field_count(&self) -> usize {
    self.field_list.len()
}

Expand Down
16 changes: 9 additions & 7 deletions dask_sql/physical/rel/custom/drop_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

if TYPE_CHECKING:
import dask_sql
from dask_sql.java import org
from dask_sql.rust import LogicalPlan

logger = logging.getLogger(__name__)

Expand All @@ -19,15 +19,17 @@ class DropTablePlugin(BaseRelPlugin):
DROP TABLE <table-name>
"""

class_name = "com.dask.sql.parser.SqlDropTable"
class_name = "DropTable"

def convert(
self, sql: "org.apache.calcite.sql.SqlNode", context: "dask_sql.Context"
) -> DataContainer:
schema_name, table_name = context.fqn(sql.getTableName())
def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
# Rust create_memory_table instance handle
drop_table = rel.drop_table()

# can we avoid hardcoding the schema name?
schema_name, table_name = context.schema_name, drop_table.getName()

if table_name not in context.schema[schema_name].tables:
if not sql.getIfExists():
if not drop_table.getIfExists():
raise RuntimeError(
f"A table with the name {table_name} is not present."
)
Expand Down
1 change: 0 additions & 1 deletion tests/integration/test_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,6 @@ def test_replace_and_error(c, temporary_data_file, df):
assert_eq(result_df, df)


@pytest.mark.skip(reason="WIP DataFusion")
def test_drop(c):
with pytest.raises(RuntimeError):
c.sql("DROP TABLE new_table")
Expand Down