1 change: 1 addition & 0 deletions .github/workflows/test.yml
@@ -39,6 +39,7 @@ jobs:

- name: Install Dependencies
run: |
df -h
pip install --upgrade -r requirements.txt -r requirements-dev.txt
pip install -e .[bgem3]

30 changes: 20 additions & 10 deletions pyterrier_dr/bge_m3.py
@@ -5,26 +5,32 @@
import pyterrier_alpha as pta
from .biencoder import BiEncoder

IS_FLAGEMBEDDING_AVAILABLE = False
try:
import FlagEmbedding as FE
IS_FLAGEMBEDDING_AVAILABLE = True
except ImportError:
pass


class BGEM3(BiEncoder):
def __init__(self, model_name='BAAI/bge-m3', batch_size=32, max_length=8192, text_field='text', verbose=False, device=None, use_fp16=False):
super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose)
if not IS_FLAGEMBEDDING_AVAILABLE:
raise ImportError("BGE-M3 requires the FlagEmbedding package. You can install it using 'pip install pyterrier-dr[bgem3]'")

self.model_name = model_name
self.use_fp16 = use_fp16
self.max_length = max_length
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = torch.device(device)
try:
from FlagEmbedding import BGEM3FlagModel
except ImportError:
raise ImportError("BGE-M3 requires the FlagEmbedding package. You can install it using 'pip install pyterrier-dr[bgem3]'")

self.model = BGEM3FlagModel(self.model_name, use_fp16=self.use_fp16, device=self.device)

self.model = FE.BGEM3FlagModel(self.model_name, use_fp16=self.use_fp16, device=self.device)

def __repr__(self):
return f'BGEM3({repr(self.model_name)})'

def encode_queries(self, texts, batch_size=None):
return self.model.encode(list(texts), batch_size=batch_size, max_length=self.max_length,
return_dense=True, return_sparse=False, return_colbert_vecs=False)['dense_vecs']
@@ -36,15 +42,18 @@ def encode_docs(self, texts, batch_size=None):
# Only does dense (single_vec) encoding
def query_encoder(self, verbose=None, batch_size=None):
return BGEM3QueryEncoder(self, verbose=verbose, batch_size=batch_size)

def doc_encoder(self, verbose=None, batch_size=None):
return BGEM3DocEncoder(self, verbose=verbose, batch_size=batch_size)

# Does all three BGE-M3 encodings: dense, sparse and ColBERT (multi-vector)
def query_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True):
return BGEM3QueryEncoder(self, verbose=verbose, batch_size=batch_size, return_dense=return_dense, return_sparse=return_sparse, return_colbert_vecs=return_colbert_vecs)

def doc_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True):
return BGEM3DocEncoder(self, verbose=verbose, batch_size=batch_size, return_dense=return_dense, return_sparse=return_sparse, return_colbert_vecs=return_colbert_vecs)


class BGEM3QueryEncoder(pt.Transformer):
def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length=None, return_dense=True, return_sparse=False, return_colbert_vecs=False):
self.bge_factory = bge_factory
@@ -55,7 +64,7 @@ def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length
self.dense = return_dense
self.sparse = return_sparse
self.multivecs = return_colbert_vecs

def encode(self, texts):
return self.bge_factory.model.encode(list(texts), batch_size=self.batch_size, max_length=self.max_length,
return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs)
@@ -88,10 +97,11 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
if self.multivecs:
inp = inp.assign(query_embs=[bgem3_results['colbert_vecs'][i] for i in inv])
return inp

def __repr__(self):
return f'{repr(self.bge_factory)}.query_encoder()'


class BGEM3DocEncoder(pt.Transformer):
def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length=None, return_dense=True, return_sparse=False, return_colbert_vecs=False):
self.bge_factory = bge_factory
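The comments in the diff above distinguish the dense-only encoders from the new multi encoders, which can also return BGE-M3's sparse (lexical) weights and ColBERT-style multi-vectors. A rough usage sketch of the new factory methods follows (not part of the diff; the sample query and document are invented, and the output column names follow the assertions in the updated tests below):

import pandas as pd
from pyterrier_dr import BGEM3

# Requires the optional dependency: pip install pyterrier-dr[bgem3]
bgem3 = BGEM3(batch_size=32)

topics = pd.DataFrame([{'qid': '1', 'query': 'chemical reactions'}])
docs = pd.DataFrame([{'docno': 'd1', 'text': 'a passage about chemical reactions'}])

# Dense-only encoding (single vector per query/document), as before.
dense_topics = bgem3.query_encoder()(topics)
dense_docs = bgem3.doc_encoder()(docs)

# All three BGE-M3 representations: dense, sparse and ColBERT multi-vectors.
# Per the tests, queries gain 'query_toks' (a dict of token weights) and 'query_embs',
# while documents gain 'toks' and 'doc_embs'.
multi_topics = bgem3.query_multi_encoder()(topics)
multi_docs = bgem3.doc_multi_encoder()(docs)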
88 changes: 48 additions & 40 deletions tests/test_models.py
@@ -6,6 +6,12 @@
import pyterrier as pt
from pyterrier_dr import FlexIndex

IS_FLAG_EMBEDDING_AVAILABLE = False
try:
import FlagEmbedding
IS_FLAG_EMBEDDING_AVAILABLE = True
except ImportError:
pass

class TestModels(unittest.TestCase):

@@ -103,46 +109,6 @@ def _base_test(self, model, test_query_encoder=True, test_doc_encoder=True, test
self.assertTrue('docno' in retr_res.columns)
self.assertTrue('score' in retr_res.columns)
self.assertTrue('rank' in retr_res.columns)

def _test_bgem3_multi(self, model, test_query_multivec_encoder=False, test_doc_multivec_encoder=False):
dataset = pt.get_dataset('irds:vaswani')

docs = list(itertools.islice(pt.get_dataset('irds:vaswani').get_corpus_iter(), 200))
docs_df = pd.DataFrame(docs)

if test_query_multivec_encoder:
with self.subTest('query_multivec_encoder'):
topics = dataset.get_topics()
enc_topics = model(topics)
self.assertEqual(len(enc_topics), len(topics))
self.assertTrue('query_toks' in enc_topics.columns)
self.assertTrue('query_embs' in enc_topics.columns)
self.assertTrue(all(c in enc_topics.columns for c in topics.columns))
self.assertEqual(enc_topics.query_toks.dtype, object)
self.assertTrue(all(isinstance(v, dict) for v in enc_topics.query_toks))
self.assertEqual(enc_topics.query_embs.dtype, object)
self.assertTrue(all(v.dtype == np.float32 for v in enc_topics.query_embs))
with self.subTest('query_multivec_encoder empty'):
enc_topics_empty = model(pd.DataFrame(columns=['qid', 'query']))
self.assertEqual(len(enc_topics_empty), 0)
self.assertTrue('query_toks' in enc_topics_empty.columns)
self.assertTrue('query_embs' in enc_topics_empty.columns)
if test_doc_multivec_encoder:
with self.subTest('doc_multi_encoder'):
enc_docs = model(pd.DataFrame(docs_df))
self.assertEqual(len(enc_docs), len(docs_df))
self.assertTrue('toks' in enc_docs.columns)
self.assertTrue('doc_embs' in enc_docs.columns)
self.assertTrue(all(c in enc_docs.columns for c in docs_df.columns))
self.assertEqual(enc_docs.toks.dtype, object)
self.assertTrue(all(isinstance(v, dict) for v in enc_docs.toks))
self.assertEqual(enc_docs.doc_embs.dtype, object)
self.assertTrue(all(v.dtype == np.float32 for v in enc_docs.doc_embs))
with self.subTest('doc_multi_encoder empty'):
enc_docs_empty = model(pd.DataFrame(columns=['docno', 'text']))
self.assertEqual(len(enc_docs_empty), 0)
self.assertTrue('toks' in enc_docs_empty.columns)
self.assertTrue('doc_embs' in enc_docs_empty.columns)

def test_tct(self):
from pyterrier_dr import TctColBert
@@ -175,6 +141,7 @@ def test_query2query(self):
from pyterrier_dr import Query2Query
self._base_test(Query2Query(), test_doc_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)

@unittest.skipUnless(IS_FLAG_EMBEDDING_AVAILABLE, "FlagEmbedding package is not available")
def test_bgem3(self):
from pyterrier_dr import BGEM3
# create BGEM3 instance
@@ -186,6 +153,47 @@ def test_bgem3(self):
self._test_bgem3_multi(bgem3.query_multi_encoder(), test_query_multivec_encoder=True)
self._test_bgem3_multi(bgem3.doc_multi_encoder(), test_doc_multivec_encoder=True)

@unittest.skipUnless(IS_FLAG_EMBEDDING_AVAILABLE, "FlagEmbedding package is not available")
def _test_bgem3_multi(self, model, test_query_multivec_encoder=False, test_doc_multivec_encoder=False):
dataset = pt.get_dataset('irds:vaswani')

docs = list(itertools.islice(pt.get_dataset('irds:vaswani').get_corpus_iter(), 200))
docs_df = pd.DataFrame(docs)

if test_query_multivec_encoder:
with self.subTest('query_multivec_encoder'):
topics = dataset.get_topics()
enc_topics = model(topics)
self.assertEqual(len(enc_topics), len(topics))
self.assertTrue('query_toks' in enc_topics.columns)
self.assertTrue('query_embs' in enc_topics.columns)
self.assertTrue(all(c in enc_topics.columns for c in topics.columns))
self.assertEqual(enc_topics.query_toks.dtype, object)
self.assertTrue(all(isinstance(v, dict) for v in enc_topics.query_toks))
self.assertEqual(enc_topics.query_embs.dtype, object)
self.assertTrue(all(v.dtype == np.float32 for v in enc_topics.query_embs))
with self.subTest('query_multivec_encoder empty'):
enc_topics_empty = model(pd.DataFrame(columns=['qid', 'query']))
self.assertEqual(len(enc_topics_empty), 0)
self.assertTrue('query_toks' in enc_topics_empty.columns)
self.assertTrue('query_embs' in enc_topics_empty.columns)
if test_doc_multivec_encoder:
with self.subTest('doc_multi_encoder'):
enc_docs = model(pd.DataFrame(docs_df))
self.assertEqual(len(enc_docs), len(docs_df))
self.assertTrue('toks' in enc_docs.columns)
self.assertTrue('doc_embs' in enc_docs.columns)
self.assertTrue(all(c in enc_docs.columns for c in docs_df.columns))
self.assertEqual(enc_docs.toks.dtype, object)
self.assertTrue(all(isinstance(v, dict) for v in enc_docs.toks))
self.assertEqual(enc_docs.doc_embs.dtype, object)
self.assertTrue(all(v.dtype == np.float32 for v in enc_docs.doc_embs))
with self.subTest('doc_multi_encoder empty'):
enc_docs_empty = model(pd.DataFrame(columns=['docno', 'text']))
self.assertEqual(len(enc_docs_empty), 0)
self.assertTrue('toks' in enc_docs_empty.columns)
self.assertTrue('doc_embs' in enc_docs_empty.columns)


if __name__ == '__main__':
unittest.main()
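For a quick local sanity check before running the newly guarded tests, the optional-import pattern used in bge_m3.py and test_models.py can be reproduced on its own (a sketch; the extra name matches the one installed in the workflow above):

# If this prints False, the BGE-M3 tests will be skipped;
# install the optional extra first: pip install pyterrier-dr[bgem3]
try:
    import FlagEmbedding  # noqa: F401
    print('FlagEmbedding available:', True)
except ImportError:
    print('FlagEmbedding available:', False)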