1818import json
1919import logging
2020import os
21+ import shutil
22+ import subprocess
2123import sys
22- from unittest . mock import patch
24+ import tempfile
2325
2426import torch
2527
28+ from accelerate .utils import write_basic_config
2629from transformers .testing_utils import TestCasePlus , get_gpu_count , slow , torch_device
2730from transformers .utils import is_apex_available
2831
2932
30- SRC_DIRS = [
31- os .path .join (os .path .dirname (__file__ ), dirname )
32- for dirname in [
33- "text-generation" ,
34- "text-classification" ,
35- "token-classification" ,
36- "language-modeling" ,
37- "multiple-choice" ,
38- "question-answering" ,
39- "summarization" ,
40- "translation" ,
41- "image-classification" ,
42- "speech-recognition" ,
43- "audio-classification" ,
44- "speech-pretraining" ,
45- "image-pretraining" ,
46- "semantic-segmentation" ,
47- ]
48- ]
49- sys .path .extend (SRC_DIRS )
50-
51-
52- if SRC_DIRS is not None :
53- import run_clm_no_trainer
54- import run_glue_no_trainer
55- import run_image_classification_no_trainer
56- import run_mlm_no_trainer
57- import run_ner_no_trainer
58- import run_qa_no_trainer as run_squad_no_trainer
59- import run_semantic_segmentation_no_trainer
60- import run_summarization_no_trainer
61- import run_swag_no_trainer
62- import run_translation_no_trainer
63-
6433logging .basicConfig (level = logging .DEBUG )
6534
6635logger = logging .getLogger ()
@@ -94,10 +63,22 @@ def is_cuda_and_apex_available():
9463
9564
9665class ExamplesTestsNoTrainer (TestCasePlus ):
@classmethod
def setUpClass(cls):
    """Create a class-wide Accelerate config and the launch command prefix.

    A fresh temporary directory is created and stored on ``cls.tmpdir``; a
    config file named ``default_config.yml`` is written into it through
    ``write_basic_config`` and its path is stored on ``cls.configPath``.
    ``cls._launch_args`` is the argv prefix (``accelerate launch
    --config_file <path>``) that each test prepends to its script arguments.

    Raises:
        Whatever ``write_basic_config`` raises. The temporary directory is
        removed first, because unittest does not call ``tearDownClass`` when
        ``setUpClass`` fails.
    """
    cls.tmpdir = tempfile.mkdtemp()
    try:
        # Write Accelerate config, will pick up on CPU, GPU, and multi-GPU
        cls.configPath = os.path.join(cls.tmpdir, "default_config.yml")
        write_basic_config(save_location=cls.configPath)
    except BaseException:
        # tearDownClass is skipped after a failed setUpClass, so clean up
        # here to avoid leaking the directory, then surface the real error.
        shutil.rmtree(cls.tmpdir, ignore_errors=True)
        raise
    cls._launch_args = ["accelerate", "launch", "--config_file", cls.configPath]
73+
@classmethod
def tearDownClass(cls):
    """Delete the temporary directory stored on ``cls.tmpdir``.

    The whole directory tree is removed, including any files inside it.
    """
    scratch_dir = cls.tmpdir
    shutil.rmtree(scratch_dir)
77+
9778 def test_run_glue_no_trainer (self ):
9879 tmp_dir = self .get_auto_remove_tmp_dir ()
9980 testargs = f"""
100- run_glue_no_trainer.py
81+ { self . examples_dir } /pytorch/text-classification/ run_glue_no_trainer.py
10182 --model_name_or_path distilbert-base-uncased
10283 --output_dir { tmp_dir }
10384 --train_file ./tests/fixtures/tests_samples/MRPC/train.csv
@@ -113,17 +94,16 @@ def test_run_glue_no_trainer(self):
11394 if is_cuda_and_apex_available ():
11495 testargs .append ("--fp16" )
11596
116- with patch .object (sys , "argv" , testargs ):
117- run_glue_no_trainer .main ()
118- result = get_results (tmp_dir )
119- self .assertGreaterEqual (result ["eval_accuracy" ], 0.75 )
120- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
121- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "glue_no_trainer" )))
97+ _ = subprocess .run (self ._launch_args + testargs , stdout = subprocess .PIPE )
98+ result = get_results (tmp_dir )
99+ self .assertGreaterEqual (result ["eval_accuracy" ], 0.75 )
100+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
101+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "glue_no_trainer" )))
122102
123103 def test_run_clm_no_trainer (self ):
124104 tmp_dir = self .get_auto_remove_tmp_dir ()
125105 testargs = f"""
126- run_clm_no_trainer.py
106+ { self . examples_dir } /pytorch/language-modeling/ run_clm_no_trainer.py
127107 --model_name_or_path distilgpt2
128108 --train_file ./tests/fixtures/sample_text.txt
129109 --validation_file ./tests/fixtures/sample_text.txt
@@ -140,17 +120,16 @@ def test_run_clm_no_trainer(self):
140120 # Skipping because there are not enough batches to train the model + would need a drop_last to work.
141121 return
142122
143- with patch .object (sys , "argv" , testargs ):
144- run_clm_no_trainer .main ()
145- result = get_results (tmp_dir )
146- self .assertLess (result ["perplexity" ], 100 )
147- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
148- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "clm_no_trainer" )))
123+ _ = subprocess .run (self ._launch_args + testargs , stdout = subprocess .PIPE )
124+ result = get_results (tmp_dir )
125+ self .assertLess (result ["perplexity" ], 100 )
126+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
127+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "clm_no_trainer" )))
149128
150129 def test_run_mlm_no_trainer (self ):
151130 tmp_dir = self .get_auto_remove_tmp_dir ()
152131 testargs = f"""
153- run_mlm_no_trainer.py
132+ { self . examples_dir } /pytorch/language-modeling/ run_mlm_no_trainer.py
154133 --model_name_or_path distilroberta-base
155134 --train_file ./tests/fixtures/sample_text.txt
156135 --validation_file ./tests/fixtures/sample_text.txt
@@ -160,20 +139,19 @@ def test_run_mlm_no_trainer(self):
160139 --with_tracking
161140 """ .split ()
162141
163- with patch .object (sys , "argv" , testargs ):
164- run_mlm_no_trainer .main ()
165- result = get_results (tmp_dir )
166- self .assertLess (result ["perplexity" ], 42 )
167- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
168- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "mlm_no_trainer" )))
142+ _ = subprocess .run (self ._launch_args + testargs , stdout = subprocess .PIPE )
143+ result = get_results (tmp_dir )
144+ self .assertLess (result ["perplexity" ], 42 )
145+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
146+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "mlm_no_trainer" )))
169147
170148 def test_run_ner_no_trainer (self ):
171149 # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
172150 epochs = 7 if get_gpu_count () > 1 else 2
173151
174152 tmp_dir = self .get_auto_remove_tmp_dir ()
175153 testargs = f"""
176- run_ner_no_trainer.py
154+ { self . examples_dir } /pytorch/token-classification/ run_ner_no_trainer.py
177155 --model_name_or_path bert-base-uncased
178156 --train_file tests/fixtures/tests_samples/conll/sample.json
179157 --validation_file tests/fixtures/tests_samples/conll/sample.json
@@ -187,18 +165,17 @@ def test_run_ner_no_trainer(self):
187165 --with_tracking
188166 """ .split ()
189167
190- with patch .object (sys , "argv" , testargs ):
191- run_ner_no_trainer .main ()
192- result = get_results (tmp_dir )
193- self .assertGreaterEqual (result ["eval_accuracy" ], 0.75 )
194- self .assertLess (result ["train_loss" ], 0.5 )
195- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
196- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "ner_no_trainer" )))
168+ _ = subprocess .run (self ._launch_args + testargs , stdout = subprocess .PIPE )
169+ result = get_results (tmp_dir )
170+ self .assertGreaterEqual (result ["eval_accuracy" ], 0.75 )
171+ self .assertLess (result ["train_loss" ], 0.5 )
172+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
173+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "ner_no_trainer" )))
197174
198175 def test_run_squad_no_trainer (self ):
199176 tmp_dir = self .get_auto_remove_tmp_dir ()
200177 testargs = f"""
201- run_qa_no_trainer.py
178+ { self . examples_dir } /pytorch/question-answering/ run_qa_no_trainer.py
202179 --model_name_or_path bert-base-uncased
203180 --version_2_with_negative
204181 --train_file tests/fixtures/tests_samples/SQUAD/sample.json
@@ -213,19 +190,18 @@ def test_run_squad_no_trainer(self):
213190 --with_tracking
214191 """ .split ()
215192
216- with patch .object (sys , "argv" , testargs ):
217- run_squad_no_trainer .main ()
218- result = get_results (tmp_dir )
219- # Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
220- self .assertGreaterEqual (result ["eval_f1" ], 30 )
221- self .assertGreaterEqual (result ["eval_exact" ], 30 )
222- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
223- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "qa_no_trainer" )))
193+ _ = subprocess .run (self ._launch_args + testargs , stdout = subprocess .PIPE )
194+ result = get_results (tmp_dir )
195+ # Because we use --version_2_with_negative the testing script uses SQuAD v2 metrics.
196+ self .assertGreaterEqual (result ["eval_f1" ], 30 )
197+ self .assertGreaterEqual (result ["eval_exact" ], 30 )
198+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
199+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "qa_no_trainer" )))
224200
225201 def test_run_swag_no_trainer (self ):
226202 tmp_dir = self .get_auto_remove_tmp_dir ()
227203 testargs = f"""
228- run_swag_no_trainer.py
204+ { self . examples_dir } /pytorch/multiple-choice/ run_swag_no_trainer.py
229205 --model_name_or_path bert-base-uncased
230206 --train_file tests/fixtures/tests_samples/swag/sample.json
231207 --validation_file tests/fixtures/tests_samples/swag/sample.json
@@ -238,17 +214,16 @@ def test_run_swag_no_trainer(self):
238214 --with_tracking
239215 """ .split ()
240216
241- with patch .object (sys , "argv" , testargs ):
242- run_swag_no_trainer .main ()
243- result = get_results (tmp_dir )
244- self .assertGreaterEqual (result ["eval_accuracy" ], 0.8 )
245- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "swag_no_trainer" )))
217+ _ = subprocess .run (self ._launch_args + testargs , stdout = subprocess .PIPE )
218+ result = get_results (tmp_dir )
219+ self .assertGreaterEqual (result ["eval_accuracy" ], 0.8 )
220+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "swag_no_trainer" )))
246221
247222 @slow
248223 def test_run_summarization_no_trainer (self ):
249224 tmp_dir = self .get_auto_remove_tmp_dir ()
250225 testargs = f"""
251- run_summarization_no_trainer.py
226+ { self . examples_dir } /pytorch/summarization/ run_summarization_no_trainer.py
252227 --model_name_or_path t5-small
253228 --train_file tests/fixtures/tests_samples/xsum/sample.json
254229 --validation_file tests/fixtures/tests_samples/xsum/sample.json
@@ -262,21 +237,20 @@ def test_run_summarization_no_trainer(self):
262237 --with_tracking
263238 """ .split ()
264239
265- with patch .object (sys , "argv" , testargs ):
266- run_summarization_no_trainer .main ()
267- result = get_results (tmp_dir )
268- self .assertGreaterEqual (result ["eval_rouge1" ], 10 )
269- self .assertGreaterEqual (result ["eval_rouge2" ], 2 )
270- self .assertGreaterEqual (result ["eval_rougeL" ], 7 )
271- self .assertGreaterEqual (result ["eval_rougeLsum" ], 7 )
272- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
273- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "summarization_no_trainer" )))
240+ _ = subprocess .run (self ._launch_args + testargs , stdout = subprocess .PIPE )
241+ result = get_results (tmp_dir )
242+ self .assertGreaterEqual (result ["eval_rouge1" ], 10 )
243+ self .assertGreaterEqual (result ["eval_rouge2" ], 2 )
244+ self .assertGreaterEqual (result ["eval_rougeL" ], 7 )
245+ self .assertGreaterEqual (result ["eval_rougeLsum" ], 7 )
246+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
247+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "summarization_no_trainer" )))
274248
275249 @slow
276250 def test_run_translation_no_trainer (self ):
277251 tmp_dir = self .get_auto_remove_tmp_dir ()
278252 testargs = f"""
279- run_translation_no_trainer.py
253+ { self . examples_dir } /pytorch/translation/ run_translation_no_trainer.py
280254 --model_name_or_path sshleifer/student_marian_en_ro_6_1
281255 --source_lang en
282256 --target_lang ro
@@ -294,12 +268,11 @@ def test_run_translation_no_trainer(self):
294268 --with_tracking
295269 """ .split ()
296270
297- with patch .object (sys , "argv" , testargs ):
298- run_translation_no_trainer .main ()
299- result = get_results (tmp_dir )
300- self .assertGreaterEqual (result ["eval_bleu" ], 30 )
301- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
302- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "translation_no_trainer" )))
271+ _ = subprocess .run (self ._launch_args + testargs , stdout = subprocess .PIPE )
272+ result = get_results (tmp_dir )
273+ self .assertGreaterEqual (result ["eval_bleu" ], 30 )
274+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
275+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "translation_no_trainer" )))
303276
304277 @slow
305278 def test_run_semantic_segmentation_no_trainer (self ):
@@ -308,7 +281,7 @@ def test_run_semantic_segmentation_no_trainer(self):
308281
309282 tmp_dir = self .get_auto_remove_tmp_dir ()
310283 testargs = f"""
311- run_semantic_segmentation_no_trainer.py
284+ { self . examples_dir } /pytorch/semantic-segmentation/ run_semantic_segmentation_no_trainer.py
312285 --dataset_name huggingface/semantic-segmentation-test-sample
313286 --output_dir { tmp_dir }
314287 --max_train_steps=10
@@ -319,15 +292,14 @@ def test_run_semantic_segmentation_no_trainer(self):
319292 --checkpointing_steps epoch
320293 """ .split ()
321294
322- with patch .object (sys , "argv" , testargs ):
323- run_semantic_segmentation_no_trainer .main ()
324- result = get_results (tmp_dir )
325- self .assertGreaterEqual (result ["eval_overall_accuracy" ], 0.10 )
295+ _ = subprocess .run (self ._launch_args + testargs , stdout = subprocess .PIPE )
296+ result = get_results (tmp_dir )
297+ self .assertGreaterEqual (result ["eval_overall_accuracy" ], 0.10 )
326298
327299 def test_run_image_classification_no_trainer (self ):
328300 tmp_dir = self .get_auto_remove_tmp_dir ()
329301 testargs = f"""
330- run_image_classification_no_trainer.py
302+ { self . examples_dir } /pytorch/image-classification/ run_image_classification_no_trainer.py
331303 --dataset_name huggingface/image-classification-test-sample
332304 --output_dir { tmp_dir }
333305 --num_warmup_steps=8
@@ -339,9 +311,8 @@ def test_run_image_classification_no_trainer(self):
339311 --seed 42
340312 """ .split ()
341313
342- with patch .object (sys , "argv" , testargs ):
343- run_image_classification_no_trainer .main ()
344- result = get_results (tmp_dir )
345- self .assertGreaterEqual (result ["eval_accuracy" ], 0.50 )
346- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
347- self .assertTrue (os .path .exists (os .path .join (tmp_dir , "image_classification_no_trainer" )))
314+ _ = subprocess .run (self ._launch_args + testargs , stdout = subprocess .PIPE )
315+ result = get_results (tmp_dir )
316+ self .assertGreaterEqual (result ["eval_accuracy" ], 0.50 )
317+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "epoch_0" )))
318+ self .assertTrue (os .path .exists (os .path .join (tmp_dir , "image_classification_no_trainer" )))
0 commit comments