2727)
2828
2929from jiter import from_json
30- from pydantic import BaseModel , create_model
30+ from pydantic import BaseModel , create_model , model_validator
3131from pydantic .fields import FieldInfo
3232
3333from instructor .mode import Mode
@@ -47,6 +47,20 @@ class MakeFieldsOptional:
4747
4848
4949class PartialLiteralMixin :
50+ """Mixin to handle Literal and Enum types during streaming.
51+
52+ When using partial streaming with models that contain Literal or Enum fields,
53+ incomplete strings like "act" (for "active") can cause validation errors.
54+
55+ Adding this mixin to your model switches from partial_mode='trailing-strings'
56+ to partial_mode='on', which drops incomplete strings entirely instead of
57+ keeping them as partial values.
58+
59+ Example:
60+ class MyModel(BaseModel, PartialLiteralMixin):
61+ status: Literal["active", "inactive"]
62+ """
63+
5064 pass
5165
5266
@@ -58,7 +72,10 @@ def process_potential_object(potential_object, partial_mode, partial_model, **kw
5872 obj = from_json (
5973 (potential_object .strip () or "{}" ).encode (), partial_mode = partial_mode
6074 )
61- obj = partial_model .model_validate (obj , strict = None , ** kwargs )
75+ # Pass context to skip model validators during streaming
76+ obj = partial_model .model_validate (
77+ obj , strict = None , context = {"partial_streaming" : True }, ** kwargs
78+ )
6279 return obj
6380
6481
@@ -67,6 +84,7 @@ def _process_generic_arg(
6784 make_fields_optional : bool = False ,
6885) -> Any :
6986 arg_origin = get_origin (arg )
87+
7088 if arg_origin is not None :
7189 # Handle any nested generic type (Union, List, Dict, etc.)
7290 nested_args = get_args (arg )
@@ -100,7 +118,7 @@ def _make_field_optional(
100118
101119 annotation = field .annotation
102120
103- # Handle generics (like List, Dict, etc.)
121+ # Handle generics (like List, Dict, Union, Literal, etc.)
104122 if get_origin (annotation ) is not None :
105123 # Get the generic base (like List, Dict) and its arguments (like User in List[User])
106124 generic_base = get_origin (annotation )
@@ -134,7 +152,12 @@ class PartialBase(Generic[T_Model]):
134152 @classmethod
135153 @cache
136154 def get_partial_model (cls ) -> type [T_Model ]:
137- """Return a partial model we can use to validate partial results."""
155+ """Return a partial model we can use to validate partial results.
156+
157+ This method creates a model with all fields optional and wraps any
158+ model validators to skip validation when fields are incomplete
159+ (None during streaming).
160+ """
138161 assert issubclass (cls , BaseModel ), (
139162 f"{ cls .__name__ } must be a subclass of BaseModel"
140163 )
@@ -145,7 +168,8 @@ def get_partial_model(cls) -> type[T_Model]:
145168 else f"Partial{ cls .__name__ } "
146169 )
147170
148- return create_model (
171+ # Create the base partial model with optional fields
172+ partial_model = create_model (
149173 model_name ,
150174 __base__ = cls ,
151175 __module__ = cls .__module__ ,
@@ -155,6 +179,75 @@ def get_partial_model(cls) -> type[T_Model]:
155179 }, # type: ignore[all]
156180 )
157181
182+ # Check if there are any model validators to wrap
183+ model_validators = cls .__pydantic_decorators__ .model_validators
184+ if not model_validators :
185+ return partial_model
186+
187+ # Collect original validator functions
188+ original_validators = {
189+ name : decorator
190+ for name , decorator in model_validators .items ()
191+ if decorator .info .mode == "after"
192+ }
193+
194+ if not original_validators :
195+ return partial_model
196+
197+ # Create a subclass that overrides model validators to skip during streaming
198+ def create_streaming_safe_validator (orig_validators ):
199+ from pydantic import ValidationInfo
200+
201+ @model_validator (mode = "wrap" )
202+ @classmethod
203+ def streaming_safe_validator (_cls , values , handler , info : ValidationInfo ):
204+ # First, run the default Pydantic validation (field validation, etc.)
205+ model = handler (values )
206+
207+ # Check if we're in partial streaming mode via context
208+ context = info .context or {}
209+ if context .get ("partial_streaming" ):
210+ # Skip model validators during streaming
211+ return model
212+
213+ # Not streaming - run original validators
214+ for _ , decorator in orig_validators .items ():
215+ result = decorator .func (model )
216+ if result is not None :
217+ model = result
218+
219+ return model
220+
221+ return streaming_safe_validator
222+
223+ def create_noop_validator ():
224+ @model_validator (mode = "wrap" )
225+ @classmethod
226+ def noop_validator (_cls , values , handler , _info ):
227+ return handler (values )
228+
229+ return noop_validator
230+
231+ # Create the wrapper validator that runs all original validators
232+ wrapper_validator = create_streaming_safe_validator (original_validators )
233+
234+ # Create validators dict: first validator is the wrapper, rest are no-ops
235+ # This ensures all parent validators are overridden
236+ validators_dict = {}
237+ validator_names = list (original_validators .keys ())
238+ validators_dict [validator_names [0 ]] = wrapper_validator
239+ for name in validator_names [1 :]:
240+ validators_dict [name ] = create_noop_validator ()
241+
242+ wrapped_model = create_model (
243+ model_name ,
244+ __base__ = partial_model ,
245+ __module__ = cls .__module__ ,
246+ __validators__ = validators_dict ,
247+ )
248+
249+ return wrapped_model
250+
158251 @classmethod
159252 def from_streaming_response (
160253 cls , completion : Iterable [Any ], mode : Mode , ** kwargs : Any
@@ -206,7 +299,9 @@ def writer_model_from_chunks(
206299 obj = from_json (
207300 (potential_object .strip () or "{}" ).encode (), partial_mode = partial_mode
208301 )
209- obj = partial_model .model_validate (obj , strict = None , ** kwargs )
302+ obj = partial_model .model_validate (
303+ obj , strict = None , context = {"partial_streaming" : True }, ** kwargs
304+ )
210305 yield obj
211306
212307 @classmethod
@@ -230,7 +325,9 @@ async def writer_model_from_chunks_async(
230325 obj = from_json (
231326 (potential_object .strip () or "{}" ).encode (), partial_mode = partial_mode
232327 )
233- obj = partial_model .model_validate (obj , strict = None , ** kwargs )
328+ obj = partial_model .model_validate (
329+ obj , strict = None , context = {"partial_streaming" : True }, ** kwargs
330+ )
234331 yield obj
235332
236333 @classmethod
@@ -274,7 +371,9 @@ async def model_from_chunks_async(
274371 obj = from_json (
275372 (potential_object .strip () or "{}" ).encode (), partial_mode = partial_mode
276373 )
277- obj = partial_model .model_validate (obj , strict = None , ** kwargs )
374+ obj = partial_model .model_validate (
375+ obj , strict = None , context = {"partial_streaming" : True }, ** kwargs
376+ )
278377 yield obj
279378
280379 @staticmethod
0 commit comments