diff --git a/nlptest/nlptest.py b/nlptest/nlptest.py index 2d10757b1..e1639c3f6 100644 --- a/nlptest/nlptest.py +++ b/nlptest/nlptest.py @@ -21,6 +21,7 @@ class Harness: Harness class evaluates the performance of a given NLP model. Given test data is used to test the model. A report is generated with test results. """ + SUPPORTED_TASKS = ["ner", "text-classification"] SUPPORTED_HUBS = ["spacy", "huggingface", "johnsnowlabs"] DEFAULTS_DATASET = { ("ner", "dslim/bert-base-NER", "huggingface"): "conll/sample.conll", @@ -36,7 +37,7 @@ class Harness: def __init__( self, model: Union[str, Any], - task: Optional[str] = "ner", + task: str, hub: Optional[str] = None, data: Optional[str] = None, config: Optional[Union[str, dict]] = None @@ -56,6 +57,9 @@ def __init__( """ super().__init__() + + if(task not in self.SUPPORTED_TASKS): + raise ValueError(f"Provided task is not supported. Please choose one of the supported tasks: {self.SUPPORTED_TASKS}") self.task = task if isinstance(model, str) and hub is None: @@ -310,7 +314,7 @@ def save(self, save_dir: str) -> None: pickle.dump(self.data, writer) @classmethod - def load(cls, save_dir: str, model: Union[str, 'ModelFactory'], task: Optional[str] = "ner", + def load(cls, save_dir: str, model: Union[str, 'ModelFactory'], task: str, hub: Optional[str] = None) -> 'Harness': """ Loads a previously saved `Harness` from a given configuration and dataset