Skip to content

Commit 4c72d98

Browse files
authored
Merge pull request #317 from JohnSnowLabs/fix/hub-check
fix invalid hub
2 parents de77d4d + 5c58759 commit 4c72d98

2 files changed

Lines changed: 12 additions & 8 deletions

File tree

nlptest/nlptest.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import pickle
44
from collections import defaultdict
5-
from typing import Optional, Union
5+
from typing import Optional, Union, Any
66

77
import pandas as pd
88
import yaml
@@ -35,7 +35,7 @@ class Harness:
3535

3636
def __init__(
3737
self,
38-
model: Union[str],
38+
model: Union[str, Any],
3939
task: Optional[str] = "ner",
4040
hub: Optional[str] = None,
4141
data: Optional[str] = None,
@@ -58,7 +58,14 @@ def __init__(
5858
super().__init__()
5959
self.task = task
6060

61-
if data is None and (task, model, hub) in self.DEFAULTS_DATASET.keys():
61+
if isinstance(model, str) and hub is None:
62+
raise ValueError(f"When passing a string argument to the 'model' parameter, you must provide an argument "
63+
f"for the 'hub' parameter as well.")
64+
65+
if hub is not None and hub not in self.SUPPORTED_HUBS:
66+
raise ValueError(f"Provided hub is not supported. Please choose one of the supported hubs: {self.SUPPORTED_HUBS}")
67+
68+
if data is None and (task, model, hub) in self.DEFAULTS_DATASET:
6269
data_path = os.path.join("data", self.DEFAULTS_DATASET[(task, model, hub)])
6370
data = resource_filename("nlptest", data_path)
6471
self.data = DataFactory(data, task=self.task).load()
@@ -77,17 +84,14 @@ def __init__(
7784
self.data = DataFactory(data, task=self.task).load() if data is not None else None
7885

7986
if isinstance(model, str):
80-
if hub is None:
81-
raise OSError(f"You need to pass the 'hub' parameter when passing a string as 'model'.")
82-
8387
self.model = ModelFactory.load_model(path=model, task=task, hub=hub)
8488
else:
8589
self.model = ModelFactory(task=task, model=model)
8690

8791
if config is not None:
8892
self._config = self.configure(config)
8993
else:
90-
logging.info(f"No configuration file was provided, loading default config.")
94+
logging.info("No configuration file was provided, loading default config.")
9195
self._config = self.configure(resource_filename("nlptest", "data/config.yml"))
9296

9397
self._testcases = None

tests/test_harness.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_Harness(self):
2828

2929
def test_missing_parameter(self):
3030
""""""
31-
with self.assertRaises(OSError) as _:
31+
with self.assertRaises(ValueError) as _:
3232
Harness(task='ner', model='dslim/bert-base-NER',
3333
data=self.data_path, config=self.config_path)
3434

0 commit comments

Comments
 (0)