From efd2ba7e09e31fdc5b53ef8c1fdf4a2b021b8a56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CAndrew?= Date: Thu, 2 Oct 2025 17:53:30 +0100 Subject: [PATCH 1/4] begin fix --- tests/test_models.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index ce52426..1c8cd3a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -6,6 +6,12 @@ import pyterrier as pt from pyterrier_dr import FlexIndex +IS_FLAG_EMBEDDING_AVAILABLE = False +try: + import FlagEmbedding + IS_FLAG_EMBEDDING_AVAILABLE = True +except ImportError: + pass class TestModels(unittest.TestCase): @@ -103,7 +109,8 @@ def _base_test(self, model, test_query_encoder=True, test_doc_encoder=True, test self.assertTrue('docno' in retr_res.columns) self.assertTrue('score' in retr_res.columns) self.assertTrue('rank' in retr_res.columns) - + + @unittest.skipUnless(IS_FLAG_EMBEDDING_AVAILABLE, "FlagEmbedding package is not available") def _test_bgem3_multi(self, model, test_query_multivec_encoder=False, test_doc_multivec_encoder=False): dataset = pt.get_dataset('irds:vaswani') @@ -175,6 +182,7 @@ def test_query2query(self): from pyterrier_dr import Query2Query self._base_test(Query2Query(), test_doc_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False) + @unittest.skipUnless(IS_FLAG_EMBEDDING_AVAILABLE, "FlagEmbedding package is not available") def test_bgem3(self): from pyterrier_dr import BGEM3 # create BGEM3 instance From 757f7957126e9ae91ab812debb2687161c189ea4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CAndrew?= Date: Thu, 2 Oct 2025 18:20:46 +0100 Subject: [PATCH 2/4] update model --- pyterrier_dr/bge_m3.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/pyterrier_dr/bge_m3.py b/pyterrier_dr/bge_m3.py index 9915d7b..117076c 100644 --- a/pyterrier_dr/bge_m3.py +++ b/pyterrier_dr/bge_m3.py @@ -5,26 +5,32 @@ import pyterrier_alpha as pta from .biencoder 
import BiEncoder +IS_FLAGEMBEDDING_AVAILABLE = False +try: + import FlagEmbedding as FE + IS_FLAGEMBEDDING_AVAILABLE = True +except ImportError: + pass + + class BGEM3(BiEncoder): def __init__(self, model_name='BAAI/bge-m3', batch_size=32, max_length=8192, text_field='text', verbose=False, device=None, use_fp16=False): super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose) + if not IS_FLAGEMBEDDING_AVAILABLE: + raise ImportError("BGE-M3 requires the FlagEmbedding package. You can install it using 'pip install pyterrier-dr[bgem3]'") + self.model_name = model_name self.use_fp16 = use_fp16 self.max_length = max_length if device is None: device = 'cuda' if torch.cuda.is_available() else 'cpu' self.device = torch.device(device) - try: - from FlagEmbedding import BGEM3FlagModel - except ImportError: - raise ImportError("BGE-M3 requires the FlagEmbedding package. You can install it using 'pip install pyterrier-dr[bgem3]'") - - self.model = BGEM3FlagModel(self.model_name, use_fp16=self.use_fp16, device=self.device) + self.model = FE.BGEM3FlagModel(self.model_name, use_fp16=self.use_fp16, device=self.device) def __repr__(self): return f'BGEM3({repr(self.model_name)})' - + def encode_queries(self, texts, batch_size=None): return self.model.encode(list(texts), batch_size=batch_size, max_length=self.max_length, return_dense=True, return_sparse=False, return_colbert_vecs=False)['dense_vecs'] @@ -36,15 +42,18 @@ def encode_docs(self, texts, batch_size=None): # Only does dense (single_vec) encoding def query_encoder(self, verbose=None, batch_size=None): return BGEM3QueryEncoder(self, verbose=verbose, batch_size=batch_size) + def doc_encoder(self, verbose=None, batch_size=None): return BGEM3DocEncoder(self, verbose=verbose, batch_size=batch_size) - + # Does all three BGE-M3 encodings: dense, sparse and colbert(multivec) def query_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True): return 
BGEM3QueryEncoder(self, verbose=verbose, batch_size=batch_size, return_dense=return_dense, return_sparse=return_sparse, return_colbert_vecs=return_colbert_vecs) + def doc_multi_encoder(self, verbose=None, batch_size=None, return_dense=True, return_sparse=True, return_colbert_vecs=True): return BGEM3DocEncoder(self, verbose=verbose, batch_size=batch_size, return_dense=return_dense, return_sparse=return_sparse, return_colbert_vecs=return_colbert_vecs) + class BGEM3QueryEncoder(pt.Transformer): def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length=None, return_dense=True, return_sparse=False, return_colbert_vecs=False): self.bge_factory = bge_factory @@ -55,7 +64,7 @@ def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length self.dense = return_dense self.sparse = return_sparse self.multivecs = return_colbert_vecs - + def encode(self, texts): return self.bge_factory.model.encode(list(texts), batch_size=self.batch_size, max_length=self.max_length, return_dense=self.dense, return_sparse=self.sparse, return_colbert_vecs=self.multivecs) @@ -88,10 +97,11 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame: if self.multivecs: inp = inp.assign(query_embs=[bgem3_results['colbert_vecs'][i] for i in inv]) return inp - + def __repr__(self): return f'{repr(self.bge_factory)}.query_encoder()' + class BGEM3DocEncoder(pt.Transformer): def __init__(self, bge_factory: BGEM3, verbose=None, batch_size=None, max_length=None, return_dense=True, return_sparse=False, return_colbert_vecs=False): self.bge_factory = bge_factory From bde726c284218657f89737634fdbc51fb7eba7b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CAndrew?= Date: Thu, 2 Oct 2025 18:40:12 +0100 Subject: [PATCH 3/4] minor update --- tests/test_models.py | 82 ++++++++++++++++++++++---------------------- 1 file changed, 41 insertions(+), 41 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 1c8cd3a..edbde0b 100644 --- a/tests/test_models.py +++ 
b/tests/test_models.py @@ -110,47 +110,6 @@ def _base_test(self, model, test_query_encoder=True, test_doc_encoder=True, test self.assertTrue('score' in retr_res.columns) self.assertTrue('rank' in retr_res.columns) - @unittest.skipUnless(IS_FLAG_EMBEDDING_AVAILABLE, "FlagEmbedding package is not available") - def _test_bgem3_multi(self, model, test_query_multivec_encoder=False, test_doc_multivec_encoder=False): - dataset = pt.get_dataset('irds:vaswani') - - docs = list(itertools.islice(pt.get_dataset('irds:vaswani').get_corpus_iter(), 200)) - docs_df = pd.DataFrame(docs) - - if test_query_multivec_encoder: - with self.subTest('query_multivec_encoder'): - topics = dataset.get_topics() - enc_topics = model(topics) - self.assertEqual(len(enc_topics), len(topics)) - self.assertTrue('query_toks' in enc_topics.columns) - self.assertTrue('query_embs' in enc_topics.columns) - self.assertTrue(all(c in enc_topics.columns for c in topics.columns)) - self.assertEqual(enc_topics.query_toks.dtype, object) - self.assertTrue(all(isinstance(v, dict) for v in enc_topics.query_toks)) - self.assertEqual(enc_topics.query_embs.dtype, object) - self.assertTrue(all(v.dtype == np.float32 for v in enc_topics.query_embs)) - with self.subTest('query_multivec_encoder empty'): - enc_topics_empty = model(pd.DataFrame(columns=['qid', 'query'])) - self.assertEqual(len(enc_topics_empty), 0) - self.assertTrue('query_toks' in enc_topics_empty.columns) - self.assertTrue('query_embs' in enc_topics_empty.columns) - if test_doc_multivec_encoder: - with self.subTest('doc_multi_encoder'): - enc_docs = model(pd.DataFrame(docs_df)) - self.assertEqual(len(enc_docs), len(docs_df)) - self.assertTrue('toks' in enc_docs.columns) - self.assertTrue('doc_embs' in enc_docs.columns) - self.assertTrue(all(c in enc_docs.columns for c in docs_df.columns)) - self.assertEqual(enc_docs.toks.dtype, object) - self.assertTrue(all(isinstance(v, dict) for v in enc_docs.toks)) - self.assertEqual(enc_docs.doc_embs.dtype, object) - 
self.assertTrue(all(v.dtype == np.float32 for v in enc_docs.doc_embs)) - with self.subTest('doc_multi_encoder empty'): - enc_docs_empty = model(pd.DataFrame(columns=['docno', 'text'])) - self.assertEqual(len(enc_docs_empty), 0) - self.assertTrue('toks' in enc_docs_empty.columns) - self.assertTrue('doc_embs' in enc_docs_empty.columns) - def test_tct(self): from pyterrier_dr import TctColBert self._base_test(TctColBert()) @@ -194,6 +153,47 @@ def test_bgem3(self): self._test_bgem3_multi(bgem3.query_multi_encoder(), test_query_multivec_encoder=True) self._test_bgem3_multi(bgem3.doc_multi_encoder(), test_doc_multivec_encoder=True) + @unittest.skipUnless(IS_FLAG_EMBEDDING_AVAILABLE, "FlagEmbedding package is not available") + def _test_bgem3_multi(self, model, test_query_multivec_encoder=False, test_doc_multivec_encoder=False): + dataset = pt.get_dataset('irds:vaswani') + + docs = list(itertools.islice(pt.get_dataset('irds:vaswani').get_corpus_iter(), 200)) + docs_df = pd.DataFrame(docs) + + if test_query_multivec_encoder: + with self.subTest('query_multivec_encoder'): + topics = dataset.get_topics() + enc_topics = model(topics) + self.assertEqual(len(enc_topics), len(topics)) + self.assertTrue('query_toks' in enc_topics.columns) + self.assertTrue('query_embs' in enc_topics.columns) + self.assertTrue(all(c in enc_topics.columns for c in topics.columns)) + self.assertEqual(enc_topics.query_toks.dtype, object) + self.assertTrue(all(isinstance(v, dict) for v in enc_topics.query_toks)) + self.assertEqual(enc_topics.query_embs.dtype, object) + self.assertTrue(all(v.dtype == np.float32 for v in enc_topics.query_embs)) + with self.subTest('query_multivec_encoder empty'): + enc_topics_empty = model(pd.DataFrame(columns=['qid', 'query'])) + self.assertEqual(len(enc_topics_empty), 0) + self.assertTrue('query_toks' in enc_topics_empty.columns) + self.assertTrue('query_embs' in enc_topics_empty.columns) + if test_doc_multivec_encoder: + with self.subTest('doc_multi_encoder'): + 
enc_docs = model(pd.DataFrame(docs_df)) + self.assertEqual(len(enc_docs), len(docs_df)) + self.assertTrue('toks' in enc_docs.columns) + self.assertTrue('doc_embs' in enc_docs.columns) + self.assertTrue(all(c in enc_docs.columns for c in docs_df.columns)) + self.assertEqual(enc_docs.toks.dtype, object) + self.assertTrue(all(isinstance(v, dict) for v in enc_docs.toks)) + self.assertEqual(enc_docs.doc_embs.dtype, object) + self.assertTrue(all(v.dtype == np.float32 for v in enc_docs.doc_embs)) + with self.subTest('doc_multi_encoder empty'): + enc_docs_empty = model(pd.DataFrame(columns=['docno', 'text'])) + self.assertEqual(len(enc_docs_empty), 0) + self.assertTrue('toks' in enc_docs_empty.columns) + self.assertTrue('doc_embs' in enc_docs_empty.columns) + if __name__ == '__main__': unittest.main() From 01f64fac283c9cc0ca73ca22869c2f5c5e496a7f Mon Sep 17 00:00:00 2001 From: Craig Macdonald Date: Sat, 4 Oct 2025 19:00:53 +0100 Subject: [PATCH 4/4] add df to GHA --- .github/workflows/test.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 57f5af0..d58cfc8 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -39,6 +39,7 @@ jobs: - name: Install Dependencies run: | + df -h pip install --upgrade -r requirements.txt -r requirements-dev.txt pip install -e .[bgem3]