Skip to content
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions dask_sql/physical/rex/core/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,25 @@ def __init__(self, operation: Callable, unary_operation: Callable = None):

def reduce(self, *operands, **kwargs):
if len(operands) > 1:
# Doing math against dates requires a Timedelta
# Find all the operands that are of a datetime64 type
date_operands = [
idx for idx in {0, 1} if pd.api.types.is_datetime64_dtype(operands[idx])
]

# If there are datetime64 operands we need to make sure that the other operands
# in the list are Timedelta for the operation to work.
if date_operands:
# Operands is a Set, since we are altering it must convert to a List
operands = list(operands)

# Knowing there are datetime types in the operands check for incompatable other types
# If found, convert them to Timedelta
for idx, operand in enumerate(operands):
# Default to `Day`/`D` since that is what PostgreSQL does
if isinstance(operand, (np.int64, str)):
operands[idx] = np.timedelta64(operand, "D")

enriched_with_kwargs = lambda kwargs: (
lambda x, y: self.operation(x, y, **kwargs)
)
Expand Down Expand Up @@ -169,11 +188,10 @@ def div(self, lhs, rhs):
# We do not need to truncate in this case
# So far, I did not spot any other occurrence
# of this function.
if isinstance(result, datetime.timedelta):
if isinstance(result, (datetime.timedelta, np.timedelta64)):
return result
else: # pragma: no cover
result = da.trunc(result)
return result
return da.trunc(result).astype(np.int64)


class CaseOperation(Operation):
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_string_filter(c, string_table):
"datetime_table",
pytest.param(
"gpu_datetime_table",
marks=(pytest.mark.gpu, pytest.mark.skip(reason="WIP DataFusion")),
marks=(pytest.mark.gpu),
),
],
)
Expand All @@ -100,7 +100,7 @@ def test_filter_cast_date(c, input_table, request):
"datetime_table",
pytest.param(
"gpu_datetime_table",
marks=(pytest.mark.gpu, pytest.mark.skip(reason="WIP DataFusion")),
marks=(pytest.mark.gpu),
),
],
)
Expand Down