Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dask_sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,6 +898,7 @@ def _register_callable(
row_udf: bool = False,
):
"""Helper function to do the function or aggregation registration"""

schema_name = schema_name or self.schema_name
schema = self.schema[schema_name]

Expand Down
12 changes: 6 additions & 6 deletions dask_sql/datacontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import dask.dataframe as dd
import pandas as pd

from dask_sql.mappings import python_to_sql_type

ColumnType = Union[str, int]

FunctionDescription = namedtuple(
Expand Down Expand Up @@ -198,11 +200,10 @@ def __init__(self, func, row_udf: bool, params, return_type=None):

self.names = [param[0] for param in params]

if return_type is None:
# These UDFs go through apply and without providing
# a return type, dask will attempt to guess it, and
# dask might be wrong.
raise ValueError("Return type must be provided")
# validate UDF metadata
for dt in (*(param[1] for param in params), return_type):
_ = python_to_sql_type(dt)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to push this off to a follow up PR if it would require a lot of time to implement, but since we are computing the SQL types of the params / return Python types here, it could be worthwhile to cache this information now and grab it when we later call python_to_sql_type for UDF preparation:

sql_return_type = python_to_sql_type(function_description.return_type)

sql_param_type = python_to_sql_type(param_type)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I played around with this a bit and ended up getting something working in brandon-b-miller#1 🙂

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @charlesbluca ! will take a look at this later this morning


self.meta = (None, return_type)

def __call__(self, *args, **kwargs):
Expand All @@ -218,7 +219,6 @@ def __call__(self, *args, **kwargs):
df = column_args[0].to_frame(self.names[0])
for name, col in zip(self.names[1:], column_args[1:]):
df[name] = col

result = df.apply(
self.func, axis=1, args=tuple(scalar_args), meta=self.meta
).astype(self.meta[1])
Expand Down
5 changes: 5 additions & 0 deletions dask_sql/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@
def python_to_sql_type(python_type):
"""Mapping between python and SQL types."""

if python_type in (int, float):
python_type = np.dtype(python_type)
elif python_type is str:
python_type = np.dtype("object")

if isinstance(python_type, np.dtype):
python_type = python_type.type

Expand Down
21 changes: 15 additions & 6 deletions tests/integration/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,12 @@ def f(row):

@pytest.mark.parametrize(
"retty",
[None, np.float64, np.float32, np.int64, np.int32, np.int16, np.int8, np.bool_],
[np.float64, np.float32, np.int64, np.int32, np.int16, np.int8, np.bool_],
)
def test_custom_function_row_return_types(c, df, retty):
def f(row):
return row["x"] ** 2

if retty is None:
with pytest.raises(ValueError):
c.register_function(f, "f", [("x", np.float64)], retty, row_udf=True)
return

c.register_function(f, "f", [("x", np.float64)], retty, row_udf=True)

return_df = c.sql("SELECT F(a) AS a FROM df")
Expand Down Expand Up @@ -199,3 +194,17 @@ def f(x):
c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64)

c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64, replace=True)


@pytest.mark.parametrize("dtype", [np.timedelta64, None, "a string"])
def test_unsupported_dtype(c, dtype):
def f(x):
return x**2

# test that an invalid return type raises
with pytest.raises(NotImplementedError):
c.register_function(f, "f", [("x", np.int64)], dtype)

# test that an invalid param type raises
with pytest.raises(NotImplementedError):
c.register_function(f, "f", [("x", dtype)], np.int64)