
Commit 7630560

Merge pull request #1126 from gpt-engineer-org/bench_config
Bench config
2 parents 31da734 + 3542d17 commit 7630560

File tree

12 files changed: +246 −51 lines changed


gpt_engineer/benchmark/__init__.py

Whitespace-only changes.

gpt_engineer/benchmark/__main__.py

Lines changed: 23 additions & 10 deletions

@@ -20,6 +20,7 @@
 The standard boilerplate for invoking the main function when the script is executed.
 """
 import importlib
+import os.path
 
 from typing import Annotated, Optional
 
@@ -29,6 +30,7 @@
 from langchain.globals import set_llm_cache
 
 from gpt_engineer.applications.cli.main import load_env_if_needed
+from gpt_engineer.benchmark.bench_config import BenchConfig
 from gpt_engineer.benchmark.benchmarks.load import get_benchmark
 from gpt_engineer.benchmark.run import print_results, run
 
@@ -69,12 +71,9 @@ def main(
             help="python file that contains a function called 'default_config_agent'"
         ),
     ],
-    benchmarks: Annotated[
-        str, typer.Argument(help="benchmark name(s) separated by ','")
-    ],
-    task_name: Annotated[
+    bench_config: Annotated[
         Optional[str], typer.Argument(help="optional task name in benchmark")
-    ] = None,
+    ] = os.path.join(os.path.dirname(__file__), "default_bench_config.toml"),
     verbose: Annotated[
         bool, typer.Option(help="print results for each task", show_default=False)
     ] = False,
@@ -88,8 +87,8 @@ def main(
        The file path to the Python module that contains a function called 'default_config_agent'.
    benchmarks : str
        A comma-separated string of benchmark names to run.
-    task_name : Optional[str], default=None
-        An optional task name to run within the benchmark.
+    bench_config : Optional[str], default=default_bench_config.toml
+        Configuration file for choosing which benchmark problems to run. See default config for more details.
    verbose : bool, default=False
        A flag to indicate whether to print results for each task.
 
@@ -99,13 +98,27 @@ def main(
     """
     set_llm_cache(SQLiteCache(database_path=".langchain.db"))
     load_env_if_needed()
+    config = BenchConfig.from_toml(bench_config)
+    print("using config file: " + bench_config)
+    benchmarks = list()
+    for specific_config_name in vars(config):
+        specific_config = getattr(config, specific_config_name)
+        if hasattr(specific_config, "active"):
+            if specific_config.active:
+                benchmarks.append(specific_config_name)
 
-    benchmarks = benchmarks.split(",")
     for benchmark_name in benchmarks:
-        benchmark = get_benchmark(benchmark_name)
+        benchmark = get_benchmark(benchmark_name, config)
+        if len(benchmark.tasks) == 0:
+            print(
+                benchmark_name
+                + " was skipped, since no tasks are specified. Increase the number of tasks in the config file at: "
+                + bench_config
+            )
+            continue
         agent = get_agent(path_to_agent)
 
-        results = run(agent, benchmark, task_name, verbose=verbose)
+        results = run(agent, benchmark, verbose=verbose)
         print(
             f"\n--- Results for agent {path_to_agent}, benchmark: {benchmark_name} ---"
         )
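
Note on the new control flow above: main no longer takes a comma-separated benchmark list and a task_name; it loads a TOML file (defaulting to default_bench_config.toml next to the module) and runs every config section whose active flag is true. The selection is plain dataclass introspection via vars()/getattr/hasattr. A minimal, self-contained sketch of that pattern; _Section and _Config are illustrative stand-ins for the real config classes, not part of this PR:

from dataclasses import dataclass, field


@dataclass
class _Section:
    # Stand-in for AppsConfig, MbppConfig, etc.; only the 'active' flag matters here.
    active: bool = True


@dataclass
class _Config:
    # Stand-in for BenchConfig, with one section switched off.
    apps: _Section = field(default_factory=_Section)
    mbpp: _Section = field(default_factory=lambda: _Section(active=False))


config = _Config()
benchmarks = []
for section_name in vars(config):          # iterate over the config's sections by name
    section = getattr(config, section_name)
    if hasattr(section, "active") and section.active:
        benchmarks.append(section_name)

print(benchmarks)  # ['apps'] -- only sections with active = true get run

After this selection, the real code additionally skips any benchmark that ends up with zero tasks and points the user back at the config file path.
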
gpt_engineer/benchmark/bench_config.py

Lines changed: 56 additions & 0 deletions

@@ -0,0 +1,56 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+
+from gpt_engineer.core.project_config import read_config
+
+
+@dataclass
+class AppsConfig:
+    active: bool | None = True
+    test_start_index: int | None = 0
+    test_end_index: int | None = 1
+    train_start_index: int | None = 0
+    train_end_index: int | None = 0
+
+
+@dataclass
+class MbppConfig:
+    active: bool | None = True
+    test_len: int | None = 1
+    train_len: int | None = 0
+
+
+@dataclass
+class GptmeConfig:
+    active: bool | None = True
+
+
+@dataclass
+class GptengConfig:
+    active: bool | None = True
+
+
+@dataclass
+class BenchConfig:
+    """Configuration for the GPT Engineer CLI and gptengineer.app via `gpt-engineer.toml`."""
+
+    apps: AppsConfig = field(default_factory=AppsConfig)
+    mbpp: MbppConfig = field(default_factory=MbppConfig)
+    gptme: GptmeConfig = field(default_factory=GptmeConfig)
+    gpteng: GptengConfig = field(default_factory=GptengConfig)
+
+    @classmethod
+    def from_toml(cls, config_file: Path | str):
+        if isinstance(config_file, str):
+            config_file = Path(config_file)
+        config_dict = read_config(config_file)
+        return cls.from_dict(config_dict)
+
+    @classmethod
+    def from_dict(cls, config_dict: dict):
+        return cls(
+            apps=AppsConfig(**config_dict.get("apps", {})),
+            mbpp=MbppConfig(**config_dict.get("mbpp", {})),
+            gptme=GptmeConfig(**config_dict.get("gptme", {})),
+            gpteng=GptengConfig(**config_dict.get("gpteng", {})),
+        )
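
A hedged usage sketch of the new config object, assuming gpt-engineer is installed at this revision so the import below resolves. The key behaviour, visible in from_dict above, is that missing or partial sections fall back to the dataclass defaults:

from gpt_engineer.benchmark.bench_config import AppsConfig, BenchConfig

# Partial input: only override what differs from the defaults.
config = BenchConfig.from_dict({"apps": {"test_end_index": 5}, "gptme": {"active": False}})

assert config.apps == AppsConfig(test_end_index=5)  # other apps fields keep their defaults
assert config.mbpp.test_len == 1                    # untouched section -> all defaults
assert config.gptme.active is False

# BenchConfig.from_toml(<path to a bench config>) gives the same result once
# read_config() has parsed the TOML file into a dict like the one above.
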

gpt_engineer/benchmark/benchmarks/apps/load.py

Lines changed: 17 additions & 15 deletions

@@ -16,8 +16,8 @@
 
 from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 
+from gpt_engineer.benchmark.bench_config import AppsConfig
 from gpt_engineer.benchmark.benchmarks.apps.problem import Problem
-from gpt_engineer.benchmark.benchmarks.apps.problems import PROBLEM_IDS
 from gpt_engineer.benchmark.types import Assertable, Benchmark, Task
 from gpt_engineer.core.default.disk_execution_env import DiskExecutionEnv
 from gpt_engineer.core.files_dict import FilesDict
@@ -57,12 +57,12 @@ def _get_dataset() -> Union[Dataset, DatasetDict]:
         print("Dataset not found locally, downloading...")
 
         dataset = load_dataset("codeparrot/apps", trust_remote_code=True)
-        dataset.save_to_disk(DATASET_PATH)
+        dataset.save_to_disk(str(DATASET_PATH))
 
     return dataset
 
 
-def load_apps():
+def load_apps(config: AppsConfig) -> Benchmark:
     """
     Loads the APPS benchmark, which consists of a series coding problems.
 
@@ -73,17 +73,19 @@ def load_apps():
     """
     dataset = _get_dataset()
     tasks = []
-
-    problems = [
-        Problem(
-            id=problem["problem_id"],
-            question=problem["question"],
-            input_output=problem["input_output"],
-            starter_code=problem["starter_code"],
-        )
-        for problem in dataset["test"]
-        if problem["problem_id"] in PROBLEM_IDS
-    ]
+    problems = list()
+    for dataset_type in ["test", "train"]:
+        problems += [
+            Problem(
+                id=problem["problem_id"],
+                question=problem["question"],
+                input_output=problem["input_output"],
+                starter_code=problem["starter_code"],
+            )
+            for index, problem in enumerate(dataset[dataset_type])
+            if (index < config.__getattribute__(dataset_type + "_end_index"))
+            and (index >= config.__getattribute__(dataset_type + "_start_index"))
+        ]
 
     for problem in problems:
         prompt = Prompt(
@@ -110,6 +112,6 @@ def load_apps():
         )
 
     return Benchmark(
-        name="APPS",
+        name="apps",
         tasks=tasks,
     )
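
The rewritten comprehension windows each dataset split by position using the config's *_start_index / *_end_index fields; config.__getattribute__(name) is just a more verbose spelling of getattr(config, name). A self-contained sketch of the same pattern, with a stand-in config class and plain lists in place of the HuggingFace dataset (window values chosen to mirror the shipped default config):

from dataclasses import dataclass


@dataclass
class _AppsConfig:
    # Stand-in mirroring the fields load_apps reads above.
    test_start_index: int = 0
    test_end_index: int = 2
    train_start_index: int = 0
    train_end_index: int = 2


def _window(rows, dataset_type, config):
    """Keep rows whose position falls in the configured half-open [start, end) window."""
    start = getattr(config, dataset_type + "_start_index")
    end = getattr(config, dataset_type + "_end_index")
    return [row for index, row in enumerate(rows) if start <= index < end]


config = _AppsConfig()
print(_window(["p0", "p1", "p2", "p3"], "test", config))   # ['p0', 'p1']
print(_window(["p0", "p1", "p2"], "train", config))        # ['p0', 'p1']
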

gpt_engineer/benchmark/benchmarks/gpteng/load.py

Lines changed: 4 additions & 2 deletions

@@ -19,11 +19,13 @@
 
 from pathlib import Path
 
+from gpt_engineer.benchmark.bench_config import GptengConfig
 from gpt_engineer.benchmark.benchmarks.gpteng.eval_tools import (
     check_evaluation_component,
 )
 from gpt_engineer.benchmark.types import Assertable, Benchmark, Task
 from gpt_engineer.core.chat_to_files import chat_to_files_dict
+from gpt_engineer.core.prompt import Prompt
 
 evaluations = [
     {
@@ -192,7 +194,7 @@ def eval_to_task(case):
     return Task(
         name=case["name"],
         initial_code=chat_to_files_dict(Path(case["code_blob"]).read_text()),
-        prompt=prompt,
+        prompt=Prompt(prompt),
         command=None,
         assertions={
             f"{e['type']}_{i}": expect_to_assertion(e)
@@ -201,7 +203,7 @@ def eval_to_task(case):
     )
 
 
-def load_gpteng():
+def load_gpteng(config: GptengConfig) -> Benchmark:
     """
     Loads the GPT-Eng benchmark, which consists of a series of tasks for evaluation.

gpt_engineer/benchmark/benchmarks/gptme/load.py

Lines changed: 2 additions & 1 deletion

@@ -10,12 +10,13 @@
 load_gptme : function
     Loads the GPT-Me benchmark, which consists of a series of tasks for evaluation.
 """
+from gpt_engineer.benchmark.bench_config import GptmeConfig
 from gpt_engineer.benchmark.types import Benchmark, Task
 from gpt_engineer.core.files_dict import FilesDict
 from gpt_engineer.core.prompt import Prompt
 
 
-def load_gptme():
+def load_gptme(config: GptmeConfig) -> Benchmark:
     """
     Loads the GPT-Me benchmark, which consists of a series of tasks for evaluation.

gpt_engineer/benchmark/benchmarks/load.py

Lines changed: 5 additions & 2 deletions

@@ -9,6 +9,7 @@
 get_benchmark : function
     Retrieves a Benchmark object by name. Raises ValueError if the benchmark is unknown.
 """
+from gpt_engineer.benchmark.bench_config import BenchConfig
 from gpt_engineer.benchmark.benchmarks.apps.load import load_apps
 from gpt_engineer.benchmark.benchmarks.gpteng.load import load_gpteng
 from gpt_engineer.benchmark.benchmarks.gptme.load import load_gptme
@@ -23,14 +24,16 @@
 }
 
 
-def get_benchmark(name: str) -> Benchmark:
+def get_benchmark(name: str, config: BenchConfig) -> Benchmark:
     """
     Retrieves a Benchmark object by name. Raises ValueError if the benchmark is unknown.
 
     Parameters
     ----------
     name : str
         The name of the benchmark to retrieve.
+    config : BenchConfig
+        Configuration object for the benchmarks.
 
     Returns
     -------
@@ -44,4 +47,4 @@ def get_benchmark(name: str) -> Benchmark:
     """
     if name not in BENCHMARKS:
         raise ValueError(f"Unknown benchmark {name}.")
-    return BENCHMARKS[name]()
+    return BENCHMARKS[name](config.__getattribute__(name))
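
With this change each loader receives only its own config section instead of the whole BenchConfig. A minimal sketch of the dispatch pattern with hypothetical stub loaders (not the real load_apps / load_mbpp):

from types import SimpleNamespace


def _load_apps(section):
    # Hypothetical stub; the real loader builds a Benchmark from the APPS dataset.
    return f"apps benchmark built from {section!r}"


def _load_mbpp(section):
    # Hypothetical stub standing in for the real load_mbpp.
    return f"mbpp benchmark built from {section!r}"


BENCHMARKS = {"apps": _load_apps, "mbpp": _load_mbpp}


def get_benchmark(name, config):
    if name not in BENCHMARKS:
        raise ValueError(f"Unknown benchmark {name}.")
    # Hand each loader only its matching section, e.g. config.apps for "apps".
    return BENCHMARKS[name](getattr(config, name))


config = SimpleNamespace(apps={"active": True}, mbpp={"active": True})
print(get_benchmark("apps", config))  # apps benchmark built from {'active': True}
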

gpt_engineer/benchmark/benchmarks/mbpp/load.py

Lines changed: 18 additions & 17 deletions

@@ -16,8 +16,8 @@
 
 from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 
+from gpt_engineer.benchmark.bench_config import MbppConfig
 from gpt_engineer.benchmark.benchmarks.mbpp.problem import Problem
-from gpt_engineer.benchmark.benchmarks.mbpp.problems import PROBLEM_IDS
 from gpt_engineer.benchmark.types import Assertable, Benchmark, Task
 from gpt_engineer.core.default.disk_execution_env import DiskExecutionEnv
 from gpt_engineer.core.files_dict import FilesDict
@@ -57,12 +57,12 @@ def _get_dataset() -> Union[Dataset, DatasetDict]:
         print("Dataset not found locally, downloading...")
 
         dataset = load_dataset("mbpp", "sanitized", trust_remote_code=True)
-        dataset.save_to_disk(DATASET_PATH)
+        dataset.save_to_disk(str(DATASET_PATH))
 
     return dataset
 
 
-def load_mbpp():
+def load_mbpp(config: MbppConfig) -> Benchmark:
     """
     Loads the MBPP benchmark, which consists of a series coding problems.
 
@@ -73,19 +73,20 @@ def load_mbpp():
     """
     dataset = _get_dataset()
     tasks = []
-
-    problems = [
-        Problem(
-            source_file=problem["source_file"],
-            task_id=problem["task_id"],
-            prompt=problem["prompt"],
-            code=problem["code"],
-            test_imports=problem["test_imports"],
-            test_list=problem["test_list"],
-        )
-        for problem in dataset["test"]
-        if problem["task_id"] in PROBLEM_IDS
-    ]
+    problems = []
+    for dataset_type in ["test", "train"]:
+        problems += [
+            Problem(
+                source_file=problem["source_file"],
+                task_id=problem["task_id"],
+                prompt=problem["prompt"],
+                code=problem["code"],
+                test_imports=problem["test_imports"],
+                test_list=problem["test_list"],
+            )
+            for index, problem in enumerate(dataset[dataset_type])
+            if index < config.__getattribute__(dataset_type + "_len")
+        ]
 
     for problem in problems:
         prompt = Prompt(
@@ -109,6 +110,6 @@ def load_mbpp():
        )
 
     return Benchmark(
-        name="MBPP",
+        name="mbpp",
        tasks=tasks,
    )
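
Unlike APPS, MBPP is trimmed purely by length: keeping rows with index < test_len (or train_len) is just "take the first N". An equivalent, self-contained sketch using itertools.islice, with plain lists and a dict standing in for the dataset splits and MbppConfig:

from itertools import islice

# Stand-ins for the MBPP dataset splits and for MbppConfig.
splits = {"test": ["t0", "t1", "t2", "t3"], "train": ["r0", "r1"]}
config = {"test_len": 2, "train_len": 0}

problems = []
for dataset_type in ["test", "train"]:
    # Same effect as the "index < *_len" filter above: keep only the first N rows.
    problems += list(islice(splits[dataset_type], config[dataset_type + "_len"]))

print(problems)  # ['t0', 't1']
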
gpt_engineer/benchmark/default_bench_config.toml

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+# For apps, the maximal range is 0:5000 for both train and test
+[apps]
+active = true
+test_start_index = 0
+test_end_index = 2
+train_start_index = 0
+train_end_index = 2
+
+# For mbpp, the maximal range is 0:47
+[mbpp]
+active = true
+test_len = 2
+train_len = 2
+
+[gpteng]
+active = true
+
+[gptme]
+active = true

gpt_engineer/benchmark/run.py

Lines changed: 1 addition & 4 deletions

@@ -14,7 +14,7 @@
 """
 import time
 
-from typing import List, Optional
+from typing import List
 
 from gpt_engineer.benchmark.types import Assertable, Benchmark, TaskResult
 from gpt_engineer.core.base_agent import BaseAgent
@@ -24,7 +24,6 @@
 def run(
     agent: BaseAgent,
     benchmark: Benchmark,
-    task_name: Optional[str] = None,
     verbose=False,
 ) -> List[TaskResult]:
     """
@@ -36,8 +35,6 @@ def run(
         The agent to use for running the benchmark tasks.
     benchmark : Benchmark
         The benchmark containing the tasks to run.
-    task_name : Optional[str], default=None
-        An optional name of a specific task to run within the benchmark.
     verbose : bool, default=False
         A flag to indicate whether to print verbose output during the benchmark.