|
| 1 | +"""Take results from recent experiments and make a bar plot""" |
| 2 | +import argparse |
| 3 | +from pathlib import Path |
| 4 | +from typing import Union |
| 5 | + |
| 6 | +import matplotlib.pyplot as plt |
| 7 | +import numpy as np |
| 8 | +import pandas as pd |
| 9 | +import seaborn as sns |
| 10 | + |
| 11 | +from evals.utils import log_utils |
| 12 | + |
| 13 | + |
| 14 | +def main(): |
| 15 | + parser = argparse.ArgumentParser() |
| 16 | + parser.add_argument("--log_dir", type=str, required=True) |
| 17 | + parser.add_argument("--out_dir", type=str, required=True) |
| 18 | + args = parser.parse_args() |
| 19 | + |
| 20 | + log_dir = args.log_dir |
| 21 | + out_dir = args.out_dir |
| 22 | + df = load_tom_results_from_dir(log_dir) |
| 23 | + make_plot(df, out_dir=Path(out_dir)) |
| 24 | + |
| 25 | + |
| 26 | +def load_tom_results_from_dir(log_dir: Union[str, Path]) -> pd.DataFrame: |
| 27 | + rows = [] |
| 28 | + final_results_dict = log_utils.get_final_results_from_dir(log_dir) |
| 29 | + |
| 30 | + for path, final_results in final_results_dict.items(): |
| 31 | + spec = log_utils.extract_spec(path) |
| 32 | + dataset, prompt_type, model = parse_spec(spec) |
| 33 | + rows.append( |
| 34 | + { |
| 35 | + "model": model, |
| 36 | + "dataset": dataset, |
| 37 | + "prompt_type": prompt_type, |
| 38 | + "accuracy": final_results["accuracy"], |
| 39 | + "bootstrap_std": final_results["bootstrap_std"], |
| 40 | + } |
| 41 | + ) |
| 42 | + return pd.DataFrame(rows) |
| 43 | + |
| 44 | + |
| 45 | +def parse_spec(spec: dict) -> tuple[str, bool, int]: |
| 46 | + """parse the spec from a MMP run""" |
| 47 | + completion_fn = spec["completion_fns"][0] |
| 48 | + dataset, prompt_type, model = completion_fn.split("/") |
| 49 | + prompt_type = prompt_type.split("_")[0] |
| 50 | + |
| 51 | + return (dataset, prompt_type, model) |
| 52 | + |
| 53 | + |
| 54 | +def make_plot(df, out_dir): |
| 55 | + sns.set_theme(style="whitegrid") |
| 56 | + sns.set_palette("dark") |
| 57 | + # Define the order of models |
| 58 | + model_order = ["gpt-3.5-turbo", "gpt-4-base", "gpt-4"] |
| 59 | + datasets = df["dataset"].unique() |
| 60 | + |
| 61 | + for dataset in datasets: |
| 62 | + ds = df[df["dataset"] == dataset.lower()] |
| 63 | + |
| 64 | + # Ensure the model column is a categorical type with the specified order |
| 65 | + ds["model"] = pd.Categorical(ds["model"], categories=model_order, ordered=True) |
| 66 | + ds = ds.sort_values("model") # Sort according to the categorical order |
| 67 | + |
| 68 | + # Unique models |
| 69 | + xs = ds["model"].unique() |
| 70 | + # Get the accuracy values for both prompt types |
| 71 | + simple_acc = ds[ds["prompt_type"] == "simple"]["accuracy"].values |
| 72 | + cot_acc = ds[ds["prompt_type"] == "cot"]["accuracy"].values |
| 73 | + |
| 74 | + # Get the corresponding error values from the "bootstrap_std" field |
| 75 | + simple_std = ds[ds["prompt_type"] == "simple"]["bootstrap_std"].values |
| 76 | + cot_std = ds[ds["prompt_type"] == "cot"]["bootstrap_std"].values |
| 77 | + |
| 78 | + # Define the width of a bar |
| 79 | + bar_width = 0.35 |
| 80 | + # Set the positions of the bars |
| 81 | + x_indices = np.arange(len(xs)) |
| 82 | + x_indices2 = [x + bar_width for x in x_indices] |
| 83 | + |
| 84 | + fig, ax1 = plt.subplots() |
| 85 | + fig.suptitle(f"Accuracy on {dataset} dataset") |
| 86 | + |
| 87 | + ax1.set_xlabel("Model") |
| 88 | + ax1.set_ylabel("Accuracy") |
| 89 | + |
| 90 | + # Plot the bars for 'simple' and 'cot' |
| 91 | + ax1.bar( |
| 92 | + x_indices, |
| 93 | + simple_acc, |
| 94 | + width=bar_width, |
| 95 | + color=sns.color_palette("pastel")[0], |
| 96 | + yerr=simple_std, |
| 97 | + label="simple", |
| 98 | + ) |
| 99 | + ax1.bar( |
| 100 | + x_indices2, |
| 101 | + cot_acc, |
| 102 | + width=bar_width, |
| 103 | + color=sns.color_palette("pastel")[1], |
| 104 | + yerr=cot_std, |
| 105 | + label="chain-of-thought", |
| 106 | + ) |
| 107 | + |
| 108 | + if dataset == "socialiqa": |
| 109 | + # Draw the horizontal line for the human baseline |
| 110 | + human_baseline = 0.881 |
| 111 | + ax1.axhline(y=human_baseline, color="gray", linestyle="--", linewidth=1) |
| 112 | + # Add the text label for the human baseline |
| 113 | + ax1.text( |
| 114 | + 0.01, human_baseline, "human baseline", va="center", ha="left", backgroundcolor="w" |
| 115 | + ) |
| 116 | + |
| 117 | + # Set the x-axis ticks to be in the middle of the two bars |
| 118 | + ax1.set_xticks([r + bar_width / 2 for r in range(len(xs))]) |
| 119 | + ax1.set_xticklabels(xs, rotation=45) # Rotate the x-axis labels if needed |
| 120 | + |
| 121 | + ax1.set_ylim(0, 1) |
| 122 | + |
| 123 | + # Add legend |
| 124 | + ax1.legend(loc="upper right", bbox_to_anchor=(1, 1)) |
| 125 | + |
| 126 | + # Save the figure |
| 127 | + plt.savefig(out_dir / f"accuracy_{dataset.lower()}.png", bbox_inches="tight") |
| 128 | + plt.tight_layout() # Adjust the plot to ensure everything fits without overlapping |
| 129 | + plt.show() |
| 130 | + |
| 131 | + |
| 132 | +if __name__ == "__main__": |
| 133 | + main() |
0 commit comments