Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dask_planner/src/sql/logical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ impl PyLogicalPlan {
.iter()
.map(|f| RelDataTypeField::from(f, schema.as_ref()))
.collect::<Result<Vec<_>>>()
.map_err(py_type_err)?;
.map_err(|e| py_type_err(e))?;
Ok(RelDataType::new(false, rel_fields))
}
}
Expand Down
25 changes: 17 additions & 8 deletions dask_sql/physical/rex/core/call.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
import operator
import re
from functools import reduce
from functools import partial, reduce
from typing import TYPE_CHECKING, Any, Callable, Union

import dask.array as da
Expand Down Expand Up @@ -36,6 +36,17 @@
SeriesOrScalar = Union[dd.Series, Any]


def as_timelike(op):
if isinstance(op, np.int64):
return np.timedelta64(op, "D")
elif isinstance(op, str):
return np.datetime64(op)
elif pd.api.types.is_datetime64_dtype(op):
return op
else:
raise ValueError(f"Don't know how to make {type(op)} timelike")


class Operation:
"""Helper wrapper around a function, which is used as operator"""

Expand Down Expand Up @@ -115,10 +126,9 @@ def __init__(self, operation: Callable, unary_operation: Callable = None):

def reduce(self, *operands, **kwargs):
if len(operands) > 1:
enriched_with_kwargs = lambda kwargs: (
lambda x, y: self.operation(x, y, **kwargs)
)
return reduce(enriched_with_kwargs(kwargs), operands)
if any(map(pd.api.types.is_datetime64_dtype, operands)):
operands = tuple(map(as_timelike, operands))
return reduce(partial(self.operation, **kwargs), operands)
else:
return self.unary_operation(*operands, **kwargs)

Expand Down Expand Up @@ -169,11 +179,10 @@ def div(self, lhs, rhs):
# We do not need to truncate in this case
# So far, I did not spot any other occurrence
# of this function.
if isinstance(result, datetime.timedelta):
if isinstance(result, (datetime.timedelta, np.timedelta64)):
return result
else: # pragma: no cover
result = da.trunc(result)
return result
return da.trunc(result).astype(np.int64)


class CaseOperation(Operation):
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/test_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_string_filter(c, string_table):
"datetime_table",
pytest.param(
"gpu_datetime_table",
marks=(pytest.mark.gpu, pytest.mark.skip(reason="WIP DataFusion")),
marks=(pytest.mark.gpu),
),
],
)
Expand All @@ -100,7 +100,7 @@ def test_filter_cast_date(c, input_table, request):
"datetime_table",
pytest.param(
"gpu_datetime_table",
marks=(pytest.mark.gpu, pytest.mark.skip(reason="WIP DataFusion")),
marks=(pytest.mark.gpu),
),
],
)
Expand Down