Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
383 changes: 188 additions & 195 deletions dask_planner/src/expression.rs

Large diffs are not rendered by default.

14 changes: 7 additions & 7 deletions dask_planner/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub mod types;

use crate::{
dialect::DaskSqlDialect,
sql::exceptions::{OptimizationException, ParsingException},
sql::exceptions::{py_optimization_exp, py_parsing_exp, py_runtime_err},
};

use arrow::datatypes::{DataType, Field, Schema};
Expand Down Expand Up @@ -171,7 +171,7 @@ impl DaskSQLContext {
schema.add_function(function);
Ok(true)
}
None => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
None => Err(py_runtime_err(format!(
"Schema: {} not found in DaskSQLContext",
schema_name
))),
Expand All @@ -189,7 +189,7 @@ impl DaskSQLContext {
schema.add_table(table);
Ok(true)
}
None => Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
None => Err(py_runtime_err(format!(
"Schema: {} not found in DaskSQLContext",
schema_name
))),
Expand All @@ -211,7 +211,7 @@ impl DaskSQLContext {
);
Ok(statements)
}
Err(e) => Err(PyErr::new::<ParsingException, _>(format!("{}", e))),
Err(e) => Err(py_parsing_exp(e)),
}
}

Expand All @@ -227,7 +227,7 @@ impl DaskSQLContext {
original_plan: k,
current_node: None,
})
.map_err(|e| PyErr::new::<ParsingException, _>(format!("{}", e)))
.map_err(|e| py_parsing_exp(e))
}

/// Accepts an existing relational plan, `LogicalPlan`, and optimizes it
Expand All @@ -249,13 +249,13 @@ impl DaskSQLContext {
original_plan: k,
current_node: None,
})
.map_err(|e| PyErr::new::<OptimizationException, _>(format!("{}", e)))
.map_err(|e| py_optimization_exp(e))
} else {
// This LogicalPlan does not support Optimization. Return original
Ok(existing_plan)
}
}
Err(e) => Err(PyErr::new::<OptimizationException, _>(format!("{}", e))),
Err(e) => Err(py_optimization_exp(e)),
}
}
}
Expand Down
11 changes: 9 additions & 2 deletions dask_planner/src/sql/exceptions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use datafusion_common::DataFusionError;
use pyo3::{create_exception, PyErr};
use std::fmt::Debug;

Expand All @@ -12,6 +11,14 @@ pub fn py_type_err(e: impl Debug) -> PyErr {
PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!("{:?}", e))
}

pub fn py_runtime_err(e: DataFusionError) -> PyErr {
pub fn py_runtime_err(e: impl Debug) -> PyErr {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{:?}", e))
}

/// Builds a `ParsingException` whose message is the `Debug` rendering of `e`.
pub fn py_parsing_exp(e: impl Debug) -> PyErr {
    let message = format!("{:?}", e);
    PyErr::new::<ParsingException, _>(message)
}

/// Builds an `OptimizationException` whose message is the `Debug` rendering of `e`.
pub fn py_optimization_exp(e: impl Debug) -> PyErr {
    let message = format!("{:?}", e);
    PyErr::new::<OptimizationException, _>(message)
}
6 changes: 2 additions & 4 deletions dask_planner/src/sql/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ impl PyLogicalPlan {
pub fn table(&mut self) -> PyResult<table::DaskTable> {
match table::table_from_logical_plan(&self.current_node()) {
Some(table) => Ok(table),
None => Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(
None => Err(py_type_err(
"Unable to compute DaskTable from DataFusion LogicalPlan",
)),
}
Expand All @@ -163,9 +163,7 @@ impl PyLogicalPlan {
/// Returns the name of the table computed by `self.table()`.
///
/// # Errors
/// Any error from `self.table()` is discarded and replaced by a Python
/// `TypeError` built with `py_type_err`.
pub fn get_current_node_table_name(&mut self) -> PyResult<String> {
    self.table()
        .map(|dask_table| dask_table.name)
        .map_err(|_| py_type_err("Unable to determine current node table name"))
}

Expand Down
18 changes: 12 additions & 6 deletions dask_planner/src/sql/logical/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ impl PyAggregate {
/// Returns the field names of the input schema of `self.distinct`.
///
/// # Errors
/// Returns a Python `TypeError` (via `py_type_err`) when `self.distinct` is
/// `None`; this case is reported as an error rather than a panic.
pub fn distinct_columns(&self) -> PyResult<Vec<String>> {
    match &self.distinct {
        Some(e) => Ok(e.input.schema().field_names()),
        None => Err(py_type_err(
            "distinct_columns invoked for non distinct instance",
        )),
    }
}

Expand All @@ -42,10 +44,12 @@ impl PyAggregate {

/// Returns the string form of the aggregate function held by `expr`.
///
/// # Errors
/// Returns a Python `TypeError` (via `py_type_err`) when `expr` is not an
/// `Expr::AggregateFunction`; this case is reported as an error rather than
/// a panic.
#[pyo3(name = "getAggregationFuncName")]
pub fn agg_func_name(&self, expr: PyExpr) -> PyResult<String> {
    match expr.expr {
        Expr::AggregateFunction { fun, .. } => Ok(fun.to_string()),
        _ => Err(py_type_err(
            "Encountered a non Aggregate type in agg_func_name",
        )),
    }
}

#[pyo3(name = "getArgs")]
Expand All @@ -55,7 +59,9 @@ impl PyAggregate {
Some(e) => py_expr_list(&e.input, &args),
None => Ok(vec![]),
},
_ => panic!("Encountered a non Aggregate type in agg_func_name"),
_ => Err(py_type_err(
"Encountered a non Aggregate type in agg_func_name",
)),
}
}

Expand Down
12 changes: 10 additions & 2 deletions dask_planner/src/sql/logical/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,20 @@ impl PyJoin {
pub fn join_conditions(&mut self) -> PyResult<Vec<(column::PyColumn, column::PyColumn)>> {
let lhs_table_name: String = match &*self.join.left {
LogicalPlan::TableScan(scan) => scan.table_name.clone(),
_ => panic!("lhs Expected TableScan but something else was received!"),
_ => {
return Err(py_type_err(
"lhs Expected TableScan but something else was received!",
))
}
};

let rhs_table_name: String = match &*self.join.right {
LogicalPlan::TableScan(scan) => scan.table_name.clone(),
_ => panic!("rhs Expected TableScan but something else was received!"),
_ => {
return Err(py_type_err(
"rhs Expected TableScan but something else was received!",
))
}
};

let mut join_conditions: Vec<(column::PyColumn, column::PyColumn)> = Vec::new();
Expand Down
5 changes: 1 addition & 4 deletions dask_planner/src/sql/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,7 @@ impl DaskTypeMap {
};
DataType::Timestamp(unit, tz)
}
_ => {
// panic!("stop here");
sql_type.to_arrow()
}
_ => sql_type.to_arrow(),
};

DaskTypeMap {
Expand Down
13 changes: 7 additions & 6 deletions dask_planner/src/sql/types/rel_data_type.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::sql::exceptions::py_runtime_err;
use crate::sql::types::rel_data_type_field::RelDataTypeField;

use std::collections::HashMap;
Expand Down Expand Up @@ -33,25 +34,25 @@ impl RelDataType {
/// * `field_name` - A String containing the name of the field to find
/// * `case_sensitive` - True if column name matching should be case sensitive and false otherwise
#[pyo3(name = "getField")]
pub fn field(&self, field_name: String, case_sensitive: bool) -> RelDataTypeField {
pub fn field(&self, field_name: String, case_sensitive: bool) -> PyResult<RelDataTypeField> {
assert!(!self.field_list.is_empty());
let field_map: HashMap<String, RelDataTypeField> = self.field_map();
if case_sensitive && !field_map.is_empty() {
field_map.get(&field_name).unwrap().clone()
Ok(field_map.get(&field_name).unwrap().clone())
} else {
for field in &self.field_list {
if (case_sensitive && field.name().eq(&field_name))
|| (!case_sensitive && field.name().eq_ignore_ascii_case(&field_name))
{
return field.clone();
return Ok(field.clone());
}
}

// TODO: Throw a proper error here
panic!(
Err(py_runtime_err(format!(
"Unable to find RelDataTypeField with name {:?} in the RelDataType field_list",
field_name
);
field_name,
)))
}
}

Expand Down