diff --git a/pyterrier_dr/cde.py b/pyterrier_dr/cde.py index f421300..e9e7e94 100644 --- a/pyterrier_dr/cde.py +++ b/pyterrier_dr/cde.py @@ -10,7 +10,6 @@ from tqdm import tqdm import pyterrier_alpha as pta - class CDE(BiEncoder): def __init__(self, model_name='jxm/cde-small-v1', cache: Optional['CDECache'] = None, batch_size=32, text_field='text', verbose=False, device=None): super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose) @@ -22,20 +21,22 @@ def __init__(self, model_name='jxm/cde-small-v1', cache: Optional['CDECache'] = self.model = SentenceTransformer(model_name, trust_remote_code=True).to(self.device).eval() self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) self.cache = cache or CDECache(cde=self) + self.verbose = verbose def encode_context(self, texts: List[str], batch_size=None): show_progress = False if isinstance(texts, tqdm): texts.disable = True show_progress = True + elif self.verbose: + show_progress = True texts = list(texts) if len(texts) == 0: return np.empty(shape=(0, 0)) return self.model.encode( - texts, - prompt_name="document", + ["search_document: " + t for t in texts], batch_size=batch_size or self.batch_size, - show_progress=show_progress, + show_progress_bar=show_progress, ) def encode_queries(self, texts: List[str], batch_size=None): @@ -43,14 +44,16 @@ def encode_queries(self, texts: List[str], batch_size=None): if isinstance(texts, tqdm): texts.disable = True show_progress = True + elif self.verbose: + show_progress = True texts = list(texts) if len(texts) == 0: return np.empty(shape=(0, 0)) result = self.model.encode( - texts, - prompt_name='query', + ["search_query: " + t for t in texts], dataset_embeddings=self.cache.context(), - show_progress=show_progress, + show_progress_bar=show_progress, + batch_size=batch_size or self.batch_size, ) # sentence transformers doesn't norm? 
result = result / np.linalg.norm(result, ord=2, axis=1, keepdims=True) @@ -61,14 +64,16 @@ def encode_docs(self, texts: List[str], batch_size=None): if isinstance(texts, tqdm): texts.disable = True show_progress = True + elif self.verbose: + show_progress = True texts = list(texts) if len(texts) == 0: return np.empty(shape=(0, 0)) result = self.model.encode( - texts, - prompt_name='document', + ["search_document: " + t for t in texts], dataset_embeddings=self.cache.context(), - show_progress=show_progress, + show_progress_bar=show_progress, + batch_size=batch_size or self.batch_size, ) # sentence transformers doesn't norm? result = result / np.linalg.norm(result, ord=2, axis=1, keepdims=True) diff --git a/pyterrier_dr/flex/np_retr.py b/pyterrier_dr/flex/np_retr.py index c6b5942..576552c 100644 --- a/pyterrier_dr/flex/np_retr.py +++ b/pyterrier_dr/flex/np_retr.py @@ -8,49 +8,74 @@ import pyterrier_alpha as pta class NumpyRetriever(pt.Transformer): - def __init__(self, + def __init__( + self, flex_index: FlexIndex, *, num_results: int = 1000, batch_size: Optional[int] = None, - drop_query_vec: bool = False + mask: Optional[np.ndarray] = None, + drop_query_vec: bool = False, ): self.flex_index = flex_index self.num_results = num_results self.batch_size = batch_size or 4096 + self.mask = mask self.drop_query_vec = drop_query_vec def fuse_rank_cutoff(self, k): if k < self.num_results: - return NumpyRetriever(self.flex_index, num_results=k, batch_size=self.batch_size, drop_query_vec=self.drop_query_vec) + return NumpyRetriever( + self.flex_index, + num_results=k, + batch_size=self.batch_size, + mask=self.mask, + drop_query_vec=self.drop_query_vec, + ) def transform(self, inp: pd.DataFrame) -> pd.DataFrame: pta.validate.query_frame(inp, extra_columns=['query_vec']) + if not len(inp): result = pta.DataFrameBuilder(['docno', 'docid', 'score',
'rank']) if self.drop_query_vec: inp = inp.drop(columns='query_vec') return result.to_df(inp) + inp = inp.reset_index(drop=True) query_vecs = np.stack(inp['query_vec']) + docnos, dvecs, config = self.flex_index.payload() + if self.flex_index.sim_fn == SimFn.cos: query_vecs = query_vecs / np.linalg.norm(query_vecs, axis=1, keepdims=True) - elif self.flex_index.sim_fn == SimFn.dot: - pass # nothing to do - else: + elif self.flex_index.sim_fn != SimFn.dot: raise ValueError(f'{self.flex_index.sim_fn} not supported') + num_q = query_vecs.shape[0] ranked_lists = RankedLists(self.num_results, num_q) + batch_it = range(0, dvecs.shape[0], self.batch_size) if self.flex_index.verbose: batch_it = pt.tqdm(batch_it, desc='NumpyRetriever scoring', unit='docbatch') + for idx_start in batch_it: doc_batch = dvecs[idx_start:idx_start+self.batch_size].T + if self.flex_index.sim_fn == SimFn.cos: doc_batch = doc_batch / np.linalg.norm(doc_batch, axis=0, keepdims=True) + scores = query_vecs @ doc_batch - dids = np.arange(idx_start, idx_start+doc_batch.shape[1], dtype='i4').reshape(1, -1).repeat(num_q, axis=0) + + if self.mask is not None: + scores *= self.mask[idx_start:idx_start+doc_batch.shape[1]].reshape(1, -1) + + dids = np.arange( + idx_start, + idx_start + doc_batch.shape[1], + dtype='i4' + ).reshape(1, -1).repeat(num_q, axis=0) + ranked_lists.update(scores, dids) result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank']) @@ -64,6 +89,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame: if self.drop_query_vec: inp = inp.drop(columns='query_vec') + return result.to_df(inp) @@ -136,7 +162,8 @@ def _np_vecs(self) -> np.ndarray: return dvecs FlexIndex.np_vecs = _np_vecs -def _np_retriever(self, *, num_results: int = 1000, batch_size: Optional[int] = None, drop_query_vec: bool = False): + +def _np_retriever(self, *, num_results: int = 1000, batch_size: Optional[int] = None, drop_query_vec: bool = False, mask: Optional[np.ndarray] = None) -> pt.Transformer: """Return a 
retriever that uses numpy to perform a brute force search over the index. The returned transformer expects a DataFrame with columns ``qid`` and ``query_vec``. It outpus @@ -150,7 +177,11 @@ def _np_retriever(self, *, num_results: int = 1000, batch_size: Optional[int] = +         mask: Optional sequence of per-document weights. If provided, scores for +             each document are multiplied by the corresponding mask value. This can +             be used to filter or downweight documents during retrieval. Returns: :class:`~pyterrier.Transformer`: A retriever that uses numpy to perform a brute force search. """ - return NumpyRetriever(self, num_results=num_results, batch_size=batch_size, drop_query_vec=drop_query_vec) + return NumpyRetriever(self, num_results=num_results, batch_size=batch_size, drop_query_vec=drop_query_vec, mask=mask) + FlexIndex.np_retriever = _np_retriever FlexIndex.retriever = _np_retriever # default retriever
diff --git a/pyterrier_dr/flex/torch_retr.py b/pyterrier_dr/flex/torch_retr.py index 674e7fb..2d50ce5 100644 --- a/pyterrier_dr/flex/torch_retr.py +++ b/pyterrier_dr/flex/torch_retr.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Sequence import numpy as np import torch import pyterrier_alpha as pta @@ -34,91 +34,138 @@ def score(self, query_vecs, docids): # transform inherited from NumpyScorer - class TorchRetriever(pt.Transformer): - def __init__(self, + def __init__( + self, flex_index: FlexIndex, torch_vecs: torch.Tensor, *, num_results: int = 1000, qbatch: int = 64, - drop_query_vec: bool = False + index_select: Optional[np.ndarray] = None, + drop_query_vec: bool = False, ): self.flex_index = flex_index self.torch_vecs = torch_vecs - self.num_results = num_results or 1000 - self.docnos, meta = flex_index.payload(return_dvecs=False) + self.num_results = num_results self.qbatch = qbatch self.drop_query_vec = drop_query_vec + self.docnos, _ = flex_index.payload(return_dvecs=False) + + self.index_select = None + if index_select is not None: + self.index_select = torch.as_tensor( + index_select, + dtype=torch.long, + device=torch_vecs.device + ) + def fuse_rank_cutoff(self, k): if k < self.num_results: return TorchRetriever( - self.flex_index, - self.torch_vecs, - num_results=k, - qbatch=self.qbatch, - drop_query_vec=self.drop_query_vec) + self.flex_index, + self.torch_vecs, + num_results=k, + qbatch=self.qbatch, + index_select=self.index_select, + drop_query_vec=self.drop_query_vec, + ) def transform(self, inp): pta.validate.query_frame(inp, extra_columns=['query_vec']) inp = inp.reset_index(drop=True) - query_vecs = np.stack(inp['query_vec']) - query_vecs = torch.from_numpy(query_vecs).to(self.torch_vecs) + + query_vecs = torch.from_numpy( + np.stack(inp['query_vec']) + ).to(self.torch_vecs) + + tv = ( + self.torch_vecs[self.index_select].T + if self.index_select is not None + else self.torch_vecs.T + ) + + result = 
pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank']) it = range(0, query_vecs.shape[0], self.qbatch) if self.flex_index.verbose: it = pt.tqdm(it, desc='TorchRetriever', unit='qbatch') - result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank']) - for start_idx in it: - end_idx = start_idx + self.qbatch - batch = query_vecs[start_idx:end_idx] - if self.flex_index.sim_fn == SimFn.dot: - scores = batch @ self.torch_vecs.T - else: + for start in it: + batch = query_vecs[start:start+self.qbatch] + + if self.flex_index.sim_fn != SimFn.dot: raise ValueError(f'{self.flex_index.sim_fn} not supported') + + scores = batch @ tv + if scores.shape[1] > self.num_results: scores, docids = scores.topk(self.num_results, dim=1) else: docids = scores.argsort(descending=True, dim=1) - scores = torch.gather(scores, dim=1, index=docids) - for s, d in zip(scores.cpu().numpy(), docids.cpu().numpy()): + scores = torch.gather(scores, 1, docids) + + scores = scores.cpu().numpy() + docids = docids.cpu().numpy() + + if self.index_select is not None: + docids = self.index_select.cpu().numpy()[docids] + + for s, d in zip(scores, docids): result.extend({ 'docno': self.docnos[d], 'docid': d, 'score': s, - 'rank': np.arange(s.shape[0]), + 'rank': np.arange(len(s)), }) if self.drop_query_vec: inp = inp.drop(columns='query_vec') + return result.to_df(inp) -def _torch_vecs(self, *, device: Optional[str] = None, fp16: bool = False) -> torch.Tensor: + +def _torch_vecs( + self, + *, + device: Optional[str] = None, + fp16: bool = False +) -> torch.Tensor: """Return the indexed vectors as a pytorch tensor. .. caution:: - This method loads the entire index into memory on the provided device. If the index is too large to fit in memory, - consider using a different method that does not fully load the index into memory, like :meth:`np_vecs` or - :meth:`get_corpus_iter`. + This method loads the entire index into memory on the provided device. 
+ If the index is too large to fit in memory, consider using :meth:`np_vecs` + or :meth:`get_corpus_iter`. Args: - device: The device to use for the tensor. If not provided, the default device is used (cuda if available, otherwise cpu). - fp16: Whether to use half precision (fp16) for the tensor. + device: The device to use for the tensor. If not provided, the default + device is used (cuda if available, otherwise cpu). + fp16: Whether to use half precision (fp16). Returns: - :class:`torch.Tensor`: The indexed vectors as a torch tensor. + :class:`torch.Tensor`: The indexed vectors. """ device = infer_device(device) key = ('torch_vecs', device, fp16) + if key not in self._cache: - dvecs, meta = self.payload(return_docnos=False) - data = torch.frombuffer(dvecs, dtype=torch.float32).reshape(*dvecs.shape) + # Load numpy-backed vectors (memory-mapped) + dvecs, _ = self.payload(return_docnos=False) + + # Important: frombuffer avoids an extra copy + data = torch.frombuffer( + dvecs, + dtype=torch.float32 + ).reshape(*dvecs.shape) + if fp16: data = data.half() + self._cache[key] = data.to(device) + return self._cache[key] FlexIndex.torch_vecs = _torch_vecs @@ -152,6 +199,7 @@ def _torch_retriever(self, fp16: bool = False, qbatch: int = 64, - drop_query_vec: bool = False + drop_query_vec: bool = False, + index_select: Optional[np.ndarray] = None, ): """Return a retriever that uses pytorch to perform brute-force retrieval results using the indexed vectors. @@ -167,9 +215,13 @@ def _torch_retriever(self, fp16: Whether to use half precision (fp16) for scoring. qbatch: The number of queries to score in each batch. drop_query_vec: Whether to drop the query vector from the output. + index_select: Optional list or array of document ids to restrict retrieval to. + If provided, retrieval is performed only over this subset of the index, + which is internally converted to a torch tensor on the target device. Returns: :class:`~pyterrier.Transformer`: A transformer that retrieves using pytorch.
""" - return TorchRetriever(self, self.torch_vecs(device=device, fp16=fp16), num_results=num_results, qbatch=qbatch, drop_query_vec=drop_query_vec) + return TorchRetriever(self, self.torch_vecs(device=device, fp16=fp16), num_results=num_results, qbatch=qbatch, drop_query_vec=drop_query_vec, index_select=index_select) + FlexIndex.torch_retriever = _torch_retriever