Skip to content

Commit 61cbca9

Browse files
committed
Added initial fix for date type mismatch
1 parent 5b8f8a9 commit 61cbca9

4 files changed

Lines changed: 29 additions & 5 deletions

File tree

dask_sql/mappings.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ def sql_to_python_value(sql_type: str, literal_value: Any) -> Any:
163163
dt = datetime.fromtimestamp(
164164
int(literal_value.getTimeInMillis()) / 1000, timezone.utc
165165
)
166-
166+
if sql_type == "DATE":
167+
return dt.date()
167168
return dt
168169

169170
elif sql_type.startswith("DECIMAL("):

dask_sql/physical/rex/core/call.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,14 +196,16 @@ def cast(self, operand, rex=None) -> SeriesOrScalar:
196196
if not is_frame(operand):
197197
return operand
198198

199-
output_type = str(rex.getType())
200-
output_type = sql_to_python_type(output_type.upper())
201-
202-
return_column = cast_column_to_type(operand, output_type)
199+
sql_output_type = str(rex.getType())
200+
python_output_type = sql_to_python_type(sql_output_type.upper())
201+
return_column = cast_column_to_type(operand, python_output_type)
203202

204203
if return_column is None:
205204
return operand
206205
else:
206+
# handle datetime type specially
207+
if sql_output_type == "DATE":
208+
return_column = return_column.dt.date
207209
return return_column
208210

209211

tests/integration/test_filter.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,14 @@ def test_string_filter(c, string_table):
6868
assert_frame_equal(
6969
return_df, string_table.head(1),
7070
)
71+
72+
73+
def test_datetime_filter(c):
    """Filter rows by casting a string column to DATE inside an IN clause.

    Registers a three-row table whose ``d_date`` column holds ISO date
    strings, selects the rows whose casted date matches one of two DATE
    literals, and checks that only those two rows come back (dtypes are
    not compared).
    """
    source = pd.DataFrame(
        {"d_date": ["2001-08-01", "2001-08-02", "2001-08-03"], "val": [1, 2, 3]}
    )
    c.create_table("datetime_tbl1", source)

    computed = c.sql(
        "SELECT val, d_date FROM datetime_tbl1 WHERE CAST(d_date as date) IN (date '2001-08-01', date '2001-08-03')"
    ).compute()
    # Drop the original row labels so the comparison is purely positional.
    actual = computed.reset_index(drop=True)

    wanted = pd.DataFrame({"val": [1, 3], "d_date": ["2001-08-01", "2001-08-03"]})
    assert_frame_equal(actual, wanted, check_dtype=False)

tests/integration/test_select.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,13 @@ def test_timezones(c, datetime_table):
115115
result_df = result_df.compute()
116116

117117
assert_frame_equal(result_df, datetime_table)
118+
119+
120+
def test_date_casting(c, datetime_table):
    """Cast both timestamp columns of ``datetime_table`` to DATE.

    The result is stringified before comparison, so the check is on the
    rendered date text of each column rather than on the resulting dtype.
    """
    casted = c.sql(
        "SELECT cast(timezone as date) as date1,cast(utc_timezone as date) as date2 FROM datetime_table "
    ).compute()
    actual = casted.astype(str)

    # Both output columns are expected to hold the same date in all three rows.
    wanted = pd.DataFrame(
        {column: ["2014-08-01"] * 3 for column in ("date1", "date2")}
    )
    assert_frame_equal(actual, wanted, check_dtype=False)

0 commit comments

Comments
 (0)