diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index 5df6a42bd..3d98c8580 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -163,7 +163,8 @@ def sql_to_python_value(sql_type: str, literal_value: Any) -> Any: dt = datetime.fromtimestamp( int(literal_value.getTimeInMillis()) / 1000, timezone.utc ) - + if sql_type == "DATE": + return dt.date() return dt elif sql_type.startswith("DECIMAL("): diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py index 4b042ede0..2dbf7caef 100644 --- a/dask_sql/physical/rex/core/call.py +++ b/dask_sql/physical/rex/core/call.py @@ -196,14 +196,16 @@ def cast(self, operand, rex=None) -> SeriesOrScalar: if not is_frame(operand): return operand - output_type = str(rex.getType()) - output_type = sql_to_python_type(output_type.upper()) - - return_column = cast_column_to_type(operand, output_type) + sql_output_type = str(rex.getType()) + python_output_type = sql_to_python_type(sql_output_type.upper()) + return_column = cast_column_to_type(operand, python_output_type) if return_column is None: return operand else: + # handle datetime type specially + if sql_output_type == "DATE": + return_column = return_column.dt.date return return_column diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py index b668d1c51..0c599065c 100644 --- a/tests/integration/test_filter.py +++ b/tests/integration/test_filter.py @@ -68,3 +68,14 @@ def test_string_filter(c, string_table): assert_frame_equal( return_df, string_table.head(1), ) + + +def test_datetime_filter(c): + df = pd.DataFrame( + {"d_date": ["2001-08-01", "2001-08-02", "2001-08-03"], "val": [1, 2, 3]} + ) + c.create_table("datetime_tbl1", df) + query = "SELECT val, d_date FROM datetime_tbl1 WHERE CAST(d_date as date) IN (date '2001-08-01', date '2001-08-03')" + result_df = c.sql(query).compute().reset_index(drop=True) + expected_df = pd.DataFrame({"val": [1, 3], "d_date": ["2001-08-01", "2001-08-03"]}) + assert_frame_equal(result_df, expected_df, check_dtype=False) diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py index 4ca56a31e..397b7fb77 100644 --- a/tests/integration/test_select.py +++ b/tests/integration/test_select.py @@ -115,3 +115,13 @@ def test_timezones(c, datetime_table): result_df = result_df.compute() assert_frame_equal(result_df, datetime_table) + + +def test_date_casting(c, datetime_table): + # check date casting + query = "SELECT cast(timezone as date) as date1,cast(utc_timezone as date) as date2 FROM datetime_table " + result_df = c.sql(query).compute().astype(str) + expected_df = pd.DataFrame( + {"date1": ["2014-08-01"] * 3, "date2": ["2014-08-01"] * 3} + ) + assert_frame_equal(result_df, expected_df, check_dtype=False)