11import unittest
2- from unittest import mock
32from typing import List , Optional
3+ from unittest import mock
44
55from transformers import is_tf_available , is_torch_available , pipeline
6- from transformers .tokenization_utils_base import to_py_obj
76from transformers .pipelines import DefaultArgumentHandler , Pipeline
87from transformers .testing_utils import _run_slow_tests , is_pipeline_test , require_tf , require_torch , slow
8+ from transformers .tokenization_utils_base import to_py_obj
99
1010
# Inputs shared across pipeline tests: a bare string and a list of strings,
# covering both single-input and batched invocation shapes.
VALID_INPUTS = ["A simple string", ["list of strings"]]
14- @is_pipeline_test
14+ # @is_pipeline_test
1515class CustomInputPipelineCommonMixin :
1616 pipeline_task = None
17- pipeline_loading_kwargs = {}
18- small_models = None # Models tested without the @slow decorator
19- large_models = None # Models tested with the @slow decorator
17+ pipeline_loading_kwargs = {} # Additional kwargs to load the pipeline with
18+ pipeline_running_kwargs = {} # Additional kwargs to run the pipeline with
19+ small_models = [] # Models tested without the @slow decorator
20+ large_models = [] # Models tested with the @slow decorator
21+ valid_inputs = VALID_INPUTS # Some inputs which are valid to compare fast and slow tokenizers
2022
2123 def setUp (self ) -> None :
2224 if not is_tf_available () and not is_torch_available ():
@@ -48,78 +50,41 @@ def setUp(self) -> None:
4850 @require_torch
4951 @slow
5052 def test_pt_defaults (self ):
51- pipeline (self .pipeline_task , framework = "pt" )
53+ pipeline (self .pipeline_task , framework = "pt" , ** self . pipeline_loading_kwargs )
5254
5355 @require_tf
5456 @slow
5557 def test_tf_defaults (self ):
56- pipeline (self .pipeline_task , framework = "tf" )
58+ pipeline (self .pipeline_task , framework = "tf" , ** self . pipeline_loading_kwargs )
5759
5860 @require_torch
5961 def test_torch_small (self ):
6062 for model_name in self .small_models :
61- nlp = pipeline (task = self .pipeline_task , model = model_name , tokenizer = model_name , framework = "pt" )
63+ nlp = pipeline (
64+ task = self .pipeline_task ,
65+ model = model_name ,
66+ tokenizer = model_name ,
67+ framework = "pt" ,
68+ ** self .pipeline_loading_kwargs ,
69+ )
6270 self ._test_pipeline (nlp )
6371
6472 @require_tf
6573 def test_tf_small (self ):
6674 for model_name in self .small_models :
67- nlp = pipeline (task = self .pipeline_task , model = model_name , tokenizer = model_name , framework = "tf" )
75+ nlp = pipeline (
76+ task = self .pipeline_task ,
77+ model = model_name ,
78+ tokenizer = model_name ,
79+ framework = "tf" ,
80+ ** self .pipeline_loading_kwargs ,
81+ )
6882 self ._test_pipeline (nlp )
6983
7084 @require_torch
7185 @slow
7286 def test_torch_large (self ):
7387 for model_name in self .large_models :
74- nlp = pipeline (task = self .pipeline_task , model = model_name , tokenizer = model_name , framework = "pt" )
75- self ._test_pipeline (nlp )
76-
77- @require_tf
78- @slow
79- def test_tf_large (self ):
80- for model_name in self .large_models :
81- nlp = pipeline (task = self .pipeline_task , model = model_name , tokenizer = model_name , framework = "tf" )
82- self ._test_pipeline (nlp )
83-
84- def _test_pipeline (self , nlp : Pipeline ):
85- raise NotImplementedError
86-
87-
88- # @is_pipeline_test
89- class MonoInputPipelineCommonMixin :
90- pipeline_task = None
91- pipeline_loading_kwargs = {} # Additional kwargs to load the pipeline with
92- pipeline_running_kwargs = {} # Additional kwargs to run the pipeline with
93- small_models = [] # Models tested without the @slow decorator
94- large_models = [] # Models tested with the @slow decorator
95- mandatory_keys = {} # Keys which should be in the output
96- valid_inputs = VALID_INPUTS # inputs which are valid
97- invalid_inputs = [None ] # inputs which are not allowed
98- expected_multi_result : Optional [List ] = None
99- expected_check_keys : Optional [List [str ]] = None
100-
101- def setUp (self ) -> None :
102- if not is_tf_available () and not is_torch_available ():
103- return # Currently no JAX pipelines
104-
105- for model_name in self .small_models :
106- pipeline (self .pipeline_task , model = model_name , tokenizer = model_name , ** self .pipeline_loading_kwargs )
107- for model_name in self .large_models :
108- pipeline (self .pipeline_task , model = model_name , tokenizer = model_name , ** self .pipeline_loading_kwargs )
109-
110- @require_torch
111- @slow
112- def test_pt_defaults_loads (self ):
113- pipeline (self .pipeline_task , framework = "pt" , ** self .pipeline_loading_kwargs )
114-
115- @require_tf
116- @slow
117- def test_tf_defaults_loads (self ):
118- pipeline (self .pipeline_task , framework = "tf" , ** self .pipeline_loading_kwargs )
119-
120- @require_torch
121- def test_torch_small (self ):
122- for model_name in self .small_models :
12388 nlp = pipeline (
12489 task = self .pipeline_task ,
12590 model = model_name ,
@@ -130,8 +95,9 @@ def test_torch_small(self):
13095 self ._test_pipeline (nlp )
13196
13297 @require_tf
133- def test_tf_small (self ):
134- for model_name in self .small_models :
98+ @slow
99+ def test_tf_large (self ):
100+ for model_name in self .large_models :
135101 nlp = pipeline (
136102 task = self .pipeline_task ,
137103 model = model_name ,
@@ -141,6 +107,9 @@ def test_tf_small(self):
141107 )
142108 self ._test_pipeline (nlp )
143109
110+ def _test_pipeline (self , nlp : Pipeline ):
111+ raise NotImplementedError
112+
144113 @require_torch
145114 def test_compare_slow_fast_torch (self ):
146115 for model_name in self .small_models :
@@ -160,7 +129,7 @@ def test_compare_slow_fast_torch(self):
160129 use_fast = True ,
161130 ** self .pipeline_loading_kwargs ,
162131 )
163- self ._compare_slow_fast_pipelines (nlp_slow , nlp_fast )
132+ self ._compare_slow_fast_pipelines (nlp_slow , nlp_fast , method = "forward" )
164133
165134 @require_tf
166135 def test_compare_slow_fast_tf (self ):
@@ -181,54 +150,51 @@ def test_compare_slow_fast_tf(self):
181150 use_fast = True ,
182151 ** self .pipeline_loading_kwargs ,
183152 )
184- self ._compare_slow_fast_pipelines (nlp_slow , nlp_fast )
185-
186- def _compare_slow_fast_pipelines (self , nlp_slow : Pipeline , nlp_fast : Pipeline ):
187- with mock .patch .object (nlp_slow .model , 'forward' , wraps = nlp_slow .model .forward ) as mock_slow ,\
188- mock .patch .object (nlp_fast .model , 'forward' , wraps = nlp_fast .model .forward ) as mock_fast :
153+ self ._compare_slow_fast_pipelines (nlp_slow , nlp_fast , method = "call" )
154+
155+ def _compare_slow_fast_pipelines (self , nlp_slow : Pipeline , nlp_fast : Pipeline , method : str ):
156+ """We check that the inputs to the models forward passes are identical for
157+ slow and fast tokenizers.
158+ """
159+ with mock .patch .object (
160+ nlp_slow .model , method , wraps = getattr (nlp_slow .model , method )
161+ ) as mock_slow , mock .patch .object (nlp_fast .model , method , wraps = getattr (nlp_fast .model , method )) as mock_fast :
189162 for inputs in self .valid_inputs :
190- outputs_slow = nlp_slow (inputs , ** self .pipeline_running_kwargs )
191- outputs_fast = nlp_fast (inputs , ** self .pipeline_running_kwargs )
163+ if isinstance (inputs , dict ):
164+ inputs .update (self .pipeline_running_kwargs )
165+ _ = nlp_slow (** inputs )
166+ _ = nlp_fast (** inputs )
167+ else :
168+ _ = nlp_slow (inputs , ** self .pipeline_running_kwargs )
169+ _ = nlp_fast (inputs , ** self .pipeline_running_kwargs )
192170
193171 mock_slow .assert_called ()
194172 mock_fast .assert_called ()
195173
196- slow_call_args , slow_call_kwargs = mock_slow .call_args
197- fast_call_args , fast_call_kwargs = mock_fast .call_args
174+ self .assertEqual (len (mock_slow .call_args_list ), len (mock_fast .call_args_list ))
175+ for mock_slow_call_args , mock_fast_call_args in zip (
176+ mock_slow .call_args_list , mock_slow .call_args_list
177+ ):
178+ slow_call_args , slow_call_kwargs = mock_slow_call_args
179+ fast_call_args , fast_call_kwargs = mock_fast_call_args
198180
199- slow_call_args , slow_call_kwargs = to_py_obj (slow_call_args ), to_py_obj (slow_call_kwargs )
200- fast_call_args , fast_call_kwargs = to_py_obj (fast_call_args ), to_py_obj (fast_call_kwargs )
181+ slow_call_args , slow_call_kwargs = to_py_obj (slow_call_args ), to_py_obj (slow_call_kwargs )
182+ fast_call_args , fast_call_kwargs = to_py_obj (fast_call_args ), to_py_obj (fast_call_kwargs )
201183
202- self .assertEqual (slow_call_args , fast_call_args )
203- self .assertDictEqual (slow_call_kwargs , fast_call_kwargs )
184+ self .assertEqual (slow_call_args , fast_call_args )
185+ self .assertDictEqual (slow_call_kwargs , fast_call_kwargs )
204186
205- self .assertEqual (outputs_slow , outputs_fast )
206187
207- @require_torch
208- @slow
209- def test_torch_large (self ):
210- for model_name in self .large_models :
211- nlp = pipeline (
212- task = self .pipeline_task ,
213- model = model_name ,
214- tokenizer = model_name ,
215- framework = "pt" ,
216- ** self .pipeline_loading_kwargs ,
217- )
218- self ._test_pipeline (nlp )
188+ @is_pipeline_test
189+ class MonoInputPipelineCommonMixin (CustomInputPipelineCommonMixin ):
190+ """A version of the CustomInputPipelineCommonMixin
191+ with a predefined `_test_pipeline` method.
192+ """
219193
220- @require_tf
221- @slow
222- def test_tf_large (self ):
223- for model_name in self .large_models :
224- nlp = pipeline (
225- task = self .pipeline_task ,
226- model = model_name ,
227- tokenizer = model_name ,
228- framework = "tf" ,
229- ** self .pipeline_loading_kwargs ,
230- )
231- self ._test_pipeline (nlp )
194+ mandatory_keys = {} # Keys which should be in the output
195+ invalid_inputs = [None ] # inputs which are not allowed
196+ expected_multi_result : Optional [List ] = None
197+ expected_check_keys : Optional [List [str ]] = None
232198
233199 def _test_pipeline (self , nlp : Pipeline ):
234200 self .assertIsNotNone (nlp )
0 commit comments