diff --git a/dask_planner/src/expression.rs b/dask_planner/src/expression.rs index 627b5efed..a6f616d4a 100644 --- a/dask_planner/src/expression.rs +++ b/dask_planner/src/expression.rs @@ -419,11 +419,11 @@ impl PyExpr { Operator::StringConcat => "VARCHAR", }, Expr::Literal(scalar_value) => match scalar_value { - ScalarValue::Null => "Null", ScalarValue::Boolean(_value) => "Boolean", ScalarValue::Float32(_value) => "Float32", ScalarValue::Float64(_value) => "Float64", ScalarValue::Decimal128(_value, ..) => "Decimal128", + ScalarValue::Dictionary(..) => "Dictionary", ScalarValue::Int8(_value) => "Int8", ScalarValue::Int16(_value) => "Int16", ScalarValue::Int32(_value) => "Int32", @@ -438,12 +438,16 @@ impl PyExpr { ScalarValue::LargeBinary(_value) => "LargeBinary", ScalarValue::Date32(_value) => "Date32", ScalarValue::Date64(_value) => "Date64", - _ => { - return Err(py_type_err(format!( - "Catch all triggered for Literal in get_type; {:?}", - scalar_value - ))) - } + ScalarValue::Null => "Null", + ScalarValue::TimestampSecond(..) => "TimestampSecond", + ScalarValue::TimestampMillisecond(..) => "TimestampMillisecond", + ScalarValue::TimestampMicrosecond(..) => "TimestampMicrosecond", + ScalarValue::TimestampNanosecond(..) => "TimestampNanosecond", + ScalarValue::IntervalYearMonth(..) => "IntervalYearMonth", + ScalarValue::IntervalDayTime(..) => "IntervalDayTime", + ScalarValue::IntervalMonthDayNano(..) => "IntervalMonthDayNano", + ScalarValue::List(..) => "List", + ScalarValue::Struct(..) => "Struct", }, Expr::ScalarFunction { fun, args: _ } => match fun { BuiltinScalarFunction::Abs => "Abs", @@ -633,6 +637,24 @@ impl PyExpr { } } + #[pyo3(name = "getIntervalDayTimeValue")] + pub fn interval_day_time_value(&mut self) -> (i32, i32) { + match &self.expr { + Expr::Literal(scalar_value) => match scalar_value { + ScalarValue::IntervalDayTime(iv) => { + let interval = iv.unwrap() as u64; + let days = (interval >> 32) as i32; + let ms = interval as i32; + (days, ms) + } + _ => { + panic!("getValue() - Unexpected value") + } + }, + _ => panic!("getValue() - Non literal value encountered"), + } + } + #[pyo3(name = "isNegated")] pub fn is_negated(&self) -> PyResult { match &self.expr { diff --git a/dask_sql/mappings.py b/dask_sql/mappings.py index 083ad921b..6c41e13d0 100644 --- a/dask_sql/mappings.py +++ b/dask_sql/mappings.py @@ -69,17 +69,23 @@ _SQL_TO_PYTHON_FRAMES = { "SqlTypeName.DOUBLE": np.float64, "SqlTypeName.FLOAT": np.float32, - "SqlTypeName.DECIMAL": np.float64, + "SqlTypeName.DECIMAL": np.float64, # We use np.float64 always, even though we might be able to use a smaller type "SqlTypeName.BIGINT": pd.Int64Dtype(), "SqlTypeName.INTEGER": pd.Int32Dtype(), "SqlTypeName.SMALLINT": pd.Int16Dtype(), "SqlTypeName.TINYINT": pd.Int8Dtype(), "SqlTypeName.BOOLEAN": pd.BooleanDtype(), "SqlTypeName.VARCHAR": pd.StringDtype(), + "SqlTypeName.CHAR": pd.StringDtype(), "SqlTypeName.DATE": np.dtype( " Any: return literal_value + elif sql_type == SqlTypeName.INTERVAL_DAY: + return timedelta(days=literal_value[0], milliseconds=literal_value[1]) elif sql_type == SqlTypeName.INTERVAL: # check for finer granular interval types, e.g., INTERVAL MONTH, INTERVAL YEAR try: @@ -200,25 +208,12 @@ def sql_to_python_value(sql_type: "SqlTypeName", literal_value: Any) -> Any: def sql_to_python_type(sql_type: "SqlTypeName") -> type: """Turn an SQL type into a dataframe dtype""" - if sql_type == SqlTypeName.VARCHAR or sql_type == SqlTypeName.CHAR: - return pd.StringDtype() - elif sql_type == SqlTypeName.TIME or sql_type == SqlTypeName.TIMESTAMP: - return np.dtype(" bool: diff --git a/dask_sql/physical/rex/core/literal.py b/dask_sql/physical/rex/core/literal.py index f257fca89..9ac67f24d 100644 --- a/dask_sql/physical/rex/core/literal.py +++ b/dask_sql/physical/rex/core/literal.py @@ -145,8 +145,13 @@ def convert( elif literal_type == "Null": literal_type = SqlTypeName.NULL literal_value = None + elif literal_type == "IntervalDayTime": + literal_type = SqlTypeName.INTERVAL_DAY + literal_value = rex.getIntervalDayTimeValue() else: - raise RuntimeError("Failed to determine DataFusion Type in literal.py") + raise RuntimeError( + f"Failed to map literal type {literal_type} to python type in literal.py" + ) # if isinstance(literal_value, org.apache.calcite.util.Sarg): # return SargPythonImplementation(literal_value, literal_type) diff --git a/tests/integration/test_rex.py b/tests/integration/test_rex.py index a4a9e6a25..1295b13d5 100644 --- a/tests/integration/test_rex.py +++ b/tests/integration/test_rex.py @@ -54,6 +54,20 @@ def test_case(c, df): assert_eq(result_df, expected_df, check_dtype=False) +def test_intervals(c): + df = c.sql( + """SELECT INTERVAL '3' DAY as "IN" + """ + ) + + expected_df = pd.DataFrame( + { + "IN": [pd.to_timedelta("3d")], + } + ) + assert_eq(df, expected_df) + + @pytest.mark.skip(reason="WIP DataFusion") def test_literals(c): df = c.sql(