Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions dask_sql/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def sql_to_python_type(sql_type: str) -> type:
return pd.StringDtype()
elif sql_type.startswith("INTERVAL"):
return np.dtype("<m8[ns]")

elif sql_type.startswith("TIMESTAMP(") or sql_type.startswith("TIME("):
return np.dtype("<M8[ns]")
elif sql_type.startswith("TIMESTAMP_WITH_LOCAL_TIME_ZONE("):
Expand Down Expand Up @@ -287,15 +288,17 @@ def cast_column_to_type(col: dd.Series, expected_type: str):
logger.debug("...not converting.")
return None

current_float = pd.api.types.is_float_dtype(current_type)
expected_integer = pd.api.types.is_integer_dtype(expected_type)
if current_float and expected_integer:
logger.debug("...truncating...")
# Currently "trunc" can not be applied to NA (the pandas missing value type),
# because NA is a different type. It works with np.NaN though.
# For our use case, that does not matter, as the conversion to integer later
# will convert both NA and np.NaN to NA.
col = da.trunc(col.fillna(value=np.NaN))
if pd.api.types.is_integer_dtype(expected_type):
if pd.api.types.is_float_dtype(current_type):
logger.debug("...truncating...")
# Currently "trunc" can not be applied to NA (the pandas missing value type),
# because NA is a different type. It works with np.NaN though.
# For our use case, that does not matter, as the conversion to integer later
# will convert both NA and np.NaN to NA.
col = da.trunc(col.fillna(value=np.NaN))
elif pd.api.types.is_timedelta64_dtype(current_type):
logger.debug(f"Explicitly casting from {current_type} to np.int64")
return col.astype(np.int64)

logger.debug(f"Need to cast from {current_type} to {expected_type}")
return col.astype(expected_type)
65 changes: 58 additions & 7 deletions dask_sql/physical/rex/core/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dask.utils import random_state_data

from dask_sql.datacontainer import DataContainer
from dask_sql.java import get_java_class
from dask_sql.mappings import cast_column_to_type, sql_to_python_type
from dask_sql.physical.rex import RexConverter
from dask_sql.physical.rex.base import BaseRexPlugin
Expand Down Expand Up @@ -168,11 +169,10 @@ def div(self, lhs, rhs):
# We do not need to truncate in this case
# So far, I did not spot any other occurrence
# of this function.
if isinstance(result, datetime.timedelta):
return result
else: # pragma: no cover
result = da.trunc(result)
if isinstance(result, (datetime.timedelta, np.timedelta64)):
return result
else:
return da.trunc(result).astype(np.int64)


class CaseOperation(Operation):
Expand Down Expand Up @@ -220,9 +220,6 @@ def __init__(self):
super().__init__(self.cast)

def cast(self, operand, rex=None) -> SeriesOrScalar:
if not is_frame(operand): # pragma: no cover
return operand

output_type = str(rex.getType())
python_type = sql_to_python_type(output_type.upper())

Expand Down Expand Up @@ -715,6 +712,43 @@ def search(self, series: dd.Series, sarg: SargPythonImplementation):
return conditions[0]


class DatetimeSubOperation(Operation):
    """
    Datetime subtraction as emitted by Calcite.

    Calcite models this as a special case of the `minus` operation that
    additionally carries a SQL interval return type; that type decides the
    unit the difference is expressed in.
    """

    needs_rex = True

    def __init__(self):
        super().__init__(self.datetime_sub)

    def datetime_sub(self, *operands, rex=None):
        """
        Subtract ``operands`` and convert the difference to the interval unit
        named by the SQL type of ``rex``.

        The string form of ``rex.getType()`` must start with ``INTERVAL``; its
        second whitespace-separated token (lower-cased) is taken as the unit.

        * year / quarter / month: a frame is divided by a one-month
          ``np.timedelta64``; a non-frame value is cast to ``timedelta64[M]``
          and incremented by one when negative.
        * any other unit: the difference is cast to ``timedelta64[ms]``.
        """
        sql_type = str(rex.getType())
        assert sql_type.startswith("INTERVAL")
        unit = sql_type.split()[1].lower()

        # Operands are reduced with `operator.sub`; negation is supplied as
        # the unary form of the operation.
        difference = ReduceOperation(
            operation=operator.sub, unary_operation=lambda x: -x
        )(*operands)

        if unit not in {"year", "quarter", "month"}:
            return difference.astype("timedelta64[ms]")

        # For INTERVAL YEAR (and QUARTER), Calcite will convert to months,
        # so all three units are handled as a month count here.
        if is_frame(difference):
            return difference / np.timedelta64(1, "M")

        # Numpy doesn't allow division by the month time unit, so scalars
        # are cast to months instead.
        months = difference.astype("timedelta64[M]")
        # Negative numpy timedeltas are off by one versus SQL once cast to
        # months, hence the correction.
        if months < 0:
            return months + 1
        return months


class RexCallPlugin(BaseRexPlugin):
"""
RexCall is used for expressions, which calculate something.
Expand Down Expand Up @@ -752,6 +786,7 @@ class RexCallPlugin(BaseRexPlugin):
"/int": IntDivisionOperator(),
# special operations
"cast": CastOperation(),
"reinterpret": CastOperation(),
"case": CaseOperation(),
"like": LikeOperation(),
"similar to": SimilarOperation(),
Expand Down Expand Up @@ -812,6 +847,7 @@ class RexCallPlugin(BaseRexPlugin):
lambda x: x + pd.tseries.offsets.MonthEnd(1),
lambda x: convert_to_datetime(x) + pd.tseries.offsets.MonthEnd(1),
),
"datetime_subtraction": DatetimeSubOperation(),
}

def convert(
Expand All @@ -827,6 +863,8 @@ def convert(

# Now use the operator name in the mapping
schema_name, operator_name = context.fqn(rex.getOperator().getNameAsId())
if special_op := check_special_operator(rex.getOperator()):
operator_name = special_op
operator_name = operator_name.lower()

try:
Expand All @@ -850,3 +888,16 @@ def convert(

return operation(*operands, **kwargs)
# TODO: We have information on the typing here - we should use it


def check_special_operator(operator: "org.apache.calcite.sql.fun"):
    """
    Return the internal operation name for operator classes whose SQL name is
    shared with an operator of a different type/kind, or ``None`` otherwise.

    Example: ``SqlDatetimeSubtractionOperator`` reports the same SQL type and
    kind as the plain `-` / `minus` operation, so it is recognised by its
    Java class and mapped to ``"datetime_subtraction"``.
    """
    overloaded_operator_names = {
        "org.apache.calcite.sql.fun.SqlDatetimeSubtractionOperator": "datetime_subtraction"
    }
    java_class = get_java_class(operator)
    # `dict.get` yields None for any class that is not listed above.
    return overloaded_operator_names.get(java_class)
89 changes: 89 additions & 0 deletions tests/integration/test_rex.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,3 +587,92 @@ def test_date_functions(c):
FROM df
"""
)


@pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
def test_timestampdiff(c, gpu):
    """
    Check `timestampdiff` for every supported unit, first on a pair of
    timestamp literals and then on timestamp columns of a registered table.
    """
    # --- single value test: both arguments are literals ---
    ts_literal1 = "2002-03-07 09:10:05.123"
    ts_literal2 = "2001-06-05 10:11:06.234"
    scalar_query = (
        f"SELECT timestampdiff(NANOSECOND, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res0,"
        f"timestampdiff(MICROSECOND, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res1,"
        f"timestampdiff(SECOND, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res2,"
        f"timestampdiff(MINUTE, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res3,"
        f"timestampdiff(HOUR, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res4,"
        f"timestampdiff(DAY, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res5,"
        f"timestampdiff(WEEK, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res6,"
        f"timestampdiff(MONTH, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res7,"
        f"timestampdiff(QUARTER, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res8,"
        f"timestampdiff(YEAR, CAST('{ts_literal1}' AS TIMESTAMP),CAST('{ts_literal2}' AS TIMESTAMP)) as res9"
    )
    scalar_result = c.sql(scalar_query)
    # The second literal is earlier than the first, so the expected values
    # are negative (or zero for the year difference).
    scalar_expected = pd.DataFrame(
        {
            "res0": [-23756339_000_000_000],
            "res1": [-23756339_000_000],
            "res2": [-23756339],
            "res3": [-395938],
            "res4": [-6598],
            "res5": [-274],
            "res6": [-39],
            "res7": [-9],
            "res8": [-3],
            "res9": [0],
        }
    )
    assert_eq(scalar_result, scalar_expected)

    # --- dataframe test: arguments come from string columns cast to TIMESTAMP ---
    timestamps_df = pd.DataFrame(
        {
            "a": [
                "2002-06-05 02:01:05.200",
                "2002-09-01 00:00:00",
                "1970-12-03 00:00:00",
            ],
            "b": [
                "2002-06-07 01:00:02.100",
                "2003-06-05 00:00:00",
                "2038-06-05 00:00:00",
            ],
        }
    )
    c.create_table("test", timestamps_df, gpu=gpu)

    frame_query = (
        "SELECT timestampdiff(NANOSECOND, CAST(a AS TIMESTAMP), CAST(b AS TIMESTAMP)) as nanoseconds,"
        "timestampdiff(MICROSECOND, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as microseconds,"
        "timestampdiff(SECOND, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as seconds,"
        "timestampdiff(MINUTE, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as minutes,"
        "timestampdiff(HOUR, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as hours,"
        "timestampdiff(DAY, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as days,"
        "timestampdiff(WEEK, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as weeks,"
        "timestampdiff(MONTH, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as months,"
        "timestampdiff(QUARTER, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as quarters,"
        "timestampdiff(YEAR, CAST(a AS TIMESTAMP),CAST(b AS TIMESTAMP)) as years"
        " FROM test"
    )
    frame_result = c.sql(frame_query)

    frame_expected = pd.DataFrame(
        {
            "nanoseconds": [
                169136_000_000_000,
                23932_800_000_000_000,
                2_130_278_400_000_000_000,
            ],
            "microseconds": [169136_000_000, 23932_800_000_000, 2_130_278_400_000_000],
            "seconds": [169136, 23932_800, 2_130_278_400],
            "minutes": [2818, 398880, 35504640],
            "hours": [46, 6648, 591744],
            "days": [1, 277, 24656],
            "weeks": [0, 39, 3522],
            "months": [0, 9, 810],
            "quarters": [0, 3, 270],
            "years": [0, 0, 67],
        }
    )

    # Only the values are compared for the table-backed query; dtypes are
    # explicitly excluded from the check.
    assert_eq(frame_result, frame_expected, check_dtype=False)