Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions haystack/utils/callable_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
# SPDX-License-Identifier: Apache-2.0

import inspect
import sys
from typing import Callable, Optional

from haystack import DeserializationError
from haystack.utils.type_serialization import thread_safe_import


def serialize_callable(callable_handle: Callable) -> str:
Expand Down Expand Up @@ -37,9 +37,10 @@ def deserialize_callable(callable_handle: str) -> Optional[Callable]:
parts = callable_handle.split(".")
module_name = ".".join(parts[:-1])
function_name = parts[-1]
module = sys.modules.get(module_name, None)
if not module:
raise DeserializationError(f"Could not locate the module of the callable: {module_name}")
try:
module = thread_safe_import(module_name)
except Exception as e:
raise DeserializationError(f"Could not locate the module of the callable: {module_name}") from e
deserialized_callable = getattr(module, function_name, None)
if not deserialized_callable:
raise DeserializationError(f"Could not locate the callable: {function_name}")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
enhancements:
- |
Improved deserialization of callables by using `importlib` instead of `sys.modules`.
This change allows importing local functions and classes that are not in `sys.modules`
when deserializing callables.
9 changes: 9 additions & 0 deletions test/utils/test_callable_serialization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import pytest
import requests

from haystack import DeserializationError
from haystack.components.generators.utils import print_streaming_chunk
from haystack.utils import serialize_callable, deserialize_callable

Expand Down Expand Up @@ -36,3 +38,10 @@ def test_callable_deserialization_non_local():
result = serialize_callable(requests.api.get)
fn = deserialize_callable(result)
assert fn is requests.api.get


def test_callable_deserialization_error():
with pytest.raises(DeserializationError):
deserialize_callable("this.is.not.a.valid.module")
with pytest.raises(DeserializationError):
deserialize_callable("sys.foobar")