diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 966611338..627b5efed 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -114,7 +114,7 @@ impl PyExpr { pub fn subquery_plan(&self) -> PyResult { match &self.expr { Expr::ScalarSubquery(subquery) => Ok((&*subquery.subquery).clone().into()), - _ => Err(PyErr::new::(format!( + _ => Err(py_type_err(format!( "Attempted to extract a LogicalPlan instance from invalid Expr {:?}. Only Subquery and related variants are supported for this operation.", &self.expr @@ -126,10 +126,10 @@ impl PyExpr { /// Column in the SQL parse tree or not #[pyo3(name = "isInputReference")] pub fn is_input_reference(&self) -> PyResult { - match &self.expr { - Expr::Column(_col) => Ok(true), - _ => Ok(false), - } + Ok(match &self.expr { + Expr::Column(_col) => true, + _ => false, + }) } #[pyo3(name = "toString")] @@ -160,7 +160,7 @@ impl PyExpr { } Ok(idx) } - Err(e) => panic!("{:?}", e), + Err(e) => Err(py_runtime_err(e)), } } else if input_plans.len() >= 2 { let mut base_schema: DFSchema = (**input_plans[0].schema()).clone(); @@ -192,19 +192,21 @@ impl PyExpr { return Ok(index); } } - panic!("Unable to find match for column with name: '{}' in DFSchema", &fq_name); + Err(py_runtime_err(format!("Unable to find match for column with name: '{}' in DFSchema", &fq_name))) } } } - Err(e) => panic!("{:?}", e), + Err(e) => Err(py_runtime_err(e)), } } else { - panic!("Not really sure what we should do right here???"); + Err(py_runtime_err( + "Not really sure what we should do right here???", + )) } } - None => { - panic!("We need a valid LogicalPlan instance to get the Expr's index in the schema") - } + None => Err(py_runtime_err( + "We need a valid LogicalPlan instance to get the Expr's index in the schema", + )), } } @@ -213,36 +215,41 @@ impl PyExpr { /// RexConverter plugin instance should be invoked to handle /// the Rex conversion #[pyo3(name = "getExprType")] - pub fn get_expr_type(&self) -> String { - String::from(match &self.expr { - Expr::Alias(..) => "Alias", - Expr::Column(..) => "Column", - Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), - Expr::Literal(..) => "Literal", - Expr::BinaryExpr { .. } => "BinaryExpr", - Expr::Not(..) => panic!("Not!!!"), - Expr::IsNotNull(..) => panic!("IsNotNull!!!"), - Expr::Negative(..) => panic!("Negative!!!"), - Expr::GetIndexedField { .. } => panic!("GetIndexedField!!!"), - Expr::IsNull(..) => panic!("IsNull!!!"), - Expr::Between { .. } => "Between", - Expr::Case { .. } => panic!("Case!!!"), - Expr::Cast { .. } => "Cast", - Expr::TryCast { .. } => panic!("TryCast!!!"), - Expr::Sort { .. } => "Sort", - Expr::ScalarFunction { .. } => "ScalarFunction", - Expr::AggregateFunction { .. } => "AggregateFunction", - Expr::WindowFunction { .. } => panic!("WindowFunction!!!"), - Expr::AggregateUDF { .. } => panic!("AggregateUDF!!!"), - Expr::InList { .. } => "InList", - Expr::Wildcard => panic!("Wildcard!!!"), - Expr::InSubquery { .. } => "Subquery", - Expr::ScalarUDF { .. } => "ScalarUDF", - Expr::Exists { .. } => "Exists", - Expr::ScalarSubquery(..) => "ScalarSubquery", - Expr::QualifiedWildcard { .. } => "Wildcard", - Expr::GroupingSet(..) => "GroupingSet", - }) + pub fn get_expr_type(&self) -> PyResult { + Ok(String::from(match &self.expr { + Expr::Alias(..) + | Expr::Column(..) + | Expr::Literal(..) + | Expr::BinaryExpr { .. } + | Expr::Between { .. } + | Expr::Cast { .. } + | Expr::Sort { .. } + | Expr::ScalarFunction { .. } + | Expr::AggregateFunction { .. } + | Expr::InList { .. } + | Expr::InSubquery { .. } + | Expr::ScalarUDF { .. } + | Expr::Exists { .. } + | Expr::ScalarSubquery(..) + | Expr::QualifiedWildcard { .. } + | Expr::GroupingSet(..) => self.expr.variant_name(), + Expr::ScalarVariable(..) + | Expr::Not(..) + | Expr::IsNotNull(..) + | Expr::Negative(..) + | Expr::GetIndexedField { .. } + | Expr::IsNull(..) + | Expr::Case { .. } + | Expr::TryCast { .. } + | Expr::WindowFunction { .. } + | Expr::AggregateUDF { .. } + | Expr::Wildcard => { + return Err(py_type_err(format!( + "Encountered unsupported expression type: {}", + &self.expr.variant_name() + ))) + } + })) } /// Determines the type of this Expr based on its variant @@ -354,32 +361,34 @@ impl PyExpr { #[pyo3(name = "getOperatorName")] pub fn get_operator_name(&self) -> PyResult { - match &self.expr { + Ok(match &self.expr { Expr::BinaryExpr { left: _, op, right: _, - } => Ok(format!("{}", op)), - Expr::ScalarFunction { fun, args: _ } => Ok(format!("{}", fun)), - Expr::Cast { .. } => Ok("cast".to_string()), - Expr::Between { .. } => Ok("between".to_string()), - Expr::Case { .. } => Ok("case".to_string()), - Expr::IsNull(..) => Ok("is null".to_string()), - Expr::IsNotNull(..) => Ok("is not null".to_string()), - Expr::ScalarUDF { fun, .. } => Ok(fun.name.clone()), - Expr::InList { .. } => Ok("in list".to_string()), - Expr::Negative(..) => Ok("negative".to_string()), - _ => Err(PyErr::new::(format!( - "Catch all triggered for get_operator_name: {:?}", - &self.expr - ))), - } + } => format!("{}", op), + Expr::ScalarFunction { fun, args: _ } => format!("{}", fun), + Expr::ScalarUDF { fun, .. } => fun.name.clone(), + Expr::Cast { .. } => "cast".to_string(), + Expr::Between { .. } => "between".to_string(), + Expr::Case { .. } => "case".to_string(), + Expr::IsNull(..) => "is null".to_string(), + Expr::IsNotNull(..) => "is not null".to_string(), + Expr::InList { .. } => "in list".to_string(), + Expr::Negative(..) => "negative".to_string(), + _ => { + return Err(py_type_err(format!( + "Catch all triggered in get_operator_name: {:?}", + &self.expr + ))) + } + }) } /// Gets the ScalarValue represented by the Expression #[pyo3(name = "getType")] pub fn get_type(&self) -> PyResult { - match &self.expr { + Ok(String::from(match &self.expr { Expr::BinaryExpr { left: _, op, @@ -402,241 +411,225 @@ impl PyExpr { | Operator::RegexNotMatch | Operator::RegexNotIMatch | Operator::BitwiseAnd - | Operator::BitwiseOr => Ok(String::from("BOOLEAN")), + | Operator::BitwiseOr => "BOOLEAN", Operator::Plus | Operator::Minus | Operator::Multiply | Operator::Modulo => { - Ok(String::from("BIGINT")) + "BIGINT" } - Operator::Divide => Ok(String::from("FLOAT")), - Operator::StringConcat => Ok(String::from("VARCHAR")), + Operator::Divide => "FLOAT", + Operator::StringConcat => "VARCHAR", }, - Expr::ScalarVariable(..) => panic!("ScalarVariable!!!"), Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::Boolean(_value) => Ok(String::from("Boolean")), - ScalarValue::Float32(_value) => Ok(String::from("Float32")), - ScalarValue::Float64(_value) => Ok(String::from("Float64")), - ScalarValue::Decimal128(_value, ..) => Ok(String::from("Decimal128")), - ScalarValue::Int8(_value) => Ok(String::from("Int8")), - ScalarValue::Int16(_value) => Ok(String::from("Int16")), - ScalarValue::Int32(_value) => Ok(String::from("Int32")), - ScalarValue::Int64(_value) => Ok(String::from("Int64")), - ScalarValue::UInt8(_value) => Ok(String::from("UInt8")), - ScalarValue::UInt16(_value) => Ok(String::from("UInt16")), - ScalarValue::UInt32(_value) => Ok(String::from("UInt32")), - ScalarValue::UInt64(_value) => Ok(String::from("UInt64")), - ScalarValue::Utf8(_value) => Ok(String::from("Utf8")), - ScalarValue::LargeUtf8(_value) => Ok(String::from("LargeUtf8")), - ScalarValue::Binary(_value) => Ok(String::from("Binary")), - ScalarValue::LargeBinary(_value) => Ok(String::from("LargeBinary")), - ScalarValue::Date32(_value) => Ok(String::from("Date32")), - ScalarValue::Date64(_value) => Ok(String::from("Date64")), - ScalarValue::Null => Ok(String::from("Null")), + ScalarValue::Null => "Null", + ScalarValue::Boolean(_value) => "Boolean", + ScalarValue::Float32(_value) => "Float32", + ScalarValue::Float64(_value) => "Float64", + ScalarValue::Decimal128(_value, ..) => "Decimal128", + ScalarValue::Int8(_value) => "Int8", + ScalarValue::Int16(_value) => "Int16", + ScalarValue::Int32(_value) => "Int32", + ScalarValue::Int64(_value) => "Int64", + ScalarValue::UInt8(_value) => "UInt8", + ScalarValue::UInt16(_value) => "UInt16", + ScalarValue::UInt32(_value) => "UInt32", + ScalarValue::UInt64(_value) => "UInt64", + ScalarValue::Utf8(_value) => "Utf8", + ScalarValue::LargeUtf8(_value) => "LargeUtf8", + ScalarValue::Binary(_value) => "Binary", + ScalarValue::LargeBinary(_value) => "LargeBinary", + ScalarValue::Date32(_value) => "Date32", + ScalarValue::Date64(_value) => "Date64", _ => { - panic!("CatchAll") + return Err(py_type_err(format!( + "Catch all triggered for Literal in get_type; {:?}", + scalar_value + ))) } }, Expr::ScalarFunction { fun, args: _ } => match fun { - BuiltinScalarFunction::Abs => Ok(String::from("Abs")), - BuiltinScalarFunction::DatePart => Ok(String::from("DatePart")), + BuiltinScalarFunction::Abs => "Abs", + BuiltinScalarFunction::DatePart => "DatePart", _ => { - panic!("fire here for scalar function") + return Err(py_type_err(format!( + "Catch all triggered for ScalarFunction in get_type; {:?}", + fun + ))) } }, Expr::Cast { expr: _, data_type } => match data_type { - DataType::Null => Ok(String::from("NULL")), - DataType::Boolean => Ok(String::from("BOOLEAN")), - DataType::Int8 => Ok(String::from("TINYINT")), - DataType::UInt8 => Ok(String::from("TINYINT")), - DataType::Int16 => Ok(String::from("SMALLINT")), - DataType::UInt16 => Ok(String::from("SMALLINT")), - DataType::Int32 => Ok(String::from("INTEGER")), - DataType::UInt32 => Ok(String::from("INTEGER")), - DataType::Int64 => Ok(String::from("BIGINT")), - DataType::UInt64 => Ok(String::from("BIGINT")), - DataType::Float32 => Ok(String::from("FLOAT")), - DataType::Float64 => Ok(String::from("DOUBLE")), - DataType::Timestamp { .. } => Ok(String::from("TIMESTAMP")), - DataType::Date32 => Ok(String::from("DATE")), - DataType::Date64 => Ok(String::from("DATE")), - DataType::Time32(..) => Ok(String::from("TIME32")), - DataType::Time64(..) => Ok(String::from("TIME64")), - DataType::Duration(..) => Ok(String::from("DURATION")), - DataType::Interval(..) => Ok(String::from("INTERVAL")), - DataType::Binary => Ok(String::from("BINARY")), - DataType::FixedSizeBinary(..) => Ok(String::from("FIXEDSIZEBINARY")), - DataType::LargeBinary => Ok(String::from("LARGEBINARY")), - DataType::Utf8 => Ok(String::from("VARCHAR")), - DataType::LargeUtf8 => Ok(String::from("BIGVARCHAR")), - DataType::List(..) => Ok(String::from("LIST")), - DataType::FixedSizeList(..) => Ok(String::from("FIXEDSIZELIST")), - DataType::LargeList(..) => Ok(String::from("LARGELIST")), - DataType::Struct(..) => Ok(String::from("STRUCT")), - DataType::Union(..) => Ok(String::from("UNION")), - DataType::Dictionary(..) => Ok(String::from("DICTIONARY")), - DataType::Decimal(..) => Ok(String::from("DECIMAL")), - DataType::Map(..) => Ok(String::from("MAP")), + DataType::Null => "NULL", + DataType::Boolean => "BOOLEAN", + DataType::Int8 | DataType::UInt8 => "TINYINT", + DataType::Int16 | DataType::UInt16 => "SMALLINT", + DataType::Int32 | DataType::UInt32 => "INTEGER", + DataType::Int64 | DataType::UInt64 => "BIGINT", + DataType::Float32 => "FLOAT", + DataType::Float64 => "DOUBLE", + DataType::Timestamp { .. } => "TIMESTAMP", + DataType::Date32 | DataType::Date64 => "DATE", + DataType::Time32(..) => "TIME32", + DataType::Time64(..) => "TIME64", + DataType::Duration(..) => "DURATION", + DataType::Interval(..) => "INTERVAL", + DataType::Binary => "BINARY", + DataType::FixedSizeBinary(..) => "FIXEDSIZEBINARY", + DataType::LargeBinary => "LARGEBINARY", + DataType::Utf8 => "VARCHAR", + DataType::LargeUtf8 => "BIGVARCHAR", + DataType::List(..) => "LIST", + DataType::FixedSizeList(..) => "FIXEDSIZELIST", + DataType::LargeList(..) => "LARGELIST", + DataType::Struct(..) => "STRUCT", + DataType::Union(..) => "UNION", + DataType::Dictionary(..) => "DICTIONARY", + DataType::Decimal(..) => "DECIMAL", + DataType::Map(..) => "MAP", _ => { - panic!("This is not yet implemented!!!") + return Err(py_type_err(format!( + "Catch all triggered for Cast in get_type; {:?}", + data_type + ))) } }, - _ => panic!("OTHER"), - } + _ => { + return Err(py_type_err(format!( + "Catch all triggered in get_type; {:?}", + &self.expr + ))) + } + })) } /// TODO: I can't express how much I dislike explicity listing all of these methods out /// but PyO3 makes it necessary since its annotations cannot be used in trait impl blocks #[pyo3(name = "getFloat32Value")] - pub fn float_32_value(&mut self) -> f32 { + pub fn float_32_value(&mut self) -> PyResult { match &self.expr { Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::Float32(iv) => iv.unwrap(), - _ => { - panic!("getValue() - Unexpected value") - } + ScalarValue::Float32(iv) => Ok(iv.unwrap()), + _ => Err(py_type_err("getValue() - Unexpected value")), }, - _ => panic!("getValue() - Non literal value encountered"), + _ => Err(py_type_err("getValue() - Non literal value encountered")), } } #[pyo3(name = "getFloat64Value")] - pub fn float_64_value(&mut self) -> f64 { + pub fn float_64_value(&mut self) -> PyResult { match &self.expr { Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::Float64(iv) => iv.unwrap(), - _ => { - panic!("getValue() - Unexpected value") - } + ScalarValue::Float64(iv) => Ok(iv.unwrap()), + _ => Err(py_type_err("getValue() - Unexpected value")), }, - _ => panic!("getValue() - Non literal value encountered"), + _ => Err(py_type_err("getValue() - Non literal value encountered")), } } #[pyo3(name = "getInt8Value")] - pub fn int_8_value(&mut self) -> i8 { + pub fn int_8_value(&mut self) -> PyResult { match &self.expr { Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::Int8(iv) => iv.unwrap(), - _ => { - panic!("getValue() - Unexpected value") - } + ScalarValue::Int8(iv) => Ok(iv.unwrap()), + _ => Err(py_type_err("getValue() - Unexpected value")), }, - _ => panic!("getValue() - Non literal value encountered"), + _ => Err(py_type_err("getValue() - Non literal value encountered")), } } #[pyo3(name = "getInt16Value")] - pub fn int_16_value(&mut self) -> i16 { + pub fn int_16_value(&mut self) -> PyResult { match &self.expr { Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::Int16(iv) => iv.unwrap(), - _ => { - panic!("getValue() - Unexpected value") - } + ScalarValue::Int16(iv) => Ok(iv.unwrap()), + _ => Err(py_type_err("getValue() - Unexpected value")), }, - _ => panic!("getValue() - Non literal value encountered"), + _ => Err(py_type_err("getValue() - Non literal value encountered")), } } #[pyo3(name = "getInt32Value")] - pub fn int_32_value(&mut self) -> i32 { + pub fn int_32_value(&mut self) -> PyResult { match &self.expr { Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::Int32(iv) => iv.unwrap(), - _ => { - panic!("getValue() - Unexpected value") - } + ScalarValue::Int32(iv) => Ok(iv.unwrap()), + _ => Err(py_type_err("getValue() - Unexpected value")), }, - _ => panic!("getValue() - Non literal value encountered"), + _ => Err(py_type_err("getValue() - Non literal value encountered")), } } #[pyo3(name = "getInt64Value")] - pub fn int_64_value(&mut self) -> i64 { + pub fn int_64_value(&mut self) -> PyResult { match &self.expr { Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::Int64(iv) => iv.unwrap(), - _ => { - panic!("getValue() - Unexpected value") - } + ScalarValue::Int64(iv) => Ok(iv.unwrap()), + _ => Err(py_type_err("getValue() - Unexpected value")), }, - _ => panic!("getValue() - Non literal value encountered"), + _ => Err(py_type_err("getValue() - Non literal value encountered")), } } #[pyo3(name = "getUInt8Value")] - pub fn uint_8_value(&mut self) -> u8 { + pub fn uint_8_value(&mut self) -> PyResult { match &self.expr { Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::UInt8(iv) => iv.unwrap(), - _ => { - panic!("getValue() - Unexpected value") - } + ScalarValue::UInt8(iv) => Ok(iv.unwrap()), + _ => Err(py_type_err("getValue() - Unexpected value")), }, - _ => panic!("getValue() - Non literal value encountered"), + _ => Err(py_type_err("getValue() - Non literal value encountered")), } } #[pyo3(name = "getUInt16Value")] - pub fn uint_16_value(&mut self) -> u16 { + pub fn uint_16_value(&mut self) -> PyResult { match &self.expr { Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::UInt16(iv) => iv.unwrap(), - _ => { - panic!("getValue() - Unexpected value") - } + ScalarValue::UInt16(iv) => Ok(iv.unwrap()), + _ => Err(py_type_err("getValue() - Unexpected value")), }, - _ => panic!("getValue() - Non literal value encountered"), + _ => Err(py_type_err("getValue() - Non literal value encountered")), } } #[pyo3(name = "getUInt32Value")] - pub fn uint_32_value(&mut self) -> u32 { + pub fn uint_32_value(&mut self) -> PyResult { match &self.expr { Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::UInt32(iv) => iv.unwrap(), - _ => { - panic!("getValue() - Unexpected value") - } + ScalarValue::UInt32(iv) => Ok(iv.unwrap()), + _ => Err(py_type_err("getValue() - Unexpected value")), }, - _ => panic!("getValue() - Non literal value encountered"), + _ => Err(py_type_err("getValue() - Non literal value encountered")), } } #[pyo3(name = "getUInt64Value")] - pub fn uint_64_value(&mut self) -> u64 { + pub fn uint_64_value(&mut self) -> PyResult { match &self.expr { Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::UInt64(iv) => iv.unwrap(), - _ => { - panic!("getValue() - Unexpected value") - } + ScalarValue::UInt64(iv) => Ok(iv.unwrap()), + _ => Err(py_type_err("getValue() - Unexpected value")), }, - _ => panic!("getValue() - Non literal value encountered"), + _ => Err(py_type_err("getValue() - Non literal value encountered")), } } #[pyo3(name = "getBoolValue")] - pub fn bool_value(&mut self) -> bool { + pub fn bool_value(&mut self) -> PyResult { match &self.expr { Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::Boolean(iv) => iv.unwrap(), - _ => { - panic!("getValue() - Unexpected value") - } + ScalarValue::Boolean(iv) => Ok(iv.unwrap()), + _ => Err(py_type_err("getValue() - Unexpected value")), }, - _ => panic!("getValue() - Non literal value encountered"), + _ => Err(py_type_err("getValue() - Non literal value encountered")), } } #[pyo3(name = "getStringValue")] - pub fn string_value(&mut self) -> String { + pub fn string_value(&mut self) -> PyResult { match &self.expr { Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::Utf8(iv) => iv.clone().unwrap(), - _ => { - panic!("getValue() - Unexpected value") - } + ScalarValue::Utf8(iv) => Ok(iv.clone().unwrap()), + _ => Err(py_type_err("getValue() - Unexpected value")), }, - _ => panic!("getValue() - Non literal value encountered"), + _ => Err(py_type_err("getValue() - Non literal value encountered")), } } @@ -647,7 +640,7 @@ impl PyExpr { | Expr::Exists { negated, .. } | Expr::InList { negated, .. } | Expr::InSubquery { negated, .. } => Ok(negated.clone()), - _ => Err(PyErr::new::(format!( + _ => Err(py_type_err(format!( "unknown Expr type {:?} encountered", &self.expr ))), diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs index 67da2dbc7..186c387f4 100644 --- a/dask_planner/src/sql.rs +++ b/dask_planner/src/sql.rs @@ -10,7 +10,7 @@ pub mod types; use crate::{ dialect::DaskSqlDialect, - sql::exceptions::{OptimizationException, ParsingException}, + sql::exceptions::{py_optimization_exp, py_parsing_exp, py_runtime_err}, }; use arrow::datatypes::{DataType, Field, Schema}; @@ -171,7 +171,7 @@ impl DaskSQLContext { schema.add_function(function); Ok(true) } - None => Err(PyErr::new::(format!( + None => Err(py_runtime_err(format!( "Schema: {} not found in DaskSQLContext", schema_name ))), @@ -189,7 +189,7 @@ impl DaskSQLContext { schema.add_table(table); Ok(true) } - None => Err(PyErr::new::(format!( + None => Err(py_runtime_err(format!( "Schema: {} not found in DaskSQLContext", schema_name ))), @@ -211,7 +211,7 @@ impl DaskSQLContext { ); Ok(statements) } - Err(e) => Err(PyErr::new::(format!("{}", e))), + Err(e) => Err(py_parsing_exp(e)), } } @@ -227,7 +227,7 @@ impl DaskSQLContext { original_plan: k, current_node: None, }) - .map_err(|e| PyErr::new::(format!("{}", e))) + .map_err(|e| py_parsing_exp(e)) } /// Accepts an existing relational plan, `LogicalPlan`, and optimizes it @@ -249,13 +249,13 @@ impl DaskSQLContext { original_plan: k, current_node: None, }) - .map_err(|e| PyErr::new::(format!("{}", e))) + .map_err(|e| py_optimization_exp(e)) } else { // This LogicalPlan does not support Optimization. Return original Ok(existing_plan) } } - Err(e) => Err(PyErr::new::(format!("{}", e))), + Err(e) => Err(py_optimization_exp(e)), } } } diff --git a/dask_planner/src/sql/exceptions.rs b/dask_planner/src/sql/exceptions.rs index dfd879758..5627bbb3f 100644 --- a/dask_planner/src/sql/exceptions.rs +++ b/dask_planner/src/sql/exceptions.rs @@ -1,4 +1,3 @@ -use datafusion_common::DataFusionError; use pyo3::{create_exception, PyErr}; use std::fmt::Debug; @@ -12,6 +11,14 @@ pub fn py_type_err(e: impl Debug) -> PyErr { PyErr::new::(format!("{:?}", e)) } -pub fn py_runtime_err(e: DataFusionError) -> PyErr { +pub fn py_runtime_err(e: impl Debug) -> PyErr { PyErr::new::(format!("{:?}", e)) } + +pub fn py_parsing_exp(e: impl Debug) -> PyErr { + PyErr::new::(format!("{:?}", e)) +} + +pub fn py_optimization_exp(e: impl Debug) -> PyErr { + PyErr::new::(format!("{:?}", e)) +} diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs index 7598df864..79fbe5733 100644 --- a/dask_planner/src/sql/logical.rs +++ b/dask_planner/src/sql/logical.rs @@ -138,7 +138,7 @@ impl PyLogicalPlan { pub fn table(&mut self) -> PyResult { match table::table_from_logical_plan(&self.current_node()) { Some(table) => Ok(table), - None => Err(PyErr::new::( + None => Err(py_type_err( "Unable to compute DaskTable from DataFusion LogicalPlan", )), } @@ -163,9 +163,7 @@ impl PyLogicalPlan { pub fn get_current_node_table_name(&mut self) -> PyResult { match self.table() { Ok(dask_table) => Ok(dask_table.name), - Err(_e) => Err(PyErr::new::( - "Unable to determine current node table name", - )), + Err(_e) => Err(py_type_err("Unable to determine current node table name")), } } diff --git a/dask_planner/src/sql/logical/aggregate.rs b/dask_planner/src/sql/logical/aggregate.rs index f14baf4b7..f73e2a48d 100644 --- a/dask_planner/src/sql/logical/aggregate.rs +++ b/dask_planner/src/sql/logical/aggregate.rs @@ -19,7 +19,9 @@ impl PyAggregate { pub fn distinct_columns(&self) -> PyResult> { match &self.distinct { Some(e) => Ok(e.input.schema().field_names()), - None => panic!("distinct_columns invoked for non distinct instance"), + None => Err(py_type_err( + "distinct_columns invoked for non distinct instance", + )), } } @@ -42,10 +44,12 @@ impl PyAggregate { #[pyo3(name = "getAggregationFuncName")] pub fn agg_func_name(&self, expr: PyExpr) -> PyResult { - Ok(match expr.expr { - Expr::AggregateFunction { fun, .. } => fun.to_string(), - _ => panic!("Encountered a non Aggregate type in agg_func_name"), - }) + match expr.expr { + Expr::AggregateFunction { fun, .. } => Ok(fun.to_string()), + _ => Err(py_type_err( + "Encountered a non Aggregate type in agg_func_name", + )), + } } #[pyo3(name = "getArgs")] @@ -55,7 +59,9 @@ impl PyAggregate { Some(e) => py_expr_list(&e.input, &args), None => Ok(vec![]), }, - _ => panic!("Encountered a non Aggregate type in agg_func_name"), + _ => Err(py_type_err( + "Encountered a non Aggregate type in agg_func_name", + )), } } diff --git a/dask_planner/src/sql/logical/join.rs b/dask_planner/src/sql/logical/join.rs index b8d94a83c..c10f11a50 100644 --- a/dask_planner/src/sql/logical/join.rs +++ b/dask_planner/src/sql/logical/join.rs @@ -49,12 +49,20 @@ impl PyJoin { pub fn join_conditions(&mut self) -> PyResult> { let lhs_table_name: String = match &*self.join.left { LogicalPlan::TableScan(scan) => scan.table_name.clone(), - _ => panic!("lhs Expected TableScan but something else was received!"), + _ => { + return Err(py_type_err( + "lhs Expected TableScan but something else was received!", + )) + } }; let rhs_table_name: String = match &*self.join.right { LogicalPlan::TableScan(scan) => scan.table_name.clone(), - _ => panic!("rhs Expected TableScan but something else was received!"), + _ => { + return Err(py_type_err( + "rhs Expected TableScan but something else was received!", + )) + } }; let mut join_conditions: Vec<(column::PyColumn, column::PyColumn)> = Vec::new(); diff --git a/dask_planner/src/sql/types.rs b/dask_planner/src/sql/types.rs index 30fd250ff..934102989 100644 --- a/dask_planner/src/sql/types.rs +++ b/dask_planner/src/sql/types.rs @@ -109,10 +109,7 @@ impl DaskTypeMap { }; DataType::Timestamp(unit, tz) } - _ => { - // panic!("stop here"); - sql_type.to_arrow() - } + _ => sql_type.to_arrow(), }; DaskTypeMap { diff --git a/dask_planner/src/sql/types/rel_data_type.rs b/dask_planner/src/sql/types/rel_data_type.rs index 8681e8fa6..259a710d6 100644 --- a/dask_planner/src/sql/types/rel_data_type.rs +++ b/dask_planner/src/sql/types/rel_data_type.rs @@ -1,3 +1,4 @@ +use crate::sql::exceptions::py_runtime_err; use crate::sql::types::rel_data_type_field::RelDataTypeField; use std::collections::HashMap; @@ -33,25 +34,25 @@ impl RelDataType { /// * `field_name` - A String containing the name of the field to find /// * `case_sensitive` - True if column name matching should be case sensitive and false otherwise #[pyo3(name = "getField")] - pub fn field(&self, field_name: String, case_sensitive: bool) -> RelDataTypeField { + pub fn field(&self, field_name: String, case_sensitive: bool) -> PyResult { assert!(!self.field_list.is_empty()); let field_map: HashMap = self.field_map(); if case_sensitive && !field_map.is_empty() { - field_map.get(&field_name).unwrap().clone() + Ok(field_map.get(&field_name).unwrap().clone()) } else { for field in &self.field_list { if (case_sensitive && field.name().eq(&field_name)) || (!case_sensitive && field.name().eq_ignore_ascii_case(&field_name)) { - return field.clone(); + return Ok(field.clone()); } } // TODO: Throw a proper error here - panic!( + Err(py_runtime_err(format!( "Unable to find RelDataTypeField with name {:?} in the RelDataType field_list", - field_name - ); + field_name, + ))) } }