Skip to content

Commit 53565e2

Browse files
thomasahleclaude
andcommitted
fix(streaming): skip model validators during partial streaming
Model validators (mode="after") can fail during streaming when they reference fields that haven't arrived yet. This commit adds automatic wrapping of model validators to skip them during streaming. Changes: - Pass context={"partial_streaming": True} during streaming validation - Wrap model validators to check context and skip during streaming - Add PartialLiteralMixin for Literal/Enum types (uses partial_mode="on") - Add comprehensive tests for validator behavior during streaming The validators run normally during final validation (without streaming context), ensuring data integrity while allowing smooth streaming. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 2c87556 commit 53565e2

2 files changed

Lines changed: 642 additions & 12 deletions

File tree

instructor/dsl/partial.py

Lines changed: 107 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
)
2828

2929
from jiter import from_json
30-
from pydantic import BaseModel, create_model
30+
from pydantic import BaseModel, create_model, model_validator
3131
from pydantic.fields import FieldInfo
3232

3333
from instructor.mode import Mode
@@ -47,6 +47,20 @@ class MakeFieldsOptional:
4747

4848

4949
class PartialLiteralMixin:
50+
"""Mixin to handle Literal and Enum types during streaming.
51+
52+
When using partial streaming with models that contain Literal or Enum fields,
53+
incomplete strings like "act" (for "active") can cause validation errors.
54+
55+
Adding this mixin to your model switches from partial_mode='trailing-strings'
56+
to partial_mode='on', which drops incomplete strings entirely instead of
57+
keeping them as partial values.
58+
59+
Example:
60+
class MyModel(BaseModel, PartialLiteralMixin):
61+
status: Literal["active", "inactive"]
62+
"""
63+
5064
pass
5165

5266

@@ -58,7 +72,10 @@ def process_potential_object(potential_object, partial_mode, partial_model, **kw
5872
obj = from_json(
5973
(potential_object.strip() or "{}").encode(), partial_mode=partial_mode
6074
)
61-
obj = partial_model.model_validate(obj, strict=None, **kwargs)
75+
# Pass context to skip model validators during streaming
76+
obj = partial_model.model_validate(
77+
obj, strict=None, context={"partial_streaming": True}, **kwargs
78+
)
6279
return obj
6380

6481

@@ -67,6 +84,7 @@ def _process_generic_arg(
6784
make_fields_optional: bool = False,
6885
) -> Any:
6986
arg_origin = get_origin(arg)
87+
7088
if arg_origin is not None:
7189
# Handle any nested generic type (Union, List, Dict, etc.)
7290
nested_args = get_args(arg)
@@ -100,7 +118,7 @@ def _make_field_optional(
100118

101119
annotation = field.annotation
102120

103-
# Handle generics (like List, Dict, etc.)
121+
# Handle generics (like List, Dict, Union, Literal, etc.)
104122
if get_origin(annotation) is not None:
105123
# Get the generic base (like List, Dict) and its arguments (like User in List[User])
106124
generic_base = get_origin(annotation)
@@ -134,7 +152,12 @@ class PartialBase(Generic[T_Model]):
134152
@classmethod
135153
@cache
136154
def get_partial_model(cls) -> type[T_Model]:
137-
"""Return a partial model we can use to validate partial results."""
155+
"""Return a partial model we can use to validate partial results.
156+
157+
This method creates a model with all fields optional and wraps any
158+
model validators to skip validation when fields are incomplete
159+
(None during streaming).
160+
"""
138161
assert issubclass(cls, BaseModel), (
139162
f"{cls.__name__} must be a subclass of BaseModel"
140163
)
@@ -145,7 +168,8 @@ def get_partial_model(cls) -> type[T_Model]:
145168
else f"Partial{cls.__name__}"
146169
)
147170

148-
return create_model(
171+
# Create the base partial model with optional fields
172+
partial_model = create_model(
149173
model_name,
150174
__base__=cls,
151175
__module__=cls.__module__,
@@ -155,6 +179,75 @@ def get_partial_model(cls) -> type[T_Model]:
155179
}, # type: ignore[all]
156180
)
157181

182+
# Check if there are any model validators to wrap
183+
model_validators = cls.__pydantic_decorators__.model_validators
184+
if not model_validators:
185+
return partial_model
186+
187+
# Collect original validator functions
188+
original_validators = {
189+
name: decorator
190+
for name, decorator in model_validators.items()
191+
if decorator.info.mode == "after"
192+
}
193+
194+
if not original_validators:
195+
return partial_model
196+
197+
# Create a subclass that overrides model validators to skip during streaming
198+
def create_streaming_safe_validator(orig_validators):
199+
from pydantic import ValidationInfo
200+
201+
@model_validator(mode="wrap")
202+
@classmethod
203+
def streaming_safe_validator(_cls, values, handler, info: ValidationInfo):
204+
# First, run the default Pydantic validation (field validation, etc.)
205+
model = handler(values)
206+
207+
# Check if we're in partial streaming mode via context
208+
context = info.context or {}
209+
if context.get("partial_streaming"):
210+
# Skip model validators during streaming
211+
return model
212+
213+
# Not streaming - run original validators
214+
for _, decorator in orig_validators.items():
215+
result = decorator.func(model)
216+
if result is not None:
217+
model = result
218+
219+
return model
220+
221+
return streaming_safe_validator
222+
223+
def create_noop_validator():
224+
@model_validator(mode="wrap")
225+
@classmethod
226+
def noop_validator(_cls, values, handler, _info):
227+
return handler(values)
228+
229+
return noop_validator
230+
231+
# Create the wrapper validator that runs all original validators
232+
wrapper_validator = create_streaming_safe_validator(original_validators)
233+
234+
# Create validators dict: first validator is the wrapper, rest are no-ops
235+
# This ensures all parent validators are overridden
236+
validators_dict = {}
237+
validator_names = list(original_validators.keys())
238+
validators_dict[validator_names[0]] = wrapper_validator
239+
for name in validator_names[1:]:
240+
validators_dict[name] = create_noop_validator()
241+
242+
wrapped_model = create_model(
243+
model_name,
244+
__base__=partial_model,
245+
__module__=cls.__module__,
246+
__validators__=validators_dict,
247+
)
248+
249+
return wrapped_model
250+
158251
@classmethod
159252
def from_streaming_response(
160253
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
@@ -206,7 +299,9 @@ def writer_model_from_chunks(
206299
obj = from_json(
207300
(potential_object.strip() or "{}").encode(), partial_mode=partial_mode
208301
)
209-
obj = partial_model.model_validate(obj, strict=None, **kwargs)
302+
obj = partial_model.model_validate(
303+
obj, strict=None, context={"partial_streaming": True}, **kwargs
304+
)
210305
yield obj
211306

212307
@classmethod
@@ -230,7 +325,9 @@ async def writer_model_from_chunks_async(
230325
obj = from_json(
231326
(potential_object.strip() or "{}").encode(), partial_mode=partial_mode
232327
)
233-
obj = partial_model.model_validate(obj, strict=None, **kwargs)
328+
obj = partial_model.model_validate(
329+
obj, strict=None, context={"partial_streaming": True}, **kwargs
330+
)
234331
yield obj
235332

236333
@classmethod
@@ -274,7 +371,9 @@ async def model_from_chunks_async(
274371
obj = from_json(
275372
(potential_object.strip() or "{}").encode(), partial_mode=partial_mode
276373
)
277-
obj = partial_model.model_validate(obj, strict=None, **kwargs)
374+
obj = partial_model.model_validate(
375+
obj, strict=None, context={"partial_streaming": True}, **kwargs
376+
)
278377
yield obj
279378

280379
@staticmethod

0 commit comments

Comments
 (0)