Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyterrier_rag/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from pyterrier_rag import measures
from pyterrier_rag import model
from pyterrier_rag import readers
from pyterrier_rag._frameworks import Iterative
from pyterrier_rag._frameworks import Iterative, Genetic

__all__ = [
'Iterative', 'model', 'readers', 'measures', '_datasets',
'Iterative', 'model', 'readers', 'measures', '_datasets', 'Genetic',
]
100 changes: 98 additions & 2 deletions pyterrier_rag/_frameworks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from typing import Optional
from typing import Optional, List, Union, Literal
import random
import itertools
from functools import partial
from collections import Counter
from outlines import prompt

import pandas as pd
Expand Down Expand Up @@ -149,4 +152,97 @@ def transform(self, inp : pd.DataFrame) -> pd.DataFrame:
if "the answer is" in answers.iloc[0].qanswer.lower():
stop = True
inp = answers
return answers
return answers


class Genetic(pt.Transformer):
    """Genetic RAG pipeline (Gen2IR)

    .. cite.dblp:: conf/doceng/KulkarniYGFM23
    """

    def __init__(self,
        fitness: pt.Transformer,
        mutators: List[pt.Transformer],
        *,
        convergence_depth: int = 2,
        mutation_depth: int = 2,
        mutations_per_generation: int = 8,
        response_type: Union[Literal['result_frame'], Literal['answer_frame']] = 'answer_frame',
        rng: Optional[int] = None,
    ):
        """
        Args:
            fitness: a Transformer that scores the input DataFrame (to determine the best answers)
            mutators: a list of Transformers that generate new answers
            convergence_depth: the depth at which we consider the results to have converged
            mutation_depth: the depth to sample from for mutations
            mutations_per_generation: the number of mutations to generate per generation
            response_type: the type of frame to return: either an ``answer_frame`` (which includes only a single qanswer per query)
                or a ``result_frame`` which includes all retrieved generated documents.
            rng: the random seed
        """
        self.fitness = fitness
        self.mutators = mutators
        self.convergence_depth = convergence_depth
        self.mutation_depth = mutation_depth
        self.mutations_per_generation = mutations_per_generation
        self.response_type = response_type
        # random.Random(None) seeds from the system, so no fallback is needed.
        # (The previous ``random.Random(rng) or random.Random()`` had an unreachable
        # ``or`` branch: a Random instance is always truthy.)
        self.rng = random.Random(rng)

    @pta.transform.by_query(add_ranks=False)
    def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
        """Run the genetic loop for a single query.

        Args:
            inp: a result frame with ``qid``, ``query``, ``docno``, and ``text`` columns

        Returns:
            an answer frame or a result frame, depending on ``response_type``
        """
        pta.validate.result_frame(inp, extra_columns=['query', 'text'])
        qid, query = inp['qid'].iloc[0], inp['query'].iloc[0]
        scores = Counter()  # (docno, text) -> fitness score, accumulated over all generations
        threshold = float('-inf')
        next_res = inp
        for generation in itertools.count():
            # Evaluate Fitness
            next_res = self.fitness(next_res)
            for docno, text, score in next_res[['docno', 'text', 'score']].itertuples(index=False):
                scores[docno, text] = score
            sorted_res = scores.most_common()
            # Converged when nothing in this generation beat the previous threshold
            if len(next_res[next_res['score'] > threshold]) == 0:
                break
            # Guard against an IndexError when fewer candidates exist than
            # convergence_depth (e.g., very small input frames).
            threshold = sorted_res[min(self.convergence_depth, len(sorted_res) - 1)][1]

            # Mutate: present the top candidates to randomly-chosen mutators,
            # each of which must produce exactly one new answer.
            top = sorted_res[:self.mutation_depth]
            top_frame = pd.DataFrame({
                'qid': qid,
                'query': query,
                'docno': [r[0][0] for r in top],
                'text': [r[0][1] for r in top],
                'score': [r[1] for r in top],
                'rank': list(range(len(top))),
            })
            next_res = []
            for i in range(self.mutations_per_generation):
                mutator = self.rng.choice(self.mutators)
                answer = mutator(top_frame)
                assert len(answer) == 1
                next_res.append({
                    'qid': qid,
                    'query': query,
                    'docno': f'g{generation}i{i}',  # synthetic docno: generation + mutation index
                    'text': answer['qanswer'].iloc[0],
                })
            next_res = pd.DataFrame(next_res)

        if self.response_type == 'answer_frame':
            # Single best answer (highest fitness score seen across all generations)
            return pd.DataFrame({
                'qid': [qid],
                'query': [query],
                'qanswer': [sorted_res[0][0][1]],
            })
        elif self.response_type == 'result_frame':
            # All retrieved/generated candidates, ranked by fitness score
            return pd.DataFrame({
                'qid': qid,
                'query': query,
                'docno': [r[0][0] for r in sorted_res],
                'text': [r[0][1] for r in sorted_res],
                'score': [r[1] for r in sorted_res],
                'rank': list(range(len(sorted_res))),
            })
        else:
            raise ValueError(f'unknown response_type: {self.response_type!r}')
16 changes: 16 additions & 0 deletions pyterrier_rag/pt_docs/pipelines.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
RAG Pipelines
========================================


Genetic
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The :class:`pyterrier_rag.Genetic` pipeline performs RAG by using a genetic algorithm to iteratively
construct a strong answer.

The pipeline expects a result frame as input and, by default, returns an answer frame (set ``response_type`` to ``result_frame`` to return all candidate answers instead).

.. cite.dblp:: conf/doceng/KulkarniYGFM23

.. autoclass:: pyterrier_rag.Genetic
:members:
3 changes: 2 additions & 1 deletion requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,6 @@ pytest-cov
pytest-subtests
pytest-json-report
ruff
pyterrier-dr # used by Genetic
bert_score
evaluate # used by bertscore
evaluate # used by bertscore
28 changes: 28 additions & 0 deletions tests/test_genetic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import pyterrier as pt
from pyterrier_rag import Genetic
import pyterrier_dr
import unittest
import pandas as pd


class ConstMutator(pt.Transformer):
    """Test stub mutator that ignores its input and always returns a fixed answer."""

    def __init__(self, answer):
        # the constant qanswer to emit for every input frame
        self.answer = answer

    def transform(self, inp):
        # Use positional access (.iloc) rather than ``inp.qid[0]``: label-based
        # lookup raises KeyError when the frame's index does not start at 0.
        return pd.DataFrame({'qid': [inp['qid'].iloc[0]], 'qanswer': [self.answer]})


class TestGenetic(unittest.TestCase):
    def test_genetic(self):
        # Two constant mutators: one weak answer and one strong answer.
        # The fitness scorer should drive the pipeline toward the strong one.
        good_answer = 'Chemical reactions happen when molecules collide with enough energy to break their existing bonds'
        mutators = [
            ConstMutator('Chemical reactions can happen sometimes'),
            ConstMutator(good_answer),
        ]
        scorer = pyterrier_dr.ElectraScorer(verbose=False)
        dataset = pt.get_dataset('irds:vaswani')
        index = pt.Artifact.from_hf('pyterrier/vaswani.terrier')
        pipeline = index.bm25(num_results=5) >> dataset.text_loader() >> Genetic(scorer, mutators)
        results = pipeline.search('when do chemical reactions happen')
        # the qanswer should be the good answer from above (with very high probability)
        self.assertEqual(results['qanswer'][0], good_answer)