22import os
33import pickle
44from collections import defaultdict
5- from typing import Optional , Union
5+ from typing import Optional , Union , Any
66
77import pandas as pd
88import yaml
@@ -35,7 +35,7 @@ class Harness:
3535
3636 def __init__ (
3737 self ,
38- model : Union [str ],
38+ model : Union [str , Any ],
3939 task : Optional [str ] = "ner" ,
4040 hub : Optional [str ] = None ,
4141 data : Optional [str ] = None ,
@@ -58,7 +58,14 @@ def __init__(
5858 super ().__init__ ()
5959 self .task = task
6060
61- if data is None and (task , model , hub ) in self .DEFAULTS_DATASET .keys ():
61+ if isinstance (model , str ) and hub is None :
62+ raise ValueError (f"When passing a string argument to the 'model' parameter, you must provide an argument "
63+ f"for the 'hub' parameter as well." )
64+
65+ if hub is not None and hub not in self .SUPPORTED_HUBS :
66+ raise ValueError (f"Provided hub is not supported. Please choose one of the supported hubs: { self .SUPPORTED_HUBS } " )
67+
68+ if data is None and (task , model , hub ) in self .DEFAULTS_DATASET :
6269 data_path = os .path .join ("data" , self .DEFAULTS_DATASET [(task , model , hub )])
6370 data = resource_filename ("nlptest" , data_path )
6471 self .data = DataFactory (data , task = self .task ).load ()
@@ -77,17 +84,14 @@ def __init__(
7784 self .data = DataFactory (data , task = self .task ).load () if data is not None else None
7885
7986 if isinstance (model , str ):
80- if hub is None :
81- raise OSError (f"You need to pass the 'hub' parameter when passing a string as 'model'." )
82-
8387 self .model = ModelFactory .load_model (path = model , task = task , hub = hub )
8488 else :
8589 self .model = ModelFactory (task = task , model = model )
8690
8791 if config is not None :
8892 self ._config = self .configure (config )
8993 else :
90- logging .info (f "No configuration file was provided, loading default config." )
94+ logging .info ("No configuration file was provided, loading default config." )
9195 self ._config = self .configure (resource_filename ("nlptest" , "data/config.yml" ))
9296
9397 self ._testcases = None
0 commit comments