import json
import logging
import random
from pathlib import Path
from typing import Any, Optional, Union

import numpy as np

import evals
import evals.metrics
from evals.api import CompletionFn
from evals.elsuite.self_prompting.task_description import sample_in_token, task_description_template
from evals.eval import SolverEval
from evals.registry import registry
from evals.solvers.solver import Solver
from evals.task_state import TaskState
from evals.utils.log_utils import extract_final_results, extract_spec

logger = logging.getLogger(__name__)


class SelfPrompting(SolverEval):
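    """Eval where a "prompter" solver writes an instruction prompt for each task, and one or
    more tasker models are scored on how well they solve held-out test samples when given
    that prompt. Optionally reports improvement relative to a previous baseline run.
    """
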
    def __init__(
        self,
        completion_fns: list[CompletionFn],
        samples_jsonl: str,
        tasker_models: list[str],
        n_tasks: int = 50,
        n_samples_per_task: int = 10,
        n_preview_samples: int = 5,
        baseline_logpath: Optional[str] = None,
        *args,
        **kwargs,
    ):
        super().__init__(completion_fns, *args, **kwargs)
        # CI doesn't have access to model APIs, so replace tasker_models with dummy models
        # if we're running in CI (i.e. if the first completion_fn is a DummyCompletionFn)
        if isinstance(completion_fns[0], evals.api.DummyCompletionFn):
            tasker_models = ["dummy" for _ in tasker_models]

        self.samples_jsonl = samples_jsonl
        self.tasker_models = tasker_models
        self.n_tasks = n_tasks
        self.n_samples_per_task = n_samples_per_task
        self.n_preview_samples = n_preview_samples
        self.baseline_logpath = (
            self._prefix_registry_path(baseline_logpath) if baseline_logpath else None
        )
        assert len(self.tasker_models) > 0, "Must provide at least one tasker model"
        assert self.n_tasks > 0, "Must provide at least one task"
        assert self.n_samples_per_task > 0, "Must provide at least one sample per task"

        np.random.seed(self.seed)

        self.tasker_completion_fns = {}
        for tasker_model in self.tasker_models:
            self.tasker_completion_fns[tasker_model] = registry.make_completion_fn(tasker_model)

    def eval_sample(self, solver: Solver, sample: Any, rng: random.Random):
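        """Dispatch a sample to the prompting or tasking stage based on its "stage" field."""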
        if sample["stage"] == "prompting":
            return self._run_prompting(solver, sample)
        elif sample["stage"] == "tasking":
            return self._run_tasking(sample)
        else:
            raise ValueError(f"Invalid stage {sample['stage']}")

    def _run_prompting(self, solver: Solver, sample: Any, *_):
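        """Ask the solver to write an instruction prompt for the tasker model.

        The prompt is expected to contain the sample_in_token placeholder; if it does not,
        the result is flagged with prompt_rule_violation=True.
        """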
        # Prompt the prompter_model to generate a prompt for the tasker_model
        task_description = task_description_template.format(
            instruction=sample["task"]["instruction"],
            samples=json.dumps(sample["task"]["train_samples"], indent=2),
            tasker_model=sample["tasker_model"],
        )
        task_state = TaskState(
            task_description=task_description,
            current_state={
                "instruction": sample["task"]["instruction"],
                "samples": sample["task"]["train_samples"],
                "tasker_model": sample["tasker_model"],
            },
        )
        solver_result = solver(task_state)
        model_instruction = solver_result.output

        prompt_rule_violation = sample_in_token not in model_instruction

        output = {
            **sample,
            "task_description": task_description,
            "current_state": task_state.current_state,
            "prompting_solver_metadata": solver_result.to_json(),
            "model_instruction": model_instruction,
            "prompt_rule_violation": prompt_rule_violation,
        }
        return output

    def _run_tasking(self, sample: Any, *_):
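        """Run the tasker model on one test sample using the solver-written instruction,
        recording exact-match and fuzzy (substring) match against the reference output.
        """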
        tasker_completion_fn = self.tasker_completion_fns[sample["tasker_model"]]

        if sample_in_token in sample["model_instruction"]:
            # Fill in the sample input
            full_prompt = sample["model_instruction"].replace(sample_in_token, sample["input"])
        else:
            # Append the sample input
            full_prompt = f"{sample['model_instruction']}\n{sample['input']}"
        tasker_output = tasker_completion_fn(full_prompt).get_completions()[0]

        exact = 1 if tasker_output == sample["output"] else 0
        fuzzy = 1 if tasker_output in sample["output"] or sample["output"] in tasker_output else 0

        output = {
            **sample,
            "full_prompt": full_prompt,
            "tasker_output": tasker_output,
            "exact": exact,
            "fuzzy": fuzzy,
        }
        evals.record.record_metrics(**output)
        return output

    def _calculate_improvement_wrt_baseline(
        self, current_res: dict[str, float]
    ) -> dict[str, float]:
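        """Compare current accuracy metrics against the final results of a baseline run,
        after checking that the baseline used the same tasker models, n_tasks, and
        n_samples_per_task. Returns an empty dict if no comparable baseline is available.
        """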
        if self.baseline_logpath is None:
            logger.warning("SKIPPING IMPROVEMENT METRICS. (No baseline logpath provided.)")
            return {}

        # Check that baseline was run on the same tasker models, tasks, and samples
        baseline_spec = extract_spec(Path(self.baseline_logpath))
        try:
            spec_args = baseline_spec["run_config"]["eval_spec"]["args"]
        except KeyError:
            logger.warning("SKIPPING IMPROVEMENT METRICS. (Failed to validate baseline spec.)")
            return {}
        if set(spec_args["tasker_models"]) != set(self.tasker_models):
            logger.warning(
                f"SKIPPING IMPROVEMENT METRICS. (Baseline tasker_models {spec_args['tasker_models']} do not match {self.tasker_models}.)"
            )
            return {}
        if (
            spec_args["n_tasks"] != self.n_tasks
        ):  # TODO: Ideally we would check that the tasks are the same
            logger.warning(
                f"SKIPPING IMPROVEMENT METRICS. (Baseline n_tasks {spec_args['n_tasks']} does not match {self.n_tasks}.)"
            )
            return {}
        if spec_args["n_samples_per_task"] != self.n_samples_per_task:
            logger.warning(
                f"SKIPPING IMPROVEMENT METRICS. (Baseline n_samples_per_task {spec_args['n_samples_per_task']} does not match {self.n_samples_per_task}.)"
            )
            return {}

        baseline_res = extract_final_results(Path(self.baseline_logpath))

        def normalized_improvement(current, baseline):
            """
            Returns a score between -1 and 1, where:
            -1 means the current score maximally regresses from the baseline (i.e. the current score is 0),
             0 means the current score is the same as the baseline, and
            +1 means the current score achieves the maximum improvement over the baseline.
            """
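            # Note: if both the current and baseline scores are exactly 1.0, the
            # else-branch below would divide by zero.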
            if current < baseline:
                return (current - baseline) / baseline
            else:
                return (current - baseline) / (1 - baseline)

        improvement_scores = {
            "accuracy_improvement_wrt_oriprompt": normalized_improvement(
                current_res["accuracy"], baseline_res["accuracy"]
            ),
            "accuracy_fuzzy_improvement_wrt_oriprompt": normalized_improvement(
                current_res["accuracy_fuzzy"], baseline_res["accuracy_fuzzy"]
            ),
            "baseline_accuracy": baseline_res["accuracy"],
            "baseline_accuracy_fuzzy": baseline_res["accuracy_fuzzy"],
        }
        logger.info(f"Improvement scores: {improvement_scores}")
        return improvement_scores

    def run(self, recorder: evals.record.Recorder) -> dict[str, Union[float, int]]:
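        """Sample tasks, have the solver write a prompt per (task, tasker_model) pair,
        evaluate each tasker model on its test samples, and aggregate accuracy metrics,
        optionally including improvement over the baseline run.
        """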
        samples = self.get_samples()

        # Shuffle and limit samples
        np.random.shuffle(samples)
        samples_by_task = samples[: self.n_tasks]
        assert len(samples_by_task) == self.n_tasks
        for task in samples_by_task:
            np.random.shuffle(task["test_samples"])
            np.random.shuffle(task["train_samples"])
            task["test_samples"] = task["test_samples"][: self.n_samples_per_task]
            task["train_samples"] = task["train_samples"][: self.n_preview_samples]
            assert len(task["test_samples"]) == self.n_samples_per_task
            assert len(task["train_samples"]) == self.n_preview_samples

        # Run prompting
        prompting_samples = []
        for task in samples_by_task:
            for tasker_model in self.tasker_models:
                prompting_samples.append(
                    {
                        "stage": "prompting",
                        "tasker_model": tasker_model,
                        "task": task,
                    }
                )
        assert len(prompting_samples) == len(self.tasker_models) * self.n_tasks
        prompting_results = self.eval_all_samples(recorder, prompting_samples)

        # Run tasking
        tasking_samples = []  # Store in flattened list for parallel eval
        for prompt_res in prompting_results:
            prompt_res["stage"] = "tasking"  # Update stage
            for sample in prompt_res["task"]["test_samples"]:
                tasking_samples.append(
                    {
                        **prompt_res,
                        "input": sample["input"],
                        "output": sample["output"],
                    }
                )
        assert len(tasking_samples) == len(prompting_results) * self.n_samples_per_task
        self.eval_all_samples(recorder, tasking_samples)

        # The score of a prompter is the average score across all tasker models it writes prompts for
        metrics = recorder.get_metrics()

        # Primary metrics
        result = {
            "accuracy": np.mean([metric["exact"] for metric in metrics]),
            "accuracy_fuzzy": np.mean([metric["fuzzy"] for metric in metrics]),
        }
        # Relative improvement against baseline
        improvement_scores = self._calculate_improvement_wrt_baseline(result)
        if improvement_scores:
            result.update(improvement_scores)

        # Peripheral metrics
        result.update(
            {
                "prompt_rule_violation_rate": np.mean(
                    [int(metric["prompt_rule_violation"]) for metric in metrics]
                ),
                "n_samples": len(metrics),
            }
        )

        # Breakdown by tasker model
        def compute_mean_tasker(key, tasker_model):
            return np.mean(
                [metric[key] for metric in metrics if metric["tasker_model"] == tasker_model]
            )

        for tasker in self.tasker_models:
            result.update(
                {
                    f"accuracy_{tasker}": compute_mean_tasker("exact", tasker),
                    f"accuracy_fuzzy_{tasker}": compute_mean_tasker("fuzzy", tasker),
                }
            )

        return result
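

# A minimal sketch of how this eval might be invoked, assuming it is registered in the
# evals registry (the eval and solver names below are hypothetical):
#
#   oaieval <solver_or_completion_fn> self_prompting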