diff --git a/planner/src/main/java/com/dask/sql/application/DaskPlanner.java b/planner/src/main/java/com/dask/sql/application/DaskPlanner.java index 47e7599b8..6240fdd1a 100644 --- a/planner/src/main/java/com/dask/sql/application/DaskPlanner.java +++ b/planner/src/main/java/com/dask/sql/application/DaskPlanner.java @@ -11,6 +11,10 @@ import com.dask.sql.rules.DaskValuesRule; import com.dask.sql.rules.DaskWindowRule; +import org.apache.calcite.config.CalciteConnectionConfig; +import org.apache.calcite.config.CalciteConnectionProperty; +import org.apache.calcite.plan.Context; +import org.apache.calcite.plan.Contexts; import org.apache.calcite.plan.ConventionTraitDef; import org.apache.calcite.plan.volcano.VolcanoPlanner; import org.apache.calcite.rel.rules.CoreRules; @@ -26,6 +30,9 @@ * as the null executor. */ public class DaskPlanner extends VolcanoPlanner { + + private final Context defaultContext; + public DaskPlanner() { // Allow transformation between logical and dask nodes addRule(DaskAggregateRule.INSTANCE); @@ -73,5 +80,13 @@ public DaskPlanner() { // We do not want to execute any SQL setExecutor(null); + + // Use our defined type system and create a default CalciteConfigContext + defaultContext = Contexts.of(CalciteConnectionConfig.DEFAULT.set( + CalciteConnectionProperty.TYPE_SYSTEM, "com.dask.sql.application.DaskSqlDialect#DASKSQL_TYPE_SYSTEM")); + } + + public Context getContext() { + return defaultContext; } } diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py index 9168a217c..3b1906910 100644 --- a/tests/integration/test_filter.py +++ b/tests/integration/test_filter.py @@ -68,3 +68,15 @@ def test_string_filter(c, string_table): assert_frame_equal( return_df, string_table.head(1), ) + + +def test_filter_datetime(c): + df = pd.DataFrame({"year": [2015, 2016], "month": [2, 3], "day": [4, 5]}) + + df["dt"] = pd.to_datetime(df) + + c.create_table("datetime_test", df) + actual_df = c.sql("select * from datetime_test where year(dt) < 2016").compute() + expected_df = df[df["year"] < 2016] + + assert_frame_equal(expected_df, actual_df)