diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py
index c69dee79c..9c7178578 100644
--- a/dask_sql/datacontainer.py
+++ b/dask_sql/datacontainer.py
@@ -205,10 +205,19 @@ def __init__(self, func, row_udf: bool, return_type=None):
 
     def __call__(self, *args, **kwargs):
         if self.row_udf:
-            df = args[0].to_frame()
-            for operand in args[1:]:
-                df[operand.name] = operand
-            result = df.apply(self.func, axis=1, meta=self.meta).astype(self.meta[1])
+            column_args = []
+            scalar_args = []
+            for operand in args:
+                if isinstance(operand, dd.Series):
+                    column_args.append(operand)
+                else:
+                    scalar_args.append(operand)
+            df = column_args[0].to_frame()
+            for col in column_args[1:]:
+                df[col.name] = col
+            result = df.apply(
+                self.func, axis=1, args=tuple(scalar_args), meta=self.meta
+            ).astype(self.meta[1])
         else:
             result = self.func(*args, **kwargs)
         return result
diff --git a/docs/pages/custom.rst b/docs/pages/custom.rst
index 2eefe6813..e38658d7b 100644
--- a/docs/pages/custom.rst
+++ b/docs/pages/custom.rst
@@ -48,6 +48,17 @@ These functions may be registered as above and flagged as row UDFs using the `ro
 
 ** Note: Row UDFs use `apply` which may have unpredictable performance characteristics, depending on the function and dataframe library **
 
+UDFs written in this way can also be extended to accept scalar arguments along with the incoming row:
+
+.. code-block:: python
+
+    def f(row, k):
+        return row['a'] + k
+
+    c.register_function(f, "f", [("a", np.int64), ("k", np.int64)], np.int64, row_udf=True)
+    c.sql("SELECT f(a, 42) FROM data")
+
+
 Aggregation Functions
 ---------------------
 
diff --git a/tests/integration/test_function.py b/tests/integration/test_function.py
index 2012e86ec..7aa27d709 100644
--- a/tests/integration/test_function.py
+++ b/tests/integration/test_function.py
@@ -1,3 +1,5 @@
+import operator
+
 import dask.dataframe as dd
 import numpy as np
 import pytest
@@ -63,6 +65,64 @@ def f(row):
     assert_frame_equal(return_df.reset_index(drop=True), expectation)
 
 
+# Test row UDFs with one arg
+@pytest.mark.parametrize("k", [1, 1.5, True])
+@pytest.mark.parametrize(
+    "op", [operator.add, operator.sub, operator.mul, operator.truediv]
+)
+@pytest.mark.parametrize("retty", [np.int64, np.float64, np.bool_])
+def test_custom_function_row_args(c, df, k, op, retty):
+    const_type = np.dtype(type(k)).type
+
+    def f(row, k):
+        return op(row["a"], k)
+
+    c.register_function(
+        f, "f", [("a", np.int64), ("k", const_type)], retty, row_udf=True
+    )
+
+    statement = f"SELECT F(a, {k}) as a from df"
+
+    return_df = c.sql(statement)
+    return_df = return_df.compute()
+    expectation = op(df[["a"]], k).astype(retty)
+    assert_frame_equal(return_df.reset_index(drop=True), expectation)
+
+
+# Test row UDFs with two args
+@pytest.mark.parametrize("k2", [1, 1.5, True])
+@pytest.mark.parametrize("k1", [1, 1.5, True])
+@pytest.mark.parametrize(
+    "op", [operator.add, operator.sub, operator.mul, operator.truediv]
+)
+@pytest.mark.parametrize("retty", [np.int64, np.float64, np.bool_])
+def test_custom_function_row_two_args(c, df, k1, k2, op, retty):
+    const_type_k1 = np.dtype(type(k1)).type
+    const_type_k2 = np.dtype(type(k2)).type
+
+    def f(row, k1, k2):
+        x = op(row["a"], k1)
+        y = op(x, k2)
+
+        return y
+
+    c.register_function(
+        f,
+        "f",
+        [("a", np.int64), ("k1", const_type_k1), ("k2", const_type_k2)],
+        retty,
+        row_udf=True,
+    )
+
+    statement = f"SELECT F(a, {k1}, {k2}) as a from df"
+
+    return_df = c.sql(statement)
+    return_df = return_df.compute()
+
+    expectation = op(op(df[["a"]], k1), k2).astype(retty)
+    assert_frame_equal(return_df.reset_index(drop=True), expectation)
+
+
 def test_multiple_definitions(c, df_simple):
     def f(x):
         return x ** 2
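
For reference, the new dispatch in `UDF.__call__` splits the incoming operands into `dd.Series` columns (assembled into a frame) and plain scalars, which are then forwarded to `apply` through its `args=` parameter. Below is a minimal standalone sketch of that behaviour using plain pandas; the frame, function name, and scalar value are illustrative and not part of the patch:

    import pandas as pd

    def f(row, k):
        # Row-wise UDF taking one scalar argument, mirroring the documented example.
        return row["a"] + k

    df = pd.DataFrame({"a": [1, 2, 3]})

    # Scalars travel via `args=`, as with the `args=tuple(scalar_args)` call in the patch;
    # only the Series operands become columns of the frame being applied over.
    result = df.apply(f, axis=1, args=(42,))
    print(result.tolist())  # [43, 44, 45]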