@@ -177,8 +177,8 @@ def inner(items):
177177def load_model (
178178 model ,
179179 config : AutoConfig ,
180- model_classes : Optional [ tuple [type , ...]] = None ,
181- task : Optional [ str ] = None ,
180+ model_classes : tuple [type , ...] | None = None ,
181+ task : str | None = None ,
182182 ** model_kwargs ,
183183):
184184 """
@@ -270,7 +270,7 @@ def load_model(
270270 return model
271271
272272
273- def get_default_model_and_revision (targeted_task : dict , task_options : Optional [ Any ] ) -> tuple [str , str ]:
273+ def get_default_model_and_revision (targeted_task : dict , task_options : Any | None ) -> tuple [str , str ]:
274274 """
275275 Select a default model to use for a given task.
276276
@@ -305,9 +305,9 @@ def get_default_model_and_revision(targeted_task: dict, task_options: Optional[A
305305
306306def load_assistant_model (
307307 model : "PreTrainedModel" ,
308- assistant_model : Optional [ Union [str , "PreTrainedModel" ]] ,
309- assistant_tokenizer : Optional [ PreTrainedTokenizer ] ,
310- ) -> tuple [Optional ["PreTrainedModel" ], Optional [ PreTrainedTokenizer ] ]:
308+ assistant_model : Union [str , "PreTrainedModel" ] | None ,
309+ assistant_tokenizer : PreTrainedTokenizer | None ,
310+ ) -> tuple [Optional ["PreTrainedModel" ], PreTrainedTokenizer | None ]:
311311 """
312312 Prepares the assistant model and the assistant tokenizer for a pipeline whose model that can call `generate`.
313313
@@ -404,9 +404,9 @@ class PipelineDataFormat:
404404
405405 def __init__ (
406406 self ,
407- output_path : Optional [ str ] ,
408- input_path : Optional [ str ] ,
409- column : Optional [ str ] ,
407+ output_path : str | None ,
408+ input_path : str | None ,
409+ column : str | None ,
410410 overwrite : bool = False ,
411411 ):
412412 self .output_path = output_path
@@ -430,7 +430,7 @@ def __iter__(self):
430430 raise NotImplementedError ()
431431
432432 @abstractmethod
433- def save (self , data : Union [ dict , list [dict ] ]):
433+ def save (self , data : dict | list [dict ]):
434434 """
435435 Save the provided data object with the representation for the current [`~pipelines.PipelineDataFormat`].
436436
@@ -439,7 +439,7 @@ def save(self, data: Union[dict, list[dict]]):
439439 """
440440 raise NotImplementedError ()
441441
442- def save_binary (self , data : Union [ dict , list [dict ] ]) -> str :
442+ def save_binary (self , data : dict | list [dict ]) -> str :
443443 """
444444 Save the provided data object as a pickle-formatted binary data on the disk.
445445
@@ -460,9 +460,9 @@ def save_binary(self, data: Union[dict, list[dict]]) -> str:
460460 @staticmethod
461461 def from_str (
462462 format : str ,
463- output_path : Optional [ str ] ,
464- input_path : Optional [ str ] ,
465- column : Optional [ str ] ,
463+ output_path : str | None ,
464+ input_path : str | None ,
465+ column : str | None ,
466466 overwrite = False ,
467467 ) -> "PipelineDataFormat" :
468468 """
@@ -507,9 +507,9 @@ class CsvPipelineDataFormat(PipelineDataFormat):
507507
508508 def __init__ (
509509 self ,
510- output_path : Optional [ str ] ,
511- input_path : Optional [ str ] ,
512- column : Optional [ str ] ,
510+ output_path : str | None ,
511+ input_path : str | None ,
512+ column : str | None ,
513513 overwrite = False ,
514514 ):
515515 super ().__init__ (output_path , input_path , column , overwrite = overwrite )
@@ -551,9 +551,9 @@ class JsonPipelineDataFormat(PipelineDataFormat):
551551
552552 def __init__ (
553553 self ,
554- output_path : Optional [ str ] ,
555- input_path : Optional [ str ] ,
556- column : Optional [ str ] ,
554+ output_path : str | None ,
555+ input_path : str | None ,
556+ column : str | None ,
557557 overwrite = False ,
558558 ):
559559 super ().__init__ (output_path , input_path , column , overwrite = overwrite )
@@ -617,7 +617,7 @@ def save(self, data: dict):
617617 """
618618 print (data )
619619
620- def save_binary (self , data : Union [ dict , list [dict ] ]) -> str :
620+ def save_binary (self , data : dict | list [dict ]) -> str :
621621 if self .output_path is None :
622622 raise KeyError (
623623 "When using piped input on pipeline outputting large object requires an output file path. "
@@ -776,13 +776,13 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
776776 def __init__ (
777777 self ,
778778 model : "PreTrainedModel" ,
779- tokenizer : Optional [ PreTrainedTokenizer ] = None ,
779+ tokenizer : PreTrainedTokenizer | None = None ,
780780 feature_extractor : Optional [PreTrainedFeatureExtractor ] = None ,
781- image_processor : Optional [ BaseImageProcessor ] = None ,
782- processor : Optional [ ProcessorMixin ] = None ,
783- modelcard : Optional [ ModelCard ] = None ,
781+ image_processor : BaseImageProcessor | None = None ,
782+ processor : ProcessorMixin | None = None ,
783+ modelcard : ModelCard | None = None ,
784784 task : str = "" ,
785- device : Optional [ Union [int , "torch.device" ]] = None ,
785+ device : Union [int , "torch.device" ] | None = None ,
786786 binary_output : bool = False ,
787787 ** kwargs ,
788788 ):
@@ -939,7 +939,7 @@ def __init__(
939939
940940 def save_pretrained (
941941 self ,
942- save_directory : Union [ str , os .PathLike ] ,
942+ save_directory : str | os .PathLike ,
943943 safe_serialization : bool = True ,
944944 ** kwargs : Any ,
945945 ):
@@ -1085,7 +1085,7 @@ def _ensure_tensor_on_device(self, inputs, device):
10851085 else :
10861086 return inputs
10871087
1088- def check_model_type (self , supported_models : Union [ list [str ], dict ] ):
1088+ def check_model_type (self , supported_models : list [str ] | dict ):
10891089 """
10901090 Check if the model class is in supported by the pipeline.
10911091
@@ -1348,9 +1348,9 @@ def register_pipeline(
13481348 self ,
13491349 task : str ,
13501350 pipeline_class : type ,
1351- pt_model : Optional [ Union [ type , tuple [type ]]] = None ,
1352- default : Optional [ dict ] = None ,
1353- type : Optional [ str ] = None ,
1351+ pt_model : type | tuple [type ] | None = None ,
1352+ default : dict | None = None ,
1353+ type : str | None = None ,
13541354 ) -> None :
13551355 if task in self .supported_tasks :
13561356 logger .warning (f"{ task } is already registered. Overwriting pipeline for task { task } ..." )
0 commit comments