Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions dask_sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,7 +748,7 @@ def _prepare_schemas(self):
logger.debug("No custom functions defined.")
for function_description in schema.function_lists:
name = function_description.name
sql_return_type = python_to_sql_type(function_description.return_type)
sql_return_type = function_description.return_type
if function_description.aggregation:
logger.debug(f"Adding function '{name}' to schema as aggregation.")
dask_function = DaskAggregateFunction(name, sql_return_type)
Expand All @@ -771,10 +771,7 @@ def _prepare_schemas(self):
@staticmethod
def _add_parameters_from_description(function_description, dask_function):
for parameter in function_description.parameters:
param_name, param_type = parameter
sql_param_type = python_to_sql_type(param_type)

dask_function.addParameter(param_name, sql_param_type, False)
dask_function.addParameter(*parameter, False)

return dask_function

Expand Down Expand Up @@ -898,9 +895,16 @@ def _register_callable(
row_udf: bool = False,
):
"""Helper function to do the function or aggregation registration"""

schema_name = schema_name or self.schema_name
schema = self.schema[schema_name]

# validate and cache UDF metadata
sql_parameters = [
(name, python_to_sql_type(param_type)) for name, param_type in parameters
]
sql_return_type = python_to_sql_type(return_type)

if not aggregation:
f = UDF(f, row_udf, parameters, return_type)
lower_name = name.lower()
Expand All @@ -920,9 +924,13 @@ def _register_callable(
)

schema.function_lists.append(
FunctionDescription(name.upper(), parameters, return_type, aggregation)
FunctionDescription(
name.upper(), sql_parameters, sql_return_type, aggregation
)
)
schema.function_lists.append(
FunctionDescription(name.lower(), parameters, return_type, aggregation)
FunctionDescription(
name.lower(), sql_parameters, sql_return_type, aggregation
)
)
schema.functions[lower_name] = f
6 changes: 0 additions & 6 deletions dask_sql/datacontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,11 +198,6 @@ def __init__(self, func, row_udf: bool, params, return_type=None):

self.names = [param[0] for param in params]

if return_type is None:
# These UDFs go through apply and without providing
# a return type, dask will attempt to guess it, and
# dask might be wrong.
raise ValueError("Return type must be provided")
self.meta = (None, return_type)

def __call__(self, *args, **kwargs):
Expand All @@ -218,7 +213,6 @@ def __call__(self, *args, **kwargs):
df = column_args[0].to_frame(self.names[0])
for name, col in zip(self.names[1:], column_args[1:]):
df[name] = col

result = df.apply(
self.func, axis=1, args=tuple(scalar_args), meta=self.meta
).astype(self.meta[1])
Expand Down
5 changes: 5 additions & 0 deletions dask_sql/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@
def python_to_sql_type(python_type):
"""Mapping between python and SQL types."""

if python_type in (int, float):
python_type = np.dtype(python_type)
elif python_type is str:
python_type = np.dtype("object")

if isinstance(python_type, np.dtype):
python_type = python_type.type

Expand Down
21 changes: 15 additions & 6 deletions tests/integration/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,12 @@ def f(row):

@pytest.mark.parametrize(
"retty",
[None, np.float64, np.float32, np.int64, np.int32, np.int16, np.int8, np.bool_],
[np.float64, np.float32, np.int64, np.int32, np.int16, np.int8, np.bool_],
)
def test_custom_function_row_return_types(c, df, retty):
def f(row):
return row["x"] ** 2

if retty is None:
with pytest.raises(ValueError):
c.register_function(f, "f", [("x", np.float64)], retty, row_udf=True)
return

c.register_function(f, "f", [("x", np.float64)], retty, row_udf=True)

return_df = c.sql("SELECT F(a) AS a FROM df")
Expand Down Expand Up @@ -199,3 +194,17 @@ def f(x):
c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64)

c.register_aggregation(fagg, "fagg", [("x", np.float64)], np.float64, replace=True)


@pytest.mark.parametrize("dtype", [np.timedelta64, None, "a string"])
def test_unsupported_dtype(c, dtype):
def f(x):
return x**2

# test that an invalid return type raises
with pytest.raises(NotImplementedError):
c.register_function(f, "f", [("x", np.int64)], dtype)

# test that an invalid param type raises
with pytest.raises(NotImplementedError):
c.register_function(f, "f", [("x", dtype)], np.int64)
62 changes: 38 additions & 24 deletions tests/unit/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from dask_sql import Context
from dask_sql.datacontainer import Statistics
from dask_sql.mappings import python_to_sql_type
from tests.utils import assert_eq

try:
Expand Down Expand Up @@ -198,6 +199,11 @@ def g(gpu=gpu):
g(gpu=gpu)


int_sql_type = python_to_sql_type(int)
float_sql_type = python_to_sql_type(float)
str_sql_type = python_to_sql_type(str)


def test_function_adding():
c = Context()

Expand All @@ -211,12 +217,12 @@ def test_function_adding():
assert c.schema[c.schema_name].functions["f"].func == f
assert len(c.schema[c.schema_name].function_lists) == 2
assert c.schema[c.schema_name].function_lists[0].name == "F"
assert c.schema[c.schema_name].function_lists[0].parameters == [("x", int)]
assert c.schema[c.schema_name].function_lists[0].return_type == float
assert c.schema[c.schema_name].function_lists[0].parameters == [("x", int_sql_type)]
assert c.schema[c.schema_name].function_lists[0].return_type == float_sql_type
assert not c.schema[c.schema_name].function_lists[0].aggregation
assert c.schema[c.schema_name].function_lists[1].name == "f"
assert c.schema[c.schema_name].function_lists[1].parameters == [("x", int)]
assert c.schema[c.schema_name].function_lists[1].return_type == float
assert c.schema[c.schema_name].function_lists[1].parameters == [("x", int_sql_type)]
assert c.schema[c.schema_name].function_lists[1].return_type == float_sql_type
assert not c.schema[c.schema_name].function_lists[1].aggregation

# Without replacement
Expand All @@ -226,12 +232,16 @@ def test_function_adding():
assert c.schema[c.schema_name].functions["f"].func == f
assert len(c.schema[c.schema_name].function_lists) == 4
assert c.schema[c.schema_name].function_lists[2].name == "F"
assert c.schema[c.schema_name].function_lists[2].parameters == [("x", float)]
assert c.schema[c.schema_name].function_lists[2].return_type == int
assert c.schema[c.schema_name].function_lists[2].parameters == [
("x", float_sql_type)
]
assert c.schema[c.schema_name].function_lists[2].return_type == int_sql_type
assert not c.schema[c.schema_name].function_lists[2].aggregation
assert c.schema[c.schema_name].function_lists[3].name == "f"
assert c.schema[c.schema_name].function_lists[3].parameters == [("x", float)]
assert c.schema[c.schema_name].function_lists[3].return_type == int
assert c.schema[c.schema_name].function_lists[3].parameters == [
("x", float_sql_type)
]
assert c.schema[c.schema_name].function_lists[3].return_type == int_sql_type
assert not c.schema[c.schema_name].function_lists[3].aggregation

# With replacement
Expand All @@ -242,12 +252,12 @@ def test_function_adding():
assert c.schema[c.schema_name].functions["f"].func == f
assert len(c.schema[c.schema_name].function_lists) == 2
assert c.schema[c.schema_name].function_lists[0].name == "F"
assert c.schema[c.schema_name].function_lists[0].parameters == [("x", str)]
assert c.schema[c.schema_name].function_lists[0].return_type == str
assert c.schema[c.schema_name].function_lists[0].parameters == [("x", str_sql_type)]
assert c.schema[c.schema_name].function_lists[0].return_type == str_sql_type
assert not c.schema[c.schema_name].function_lists[0].aggregation
assert c.schema[c.schema_name].function_lists[1].name == "f"
assert c.schema[c.schema_name].function_lists[1].parameters == [("x", str)]
assert c.schema[c.schema_name].function_lists[1].return_type == str
assert c.schema[c.schema_name].function_lists[1].parameters == [("x", str_sql_type)]
assert c.schema[c.schema_name].function_lists[1].return_type == str_sql_type
assert not c.schema[c.schema_name].function_lists[1].aggregation


Expand All @@ -264,12 +274,12 @@ def test_aggregation_adding():
assert c.schema[c.schema_name].functions["f"] == f
assert len(c.schema[c.schema_name].function_lists) == 2
assert c.schema[c.schema_name].function_lists[0].name == "F"
assert c.schema[c.schema_name].function_lists[0].parameters == [("x", int)]
assert c.schema[c.schema_name].function_lists[0].return_type == float
assert c.schema[c.schema_name].function_lists[0].parameters == [("x", int_sql_type)]
assert c.schema[c.schema_name].function_lists[0].return_type == float_sql_type
assert c.schema[c.schema_name].function_lists[0].aggregation
assert c.schema[c.schema_name].function_lists[1].name == "f"
assert c.schema[c.schema_name].function_lists[1].parameters == [("x", int)]
assert c.schema[c.schema_name].function_lists[1].return_type == float
assert c.schema[c.schema_name].function_lists[1].parameters == [("x", int_sql_type)]
assert c.schema[c.schema_name].function_lists[1].return_type == float_sql_type
assert c.schema[c.schema_name].function_lists[1].aggregation

# Without replacement
Expand All @@ -279,12 +289,16 @@ def test_aggregation_adding():
assert c.schema[c.schema_name].functions["f"] == f
assert len(c.schema[c.schema_name].function_lists) == 4
assert c.schema[c.schema_name].function_lists[2].name == "F"
assert c.schema[c.schema_name].function_lists[2].parameters == [("x", float)]
assert c.schema[c.schema_name].function_lists[2].return_type == int
assert c.schema[c.schema_name].function_lists[2].parameters == [
("x", float_sql_type)
]
assert c.schema[c.schema_name].function_lists[2].return_type == int_sql_type
assert c.schema[c.schema_name].function_lists[2].aggregation
assert c.schema[c.schema_name].function_lists[3].name == "f"
assert c.schema[c.schema_name].function_lists[3].parameters == [("x", float)]
assert c.schema[c.schema_name].function_lists[3].return_type == int
assert c.schema[c.schema_name].function_lists[3].parameters == [
("x", float_sql_type)
]
assert c.schema[c.schema_name].function_lists[3].return_type == int_sql_type
assert c.schema[c.schema_name].function_lists[3].aggregation

# With replacement
Expand All @@ -295,12 +309,12 @@ def test_aggregation_adding():
assert c.schema[c.schema_name].functions["f"] == f
assert len(c.schema[c.schema_name].function_lists) == 2
assert c.schema[c.schema_name].function_lists[0].name == "F"
assert c.schema[c.schema_name].function_lists[0].parameters == [("x", str)]
assert c.schema[c.schema_name].function_lists[0].return_type == str
assert c.schema[c.schema_name].function_lists[0].parameters == [("x", str_sql_type)]
assert c.schema[c.schema_name].function_lists[0].return_type == str_sql_type
assert c.schema[c.schema_name].function_lists[0].aggregation
assert c.schema[c.schema_name].function_lists[1].name == "f"
assert c.schema[c.schema_name].function_lists[1].parameters == [("x", str)]
assert c.schema[c.schema_name].function_lists[1].return_type == str
assert c.schema[c.schema_name].function_lists[1].parameters == [("x", str_sql_type)]
assert c.schema[c.schema_name].function_lists[1].return_type == str_sql_type
assert c.schema[c.schema_name].function_lists[1].aggregation


Expand Down