-
Notifications
You must be signed in to change notification settings - Fork 12
CDE utilization fix #44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
272d12f
9f1636a
183ed46
1f192ba
236e01a
e341c1d
f34d6ba
3d4c263
8eb5393
fe7b0f2
834fc67
87c6e49
44c04fd
3c01686
1e61e9c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,6 @@ | |
| from tqdm import tqdm | ||
| import pyterrier_alpha as pta | ||
|
|
||
|
|
||
| class CDE(BiEncoder): | ||
| def __init__(self, model_name='jxm/cde-small-v1', cache: Optional['CDECache'] = None, batch_size=32, text_field='text', verbose=False, device=None): | ||
| super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose) | ||
|
|
@@ -22,35 +21,39 @@ def __init__(self, model_name='jxm/cde-small-v1', cache: Optional['CDECache'] = | |
| self.model = SentenceTransformer(model_name, trust_remote_code=True).to(self.device).eval() | ||
| self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) | ||
| self.cache = cache or CDECache(cde=self) | ||
| self.verbose = verbose | ||
|
|
||
| def encode_context(self, texts: List[str], batch_size=None): | ||
| show_progress = False | ||
| if isinstance(texts, tqdm): | ||
| texts.disable = True | ||
| show_progress = True | ||
| elif self.verbose: | ||
| show_progress = True | ||
| texts = list(texts) | ||
| if len(texts) == 0: | ||
| return np.empty(shape=(0, 0)) | ||
| return self.model.encode( | ||
| texts, | ||
| prompt_name="document", | ||
| ["search_document: " + t for t in texts], | ||
| batch_size=batch_size or self.batch_size, | ||
| show_progress=show_progress, | ||
| show_progress_bar=show_progress, | ||
| ) | ||
|
|
||
| def encode_queries(self, texts: List[str], batch_size=None): | ||
| show_progress = False | ||
| if isinstance(texts, tqdm): | ||
| texts.disable = True | ||
| show_progress = True | ||
| elif self.verbose: | ||
| show_progress = True | ||
| texts = list(texts) | ||
| if len(texts) == 0: | ||
| return np.empty(shape=(0, 0)) | ||
| result = self.model.encode( | ||
| texts, | ||
| prompt_name='query', | ||
| ["search_query: " + t for t in texts], | ||
| dataset_embeddings=self.cache.context(), | ||
| show_progress=show_progress, | ||
| show_progress_bar=show_progress, | ||
| batch_size=batch_size or self.batch_size, | ||
| ) | ||
| # sentence transformers doesn't norm? | ||
| result = result / np.linalg.norm(result, ord=2, axis=1, keepdims=True) | ||
|
|
@@ -61,14 +64,19 @@ def encode_docs(self, texts: List[str], batch_size=None): | |
| if isinstance(texts, tqdm): | ||
| texts.disable = True | ||
| show_progress = True | ||
| elif self.verbose: | ||
| show_progress = True | ||
| texts = list(texts) | ||
| if len(texts) == 0: | ||
| return np.empty(shape=(0, 0)) | ||
|
|
||
| # print("texts", texts, type(texts)) | ||
| # print("self.cache.context()", self.cache.context(), type(self.cache.context())) | ||
| result = self.model.encode( | ||
| texts, | ||
| prompt_name='document', | ||
| ["search_documents: " + t for t in texts], | ||
| dataset_embeddings=self.cache.context(), | ||
| show_progress=show_progress, | ||
| show_progress_bar=show_progress, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why the rename `show_progress` -> `show_progress_bar`? |
||
| batch_size=batch_size or self.batch_size, | ||
| ) | ||
| # sentence transformers doesn't norm? | ||
| result = result / np.linalg.norm(result, ord=2, axis=1, keepdims=True) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| from typing import Optional | ||
| from typing import Optional, Sequence | ||
| import numpy as np | ||
| import torch | ||
| import pyterrier_alpha as pta | ||
|
|
@@ -34,91 +34,138 @@ def score(self, query_vecs, docids): | |
|
|
||
| # transform inherited from NumpyScorer | ||
|
|
||
|
|
||
| class TorchRetriever(pt.Transformer): | ||
| def __init__(self, | ||
| def __init__( | ||
| self, | ||
| flex_index: FlexIndex, | ||
| torch_vecs: torch.Tensor, | ||
| *, | ||
| num_results: int = 1000, | ||
| qbatch: int = 64, | ||
| drop_query_vec: bool = False | ||
| index_select: Optional[np.ndarray] = None, | ||
| drop_query_vec: bool = False, | ||
| ): | ||
| self.flex_index = flex_index | ||
| self.torch_vecs = torch_vecs | ||
| self.num_results = num_results or 1000 | ||
| self.docnos, meta = flex_index.payload(return_dvecs=False) | ||
| self.num_results = num_results | ||
| self.qbatch = qbatch | ||
| self.drop_query_vec = drop_query_vec | ||
|
|
||
| self.docnos, _ = flex_index.payload(return_dvecs=False) | ||
|
|
||
| self.index_select = None | ||
| if index_select is not None: | ||
| self.index_select = torch.as_tensor( | ||
| index_select, | ||
| dtype=torch.long, | ||
| device=torch_vecs.device | ||
| ) | ||
|
|
||
| def fuse_rank_cutoff(self, k): | ||
| if k < self.num_results: | ||
| return TorchRetriever( | ||
| self.flex_index, | ||
| self.torch_vecs, | ||
| num_results=k, | ||
| qbatch=self.qbatch, | ||
| drop_query_vec=self.drop_query_vec) | ||
| self.flex_index, | ||
| self.torch_vecs, | ||
| num_results=k, | ||
| qbatch=self.qbatch, | ||
| index_select=self.index_select, | ||
| drop_query_vec=self.drop_query_vec, | ||
| ) | ||
|
|
||
| def transform(self, inp): | ||
| pta.validate.query_frame(inp, extra_columns=['query_vec']) | ||
| inp = inp.reset_index(drop=True) | ||
| query_vecs = np.stack(inp['query_vec']) | ||
| query_vecs = torch.from_numpy(query_vecs).to(self.torch_vecs) | ||
|
|
||
| query_vecs = torch.from_numpy( | ||
| np.stack(inp['query_vec']) | ||
| ).to(self.torch_vecs) | ||
|
|
||
| tv = ( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think `tv` needs an explanatory comment. |
||
| self.torch_vecs[self.index_select].T | ||
| if self.index_select is not None | ||
| else self.torch_vecs.T | ||
| ) | ||
|
|
||
| result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank']) | ||
|
|
||
| it = range(0, query_vecs.shape[0], self.qbatch) | ||
| if self.flex_index.verbose: | ||
| it = pt.tqdm(it, desc='TorchRetriever', unit='qbatch') | ||
|
|
||
| result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank']) | ||
| for start_idx in it: | ||
| end_idx = start_idx + self.qbatch | ||
| batch = query_vecs[start_idx:end_idx] | ||
| if self.flex_index.sim_fn == SimFn.dot: | ||
| scores = batch @ self.torch_vecs.T | ||
| else: | ||
| for start in it: | ||
| batch = query_vecs[start:start+self.qbatch] | ||
|
|
||
| if self.flex_index.sim_fn != SimFn.dot: | ||
| raise ValueError(f'{self.flex_index.sim_fn} not supported') | ||
|
|
||
| scores = batch @ tv | ||
|
|
||
| if scores.shape[1] > self.num_results: | ||
| scores, docids = scores.topk(self.num_results, dim=1) | ||
| else: | ||
| docids = scores.argsort(descending=True, dim=1) | ||
| scores = torch.gather(scores, dim=1, index=docids) | ||
| for s, d in zip(scores.cpu().numpy(), docids.cpu().numpy()): | ||
| scores = torch.gather(scores, 1, docids) | ||
|
|
||
| scores = scores.cpu().numpy() | ||
| docids = docids.cpu().numpy() | ||
|
|
||
| if self.index_select is not None: | ||
| docids = self.index_select.cpu().numpy()[docids] | ||
|
|
||
| for s, d in zip(scores, docids): | ||
| result.extend({ | ||
| 'docno': self.docnos[d], | ||
| 'docid': d, | ||
| 'score': s, | ||
| 'rank': np.arange(s.shape[0]), | ||
| 'rank': np.arange(len(s)), | ||
| }) | ||
|
|
||
| if self.drop_query_vec: | ||
| inp = inp.drop(columns='query_vec') | ||
|
|
||
| return result.to_df(inp) | ||
|
|
||
|
|
||
| def _torch_vecs(self, *, device: Optional[str] = None, fp16: bool = False) -> torch.Tensor: | ||
|
|
||
| def _torch_vecs( | ||
| self, | ||
| *, | ||
| device: Optional[str] = None, | ||
| fp16: bool = False | ||
| ) -> torch.Tensor: | ||
| """Return the indexed vectors as a pytorch tensor. | ||
|
|
||
| .. caution:: | ||
| This method loads the entire index into memory on the provided device. If the index is too large to fit in memory, | ||
| consider using a different method that does not fully load the index into memory, like :meth:`np_vecs` or | ||
| :meth:`get_corpus_iter`. | ||
| This method loads the entire index into memory on the provided device. | ||
| If the index is too large to fit in memory, consider using :meth:`np_vecs` | ||
| or :meth:`get_corpus_iter`. | ||
|
|
||
| Args: | ||
| device: The device to use for the tensor. If not provided, the default device is used (cuda if available, otherwise cpu). | ||
| fp16: Whether to use half precision (fp16) for the tensor. | ||
| device: The device to use for the tensor. If not provided, the default | ||
| device is used (cuda if available, otherwise cpu). | ||
| fp16: Whether to use half precision (fp16). | ||
|
|
||
| Returns: | ||
| :class:`torch.Tensor`: The indexed vectors as a torch tensor. | ||
| :class:`torch.Tensor`: The indexed vectors. | ||
| """ | ||
| device = infer_device(device) | ||
| key = ('torch_vecs', device, fp16) | ||
|
|
||
| if key not in self._cache: | ||
| dvecs, meta = self.payload(return_docnos=False) | ||
| data = torch.frombuffer(dvecs, dtype=torch.float32).reshape(*dvecs.shape) | ||
| # Load numpy-backed vectors (memory-mapped) | ||
| dvecs, _ = self.payload(return_docnos=False) | ||
|
|
||
| # Important: frombuffer avoids an extra copy | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Isn't this just formatting? What's the change here? |
||
| data = torch.frombuffer( | ||
| dvecs, | ||
| dtype=torch.float32 | ||
| ).reshape(*dvecs.shape) | ||
|
|
||
| if fp16: | ||
| data = data.half() | ||
|
|
||
| self._cache[key] = data.to(device) | ||
|
|
||
| return self._cache[key] | ||
| FlexIndex.torch_vecs = _torch_vecs | ||
|
|
||
|
|
@@ -152,6 +199,7 @@ def _torch_retriever(self, | |
| fp16: bool = False, | ||
| qbatch: int = 64, | ||
| drop_query_vec: bool = False | ||
| index_select: Optional[np.ndarray] = None, | ||
| ): | ||
| """Return a retriever that uses pytorch to perform brute-force retrieval results using the indexed vectors. | ||
|
|
||
|
|
@@ -167,9 +215,13 @@ def _torch_retriever(self, | |
| fp16: Whether to use half precision (fp16) for scoring. | ||
| qbatch: The number of queries to score in each batch. | ||
| drop_query_vec: Whether to drop the query vector from the output. | ||
| index_select: Optional list or array of document ids to restrict retrieval to. | ||
| If provided, retrieval is performed only over this subset of the index, | ||
| which is internally converted to a torch tensor on the target device. | ||
|
|
||
| Returns: | ||
| :class:`~pyterrier.Transformer`: A transformer that retrieves using pytorch. | ||
| """ | ||
| return TorchRetriever(self, self.torch_vecs(device=device, fp16=fp16), num_results=num_results, qbatch=qbatch, drop_query_vec=drop_query_vec) | ||
| return TorchRetriever(self, self.torch_vecs(device=device, fp16=fp16), num_results=num_results, qbatch=qbatch, drop_query_vec=drop_query_vec, index_select=index_select) | ||
|
|
||
| FlexIndex.torch_retriever = _torch_retriever | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Dead code.