-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrun_llm_component.py
More file actions
141 lines (112 loc) · 6.5 KB
/
run_llm_component.py
File metadata and controls
141 lines (112 loc) · 6.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import pandas as pd
import json
import os
import argparse
from utils import * # custom util functions
# Parse command-line arguments
SUPPORTED_MODELS = ['gpt-4o', 'llama-3.3-70b-versatile', 'deepseek-r1-distill-llama-70b']


def _parse_flag(value):
    """Interpret a CLI flag value such as '0', '1', 'true' or 'false' as a bool.

    A plain ``bool()`` of the raw argument string is wrong here because
    ``bool("0")`` and ``bool("false")`` are both True.
    """
    return str(value).strip().lower() not in ('', '0', 'false', 'no', 'off')


parser = argparse.ArgumentParser(description="Run fail-to-pass test generation and evaluation for Pynguin.")
# The label names the Pynguin log folder and the predictions file, so it cannot be omitted.
parser.add_argument("--label", required=True, help="label for this run")
# `choices` validates the model even under `python -O`, where an assert would be stripped.
parser.add_argument("--model", choices=SUPPORTED_MODELS, default='gpt-4o',
                    help="LLM model to use. One of 'gpt-4o' (default), 'llama-3.3-70b-versatile', 'deepseek-r1-distill-llama-70b'")
parser.add_argument("--T", help="LLM temperature (default: 0.0). Not applicable to DeepSeek.", default=0.0, type=float)
parser.add_argument("--debug", default=False, type=_parse_flag,
                    help="if true, run only a subset of instances (to be used for local development only)")
cli_args = parser.parse_args()
label = cli_args.label
model = cli_args.model
T = cli_args.T
debgug = bool(cli_args.debug)  # (sic) original spelling kept; this is the name read when subsetting the dataset
repo_base = 'repos/'
sbst_results_folder = os.path.join('logs_pynguin', label)
dataset_fname_raw = "tddbench_verified.pickle"
dataset_fname_processed = "tddbench_verified_processed.pickle"
# NOTE: unpickling can execute arbitrary code; only load these dataset files from a trusted source.
if os.path.isfile(dataset_fname_processed):
    # Read the processed and cached version of the dataset
    d = pd.read_pickle(dataset_fname_processed)
else:
    print("Processing dataset for the first time. The results will be cached after this.")
    d = pd.read_pickle(dataset_fname_raw)
    # Slice the golden code contents with slice_golden_file (pre-patch file, line numbers appended)
    d['golden_code_contents_sliced_long'] = d.apply(lambda row: slice_golden_file(
        row['golden_code_contents'],
        row['patch'],
        row['problem_statement'],
        return_file="pre",
        append_line_numbers=True,
    ), axis=1)
    # Retrieve the most coupled test file; the helper's result is expanded into three columns
    d[['predicted_test_file_name', 'predicted_test_file_content', 'predicted_test_file_content_sliced']] = d.apply(
        lambda row: pd.Series(get_contents_of_test_file_to_inject(row, repo_base)), axis=1)
    # Cache the processed dataset under the same name that is checked above
    d.to_pickle(dataset_fname_processed)
if debgug:
    # Local development only. `.copy()` makes the subset an independent frame, so the
    # column assignment below does not target a slice of the full dataset.
    d = d.head(3).copy()
dataset_size = len(d)
# Load the SBST (Pynguin) generated tests. This is done outside the cache because the
# results folder depends on the run label.
d['sbst_test_file_content'] = d.apply(lambda row: pd.Series(get_sbst_tests(row, sbst_results_folder)), axis=1)
# Ablation inputs for the LLM prompt. These values reproduce the default setup C6
# of the paper (Table II).
include_issue_description = True
include_golden_code = True  # include the code files changed by the golden patch
sliced = True  # slice those code files instead of including them whole
include_issue_comments = False
include_pr_desc = False
include_predicted_test_file = False  # include parts of the test file the generated test is injected into
include_sbst_test_file = True  # include Pynguin-generated tests for the changed module

prompt_options = {
    "include_issue_description": include_issue_description,
    "include_golden_code": include_golden_code,
    "sliced": sliced,
    "include_issue_comments": include_issue_comments,
    "include_pr_desc": include_pr_desc,
    "include_predicted_test_file": include_predicted_test_file,
    "include_sbst_test_file": include_sbst_test_file,
}
d['prompt'] = d.apply(lambda row: build_prompt(row, **prompt_options), axis=1)

# Predictions file consumed by the SWT-Bench evaluation harness
json_file_path_final_preds = f"preds_{label}.jsonl"

# Identifier of this run's configuration; stored alongside every prediction.
# NOTE(review): `goldenCode{include_golden_code}{sliced}` has no separator between the two
# flags (e.g. "goldenCodeTrueTrue"); kept as-is so existing run identifiers still match.
config_str = (
    f"{model}_T{T}_size{dataset_size}"
    f"_issueDesc{include_issue_description}"
    f"_goldenCode{include_golden_code}{sliced}"
    f"_issueComments{include_issue_comments}"
    f"_prDesc{include_pr_desc}"
    f"_testFile{include_predicted_test_file}"
    f"_sbst{include_sbst_test_file}"
)
# Load the IDs of trivial instances (one per line); these instances are excluded from
# generation. A set gives O(1) membership checks, and blank lines are ignored.
with open('trivial_ids.txt', 'r', encoding='utf-8') as f:
    trivial_ids = {line.strip() for line in f if line.strip()}
def _extract_python_code(response):
    """Return the body of the first ```python fenced block in *response*.

    Returns None when *response* is not a string or has no ```python fence.
    Only the first block is kept: text after its closing fence (typically prose
    explaining the test) would otherwise be injected into the test file and
    break it. An unterminated fence yields everything after the opening marker.

    NOTE(review): if a reasoning model's output still contains draft code blocks
    before the final answer, the first block may be a draft -- confirm what
    query_model returns for such models.
    """
    if not isinstance(response, str):
        return None
    _, fence, after = response.partition('```python')
    if not fence:
        return None
    return after.split('```', 1)[0]


# Loop over the dataset and feed the prompt to the LLM for each instance
for i, (_, row) in enumerate(d.iterrows()):
    instance_id = row['instance_id']
    if instance_id in trivial_ids:  # skip trivial instances
        continue
    _, repo_folder, _ = parse_instanceID_string(instance_id)
    repo_dir = repo_base + repo_folder
    test_filename = row['predicted_test_file_name']
    test_file_content = row['predicted_test_file_content']
    if not test_filename:
        # Without a target test file there is nothing to inject into and no path
        # for the diff header, so this instance cannot be evaluated.
        print("No suitable file found for %s, skipping" % instance_id)
        continue
    # Query model while guarding against context overflow or invalid API keys
    try:
        response = query_model(row['prompt'], model=model, T=T)
    except Exception as e:
        print("The following exception occured for i=%d: %s Skipping" % (i, e))
        continue
    # Remove the surrounding Markdown fence; skip the instance if the format is invalid
    new_test = _extract_python_code(response)
    if new_test is None:
        print("Invalid output format for %s, skipping" % instance_id)
        continue
    # Append the generated test in the designated existing test file
    try:
        new_test_file_content = append_function(test_file_content, new_test, insert_in_class="NOCLASS")
    except Exception as e:
        print("Skipping %s due to the following error in test injection: %s" % (instance_id, e))
        continue
    test_filename = test_filename.replace(repo_dir + '/', '')  # make path relative
    # The SWT-Bench evaluation harness needs the test as a diff to the test file
    # (a format inherited from SWE-Bench).
    model_patch = unified_diff(test_file_content, new_test_file_content, fromfile=test_filename, tofile=test_filename)
    print("Generated response for i=%d (%s)" % (i, instance_id))
    data = {
        "model_name_or_path": model,
        "instance_id": instance_id,
        "model_patch": model_patch,
        "config_string": config_str,
    }
    # Append one JSON object per line as soon as it is ready, so completed instances
    # survive an interrupted run. NOTE: re-running with the same label appends duplicates.
    with open(json_file_path_final_preds, "a", encoding="utf-8") as f:
        f.write(json.dumps(data) + "\n")