From 8919fc0326b20a807e0d22d34e3d92372fb1bc93 Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Wed, 23 Feb 2022 14:53:17 -0800 Subject: [PATCH 1/6] retain original input names and transform to them later --- dask_sql/context.py | 3 ++- dask_sql/datacontainer.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/dask_sql/context.py b/dask_sql/context.py index af4a12b4b..3c22bdc03 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -965,7 +965,8 @@ def _register_callable( if not aggregation: f = UDF(f, row_udf, return_type) - + nm = [i[0] for i in parameters] + f._names = nm lower_name = name.lower() if lower_name in schema.functions: if replace: diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py index f9605625f..5af4bcf19 100644 --- a/dask_sql/datacontainer.py +++ b/dask_sql/datacontainer.py @@ -215,6 +215,7 @@ def __call__(self, *args, **kwargs): df = column_args[0].to_frame() for col in column_args[1:]: df[col.name] = col + df.columns = self._names result = df.apply( self.func, axis=1, args=tuple(scalar_args), meta=self.meta ).astype(self.meta[1]) From 7787d111e596761c72cdaf7dee5a6e9f6fe48ea6 Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Wed, 23 Feb 2022 15:15:04 -0800 Subject: [PATCH 2/6] adjust impl --- dask_sql/context.py | 4 +--- dask_sql/datacontainer.py | 11 +++++++---- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dask_sql/context.py b/dask_sql/context.py index 3c22bdc03..3237e0603 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -964,9 +964,7 @@ def _register_callable( schema = self.schema[schema_name] if not aggregation: - f = UDF(f, row_udf, return_type) - nm = [i[0] for i in parameters] - f._names = nm + f = UDF(f, row_udf, parameters, return_type) lower_name = name.lower() if lower_name in schema.functions: if replace: diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py index 5af4bcf19..d00789de3 100644 --- a/dask_sql/datacontainer.py +++ b/dask_sql/datacontainer.py @@ -183,7 +183,7 @@ def assign(self) -> dd.DataFrame: class UDF: - def __init__(self, func, row_udf: bool, return_type=None): + def __init__(self, func, row_udf: bool, params, return_type=None): """ Helper class that handles different types of UDFs and manages how they should be mapped to dask operations. 
Two versions of @@ -196,6 +196,8 @@ def __init__(self, func, row_udf: bool, return_type=None): self.row_udf = row_udf self.func = func + self.names = [param[0] for param in params] + if return_type is None: # These UDFs go through apply and without providing # a return type, dask will attempt to guess it, and @@ -212,10 +214,11 @@ def __call__(self, *args, **kwargs): column_args.append(operand) else: scalar_args.append(operand) + df = column_args[0].to_frame() - for col in column_args[1:]: - df[col.name] = col - df.columns = self._names + for name, col in zip(self.names, column_args): + df[name] = col + result = df.apply( self.func, axis=1, args=tuple(scalar_args), meta=self.meta ).astype(self.meta[1]) From 6aa08cdf052ca436a555ed186372dba0d08861ff Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Mon, 28 Feb 2022 13:29:36 -0800 Subject: [PATCH 3/6] tests and updates --- dask_sql/datacontainer.py | 4 ++-- tests/integration/fixtures.py | 15 +++++++++++++++ tests/integration/test_function.py | 24 +++++++++++++++++++++++- 3 files changed, 40 insertions(+), 3 deletions(-) diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py index d00789de3..bdb7f7270 100644 --- a/dask_sql/datacontainer.py +++ b/dask_sql/datacontainer.py @@ -214,11 +214,11 @@ def __call__(self, *args, **kwargs): column_args.append(operand) else: scalar_args.append(operand) - + df = column_args[0].to_frame() for name, col in zip(self.names, column_args): df[name] = col - + result = df.apply( self.func, axis=1, args=tuple(scalar_args), meta=self.meta ).astype(self.meta[1]) diff --git a/tests/integration/fixtures.py b/tests/integration/fixtures.py index ff5f9a23e..e7c6d0ece 100644 --- a/tests/integration/fixtures.py +++ b/tests/integration/fixtures.py @@ -37,6 +37,19 @@ def df_simple(): return pd.DataFrame({"a": [1, 2, 3], "b": [1.1, 2.2, 3.3]}) +@pytest.fixture() +def df_wide(): + return pd.DataFrame( + { + "a": [0, 1, 2], + "b": [3, 4, 5], + "c": [6, 7, 8], + "d": [9, 10, 11], + "e": [12, 13, 14], + } + ) + + @pytest.fixture() def df(): np.random.seed(42) @@ -126,6 +139,7 @@ def gpu_datetime_table(datetime_table): @pytest.fixture() def c( df_simple, + df_wide, df, user_table_1, user_table_2, @@ -142,6 +156,7 @@ def c( ): dfs = { "df_simple": df_simple, + "df_wide": df_wide, "df": df, "user_table_1": user_table_1, "user_table_2": user_table_2, diff --git a/tests/integration/test_function.py b/tests/integration/test_function.py index 7aa27d709..eda4f0244 100644 --- a/tests/integration/test_function.py +++ b/tests/integration/test_function.py @@ -1,9 +1,10 @@ +import itertools import operator import dask.dataframe as dd import numpy as np import pytest -from pandas.testing import assert_frame_equal +from pandas.testing import assert_frame_equal, assert_series_equal def test_custom_function(c, df): @@ -40,6 +41,27 @@ def f(row): assert_frame_equal(return_df.reset_index(drop=True), df[["a"]] ** 2) +@pytest.mark.parametrize("colnames", list(itertools.combinations(["a", "b", "c"], 2))) +def test_custom_function_any_colnames(colnames, df_wide, c): + # a third column is needed + + def f(row): + return row["x"] + row["y"] + + colname_x, colname_y = colnames + c.register_function( + f, "f", [("x", np.int64), ("y", np.int64)], np.int64, row_udf=True + ) + + return_df = c.sql(f"SELECT F({colname_x},{colname_y}) FROM df_wide") + + return_df = return_df.compute() + expect = df_wide[colname_x] + df_wide[colname_y] + got = return_df[return_df.columns[0]] + + assert_series_equal(expect, got, check_names=False) + + 
@pytest.mark.parametrize( "retty", [None, np.float64, np.float32, np.int64, np.int32, np.int16, np.int8, np.bool_], From cf7b183e0c207e74dd0e81031f66d65d3d87586c Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Mon, 28 Feb 2022 13:53:58 -0800 Subject: [PATCH 4/6] fix tests --- tests/integration/test_show.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration/test_show.py b/tests/integration/test_show.py index 6e2af33d1..a04129489 100644 --- a/tests/integration/test_show.py +++ b/tests/integration/test_show.py @@ -35,6 +35,7 @@ def test_tables(c): "Table": [ "df", "df_simple", + "df_wide", "user_table_1", "user_table_2", "long_table", @@ -47,6 +48,7 @@ def test_tables(c): else [ "df", "df_simple", + "df_wide", "user_table_1", "user_table_2", "long_table", From f2ce42fb4a28566ae924eaeee4fdbbc330324942 Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Tue, 1 Mar 2022 10:19:55 -0800 Subject: [PATCH 5/6] Address reviews, fix tests --- dask_sql/datacontainer.py | 4 ++-- tests/integration/test_function.py | 11 +++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py index bdb7f7270..10956ec1d 100644 --- a/dask_sql/datacontainer.py +++ b/dask_sql/datacontainer.py @@ -215,8 +215,8 @@ def __call__(self, *args, **kwargs): else: scalar_args.append(operand) - df = column_args[0].to_frame() - for name, col in zip(self.names, column_args): + df = column_args[0].to_frame(self.names[0]) + for name, col in zip(self.names[1:], column_args[1:]): df[name] = col result = df.apply( diff --git a/tests/integration/test_function.py b/tests/integration/test_function.py index eda4f0244..a116e2459 100644 --- a/tests/integration/test_function.py +++ b/tests/integration/test_function.py @@ -4,7 +4,7 @@ import dask.dataframe as dd import numpy as np import pytest -from pandas.testing import assert_frame_equal, assert_series_equal +from pandas.testing import assert_frame_equal def test_custom_function(c, df): @@ -26,7 +26,7 @@ def f(x): def test_custom_function_row(c, df): def f(row): - return row["a"] ** 2 + return row["x"] ** 2 c.register_function(f, "f", [("x", np.float64)], np.float64, row_udf=True) @@ -55,11 +55,10 @@ def f(row): return_df = c.sql(f"SELECT F({colname_x},{colname_y}) FROM df_wide") - return_df = return_df.compute() expect = df_wide[colname_x] + df_wide[colname_y] - got = return_df[return_df.columns[0]] + got = return_df.iloc[:, 0] - assert_series_equal(expect, got, check_names=False) + dd.assert_eq(expect, got, check_names=False) @pytest.mark.parametrize( @@ -68,7 +67,7 @@ def f(row): ) def test_custom_function_row_return_types(c, df, retty): def f(row): - return row["a"] ** 2 + return row["x"] ** 2 if retty is None: with pytest.raises(ValueError): From ee0c52022243899a7f570c6c18ee3db0d1e71bea Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Tue, 1 Mar 2022 12:34:47 -0800 Subject: [PATCH 6/6] update docs --- dask_sql/context.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/dask_sql/context.py b/dask_sql/context.py index 3237e0603..0b6b3c5c8 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -326,7 +326,9 @@ def f(x): f (:obj:`Callable`): The function to register name (:obj:`str`): Under which name should the new function be addressable in SQL parameters (:obj:`List[Tuple[str, type]]`): A list ot tuples of parameter name and parameter type. - Use `numpy dtypes `_ if possible. + Use `numpy dtypes `_ if possible. 
This + function is sensitive to the order of specified parameters when `row_udf=True`, and it is assumed + that column arguments are specified in order, followed by scalar arguments. return_type (:obj:`type`): The return type of the function replace (:obj:`bool`): If `True`, do not raise an error if a function with the same name is already present; instead, replace the original function. Default is `False`.
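
A minimal usage sketch of the behavior this series enables (not part of the patches themselves; the table name "demo" and its sample data are illustrative assumptions, while the register_function call mirrors the tests added in PATCH 3/6):

    import numpy as np
    import pandas as pd

    from dask_sql import Context

    c = Context()
    # Any column names work; they do not have to match the UDF's parameter names.
    c.create_table("demo", pd.DataFrame({"a": [0, 1, 2], "b": [3, 4, 5]}))


    def f(row):
        # Inside the UDF, values are addressed by the *declared* parameter names
        # ("x", "y"); the positional column arguments are renamed to these names
        # before the row-wise apply.
        return row["x"] + row["y"]


    c.register_function(
        f, "f", [("x", np.int64), ("y", np.int64)], np.int64, row_udf=True
    )

    # Column arguments are passed positionally and mapped to "x" and "y" in order.
    result = c.sql("SELECT F(a, b) FROM demo").compute()
    print(result.iloc[:, 0].tolist())  # [3, 5, 7]

As the docstring added in PATCH 6/6 notes, with `row_udf=True` the column arguments in the SQL call must appear in the same order as the declared parameters, with any scalar arguments following them.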