diff --git a/dask_sql/context.py b/dask_sql/context.py index b99865e45..2b3565f58 100644 --- a/dask_sql/context.py +++ b/dask_sql/context.py @@ -815,7 +815,7 @@ def _register_callable( schema = self.schema[schema_name] if not aggregation: - f = UDF(f, row_udf) + f = UDF(f, row_udf, return_type) lower_name = name.lower() if lower_name in schema.functions: diff --git a/dask_sql/datacontainer.py b/dask_sql/datacontainer.py index cadc44f74..c2eaacb0c 100644 --- a/dask_sql/datacontainer.py +++ b/dask_sql/datacontainer.py @@ -182,7 +182,7 @@ def assign(self) -> dd.DataFrame: class UDF: - def __init__(self, func, row_udf: bool): + def __init__(self, func, row_udf: bool, return_type=None): """ Helper class that handles different types of UDFs and manages how they should be mapped to dask operations. Two versions of @@ -194,13 +194,14 @@ def __init__(self, func, row_udf: bool): """ self.row_udf = row_udf self.func = func + self.meta = (None, return_type) def __call__(self, *args, **kwargs): if self.row_udf: df = args[0].to_frame() for operand in args[1:]: df[operand.name] = operand - result = df.apply(self.func, axis=1) + result = df.apply(self.func, axis=1, meta=self.meta) else: result = self.func(*args, **kwargs) return result