Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions examples/bert-loses-patience/test_run_glue_with_pabee.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_setup_file():
return args.f


def clean_test_dir(path="./tests/fixtures/tests_samples/temp_dir"):
def clean_test_dir(path: str) -> None:
    """Recursively delete the test output directory at *path*.

    Errors (e.g. the directory not existing, or permission problems on
    individual entries) are ignored so test teardown never fails; rmtree
    keeps removing remaining entries even if one removal errors.
    """
    shutil.rmtree(path, ignore_errors=True)


Expand All @@ -37,7 +37,6 @@ def test_run_glue(self):
--task_name mrpc
--do_train
--do_eval
--output_dir ./tests/fixtures/tests_samples/temp_dir
--per_gpu_train_batch_size=2
--per_gpu_eval_batch_size=1
--learning_rate=2e-5
Expand All @@ -46,10 +45,13 @@ def test_run_glue(self):
--overwrite_output_dir
--seed=42
--max_seq_length=128
""".split()
"""
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
with patch.object(sys, "argv", testargs):
result = run_glue_with_pabee.main()
for value in result.values():
self.assertGreaterEqual(value, 0.75)

clean_test_dir()
clean_test_dir(output_dir)
34 changes: 21 additions & 13 deletions examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def get_setup_file():
return args.f


def clean_test_dir(path="./tests/fixtures/tests_samples/temp_dir"):
def clean_test_dir(path: str) -> None:
    """Recursively delete the test output directory at *path*.

    Errors (e.g. the directory not existing, or permission problems on
    individual entries) are ignored so test teardown never fails; rmtree
    keeps removing remaining entries even if one removal errors.
    """
    shutil.rmtree(path, ignore_errors=True)


Expand All @@ -68,7 +68,6 @@ def test_run_glue(self):
--task_name mrpc
--do_train
--do_eval
--output_dir ./tests/fixtures/tests_samples/temp_dir
--per_device_train_batch_size=2
--per_device_eval_batch_size=1
--learning_rate=1e-4
Expand All @@ -77,13 +76,16 @@ def test_run_glue(self):
--overwrite_output_dir
--seed=42
--max_seq_length=128
""".split()
"""
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
with patch.object(sys, "argv", testargs):
result = run_glue.main()
del result["eval_loss"]
for value in result.values():
self.assertGreaterEqual(value, 0.75)
clean_test_dir()
clean_test_dir(output_dir)

def test_run_pl_glue(self):
stream_handler = logging.StreamHandler(sys.stdout)
Expand All @@ -96,13 +98,15 @@ def test_run_pl_glue(self):
--task mrpc
--do_train
--do_predict
--output_dir ./tests/fixtures/tests_samples/temp_dir
--train_batch_size=32
--learning_rate=1e-4
--num_train_epochs=1
--seed=42
--max_seq_length=128
""".split()
"""
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()

if torch.cuda.is_available():
testargs += ["--fp16", "--gpus=1"]
Expand All @@ -119,7 +123,7 @@ def test_run_pl_glue(self):
# for k, v in result.items():
# self.assertGreaterEqual(v, 0.75, f"({k})")
#
clean_test_dir()
clean_test_dir(output_dir)

def test_run_language_modeling(self):
stream_handler = logging.StreamHandler(sys.stdout)
Expand All @@ -133,17 +137,19 @@ def test_run_language_modeling(self):
--line_by_line
--train_data_file ./tests/fixtures/sample_text.txt
--eval_data_file ./tests/fixtures/sample_text.txt
--output_dir ./tests/fixtures/tests_samples/temp_dir
--overwrite_output_dir
--do_train
--do_eval
--num_train_epochs=1
--no_cuda
""".split()
"""
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
with patch.object(sys, "argv", testargs):
result = run_language_modeling.main()
self.assertLess(result["perplexity"], 35)
clean_test_dir()
clean_test_dir(output_dir)

def test_run_squad(self):
stream_handler = logging.StreamHandler(sys.stdout)
Expand All @@ -154,7 +160,6 @@ def test_run_squad(self):
--model_type=distilbert
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
--data_dir=./tests/fixtures/tests_samples/SQUAD
--output_dir=./tests/fixtures/tests_samples/temp_dir
--max_steps=10
--warmup_steps=2
--do_train
Expand All @@ -165,12 +170,15 @@ def test_run_squad(self):
--per_gpu_eval_batch_size=1
--overwrite_output_dir
--seed=42
""".split()
"""
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
with patch.object(sys, "argv", testargs):
result = run_squad.main()
self.assertGreaterEqual(result["f1"], 25)
self.assertGreaterEqual(result["exact"], 21)
clean_test_dir()
clean_test_dir(output_dir)

def test_generation(self):
stream_handler = logging.StreamHandler(sys.stdout)
Expand Down