diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py
index 09828998c..500d7c79d 100644
--- a/dask_sql/mappings.py
+++ b/dask_sql/mappings.py
@@ -1,5 +1,5 @@
 import logging
-from datetime import datetime, timedelta, timezone
+from datetime import timedelta
 from typing import Any
 
 import dask.array as da
@@ -77,7 +77,9 @@
     "VARCHAR": pd.StringDtype(),
     "CHAR": pd.StringDtype(),
     "STRING": pd.StringDtype(),  # Although not in the standard, makes compatibility easier
-    "DATE": np.dtype("<M8[ns]"),
+    "DATE": np.dtype(
+        "<M8[ns]"
+    ),  # TODO: ideally this would be np.dtype("<M8[D]") but that doesn't work for pandas
@@ ... @@ def sql_to_python_value(sql_type: str, literal_value: Any) -> Any:
         tz = literal_value.getTimeZone().getID()
         assert str(tz) == "UTC", "The code can currently only handle UTC timezones"
 
-        dt = datetime.fromtimestamp(
-            int(literal_value.getTimeInMillis()) / 1000, timezone.utc
-        )
-
-        return dt
+        dt = np.datetime64(literal_value.getTimeInMillis(), "ms")
+        if sql_type == "DATE":
+            return dt.astype("<M8[D]")
diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py
--- a/dask_sql/physical/rex/core/call.py
+++ b/dask_sql/physical/rex/core/call.py
@@ ... @@ def call(self, operand, rex=None) -> SeriesOrScalar:
             return operand
 
         output_type = str(rex.getType())
-        output_type = sql_to_python_type(output_type.upper())
+        python_type = sql_to_python_type(output_type.upper())
 
-        return_column = cast_column_to_type(operand, output_type)
+        return_column = cast_column_to_type(operand, python_type)
 
         if return_column is None:
-            return operand
-        else:
-            return return_column
+            return_column = operand
+
+        # TODO: ideally we don't want to directly access the datetimes,
+        # but Pandas can't truncate timezone datetimes and cuDF can't
+        # truncate datetimes
+        if output_type == "DATE":
+            return return_column.dt.floor("D").astype(python_type)
+
+        return return_column
 
 
 class IsFalseOperation(Operation):
diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py
index aa381de88..ff5f9a23e 100644
--- a/tests/integration/fixtures.py
+++ b/tests/integration/fixtures.py
@@ -86,11 +86,13 @@ def datetime_table():
     return pd.DataFrame(
         {
             "timezone": pd.date_range(
-                start="2014-08-01 09:00", freq="H", periods=3, tz="Europe/Berlin"
+                start="2014-08-01 09:00", freq="8H", periods=6, tz="Europe/Berlin"
+            ),
+            "no_timezone": pd.date_range(
+                start="2014-08-01 09:00", freq="8H", periods=6
             ),
-            "no_timezone": pd.date_range(start="2014-08-01 09:00", freq="H", periods=3),
             "utc_timezone": pd.date_range(
-                start="2014-08-01 09:00", freq="H", periods=3, tz="UTC"
+                start="2014-08-01 09:00", freq="8H", periods=6, tz="UTC"
             ),
         }
     )
@@ -116,6 +118,11 @@ def gpu_string_table(string_table):
     return cudf.from_pandas(string_table) if cudf else None
 
 
+@pytest.fixture()
+def gpu_datetime_table(datetime_table):
+    return cudf.from_pandas(datetime_table) if cudf else None
+
+
 @pytest.fixture()
 def c(
     df_simple,
@@ -131,6 +138,7 @@ def c(
     gpu_df,
     gpu_long_table,
     gpu_string_table,
+    gpu_datetime_table,
 ):
     dfs = {
         "df_simple": df_simple,
@@ -146,6 +154,7 @@
         "gpu_df": gpu_df,
         "gpu_long_table": gpu_long_table,
         "gpu_string_table": gpu_string_table,
+        "gpu_datetime_table": gpu_datetime_table,
     }
 
     # Lazy import, otherwise the pytest framework has problems
diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py
index 3b1906910..ad98d4416 100644
--- a/tests/integration/test_filter.py
+++ b/tests/integration/test_filter.py
@@ -1,4 +1,6 @@
+import dask.dataframe as dd
 import pandas as pd
+import pytest
 from pandas.testing import assert_frame_equal
 
 from dask_sql._compat import INT_NAN_IMPLEMENTED
@@ -70,7 +72,47 @@ def test_string_filter(c, string_table):
     )
 
 
-def test_filter_datetime(c):
+@pytest.mark.parametrize(
+    "input_table",
+    ["datetime_table", pytest.param("gpu_datetime_table", marks=pytest.mark.gpu),],
+)
+def test_filter_cast_date(c, input_table, request):
+    datetime_table = request.getfixturevalue(input_table)
+    return_df = c.sql(
+        f"""
+    SELECT * FROM {input_table} WHERE
+        CAST(timezone AS DATE) > DATE '2014-08-01'
+    """
+    )
+
+    expected_df = datetime_table[
+        datetime_table["timezone"].astype("<M8[ns]") > pd.Timestamp("2014-08-01")
+    ]
+    dd.assert_eq(return_df, expected_df)
+
+
+@pytest.mark.parametrize(
+    "input_table",
+    ["datetime_table", pytest.param("gpu_datetime_table", marks=pytest.mark.gpu),],
+)
+def test_filter_cast_timestamp(c, input_table, request):
+    datetime_table = request.getfixturevalue(input_table)
+    return_df = c.sql(
+        f"""
+    SELECT * FROM {input_table} WHERE
+        CAST(timezone AS TIMESTAMP) >= TIMESTAMP '2014-08-01 23:00:00'
+    """
+    )
+
+    expected_df = datetime_table[
+        datetime_table["timezone"].astype("<M8[ns]")
+        >= pd.Timestamp("2014-08-01 23:00:00")
+    ]
+    dd.assert_eq(return_df, expected_df)
+
+
+def test_filter_year(c):
     df = pd.DataFrame({"year": [2015, 2016], "month": [2, 3], "day": [4, 5]})
 
     df["dt"] = pd.to_datetime(df)
diff --git a/tests/integration/test_select.py b/tests/integration/test_select.py
index 437631cef..f7c20df3d 100644
--- a/tests/integration/test_select.py
+++ b/tests/integration/test_select.py
@@ -1,3 +1,4 @@
+import dask.dataframe as dd
 import numpy as np
 import pandas as pd
 import pytest
@@ -118,6 +119,56 @@ def test_timezones(c, datetime_table):
     assert_frame_equal(result_df, datetime_table)
 
 
+@pytest.mark.parametrize(
+    "input_table",
+    ["datetime_table", pytest.param("gpu_datetime_table", marks=pytest.mark.gpu),],
+)
+def test_date_casting(c, input_table, request):
+    datetime_table = request.getfixturevalue(input_table)
+    result_df = c.sql(
+        f"""
+    SELECT
+        CAST(timezone AS DATE) AS timezone,
+        CAST(no_timezone AS DATE) AS no_timezone,
+        CAST(utc_timezone AS DATE) AS utc_timezone
+    FROM {input_table}
+    """
+    )
+
+    expected_df = datetime_table
+    expected_df["timezone"] = (
+        expected_df["timezone"].astype("
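Outside the diff, a minimal usage sketch of the behaviour these changes target: casting a timezone-aware column to DATE through the public dask_sql.Context API. The table name "example" and column name "ts" are invented for illustration and are not part of the change; this only assumes the documented Context.create_table/Context.sql entry points.

    import pandas as pd
    from dask_sql import Context

    c = Context()
    df = pd.DataFrame(
        # timezone-aware timestamps, mirroring the datetime_table fixture above
        {"ts": pd.date_range("2014-08-01 09:00", freq="8H", periods=6, tz="UTC")}
    )
    c.create_table("example", df)

    # With the CastOperation change, CAST(... AS DATE) should floor the values
    # to day precision instead of keeping the time-of-day component.
    result = c.sql("SELECT CAST(ts AS DATE) AS d FROM example").compute()
    print(result)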