Skip to content

Commit 0cf8f00

Browse files
committed
Add theory of mind eval
1 parent a06a07b commit 0cf8f00

File tree

10 files changed

+1077
-0
lines changed

10 files changed

+1077
-0
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Eval description
2+
This evaluation tests LLMs' performance on theory of mind and social intelligence benchmarks [ToMi](https://github.com/facebookresearch/ToMi) and [SocialIQA](https://allenai.org/data/socialiqa).
3+
4+
The `ToMi` test set contains 5,993 question-answer pairs. These are instances of the [Sally-Anne test](https://en.wikipedia.org/wiki/Sally%E2%80%93Anne_test), which assesses the ability of a person to infer false beliefs in others. The original setting involves two people, Sally and Anne, who are together in a room. Sally places a marble in a box. Then, Anne leaves the room, and while she is away, Sally moves the marble to a basket elsewhere in the room. When Anne returns to the room, where will she search for the marble? If the person responding “has” theory-of-mind they’ll respond that Anne searches for the marble in the box, where she had last seen it. If they do not, they ascribe their own, accurate belief regarding the location to Anne, and say that she looks for it in the basket.
5+
6+
The `SocialIQA` test set contains 2,224 question-answer pairs covering a variety of social scenarios. These are multiple-choice, with 3 options of which only one is correct. The questions cover a person’s wants, needs, motivations, and reactions, as well as the effects of an action (on self or others), and how that action reflects on the person carrying it out (e.g. how others would perceive them after having carried out the action).
7+
8+
Two "light" versions of the datasets are also provided, containing 1/10th of the data points. These are useful for iterating on prompts and developing other scaffolding.
9+
10+
# Token and pricing estimates
11+
On average:
12+
- On the `SocialIQA` dataset, models consume ~250k tokens per run using the simple solver, and ~900k using the CoT solver.
13+
- On the `ToMi` dataset, models consume ~700k tokens per run using the simple solver, and ~2.4m using the CoT solver.
14+
15+
To calculate dollar cost from token counts, please check the latest token pricing [here](https://openai.com/pricing). Note that we count both input and output tokens together, so a lower and upper estimate of the cost of each variant can be predicted.
16+
17+
# Experiments
18+
As a starting point for deeper exploration, we provide scripts for comparing various solvers and eval variants, as well as for plotting the results. To run these:
19+
```
20+
cd scripts/
21+
bash run_experiments.sh
22+
```
23+
24+
# Contribution statement
25+
Eval design was primarily conducted by Andrei Alexandru, under the guidance of (alphabetically by last-name) Steven Adler, James Aung, Rosie Campbell and Jade Leung who provided research input, report revisions, and project management support.
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import json
2+
3+
# %%
4+
filepath = "/evals/registry/data/theory_of_mind/tomi/train.txt"
5+
6+
lines, datapoints = [], []
7+
with open(filepath, "r") as f:
8+
for line in f:
9+
line_index = line.split(" ")[0]
10+
if int(line_index) == 1:
11+
if len(lines) == 0:
12+
lines.append(line)
13+
else:
14+
target = lines[-1].split("\t")[-2]
15+
last_line = lines[-1].split("\t")[0]
16+
lines = [" ".join(line.replace("\n", "").split(" ")[1:]) for line in lines[:-1]]
17+
context = " ".join(lines) + " " + " ".join(last_line.split(" ")[1:])
18+
datapoints.append({"context": context, "target": target})
19+
lines = [line]
20+
else:
21+
lines.append(line)
22+
# %%
23+
def convert_datapoints_to_eval_dataset(datapoints: list) -> list:
24+
system_prompt = "You will read a number of sentences describing a situation involving several people, as well as a question regarding the situation. Your task is to answer the question based on the information in the sentences."
25+
eval_dataset = []
26+
for datapoint in datapoints:
27+
context = datapoint["context"]
28+
target = datapoint["target"]
29+
eval_datapoint = {
30+
"input": [
31+
{"role": "system", "content": system_prompt},
32+
{"role": "user", "content": context},
33+
],
34+
"ideal": target,
35+
}
36+
eval_dataset += [eval_datapoint]
37+
return eval_dataset
38+
39+
40+
# %%
41+
eval_dataset = convert_datapoints_to_eval_dataset(datapoints)
42+
# %%
43+
output_file = "tomi_train.jsonl"
44+
45+
with open(output_file, "w") as out:
46+
for datapoint in eval_dataset:
47+
out.write(json.dumps(datapoint) + "\n")
48+
# %%
49+
filepath = "/evals/registry/data/theory_of_mind/socialiqa/test.jsonl"
50+
system_prompt = "You will read a number of sentences describing a situation, followed by a question regarding the situation. Your task is to answer the question based on the information in the sentences by choosing from one of three answers A, B or C."
51+
52+
dataset = []
53+
with open(filepath, "r") as f:
54+
for line in f:
55+
entry = json.loads(line)
56+
template = f"{entry['context']} {entry['question']} A: {entry['answerA']}; B: {entry['answerB']}; C: {entry['answerC']}."
57+
dataset.append(
58+
{
59+
"input": [
60+
{"role": "system", "content": system_prompt},
61+
{"role": "user", "content": template},
62+
],
63+
"ideal": entry["correct"],
64+
}
65+
)
66+
# %%
67+
output_file = "socialiqa_test.jsonl"
68+
with open(output_file, "w") as out:
69+
for datapoint in dataset:
70+
out.write(json.dumps(datapoint) + "\n")
71+
72+
# %%
73+
74+
filepath = "evals/registry/data/theory_of_mind/socialiqa/test.jsonl"
75+
outpath = "evals/registry/data/theory_of_mind/socialiqa/newtest.jsonl"
76+
77+
dataset = []
78+
with open(filepath, "r") as f, open(outpath, "w") as out:
79+
for line in f:
80+
entry = json.loads(line)
81+
entry["input"] = [entry["input"][1]]
82+
out.write(json.dumps(entry) + "\n")
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
"""Take results from recent experiments and make a bar plot"""
2+
import argparse
3+
from pathlib import Path
4+
from typing import Union
5+
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
import pandas as pd
9+
import seaborn as sns
10+
11+
from evals.utils import log_utils
12+
13+
14+
def main():
15+
parser = argparse.ArgumentParser()
16+
parser.add_argument("--log_dir", type=str, required=True)
17+
parser.add_argument("--out_dir", type=str, required=True)
18+
args = parser.parse_args()
19+
20+
log_dir = args.log_dir
21+
out_dir = args.out_dir
22+
df = load_tom_results_from_dir(log_dir)
23+
make_plot(df, out_dir=Path(out_dir))
24+
25+
26+
def load_tom_results_from_dir(log_dir: Union[str, Path]) -> pd.DataFrame:
27+
rows = []
28+
final_results_dict = log_utils.get_final_results_from_dir(log_dir)
29+
30+
for path, final_results in final_results_dict.items():
31+
spec = log_utils.extract_spec(path)
32+
dataset, prompt_type, model = parse_spec(spec)
33+
rows.append(
34+
{
35+
"model": model,
36+
"dataset": dataset,
37+
"prompt_type": prompt_type,
38+
"accuracy": final_results["accuracy"],
39+
"bootstrap_std": final_results["bootstrap_std"],
40+
}
41+
)
42+
return pd.DataFrame(rows)
43+
44+
45+
def parse_spec(spec: dict) -> tuple[str, bool, int]:
46+
"""parse the spec from a MMP run"""
47+
completion_fn = spec["completion_fns"][0]
48+
dataset, prompt_type, model = completion_fn.split("/")
49+
prompt_type = prompt_type.split("_")[0]
50+
51+
return (dataset, prompt_type, model)
52+
53+
54+
def make_plot(df, out_dir):
55+
sns.set_theme(style="whitegrid")
56+
sns.set_palette("dark")
57+
# Define the order of models
58+
model_order = ["gpt-3.5-turbo", "gpt-4-base", "gpt-4"]
59+
datasets = df["dataset"].unique()
60+
61+
for dataset in datasets:
62+
ds = df[df["dataset"] == dataset.lower()]
63+
64+
# Ensure the model column is a categorical type with the specified order
65+
ds["model"] = pd.Categorical(ds["model"], categories=model_order, ordered=True)
66+
ds = ds.sort_values("model") # Sort according to the categorical order
67+
68+
# Unique models
69+
xs = ds["model"].unique()
70+
# Get the accuracy values for both prompt types
71+
simple_acc = ds[ds["prompt_type"] == "simple"]["accuracy"].values
72+
cot_acc = ds[ds["prompt_type"] == "cot"]["accuracy"].values
73+
74+
# Get the corresponding error values from the "bootstrap_std" field
75+
simple_std = ds[ds["prompt_type"] == "simple"]["bootstrap_std"].values
76+
cot_std = ds[ds["prompt_type"] == "cot"]["bootstrap_std"].values
77+
78+
# Define the width of a bar
79+
bar_width = 0.35
80+
# Set the positions of the bars
81+
x_indices = np.arange(len(xs))
82+
x_indices2 = [x + bar_width for x in x_indices]
83+
84+
fig, ax1 = plt.subplots()
85+
fig.suptitle(f"Accuracy on {dataset} dataset")
86+
87+
ax1.set_xlabel("Model")
88+
ax1.set_ylabel("Accuracy")
89+
90+
# Plot the bars for 'simple' and 'cot'
91+
ax1.bar(
92+
x_indices,
93+
simple_acc,
94+
width=bar_width,
95+
color=sns.color_palette("pastel")[0],
96+
yerr=simple_std,
97+
label="simple",
98+
)
99+
ax1.bar(
100+
x_indices2,
101+
cot_acc,
102+
width=bar_width,
103+
color=sns.color_palette("pastel")[1],
104+
yerr=cot_std,
105+
label="chain-of-thought",
106+
)
107+
108+
if dataset == "socialiqa":
109+
# Draw the horizontal line for the human baseline
110+
human_baseline = 0.881
111+
ax1.axhline(y=human_baseline, color="gray", linestyle="--", linewidth=1)
112+
# Add the text label for the human baseline
113+
ax1.text(
114+
0.01, human_baseline, "human baseline", va="center", ha="left", backgroundcolor="w"
115+
)
116+
117+
# Set the x-axis ticks to be in the middle of the two bars
118+
ax1.set_xticks([r + bar_width / 2 for r in range(len(xs))])
119+
ax1.set_xticklabels(xs, rotation=45) # Rotate the x-axis labels if needed
120+
121+
ax1.set_ylim(0, 1)
122+
123+
# Add legend
124+
ax1.legend(loc="upper right", bbox_to_anchor=(1, 1))
125+
126+
# Save the figure
127+
plt.savefig(out_dir / f"accuracy_{dataset.lower()}.png", bbox_inches="tight")
128+
plt.tight_layout() # Adjust the plot to ensure everything fits without overlapping
129+
plt.show()
130+
131+
132+
if __name__ == "__main__":
133+
main()
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
logdir=./logs
2+
outputdir=./outputs
3+
timestamp=$(date +%Y%m%d_%H%M%S)
4+
logpathbase="$logdir/$timestamp/"
5+
6+
echo Running experiments and logging to $logpathbase
7+
8+
DATASETS="tomi socialiqa hitom"
9+
MODELS="gpt-3.5-turbo gpt-4 gpt-4-base"
10+
SOLVER_TYPES="simple_solver cot_solver"
11+
12+
for dataset in $DATASETS
13+
do
14+
for model in $MODELS
15+
do
16+
for solver in $SOLVER_TYPES
17+
do
18+
oaieval $dataset/$solver/$model "theory_of_mind."$dataset --record_path "$logpathbase/$model-$variant.log"
19+
done
20+
done
21+
done
22+
23+
echo Done running experiments, all logs in $logpathbase
24+
25+
echo Producing plots, outputs to $outputdir
26+
python3 make_plots.py --log_dir $logpathbase --out_dir $outputdir
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
tomi/test.jsonl filter=lfs diff=lfs merge=lfs -text
2+
tomi/test_light.jsonl filter=lfs diff=lfs merge=lfs -text
3+
socialiqa/test.jsonl filter=lfs diff=lfs merge=lfs -text
4+
socialiqa/test_light.jsonl filter=lfs diff=lfs merge=lfs -text
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
ToMi:
2+
License: Creative Commons Attribution-NonCommercial 4.0 International (CC-BY-NC 4.0) https://creativecommons.org/licenses/by-nc/4.0/legalcode.en
3+
Source: https://github.com/facebookresearch/ToMi
4+
5+
SocialIQA:
6+
License: Creative Commons Attribution 4.0 International (CC-BY 4.0) https://creativecommons.org/licenses/by/4.0/legalcode.en
7+
Source: https://allenai.org/data/socialiqa

0 commit comments

Comments
 (0)