Skip to content

Commit 35a015e

Browse files
authored
Add typing to cli.oaieval (openai#1036)
- Enable type checking for `cli.oaieval` - Correct the type in cli.oaieval]
1 parent cce2fc0 commit 35a015e

File tree

3 files changed

+50
-14
lines changed

3 files changed

+50
-14
lines changed

evals/cli/oaieval.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,21 @@
55
import logging
66
import shlex
77
import sys
8-
from typing import Any, Mapping, Optional
8+
from typing import Any, Mapping, Optional, Union, cast
99

1010
import openai
1111

1212
import evals
1313
import evals.api
1414
import evals.base
1515
import evals.record
16+
from evals.eval import Eval
1617
from evals.registry import Registry
1718

1819
logger = logging.getLogger(__name__)
1920

2021

21-
def _purple(str):
22+
def _purple(str: str) -> str:
2223
return f"\033[1;35m{str}\033[0m"
2324

2425

@@ -41,7 +42,11 @@ def get_parser() -> argparse.ArgumentParser:
4142
"--log_to_file", type=str, default=None, help="Log to a file instead of stdout"
4243
)
4344
parser.add_argument(
44-
"--registry_path", type=str, default=None, action="append", help="Path to the registry"
45+
"--registry_path",
46+
type=str,
47+
default=None,
48+
action="append",
49+
help="Path to the registry",
4550
)
4651
parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=False)
4752
parser.add_argument("--local-run", action=argparse.BooleanOptionalAction, default=True)
@@ -50,7 +55,25 @@ def get_parser() -> argparse.ArgumentParser:
5055
return parser
5156

5257

53-
def run(args, registry: Optional[Registry] = None):
58+
class OaiEvalArguments(argparse.Namespace):
59+
completion_fn: str
60+
eval: str
61+
extra_eval_params: str
62+
max_samples: Optional[int]
63+
cache: bool
64+
visible: Optional[bool]
65+
seed: int
66+
user: str
67+
record_path: Optional[str]
68+
log_to_file: Optional[str]
69+
registry_path: Optional[str]
70+
debug: bool
71+
local_run: bool
72+
dry_run: bool
73+
dry_run_logging: bool
74+
75+
76+
def run(args: OaiEvalArguments, registry: Optional[Registry] = None) -> str:
5477
if args.debug:
5578
logging.getLogger().setLevel(logging.DEBUG)
5679

@@ -61,7 +84,7 @@ def run(args, registry: Optional[Registry] = None):
6184

6285
registry = registry or Registry()
6386
if args.registry_path:
64-
registry.add_registry_paths(args.registry_path)
87+
registry.add_registry_paths([args.registry_path])
6588

6689
eval_spec = registry.get_eval(args.eval)
6790
assert (
@@ -83,6 +106,9 @@ def run(args, registry: Optional[Registry] = None):
83106
}
84107

85108
eval_name = eval_spec.key
109+
if eval_name is None:
110+
raise Exception("you must provide a eval name")
111+
86112
run_spec = evals.base.RunSpec(
87113
completion_fns=completion_fns,
88114
eval_name=eval_name,
@@ -95,26 +121,30 @@ def run(args, registry: Optional[Registry] = None):
95121
record_path = f"/tmp/evallogs/{run_spec.run_id}_{args.completion_fn}_{args.eval}.jsonl"
96122
else:
97123
record_path = args.record_path
124+
125+
recorder: evals.record.RecorderBase
98126
if args.dry_run:
99127
recorder = evals.record.DummyRecorder(run_spec=run_spec, log=args.dry_run_logging)
100128
elif args.local_run:
101129
recorder = evals.record.LocalRecorder(record_path, run_spec=run_spec)
102130
else:
103131
recorder = evals.record.Recorder(record_path, run_spec=run_spec)
104132

105-
api_extra_options = {}
133+
api_extra_options: dict[str, Any] = {}
106134
if not args.cache:
107135
api_extra_options["cache_level"] = 0
108136

109137
run_url = f"{run_spec.run_id}"
110138
logger.info(_purple(f"Run started: {run_url}"))
111139

112-
def parse_extra_eval_params(param_str: Optional[str]) -> Mapping[str, Any]:
140+
def parse_extra_eval_params(
141+
param_str: Optional[str],
142+
) -> Mapping[str, Union[str, int, float]]:
113143
"""Parse a string of the form "key1=value1,key2=value2" into a dict."""
114144
if not param_str:
115145
return {}
116146

117-
def to_number(x):
147+
def to_number(x: str) -> Union[int, float, str]:
118148
try:
119149
return int(x)
120150
except:
@@ -131,7 +161,7 @@ def to_number(x):
131161
extra_eval_params = parse_extra_eval_params(args.extra_eval_params)
132162

133163
eval_class = registry.get_class(eval_spec)
134-
eval = eval_class(
164+
eval: Eval = eval_class(
135165
completion_fns=completion_fn_instances,
136166
seed=args.seed,
137167
name=eval_name,
@@ -150,17 +180,19 @@ def to_number(x):
150180
return run_spec.run_id
151181

152182

153-
def main():
183+
def main() -> None:
154184
parser = get_parser()
155-
args = parser.parse_args(sys.argv[1:])
185+
args = cast(OaiEvalArguments, parser.parse_args(sys.argv[1:]))
156186
logging.basicConfig(
157187
format="[%(asctime)s] [%(filename)s:%(lineno)d] %(message)s",
158188
level=logging.INFO,
159189
filename=args.log_to_file if args.log_to_file else None,
160190
)
161191
logging.getLogger("openai").setLevel(logging.WARN)
162-
if hasattr(openai.error, "set_display_cause"):
163-
openai.error.set_display_cause()
192+
193+
# TODO)) why do we need this?
194+
if hasattr(openai.error, "set_display_cause"): # type: ignore
195+
openai.error.set_display_cause() # type: ignore
164196
run(args)
165197

166198

evals/utils/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def t(duration: float) -> str:
1717
return f"{duration//60}min{int(duration%60)}s"
1818

1919

20-
def make_object(object_ref: Any, *args: Any, **kwargs: Any) -> Any:
20+
def make_object(object_ref: str, *args: Any, **kwargs: Any) -> Any:
2121
modname, qualname_separator, qualname = object_ref.partition(":")
2222
obj = importlib.import_module(modname)
2323
if qualname_separator:

mypy.ini

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ disallow_untyped_defs=True
3030
ignore_errors=False
3131
disallow_untyped_defs=True
3232

33+
[mypy-evals.cli.oaieval]
34+
ignore_errors=False
35+
disallow_untyped_defs=True
36+
3337
[mypy-scripts.*]
3438
ignore_errors=False
3539
disallow_untyped_defs=True

0 commit comments

Comments
 (0)