Skip to content

Commit 708f91a

Browse files
committed
test: verify litellm exception cloudpickle roundtripping
1 parent 90a8122 commit 708f91a

1 file changed

Lines changed: 293 additions & 0 deletions

File tree

tests/test_docket_serialization.py

Lines changed: 293 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,293 @@
1+
"""
2+
Test cloudpickle serialization of docket task arguments and exceptions.
3+
4+
Docket uses cloudpickle to serialize:
5+
1. Task arguments (when scheduling: cloudpickle.dumps(args))
6+
2. Task results (on success: cloudpickle.dumps(result))
7+
3. Task exceptions (on failure: cloudpickle.dumps(exception))
8+
9+
litellm exception classes have mandatory __init__ args (message, model,
10+
llm_provider) that cloudpickle doesn't preserve during deserialization.
11+
This means cloudpickle.loads() calls ExceptionClass.__init__() without
12+
the required positional args, raising TypeError. This prevents the docket
13+
worker from storing or reporting task errors.
14+
15+
The fix is in agent_memory_server.litellm_pickle_compat, which patches
16+
litellm exceptions with __reduce__ methods for safe pickle round-tripping.
17+
"""
18+
19+
import threading
20+
21+
import cloudpickle
22+
import httpx
23+
import litellm
24+
import pytest
25+
26+
from agent_memory_server.models import MemoryMessage, MemoryRecord
27+
28+
29+
# Factory to create litellm exceptions with correct constructor args
30+
def _make_litellm_exc(exc_class, **overrides):
    """Instantiate ``exc_class`` with a standard set of keyword arguments.

    The defaults are ``message="test error"``, ``model="test-model"`` and
    ``llm_provider="test"``; entries in ``overrides`` replace or extend them.
    ``PermissionDeniedError`` and ``UnprocessableEntityError`` additionally
    receive a stub 500 ``httpx.Response`` as ``response``, and ``APIError``
    receives ``status_code=500``, unless the caller supplied those keys.
    """
    ctor_args = {
        "message": "test error",
        "model": "test-model",
        "llm_provider": "test",
        **overrides,
    }

    stub_request = httpx.Request(method="POST", url="https://test.example.com")
    stub_response = httpx.Response(status_code=500, request=stub_request)

    # These classes take ``response`` as a non-optional constructor argument.
    response_required = (
        litellm.exceptions.PermissionDeniedError,
        litellm.exceptions.UnprocessableEntityError,
    )
    if exc_class in response_required and "response" not in ctor_args:
        ctor_args["response"] = stub_response

    # APIError takes ``status_code`` as its first constructor argument.
    if exc_class is litellm.exceptions.APIError and "status_code" not in ctor_args:
        ctor_args["status_code"] = 500

    return exc_class(**ctor_args)
51+
52+
53+
class TestTaskArgumentSerialization:
    """Check that the model objects passed as task arguments survive cloudpickle."""

    @staticmethod
    def _roundtrip(obj):
        """Serialize ``obj`` with cloudpickle and immediately load it back."""
        payload = cloudpickle.dumps(obj)
        return cloudpickle.loads(payload)

    def test_memory_record_serializes(self):
        """A single MemoryRecord keeps its ``id`` and ``text`` after a roundtrip."""
        original = MemoryRecord(
            id="test-123",
            text="User prefers dark mode",
            user_id="alice",
            namespace="test",
        )
        copy = self._roundtrip(original)
        assert copy.id == "test-123"
        assert copy.text == "User prefers dark mode"

    def test_memory_message_serializes(self):
        """A MemoryMessage keeps its ``role`` and ``content`` after a roundtrip.

        The original note for this test says MemoryMessage carries a ClassVar
        threading.Lock (used for deprecation warnings), which is why it is
        exercised separately.
        """
        original = MemoryMessage(role="user", content="Hello")
        copy = self._roundtrip(original)
        assert copy.role == "user"
        assert copy.content == "Hello"

    def test_list_of_memory_records_serializes(self):
        """A list of five MemoryRecord objects comes back with five elements."""
        batch = [
            MemoryRecord(id=f"test-{idx}", text=f"Memory {idx}") for idx in range(5)
        ]
        copy = self._roundtrip(batch)
        assert len(copy) == 5
83+
84+
85+
class TestExceptionSerializationBaseline:
    """Verify baseline cloudpickle behavior for non-litellm objects and exceptions."""

    def test_plain_exception_serializes(self):
        """A builtin exception keeps its message across a roundtrip."""
        exc = ValueError("something went wrong")
        data = cloudpickle.dumps(exc)
        restored = cloudpickle.loads(data)
        assert str(restored) == "something went wrong"

    def test_threading_lock_does_not_serialize(self):
        """Pickling a bare threading.Lock raises TypeError ("cannot pickle")."""
        lock = threading.Lock()
        with pytest.raises(TypeError, match="cannot pickle"):
            cloudpickle.dumps(lock)

    def test_httpx_client_does_not_serialize(self):
        """Pickling an httpx.Client raises TypeError ("cannot pickle")."""
        client = httpx.Client(timeout=10.0)
        try:
            with pytest.raises(TypeError, match="cannot pickle"):
                cloudpickle.dumps(client)
        finally:
            # Close the client even when the expectation above fails, so a
            # failing assertion does not also leak the client.
            client.close()

    def test_httpx_connect_error_serializes(self):
        """httpx.ConnectError comes back as the same exception type."""
        exc = httpx.ConnectError("connection failed")
        data = cloudpickle.dumps(exc)
        restored = cloudpickle.loads(data)
        assert isinstance(restored, httpx.ConnectError)

    def test_httpx_timeout_error_serializes(self):
        """httpx.ReadTimeout comes back as the same exception type."""
        exc = httpx.ReadTimeout("Connection timed out")
        data = cloudpickle.dumps(exc)
        restored = cloudpickle.loads(data)
        assert isinstance(restored, httpx.ReadTimeout)

    def test_exception_from_memory_message_validation_serializes(self):
        """The exception raised for invalid MemoryMessage input can be pickled."""
        try:
            MemoryMessage(role=123, content=456)  # type: ignore
        except Exception as e:
            try:
                data = cloudpickle.dumps(e)
                cloudpickle.loads(data)
            except TypeError as pickle_err:
                pytest.fail(
                    f"MemoryMessage validation exception cannot be pickled: {pickle_err}"
                )
        else:
            # Without this branch the test would pass without pickling anything
            # if the model ever started accepting these values.
            pytest.fail(
                "MemoryMessage accepted invalid input; no validation exception to pickle"
            )

    def test_exception_with_traceback_from_locked_class_serializes(self):
        """An exception raised while a class-level lock is held can be pickled."""

        class ServiceWithLock:
            _lock = threading.Lock()

            def do_work(self):
                with self._lock:
                    raise RuntimeError("LLM call failed")

        svc = ServiceWithLock()
        try:
            svc.do_work()
        except RuntimeError as e:
            try:
                data = cloudpickle.dumps(e)
                cloudpickle.loads(data)
            except TypeError as pickle_err:
                pytest.fail(
                    f"Exception with lock in traceback cannot be pickled: {pickle_err}"
                )

    def test_chained_exception_with_httpx_context_serializes(self):
        """A RuntimeError chained from an httpx.ConnectError can be pickled."""
        connect_err = httpx.ConnectError("connection failed")

        with pytest.raises(RuntimeError) as excinfo:
            try:
                raise connect_err
            except Exception as err:
                raise RuntimeError("LLM extraction failed") from err

        e = excinfo.value
        # Confirm the chain exists before testing that it pickles.
        assert e.__cause__ is not None
        assert isinstance(e.__cause__, httpx.ConnectError)
        try:
            data = cloudpickle.dumps(e)
            cloudpickle.loads(data)
        except TypeError as pickle_err:
            pytest.fail(
                f"Chained exception with httpx context cannot be pickled: {pickle_err}"
            )
168+
169+
170+
class TestLiteLLMExceptionBugProof:
    """
    Demonstrate the constructor constraint behind the pickling problem.

    Each listed litellm exception class is called with no arguments and is
    expected to raise TypeError for its missing required parameters. The
    module notes attribute the cloudpickle deserialization failure to the
    exception being rebuilt without those constructor arguments.
    """

    # litellm exception classes whose constructors are probed below.
    _REQUIRED_ARG_CLASSES = [
        litellm.exceptions.APIConnectionError,
        litellm.exceptions.RateLimitError,
        litellm.exceptions.Timeout,
        litellm.exceptions.ServiceUnavailableError,
        litellm.exceptions.BadRequestError,
        litellm.exceptions.AuthenticationError,
        litellm.exceptions.NotFoundError,
        litellm.exceptions.ContentPolicyViolationError,
    ]

    @pytest.mark.parametrize(
        "exc_class",
        _REQUIRED_ARG_CLASSES,
        ids=lambda klass: klass.__name__,
    )
    def test_litellm_init_requires_positional_args(self, exc_class):
        """
        Calling the class with no arguments raises TypeError.

        The error text must match ``missing.*required``. The original note
        for this test says docket/worker.py (around line 1001) calls
        cloudpickle.dumps(e) on every failed task, and that without a
        __reduce__ patch the deserialized exception cannot be rebuilt.
        """
        expected_failure = pytest.raises(TypeError, match="missing.*required")
        with expected_failure:
            exc_class()
205+
206+
207+
class TestLiteLLMExceptionPatched:
    """
    Verify the fix: with litellm_pickle_compat imported, litellm exceptions
    roundtrip through cloudpickle.

    Per the notes that accompany this test module, the patch adds __reduce__
    methods that bypass __init__ on deserialization and restore __dict__
    directly. The patch itself lives in agent_memory_server.litellm_pickle_compat
    and is not shown here.
    """

    @classmethod
    def setup_class(cls):
        """Import the compat module so the patch is in place before any test runs."""
        import agent_memory_server.litellm_pickle_compat  # noqa: F401

    @pytest.mark.parametrize(
        "exc_class",
        [
            litellm.exceptions.APIConnectionError,
            litellm.exceptions.RateLimitError,
            litellm.exceptions.Timeout,
            litellm.exceptions.ServiceUnavailableError,
            litellm.exceptions.BadRequestError,
            litellm.exceptions.AuthenticationError,
            litellm.exceptions.NotFoundError,
            litellm.exceptions.ContentPolicyViolationError,
            litellm.exceptions.InternalServerError,
            litellm.exceptions.BadGatewayError,
            litellm.exceptions.PermissionDeniedError,
            litellm.exceptions.UnprocessableEntityError,
            litellm.exceptions.APIError,
            litellm.exceptions.APIResponseValidationError,
            litellm.exceptions.ContextWindowExceededError,
        ],
        ids=lambda c: c.__name__,
    )
    def test_patched_litellm_exception_roundtrips(self, exc_class):
        """Each litellm exception keeps its type and key fields after a roundtrip."""
        exc = _make_litellm_exc(exc_class)
        data = cloudpickle.dumps(exc)
        restored = cloudpickle.loads(data)
        # Check the concrete class: isinstance(..., Exception) would also pass
        # for a restored object that had lost its litellm type.
        assert isinstance(restored, exc_class)
        assert "test error" in restored.message
        assert restored.model == "test-model"
        assert restored.llm_provider == "test"

    def test_patched_exception_preserves_status_code(self):
        """A restored RateLimitError still reports status_code 429."""
        exc = _make_litellm_exc(litellm.exceptions.RateLimitError)
        data = cloudpickle.dumps(exc)
        restored = cloudpickle.loads(data)
        assert restored.status_code == 429

    def test_patched_exception_preserves_str_representation(self):
        """str() on the restored exception contains the error message."""
        exc = _make_litellm_exc(
            litellm.exceptions.APIConnectionError,
            message="connection refused",
        )
        data = cloudpickle.dumps(exc)
        restored = cloudpickle.loads(data)
        assert "connection refused" in str(restored)

    def test_patched_exception_preserves_chaining(self):
        """Exception chaining (__cause__) survives the roundtrip."""
        cause = ValueError("upstream failure")
        exc = _make_litellm_exc(
            litellm.exceptions.APIConnectionError,
            message="connection failed",
        )
        # Set the cause directly instead of using ``raise ... from ...``.
        exc.__cause__ = cause
        data = cloudpickle.dumps(exc)
        restored = cloudpickle.loads(data)
        assert restored.__cause__ is not None
        assert isinstance(restored.__cause__, ValueError)
        assert str(restored.__cause__) == "upstream failure"

    def test_patch_is_idempotent(self):
        """Calling patch() twice still leaves exceptions able to roundtrip."""
        import agent_memory_server.litellm_pickle_compat as compat

        compat.patch()
        compat.patch()

        exc = _make_litellm_exc(litellm.exceptions.Timeout, message="timed out")
        data = cloudpickle.dumps(exc)
        restored = cloudpickle.loads(data)
        assert "timed out" in restored.message

0 commit comments

Comments
 (0)