55import logging
66import shlex
77import sys
8- from typing import Any , Mapping , Optional
8+ from typing import Any , Mapping , Optional , Union , cast
99
1010import openai
1111
1212import evals
1313import evals .api
1414import evals .base
1515import evals .record
16+ from evals .eval import Eval
1617from evals .registry import Registry
1718
1819logger = logging .getLogger (__name__ )
1920
2021
21- def _purple (str ) :
22+ def _purple (str : str ) -> str :
2223 return f"\033 [1;35m{ str } \033 [0m"
2324
2425
@@ -41,7 +42,11 @@ def get_parser() -> argparse.ArgumentParser:
4142 "--log_to_file" , type = str , default = None , help = "Log to a file instead of stdout"
4243 )
4344 parser .add_argument (
44- "--registry_path" , type = str , default = None , action = "append" , help = "Path to the registry"
45+ "--registry_path" ,
46+ type = str ,
47+ default = None ,
48+ action = "append" ,
49+ help = "Path to the registry" ,
4550 )
4651 parser .add_argument ("--debug" , action = argparse .BooleanOptionalAction , default = False )
4752 parser .add_argument ("--local-run" , action = argparse .BooleanOptionalAction , default = True )
@@ -50,7 +55,25 @@ def get_parser() -> argparse.ArgumentParser:
5055 return parser
5156
5257
53- def run (args , registry : Optional [Registry ] = None ):
class OaiEvalArguments(argparse.Namespace):
    """Typed view of the namespace produced by this module's argument parser.

    The class body contains annotations only; it adds no runtime attributes
    or behavior on top of ``argparse.Namespace``. It exists so that static
    type checkers know which attributes the parsed arguments carry.
    """

    # NOTE(review): the parser declarations for the fields below without a
    # comment are not all visible here; confirm their types against
    # ``get_parser`` when changing them.
    completion_fn: str
    eval: str
    extra_eval_params: str
    max_samples: Optional[int]
    cache: bool
    visible: Optional[bool]
    seed: int
    user: str
    record_path: Optional[str]
    # ``--log_to_file``: path of a log file to use instead of stdout, or None.
    log_to_file: Optional[str]
    # ``--registry_path`` is declared with ``action="append"`` and
    # ``default=None``, so argparse yields a list with one entry per
    # occurrence of the flag, or None when the flag is never given.
    registry_path: Optional[list[str]]
    # ``--debug`` / ``--no-debug`` (BooleanOptionalAction, default False).
    debug: bool
    # ``--local-run`` / ``--no-local-run`` (BooleanOptionalAction, default True).
    local_run: bool
    dry_run: bool
    dry_run_logging: bool
74+
75+
76+ def run (args : OaiEvalArguments , registry : Optional [Registry ] = None ) -> str :
5477 if args .debug :
5578 logging .getLogger ().setLevel (logging .DEBUG )
5679
@@ -61,7 +84,7 @@ def run(args, registry: Optional[Registry] = None):
6184
6285 registry = registry or Registry ()
6386 if args .registry_path :
64- registry .add_registry_paths (args .registry_path )
87+ registry .add_registry_paths ([ args .registry_path ] )
6588
6689 eval_spec = registry .get_eval (args .eval )
6790 assert (
@@ -83,6 +106,9 @@ def run(args, registry: Optional[Registry] = None):
83106 }
84107
85108 eval_name = eval_spec .key
109+ if eval_name is None :
110+ raise Exception ("you must provide a eval name" )
111+
86112 run_spec = evals .base .RunSpec (
87113 completion_fns = completion_fns ,
88114 eval_name = eval_name ,
@@ -95,26 +121,30 @@ def run(args, registry: Optional[Registry] = None):
95121 record_path = f"/tmp/evallogs/{ run_spec .run_id } _{ args .completion_fn } _{ args .eval } .jsonl"
96122 else :
97123 record_path = args .record_path
124+
125+ recorder : evals .record .RecorderBase
98126 if args .dry_run :
99127 recorder = evals .record .DummyRecorder (run_spec = run_spec , log = args .dry_run_logging )
100128 elif args .local_run :
101129 recorder = evals .record .LocalRecorder (record_path , run_spec = run_spec )
102130 else :
103131 recorder = evals .record .Recorder (record_path , run_spec = run_spec )
104132
105- api_extra_options = {}
133+ api_extra_options : dict [ str , Any ] = {}
106134 if not args .cache :
107135 api_extra_options ["cache_level" ] = 0
108136
109137 run_url = f"{ run_spec .run_id } "
110138 logger .info (_purple (f"Run started: { run_url } " ))
111139
112- def parse_extra_eval_params (param_str : Optional [str ]) -> Mapping [str , Any ]:
140+ def parse_extra_eval_params (
141+ param_str : Optional [str ],
142+ ) -> Mapping [str , Union [str , int , float ]]:
113143 """Parse a string of the form "key1=value1,key2=value2" into a dict."""
114144 if not param_str :
115145 return {}
116146
117- def to_number (x ) :
147+ def to_number (x : str ) -> Union [ int , float , str ] :
118148 try :
119149 return int (x )
120150 except :
@@ -131,7 +161,7 @@ def to_number(x):
131161 extra_eval_params = parse_extra_eval_params (args .extra_eval_params )
132162
133163 eval_class = registry .get_class (eval_spec )
134- eval = eval_class (
164+ eval : Eval = eval_class (
135165 completion_fns = completion_fn_instances ,
136166 seed = args .seed ,
137167 name = eval_name ,
@@ -150,17 +180,19 @@ def to_number(x):
150180 return run_spec .run_id
151181
152182
def main() -> None:
    """Command-line entry point: parse arguments, set up logging, run the eval.

    Arguments are read from ``sys.argv[1:]`` using the parser returned by
    ``get_parser``. Root logging is configured at INFO level (to the file
    named by ``--log_to_file`` when given), the ``openai`` logger is limited
    to warnings and above, and the parsed arguments are handed to ``run``.
    """
    parser = get_parser()
    # parse_args returns a plain Namespace; the cast only informs the type
    # checker about the attributes it carries and has no runtime effect.
    args = cast(OaiEvalArguments, parser.parse_args(sys.argv[1:]))
    logging.basicConfig(
        format="[%(asctime)s] [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
        # An empty or missing path falls back to the default stream handler.
        filename=args.log_to_file or None,
    )
    logging.getLogger("openai").setLevel(logging.WARNING)

    # TODO: document why set_display_cause needs to be called here.
    # Probe defensively: the installed ``openai`` package may lack the
    # ``error`` module or the ``set_display_cause`` hook, in which case this
    # step is skipped instead of raising AttributeError.
    openai_error = getattr(openai, "error", None)
    set_display_cause = getattr(openai_error, "set_display_cause", None)
    if callable(set_display_cause):
        set_display_cause()
    run(args)
165197
166198
0 commit comments