diff --git a/dask_planner/src/sql/logical.rs b/dask_planner/src/sql/logical.rs
index e53ac50b8..e85d9c985 100644
--- a/dask_planner/src/sql/logical.rs
+++ b/dask_planner/src/sql/logical.rs
@@ -212,7 +212,7 @@ impl PyLogicalPlan {
                     .iter()
                     .map(|f| RelDataTypeField::from(f, schema.as_ref()))
                     .collect::<Result<Vec<_>>>()
-                    .map_err(py_type_err)?;
+                    .map_err(|e| py_type_err(e))?;
                 Ok(RelDataType::new(false, rel_fields))
             }
         }
diff --git a/dask_sql/physical/rex/core/call.py b/dask_sql/physical/rex/core/call.py
index 69fad753c..b6112b831 100644
--- a/dask_sql/physical/rex/core/call.py
+++ b/dask_sql/physical/rex/core/call.py
@@ -2,7 +2,7 @@
 import logging
 import operator
 import re
-from functools import reduce
+from functools import partial, reduce
 from typing import TYPE_CHECKING, Any, Callable, Union
 
 import dask.array as da
@@ -36,6 +36,17 @@
 SeriesOrScalar = Union[dd.Series, Any]
 
 
+def as_timelike(op):
+    if isinstance(op, np.int64):
+        return np.timedelta64(op, "D")
+    elif isinstance(op, str):
+        return np.datetime64(op)
+    elif pd.api.types.is_datetime64_dtype(op):
+        return op
+    else:
+        raise ValueError(f"Don't know how to make {type(op)} timelike")
+
+
 class Operation:
     """Helper wrapper around a function, which is used as operator"""
 
@@ -115,10 +126,9 @@ def __init__(self, operation: Callable, unary_operation: Callable = None):
 
     def reduce(self, *operands, **kwargs):
         if len(operands) > 1:
-            enriched_with_kwargs = lambda kwargs: (
-                lambda x, y: self.operation(x, y, **kwargs)
-            )
-            return reduce(enriched_with_kwargs(kwargs), operands)
+            if any(map(pd.api.types.is_datetime64_dtype, operands)):
+                operands = tuple(map(as_timelike, operands))
+            return reduce(partial(self.operation, **kwargs), operands)
         else:
             return self.unary_operation(*operands, **kwargs)
 
@@ -169,11 +179,10 @@ def div(self, lhs, rhs):
         # We do not need to truncate in this case
         # So far, I did not spot any other occurrence
         # of this function.
-        if isinstance(result, datetime.timedelta):
+        if isinstance(result, (datetime.timedelta, np.timedelta64)):
             return result
         else:  # pragma: no cover
-            result = da.trunc(result)
-            return result
+            return da.trunc(result).astype(np.int64)
 
 
 class CaseOperation(Operation):
diff --git a/tests/integration/test_filter.py b/tests/integration/test_filter.py
index f7c8c5502..d1508e5f8 100644
--- a/tests/integration/test_filter.py
+++ b/tests/integration/test_filter.py
@@ -74,7 +74,7 @@ def test_string_filter(c, string_table):
         "datetime_table",
         pytest.param(
             "gpu_datetime_table",
-            marks=(pytest.mark.gpu, pytest.mark.skip(reason="WIP DataFusion")),
+            marks=(pytest.mark.gpu),
         ),
     ],
 )
@@ -100,7 +100,7 @@ def test_filter_cast_date(c, input_table, request):
         "datetime_table",
         pytest.param(
             "gpu_datetime_table",
-            marks=(pytest.mark.gpu, pytest.mark.skip(reason="WIP DataFusion")),
+            marks=(pytest.mark.gpu),
         ),
     ],
 )