Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyterrier_dr/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@
from pyterrier_dr.prf import AveragePrf, VectorPrf
from pyterrier_dr._ils import ILS, ils
from pyterrier_dr._mmr import MmrScorer
from pyterrier_dr.jina import JinaEmbedder

__all__ = ["FlexIndex", "DocnoFile", "NilIndex", "NumpyIndex", "RankedLists", "FaissFlat", "FaissHnsw", "MemIndex", "TorchIndex",
"BiEncoder", "BiQueryEncoder", "BiDocEncoder", "BiScorer", "HgfBiEncoder", "TasB", "RetroMAE", "SBertBiEncoder", "Ance",
"Query2Query", "GTR", "E5", "TctColBert", "ElectraScorer", "LightningIRMonoScorer", "BGEM3", "BGEM3QueryEncoder", "BGEM3DocEncoder", "CDE", "CDECache",
"SimFn", "infer_device", "AveragePrf", "VectorPrf", "ILS", "ils", "MmrScorer"]
"SimFn", "infer_device", "AveragePrf", "VectorPrf", "ILS", "ils", "MmrScorer", "JinaEmbedder"]
53 changes: 53 additions & 0 deletions pyterrier_dr/jina.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoConfig
from .biencoder import BiEncoder
from tqdm import tqdm

class JinaEmbedder(BiEncoder):
    """Bi-encoder that embeds text with a Jina embeddings model via SentenceTransformers.

    Queries and documents are both encoded with the model's ``"retrieval"``
    task; the two paths differ only in the prompt name handed to the model
    (``"query"`` for queries, ``"passage"`` for documents, by default).
    """

    def __init__(self, model_name='jinaai/jina-embeddings-v4', batch_size=32, text_field='text', verbose=False, device=None):
        """Load the model and its configuration.

        Args:
            model_name: Hugging Face model identifier to load.
            batch_size: Default batch size used when an encode call does not
                supply its own.
            text_field: Forwarded to ``BiEncoder.__init__``.
            verbose: Forwarded to ``BiEncoder.__init__``.
            device: Device spec accepted by ``torch.device``. When ``None``,
                ``'cuda'`` is used if available, otherwise ``'cpu'``.
        """
        super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose)
        self.model_name = model_name
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        # NOTE: trust_remote_code=True permits running code shipped in the
        # model repository; only point model_name at repositories you trust.
        self.model = SentenceTransformer(model_name, trust_remote_code=True).to(self.device).eval()
        self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

    def encode_queries(self, texts, batch_size=None, prompt="query"):
        """Encode query texts.

        Args:
            texts: Iterable of query strings, optionally wrapped in ``tqdm``.
            batch_size: Overrides the instance batch size when truthy.
            prompt: Prompt name passed to the model as ``prompt_name``.

        Returns:
            The result of ``self.model.encode`` for ``texts``, or an empty
            ``(0, 0)`` array when ``texts`` is empty.
        """
        return self._encode(texts, batch_size, prompt)

    def encode_docs(self, texts, batch_size=None, prompt="passage"):
        """Encode document texts.

        Args:
            texts: Iterable of document strings, optionally wrapped in ``tqdm``.
            batch_size: Overrides the instance batch size when truthy.
            prompt: Prompt name passed to the model as ``prompt_name``.

        Returns:
            The result of ``self.model.encode`` for ``texts``, or an empty
            ``(0, 0)`` array when ``texts`` is empty.
        """
        return self._encode(texts, batch_size, prompt)

    def _encode(self, texts, batch_size, prompt):
        """Shared implementation behind ``encode_queries`` and ``encode_docs``."""
        show_progress = False
        if isinstance(texts, tqdm):
            # Silence the caller's progress bar and ask the model to draw its
            # own instead, so only one bar is shown.
            texts.disable = True
            show_progress = True
        # Materialise so the input can be checked for emptiness before encoding.
        texts = list(texts)

        if not texts:
            # Nothing to encode: return an empty array without calling the model.
            return np.empty(shape=(0, 0))

        return self.model.encode(
            sentences=texts,
            batch_size=batch_size or self.batch_size,
            show_progress_bar=show_progress,
            task="retrieval",
            prompt_name=prompt,
        )

    def __repr__(self):
        return f'JinaEmbedder({self.model_name!r})'
6 changes: 6 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,12 @@ def test_bgem3(self):
self._test_bgem3_multi(bgem3.query_multi_encoder(), test_query_multivec_encoder=True)
self._test_bgem3_multi(bgem3.doc_multi_encoder(), test_doc_multivec_encoder=True)

def test_jina_embedder(self):
    """Run the shared base test against JinaEmbedder's query and document encoders."""
    from pyterrier_dr import JinaEmbedder

    # Construct the embedder once and reuse it: every construction loads the
    # model, so a second instance would only repeat that work.
    model = JinaEmbedder()

    # First pass disables the doc-encoder check, second pass the query-encoder
    # check; scorer, indexer and retriever checks are disabled in both.
    self._base_test(model, test_doc_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)
    self._base_test(model, test_query_encoder=False, test_scorer=False, test_indexer=False, test_retriever=False)

@unittest.skipIf(not LIGHTNING_IR_AVAILIBLE, "lightning_ir is not installed")
def test_lightning_ir_mono_electra(self):
from pyterrier_dr import LightningIRMonoScorer
Expand Down
Loading