Merged

20 commits
863fcb3
Implement PySort logical_plan
ayushdg May 5, 2022
f565c2a
Add a sort plan accessor to the logical plan
ayushdg May 5, 2022
75933b8
Python: Update the sort plugin
ayushdg May 5, 2022
e882408
Python: Uncomment tests
ayushdg May 5, 2022
5f7e8c6
Merge branch 'datafusion-sql-planner' of github.com:dask-contrib/dask…
ayushdg May 10, 2022
b6c534e
PLanner: Update accessor pattern for concrete logical plan implementa…
ayushdg May 10, 2022
82b4234
Test: Address review comments
ayushdg May 10, 2022
ee99096
add support for expr_to_field for Expr::Sort expressions
andygrove May 10, 2022
5d1a561
Merge branch 'datafusion-sql-planner' of github.com:dask-contrib/dask…
ayushdg May 10, 2022
2932a28
Merge commit 'refs/pull/515/head' of github.com:dask-contrib/dask-sql…
ayushdg May 10, 2022
05c6a85
Planner: Update sort expr utilities and import cleanup
ayushdg May 11, 2022
f9f569e
Python: Re-enable skipped sort tests
ayushdg May 11, 2022
0840870
Merge branch 'datafusion-sql-planner' of github.com:dask-contrib/dask…
ayushdg May 11, 2022
8d81c44
Python: Handle case where orderby column name is an alias
ayushdg May 11, 2022
0a75c32
Apply suggestions from code review
ayushdg May 11, 2022
74d3451
Style: Fix formatting
ayushdg May 11, 2022
8aaf8fe
Merge branch 'datafusion-sql-planner' of github.com:dask-contrib/dask…
ayushdg May 12, 2022
6bbb8d7
Planner: Remove public scope for LogicalPlan import
ayushdg May 12, 2022
36b8460
Python: Add more complex sort tests with alias that error right now
ayushdg May 12, 2022
315fad2
Python: Remove old commented code
ayushdg May 13, 2022
2 changes: 1 addition & 1 deletion dask_planner/src/expression.rs
@@ -134,7 +134,7 @@ impl PyExpr {
Expr::Case { .. } => panic!("Case!!!"),
Expr::Cast { .. } => "Cast",
Expr::TryCast { .. } => panic!("TryCast!!!"),
Expr::Sort { .. } => panic!("Sort!!!"),
Expr::Sort { .. } => "Sort",
Expr::ScalarFunction { .. } => "ScalarFunction",
Expr::AggregateFunction { .. } => "AggregateFunction",
Expr::WindowFunction { .. } => panic!("WindowFunction!!!"),
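Note: with this change Expr::Sort is reported by its variant name ("Sort") instead of panicking, in line with the variants that already had names. A minimal, hypothetical illustration of why that matters for Python callers (the accessor name below is an assumption for illustration, not part of this diff):

    def handles_sort(expr) -> bool:
        # Hypothetical accessor name; the point is that "Sort" is now a
        # reportable expression type rather than a panic.
        return expr.get_expr_type() == "Sort"
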
11 changes: 11 additions & 0 deletions dask_planner/src/sql/logical.rs
@@ -8,6 +8,7 @@ mod explain;
mod filter;
mod join;
pub mod projection;
mod sort;

pub use datafusion::logical_expr::LogicalPlan;

@@ -100,6 +101,16 @@ impl PyLogicalPlan {
))
}

/// LogicalPlan::Sort as PySort
pub fn sort(&self) -> PyResult<sort::PySort> {
self.current_node
.as_ref()
.map(|plan| plan.clone().into())
.ok_or(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"current_node was None",
))
}

/// Gets the "input" for the current LogicalPlan
pub fn get_inputs(&mut self) -> PyResult<Vec<PyLogicalPlan>> {
let mut py_inputs: Vec<PyLogicalPlan> = Vec::new();
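Note: the new sort() accessor follows the same pattern as the other per-variant accessors, cloning the current node and converting it into the PySort wrapper defined in the next file. A sketch of the call pattern used by the sort plugin's convert() further down (rel is assumed to be a PyLogicalPlan whose current node is a LogicalPlan::Sort):

    def _collation_from_plan(rel):
        # Mirrors the usage in dask_sql/physical/rel/logical/sort.py below.
        sort_plan = rel.sort()             # PySort wrapper over the Sort node
        return (
            sort_plan.getCollation(),      # list of PyExpr, one per ORDER BY key
            sort_plan.getAscending(),      # list[bool], aligned with getCollation()
            sort_plan.getNullsFirst(),     # list[bool], aligned with getCollation()
        )
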
71 changes: 71 additions & 0 deletions dask_planner/src/sql/logical/sort.rs
@@ -0,0 +1,71 @@
use crate::expression::PyExpr;

use datafusion::logical_expr::{logical_plan::Sort, Expr, LogicalPlan};
use pyo3::prelude::*;

#[pyclass(name = "Sort", module = "dask_planner", subclass)]
#[derive(Clone)]
pub struct PySort {
sort: Sort,
}

impl PySort {
/// Returns if a sort expression denotes an ascending sort
fn is_ascending(&self, expr: &Expr) -> Result<bool, PyErr> {
match expr {
Expr::Sort { asc, .. } => Ok(asc.clone()),
_ => Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
"Provided Expr {:?} is not a sort type",
expr
))),
}
}
/// Returns if nulls should be placed first in a sort expression
fn is_nulls_first(&self, expr: &Expr) -> Result<bool, PyErr> {
match &expr {
Expr::Sort { nulls_first, .. } => Ok(nulls_first.clone()),
_ => Err(PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!(
"Provided Expr {:?} is not a sort type",
expr
))),
}
}
}
#[pymethods]
impl PySort {
/// Returns a Vec of the sort expressions
#[pyo3(name = "getCollation")]
pub fn sort_expressions(&self) -> PyResult<Vec<PyExpr>> {
let mut sort_exprs: Vec<PyExpr> = Vec::new();
for expr in &self.sort.expr {
sort_exprs.push(PyExpr::from(expr.clone(), Some(self.sort.input.clone())));
}
Ok(sort_exprs)
}

#[pyo3(name = "getAscending")]
pub fn get_ascending(&self) -> PyResult<Vec<bool>> {
self.sort
.expr
.iter()
.map(|sortexpr| self.is_ascending(sortexpr))
.collect::<Result<Vec<_>, _>>()
}
#[pyo3(name = "getNullsFirst")]
pub fn get_nulls_first(&self) -> PyResult<Vec<bool>> {
self.sort
.expr
.iter()
.map(|sortexpr| self.is_nulls_first(sortexpr))
.collect::<Result<Vec<_>, _>>()
}
}

impl From<LogicalPlan> for PySort {
fn from(logical_plan: LogicalPlan) -> PySort {
match logical_plan {
LogicalPlan::Sort(srt) => PySort { sort: srt },
_ => panic!("something went wrong here"),
Review comment (Collaborator): No need to change, just for reference #510
}
}
}
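Note: since getCollation(), getAscending() and getNullsFirst() all iterate self.sort.expr, the three results are index-aligned. A hedged sketch of how a consumer can zip them back together (the helper name is an assumption; the expected values in the docstring are illustrative):

    def collation_as_dicts(py_sort):
        """Zip the index-aligned PySort getters into one record per ORDER BY key.

        For example, `ORDER BY a DESC NULLS LAST, b ASC NULLS FIRST` should yield
        ascending=[False, True] and nulls_first=[False, True].
        """
        return [
            {"expr": expr, "ascending": asc, "nulls_first": nulls_first}
            for expr, asc, nulls_first in zip(
                py_sort.getCollation(), py_sort.getAscending(), py_sort.getNullsFirst()
            )
        ]
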
24 changes: 9 additions & 15 deletions dask_sql/physical/rel/logical/sort.py
@@ -2,8 +2,7 @@

from dask_sql.datacontainer import DataContainer
from dask_sql.physical.rel.base import BaseRelPlugin

# from dask_sql.physical.utils.sort import apply_sort
from dask_sql.physical.utils.sort import apply_sort

if TYPE_CHECKING:
import dask_sql
@@ -21,21 +20,16 @@ def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContai
(dc,) = self.assert_inputs(rel, 1, context)
df = dc.df
cc = dc.column_container

# TODO: Commented out to pass flake8, will be fixed in sort PR
# sort_collation = rel.getCollation().getFieldCollations()
# sort_columns = [
# cc.get_backend_by_frontend_index(int(x.getFieldIndex()))
# for x in sort_collation
# ]

# ASCENDING = org.apache.calcite.rel.RelFieldCollation.Direction.ASCENDING
# FIRST = org.apache.calcite.rel.RelFieldCollation.NullDirection.FIRST
# sort_ascending = [x.getDirection() == ASCENDING for x in sort_collation]
# sort_null_first = [x.nullDirection == FIRST for x in sort_collation]
sort_expressions = rel.sort().getCollation()
sort_columns = [
cc.get_backend_by_frontend_name(expr.column_name(rel))
for expr in sort_expressions
]
sort_ascending = rel.sort().getAscending()
sort_null_first = rel.sort().getNullsFirst()

df = df.persist()
# df = apply_sort(df, sort_columns, sort_ascending, sort_null_first)
df = apply_sort(df, sort_columns, sort_ascending, sort_null_first)

cc = self.fix_column_to_row_type(cc, rel.getRowType())
# No column type has changed, so no need to cast again
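Note: for context, a rough pandas-only illustration of the semantics the plugin now delegates to apply_sort; the real dask_sql implementation operates on dask DataFrames and handles multi-partition sorting, and the helper name here is an assumption:

    def sort_like_apply_sort(df, sort_columns, sort_ascending, sort_null_first):
        # df is assumed to be a pandas DataFrame. pandas supports only a single
        # null position for the whole sort, so this sketch assumes the null
        # placement is uniform across the sort keys.
        na_position = "first" if all(sort_null_first) else "last"
        return df.sort_values(
            list(sort_columns), ascending=list(sort_ascending), na_position=na_position
        )
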
57 changes: 50 additions & 7 deletions tests/integration/test_sort.py
@@ -6,7 +6,6 @@
from tests.utils import assert_eq


@pytest.mark.skip(reason="WIP DataFusion")
@pytest.mark.parametrize(
"input_table_1,input_df",
[
@@ -67,7 +66,6 @@ def test_sort(c, input_table_1, input_df, request):
assert_eq(df_result, df_expected, check_index=False)


@pytest.mark.skip(reason="WIP DataFusion")
@pytest.mark.parametrize(
"input_table_1",
["user_table_1", pytest.param("gpu_user_table_1", marks=pytest.mark.gpu)],
@@ -90,7 +88,6 @@ def test_sort_by_alias(c, input_table_1, request):
assert_eq(df_result, df_expected, check_index=False)


@pytest.mark.skip(reason="WIP DataFusion")
@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_sort_with_nan(gpu):
c = Context()
@@ -181,7 +178,6 @@ )
)


@pytest.mark.skip(reason="WIP DataFusion")
@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_sort_with_nan_more_columns(gpu):
c = Context()
@@ -240,7 +236,6 @@ def test_sort_with_nan_more_columns(gpu):
)


@pytest.mark.skip(reason="WIP DataFusion")
@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_sort_with_nan_many_partitions(gpu):
c = Context()
@@ -281,7 +276,6 @@ )
)


@pytest.mark.skip(reason="WIP DataFusion")
@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_sort_strings(c, gpu):
string_table = pd.DataFrame({"a": ["zzhsd", "öfjdf", "baba"]})
@@ -301,11 +295,60 @@ def test_sort_strings(c, gpu):
assert_eq(df_result, df_expected, check_index=False)


@pytest.mark.skip(reason="WIP DataFusion")
@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_sort_not_allowed(c, gpu):
table_name = "gpu_user_table_1" if gpu else "user_table_1"

# Wrong column
with pytest.raises(Exception):
c.sql(f"SELECT * FROM {table_name} ORDER BY 42")


@pytest.mark.xfail(reason="Projection step before sort currently failing")
@pytest.mark.parametrize(
"input_table_1",
["user_table_1", pytest.param("gpu_user_table_1", marks=pytest.mark.gpu)],
)
def test_sort_by_old_alias(c, input_table_1, request):
user_table_1 = request.getfixturevalue(input_table_1)

df_result = c.sql(
f"""
SELECT
b AS my_column
FROM {input_table_1}
ORDER BY b, user_id DESC
"""
).rename(columns={"my_column": "b"})
df_expected = user_table_1.sort_values(["b", "user_id"], ascending=[True, False])[
["b"]
]

assert_eq(df_result, df_expected, check_index=False)

df_result = c.sql(
f"""
SELECT
b*-1 AS my_column
FROM {input_table_1}
ORDER BY b, user_id DESC
"""
).rename(columns={"my_column": "b"})
df_expected = user_table_1.sort_values(["b", "user_id"], ascending=[True, False])[
["b"]
]
df_expected["b"] *= -1
assert_eq(df_result, df_expected, check_index=False)

df_result = c.sql(
f"""
SELECT
b*-1 AS my_column
FROM {input_table_1}
ORDER BY my_column, user_id DESC
"""
).rename(columns={"my_column": "b"})
df_expected["b"] *= -1
df_expected = user_table_1.sort_values(["b", "user_id"], ascending=[True, False])[
["b"]
]