Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
863fcb3
Implement PySort logical_plan
ayushdg May 5, 2022
f565c2a
Add a sort plan accessor to the logical plan
ayushdg May 5, 2022
75933b8
Python: Update the sort plugin
ayushdg May 5, 2022
e882408
Python: Uncomment tests
ayushdg May 5, 2022
5f7e8c6
Merge branch 'datafusion-sql-planner' of github.com:dask-contrib/dask…
ayushdg May 10, 2022
b6c534e
Planner: Update accessor pattern for concrete logical plan implementa…
ayushdg May 10, 2022
82b4234
Test: Address review comments
ayushdg May 10, 2022
ee99096
add support for expr_to_field for Expr::Sort expressions
andygrove May 10, 2022
5d1a561
Merge branch 'datafusion-sql-planner' of github.com:dask-contrib/dask…
ayushdg May 10, 2022
2932a28
Merge commit 'refs/pull/515/head' of github.com:dask-contrib/dask-sql…
ayushdg May 10, 2022
05c6a85
Planner: Update sort expr utilities and import cleanup
ayushdg May 11, 2022
f9f569e
Python: Re-enable skipped sort tests
ayushdg May 11, 2022
0840870
Merge branch 'datafusion-sql-planner' of github.com:dask-contrib/dask…
ayushdg May 11, 2022
8d81c44
Python: Handle case where orderby column name is an alias
ayushdg May 11, 2022
0a75c32
Apply suggestions from code review
ayushdg May 11, 2022
74d3451
Style: Fix formatting
ayushdg May 11, 2022
8aaf8fe
Merge branch 'datafusion-sql-planner' of github.com:dask-contrib/dask…
ayushdg May 12, 2022
6bbb8d7
Planner: Remove public scope for LogicalPlan import
ayushdg May 12, 2022
36b8460
Python: Add more complex sort tests with alias that error right now
ayushdg May 12, 2022
315fad2
Python: Remove old commented code
ayushdg May 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dask_planner/src/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl PyExpr {
Expr::Case { .. } => panic!("Case!!!"),
Expr::Cast { .. } => "Cast",
Expr::TryCast { .. } => panic!("TryCast!!!"),
Expr::Sort { .. } => panic!("Sort!!!"),
Expr::Sort { .. } => "Sort",
Expr::ScalarFunction { .. } => "ScalarFunction",
Expr::AggregateFunction { .. } => "AggregateFunction",
Expr::WindowFunction { .. } => panic!("WindowFunction!!!"),
Expand Down
51 changes: 39 additions & 12 deletions dask_planner/src/sql/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod aggregate;
mod filter;
mod join;
pub mod projection;
mod sort;

pub use datafusion_expr::LogicalPlan;

Expand Down Expand Up @@ -49,28 +50,54 @@ impl PyLogicalPlan {

#[pymethods]
impl PyLogicalPlan {
/// LogicalPlan::Projection as PyProjection
pub fn projection(&self) -> PyResult<projection::PyProjection> {
let proj: projection::PyProjection = self.current_node.clone().unwrap().into();
Ok(proj)
/// LogicalPlan::Aggregate as PyAggregate
pub fn aggregate(&self) -> PyResult<aggregate::PyAggregate> {
self.current_node
.as_ref()
.map(|plan| plan.clone().into())
.ok_or(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"current_node was None",
))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this approach is fine for this PR but I would like to follow up with a helper method for these conversions.

}

/// LogicalPlan::Filter as PyFilter
pub fn filter(&self) -> PyResult<filter::PyFilter> {
let filter: filter::PyFilter = self.current_node.clone().unwrap().into();
Ok(filter)
self.current_node
.as_ref()
.map(|plan| plan.clone().into())
.ok_or(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"current_node was None",
))
}

/// LogicalPlan::Join as PyJoin
pub fn join(&self) -> PyResult<join::PyJoin> {
let join: join::PyJoin = self.current_node.clone().unwrap().into();
Ok(join)
self.current_node
.as_ref()
.map(|plan| plan.clone().into())
.ok_or(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"current_node was None",
))
}

/// LogicalPlan::Aggregate as PyAggregate
pub fn aggregate(&self) -> PyResult<aggregate::PyAggregate> {
let agg: aggregate::PyAggregate = self.current_node.clone().unwrap().into();
Ok(agg)
/// LogicalPlan::Projection as PyProjection
pub fn projection(&self) -> PyResult<projection::PyProjection> {
self.current_node
.as_ref()
.map(|plan| plan.clone().into())
.ok_or(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"current_node was None",
))
}

/// LogicalPlan::Sort as PySort
pub fn sort(&self) -> PyResult<sort::PySort> {
    self.current_node
        .as_ref()
        .map(|plan| plan.clone().into())
        // ok_or_else: build the PyErr lazily so no error object is
        // allocated on the (common) Some path.
        .ok_or_else(|| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>("current_node was None")
        })
}

/// Gets the "input" for the current LogicalPlan
Expand Down
78 changes: 78 additions & 0 deletions dask_planner/src/sql/logical/sort.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
use crate::expression::PyExpr;

use datafusion_expr::logical_plan::Sort;
pub use datafusion_expr::{logical_plan::LogicalPlan, Expr};

use crate::sql::exceptions::py_type_err;
use pyo3::prelude::*;

/// Python wrapper exposing DataFusion's `Sort` logical plan node to the
/// dask_planner Python module.
#[pyclass(name = "Sort", module = "dask_planner", subclass)]
#[derive(Clone)]
pub struct PySort {
    // The wrapped DataFusion Sort plan node (sort expressions + input plan).
    sort: Sort,
}

impl PySort {
    /// Returns true if the given sort expression denotes an ascending sort.
    ///
    /// # Panics
    /// Panics if `expr` is not an `Expr::Sort` variant.
    fn is_ascending(&self, expr: Expr) -> bool {
        match expr {
            // `..` elides the fields we don't inspect; `asc` is a Copy bool.
            Expr::Sort { asc, .. } => asc,
            _ => panic!("Provided expression is not a sort expression"),
        }
    }

    /// Returns true if nulls should be placed first for the given sort
    /// expression.
    ///
    /// # Panics
    /// Panics if `expr` is not an `Expr::Sort` variant.
    fn is_nulls_first(&self, expr: Expr) -> bool {
        match expr {
            // Match by value for consistency with `is_ascending`; `bool` is
            // Copy so no clone is needed.
            Expr::Sort { nulls_first, .. } => nulls_first,
            _ => panic!("Provided expression is not a sort expression"),
        }
    }
}
#[pymethods]
impl PySort {
    /// Returns the sort expressions (collation) of this sort node, each
    /// wrapped as a `PyExpr` carrying the sort's input plan for column
    /// resolution.
    #[pyo3(name = "getCollation")]
    pub fn sort_expressions(&self) -> PyResult<Vec<PyExpr>> {
        Ok(self
            .sort
            .expr
            .iter()
            .map(|expr| PyExpr::from(expr.clone(), Some(self.sort.input.clone())))
            .collect())
    }

    /// Returns, for each sort expression in order, whether the sort is
    /// ascending.
    #[pyo3(name = "getAscending")]
    pub fn get_ascending(&self) -> PyResult<Vec<bool>> {
        Ok(self
            .sort
            .expr
            .iter()
            .map(|sortexpr| self.is_ascending(sortexpr.clone()))
            .collect())
    }

    /// Returns, for each sort expression in order, whether nulls are placed
    /// first.
    #[pyo3(name = "getNullsFirst")]
    pub fn get_nulls_first(&self) -> PyResult<Vec<bool>> {
        Ok(self
            .sort
            .expr
            .iter()
            .map(|sortexpr| self.is_nulls_first(sortexpr.clone()))
            .collect())
    }
}

impl From<LogicalPlan> for PySort {
    /// Converts a `LogicalPlan::Sort` node into a `PySort`.
    ///
    /// # Panics
    /// Panics if `logical_plan` is any other variant; callers are expected to
    /// have checked the plan type first. Making these conversions fallible is
    /// tracked in dask-contrib/dask-sql#510.
    fn from(logical_plan: LogicalPlan) -> PySort {
        match logical_plan {
            LogicalPlan::Sort(sort) => PySort { sort },
            _ => panic!("Expected LogicalPlan::Sort, found another LogicalPlan variant"),
        }
    }
}
10 changes: 6 additions & 4 deletions dask_sql/physical/rel/logical/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@

from dask_sql.datacontainer import DataContainer
from dask_sql.physical.rel.base import BaseRelPlugin

# from dask_sql.physical.utils.sort import apply_sort
from dask_sql.physical.utils.sort import apply_sort

if TYPE_CHECKING:
import dask_sql
Expand All @@ -21,7 +20,10 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai
(dc,) = self.assert_inputs(rel, 1, context)
df = dc.df
cc = dc.column_container

sort_expressions = rel.sort().getCollation()
sort_columns = [expr.column_name(rel) for expr in sort_expressions]
sort_ascending = rel.sort().getAscending()
sort_null_first = rel.sort().getNullsFirst()
# TODO: Commented out to pass flake8, will be fixed in sort PR
# sort_collation = rel.getCollation().getFieldCollations()
# sort_columns = [
Expand All @@ -35,7 +37,7 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai
# sort_null_first = [x.nullDirection == FIRST for x in sort_collation]

df = df.persist()
# df = apply_sort(df, sort_columns, sort_ascending, sort_null_first)
df = apply_sort(df, sort_columns, sort_ascending, sort_null_first)

cc = self.fix_column_to_row_type(cc, rel.getRowType())
# No column type has changed, so no need to cast again
Expand Down
12 changes: 6 additions & 6 deletions tests/integration/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tests.utils import assert_eq


@pytest.mark.skip(reason="WIP DataFusion")
# @pytest.mark.skip(reason="WIP DataFusion")
@pytest.mark.parametrize(
"input_table_1,input_df",
[
Expand Down Expand Up @@ -90,7 +90,7 @@ def test_sort_by_alias(c, input_table_1, request):
assert_eq(df_result, df_expected, check_index=False)


@pytest.mark.skip(reason="WIP DataFusion")
# @pytest.mark.skip(reason="WIP DataFusion")
@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_sort_with_nan(gpu):
c = Context()
Expand Down Expand Up @@ -181,7 +181,7 @@ def test_sort_with_nan(gpu):
)


@pytest.mark.skip(reason="WIP DataFusion")
# @pytest.mark.skip(reason="WIP DataFusion")
@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_sort_with_nan_more_columns(gpu):
c = Context()
Expand Down Expand Up @@ -240,7 +240,7 @@ def test_sort_with_nan_more_columns(gpu):
)


@pytest.mark.skip(reason="WIP DataFusion")
# @pytest.mark.skip(reason="WIP DataFusion")
@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_sort_with_nan_many_partitions(gpu):
c = Context()
Expand Down Expand Up @@ -281,7 +281,7 @@ def test_sort_with_nan_many_partitions(gpu):
)


@pytest.mark.skip(reason="WIP DataFusion")
# @pytest.mark.skip(reason="WIP DataFusion")
@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_sort_strings(c, gpu):
string_table = pd.DataFrame({"a": ["zzhsd", "öfjdf", "baba"]})
Expand All @@ -301,7 +301,7 @@ def test_sort_strings(c, gpu):
assert_eq(df_result, df_expected, check_index=False)


@pytest.mark.skip(reason="WIP DataFusion")
# @pytest.mark.skip(reason="WIP DataFusion")
@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_sort_not_allowed(c, gpu):
table_name = "gpu_user_table_1" if gpu else "user_table_1"
Expand Down