Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions pyterrier_dr/cde.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from tqdm import tqdm
import pyterrier_alpha as pta


class CDE(BiEncoder):
def __init__(self, model_name='jxm/cde-small-v1', cache: Optional['CDECache'] = None, batch_size=32, text_field='text', verbose=False, device=None):
super().__init__(batch_size=batch_size, text_field=text_field, verbose=verbose)
Expand All @@ -22,35 +21,39 @@ def __init__(self, model_name='jxm/cde-small-v1', cache: Optional['CDECache'] =
self.model = SentenceTransformer(model_name, trust_remote_code=True).to(self.device).eval()
self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
self.cache = cache or CDECache(cde=self)
self.verbose = verbose

def encode_context(self, texts: List[str], batch_size=None):
show_progress = False
if isinstance(texts, tqdm):
texts.disable = True
show_progress = True
elif self.verbose:
show_progress = True
texts = list(texts)
if len(texts) == 0:
return np.empty(shape=(0, 0))
return self.model.encode(
texts,
prompt_name="document",
["search_document: " + t for t in texts],
batch_size=batch_size or self.batch_size,
show_progress=show_progress,
show_progress_bar=show_progress,
)

def encode_queries(self, texts: List[str], batch_size=None):
show_progress = False
if isinstance(texts, tqdm):
texts.disable = True
show_progress = True
elif self.verbose:
show_progress = True
texts = list(texts)
if len(texts) == 0:
return np.empty(shape=(0, 0))
result = self.model.encode(
texts,
prompt_name='query',
["search_query: " + t for t in texts],
dataset_embeddings=self.cache.context(),
show_progress=show_progress,
show_progress_bar=show_progress,
batch_size=batch_size or self.batch_size,
)
# sentence transformers doesn't norm?
result = result / np.linalg.norm(result, ord=2, axis=1, keepdims=True)
Expand All @@ -61,14 +64,19 @@ def encode_docs(self, texts: List[str], batch_size=None):
if isinstance(texts, tqdm):
texts.disable = True
show_progress = True
elif self.verbose:
show_progress = True
texts = list(texts)
if len(texts) == 0:
return np.empty(shape=(0, 0))

# print("texts", texts, type(texts))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is dead code (commented-out debug prints).

# print("self.cache.context()", self.cache.context(), type(self.cache.context()))
result = self.model.encode(
texts,
prompt_name='document',
["search_documents: " + t for t in texts],
dataset_embeddings=self.cache.context(),
show_progress=show_progress,
show_progress_bar=show_progress,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was `show_progress` changed to `show_progress_bar`?

batch_size=batch_size or self.batch_size,
)
# sentence transformers doesn't norm?
result = result / np.linalg.norm(result, ord=2, axis=1, keepdims=True)
Expand Down
51 changes: 42 additions & 9 deletions pyterrier_dr/flex/np_retr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,49 +8,74 @@
import pyterrier_alpha as pta

class NumpyRetriever(pt.Transformer):
def __init__(self,
def __init__(
self,
flex_index: FlexIndex,
*,
num_results: int = 1000,
batch_size: Optional[int] = None,
drop_query_vec: bool = False
mask: Optional[np.ndarray] = None,
drop_query_vec: bool = False,
):
self.flex_index = flex_index
self.num_results = num_results
self.batch_size = batch_size or 4096
self.mask = mask
self.drop_query_vec = drop_query_vec

def fuse_rank_cutoff(self, k):
if k < self.num_results:
return NumpyRetriever(self.flex_index, num_results=k, batch_size=self.batch_size, drop_query_vec=self.drop_query_vec)
return NumpyRetriever(
self.flex_index,
num_results=k,
batch_size=self.batch_size,
mask=self.mask,
drop_query_vec=self.drop_query_vec,
)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
pta.validate.query_frame(inp, extra_columns=['query_vec'])

if not len(inp):
result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank'])
if self.drop_query_vec:
inp = inp.drop(columns='query_vec')
return result.to_df(inp)

inp = inp.reset_index(drop=True)
query_vecs = np.stack(inp['query_vec'])

docnos, dvecs, config = self.flex_index.payload()

if self.flex_index.sim_fn == SimFn.cos:
query_vecs = query_vecs / np.linalg.norm(query_vecs, axis=1, keepdims=True)
elif self.flex_index.sim_fn == SimFn.dot:
pass # nothing to do
else:
elif self.flex_index.sim_fn != SimFn.dot:
raise ValueError(f'{self.flex_index.sim_fn} not supported')

num_q = query_vecs.shape[0]
ranked_lists = RankedLists(self.num_results, num_q)

batch_it = range(0, dvecs.shape[0], self.batch_size)
if self.flex_index.verbose:
batch_it = pt.tqdm(batch_it, desc='NumpyRetriever scoring', unit='docbatch')

for idx_start in batch_it:
doc_batch = dvecs[idx_start:idx_start+self.batch_size].T

if self.flex_index.sim_fn == SimFn.cos:
doc_batch = doc_batch / np.linalg.norm(doc_batch, axis=0, keepdims=True)

scores = query_vecs @ doc_batch
dids = np.arange(idx_start, idx_start+doc_batch.shape[1], dtype='i4').reshape(1, -1).repeat(num_q, axis=0)

if self.mask is not None:
scores *= self.mask[idx_start:idx_start+doc_batch.shape[1]].reshape(1, -1)

dids = np.arange(
idx_start,
idx_start + doc_batch.shape[1],
dtype='i4'
).reshape(1, -1).repeat(num_q, axis=0)

ranked_lists.update(scores, dids)

result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank'])
Expand All @@ -64,6 +89,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:

if self.drop_query_vec:
inp = inp.drop(columns='query_vec')

return result.to_df(inp)


Expand Down Expand Up @@ -136,7 +162,8 @@ def _np_vecs(self) -> np.ndarray:
return dvecs
FlexIndex.np_vecs = _np_vecs

def _np_retriever(self, *, num_results: int = 1000, batch_size: Optional[int] = None, drop_query_vec: bool = False):

def _np_retriever(self, *, num_results: int = 1000, batch_size: Optional[int] = None, drop_query_vec: bool = False, mask: Optional[np.ndarray] = None) -> pt.Transformer:
"""Return a retriever that uses numpy to perform a brute force search over the index.

The returned transformer expects a DataFrame with columns ``qid`` and ``query_vec``. It outputs
Expand All @@ -150,7 +177,8 @@ def _np_retriever(self, *, num_results: int = 1000, batch_size: Optional[int] =
Returns:
:class:`~pyterrier.Transformer`: A retriever that uses numpy to perform a brute force search.
"""
return NumpyRetriever(self, num_results=num_results, batch_size=batch_size, drop_query_vec=drop_query_vec)
return NumpyRetriever(self, num_results=num_results, batch_size=batch_size, drop_query_vec=drop_query_vec, mask=mask)

FlexIndex.np_retriever = _np_retriever
FlexIndex.retriever = _np_retriever # default retriever

Expand Down Expand Up @@ -194,6 +222,11 @@ def _np_scorer(self, *, num_results: Optional[int] = None) -> pt.Transformer:

Args:
num_results: The number of results to return per query. If not provided, all results from the original frame are returned.
mask: Optional sequence of per-document weights.
If provided, scores for each document are multiplied by the corresponding
mask value. This can be used to filter or downweight documents during
retrieval.


Returns:
:class:`~pyterrier.Transformer`: A transformer that scores query vectors with numpy.
Expand Down
118 changes: 85 additions & 33 deletions pyterrier_dr/flex/torch_retr.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Sequence
import numpy as np
import torch
import pyterrier_alpha as pta
Expand Down Expand Up @@ -34,91 +34,138 @@ def score(self, query_vecs, docids):

# transform inherited from NumpyScorer


class TorchRetriever(pt.Transformer):
def __init__(self,
def __init__(
self,
flex_index: FlexIndex,
torch_vecs: torch.Tensor,
*,
num_results: int = 1000,
qbatch: int = 64,
drop_query_vec: bool = False
index_select: Optional[np.ndarray] = None,
drop_query_vec: bool = False,
):
self.flex_index = flex_index
self.torch_vecs = torch_vecs
self.num_results = num_results or 1000
self.docnos, meta = flex_index.payload(return_dvecs=False)
self.num_results = num_results
self.qbatch = qbatch
self.drop_query_vec = drop_query_vec

self.docnos, _ = flex_index.payload(return_dvecs=False)

self.index_select = None
if index_select is not None:
self.index_select = torch.as_tensor(
index_select,
dtype=torch.long,
device=torch_vecs.device
)

def fuse_rank_cutoff(self, k):
if k < self.num_results:
return TorchRetriever(
self.flex_index,
self.torch_vecs,
num_results=k,
qbatch=self.qbatch,
drop_query_vec=self.drop_query_vec)
self.flex_index,
self.torch_vecs,
num_results=k,
qbatch=self.qbatch,
index_select=self.index_select,
drop_query_vec=self.drop_query_vec,
)

def transform(self, inp):
pta.validate.query_frame(inp, extra_columns=['query_vec'])
inp = inp.reset_index(drop=True)
query_vecs = np.stack(inp['query_vec'])
query_vecs = torch.from_numpy(query_vecs).to(self.torch_vecs)

query_vecs = torch.from_numpy(
np.stack(inp['query_vec'])
).to(self.torch_vecs)

tv = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think `tv` needs an explanatory comment.

self.torch_vecs[self.index_select].T
if self.index_select is not None
else self.torch_vecs.T
)

result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank'])

it = range(0, query_vecs.shape[0], self.qbatch)
if self.flex_index.verbose:
it = pt.tqdm(it, desc='TorchRetriever', unit='qbatch')

result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank'])
for start_idx in it:
end_idx = start_idx + self.qbatch
batch = query_vecs[start_idx:end_idx]
if self.flex_index.sim_fn == SimFn.dot:
scores = batch @ self.torch_vecs.T
else:
for start in it:
batch = query_vecs[start:start+self.qbatch]

if self.flex_index.sim_fn != SimFn.dot:
raise ValueError(f'{self.flex_index.sim_fn} not supported')

scores = batch @ tv

if scores.shape[1] > self.num_results:
scores, docids = scores.topk(self.num_results, dim=1)
else:
docids = scores.argsort(descending=True, dim=1)
scores = torch.gather(scores, dim=1, index=docids)
for s, d in zip(scores.cpu().numpy(), docids.cpu().numpy()):
scores = torch.gather(scores, 1, docids)

scores = scores.cpu().numpy()
docids = docids.cpu().numpy()

if self.index_select is not None:
docids = self.index_select.cpu().numpy()[docids]

for s, d in zip(scores, docids):
result.extend({
'docno': self.docnos[d],
'docid': d,
'score': s,
'rank': np.arange(s.shape[0]),
'rank': np.arange(len(s)),
})

if self.drop_query_vec:
inp = inp.drop(columns='query_vec')

return result.to_df(inp)


def _torch_vecs(self, *, device: Optional[str] = None, fp16: bool = False) -> torch.Tensor:

def _torch_vecs(
self,
*,
device: Optional[str] = None,
fp16: bool = False
) -> torch.Tensor:
"""Return the indexed vectors as a pytorch tensor.

.. caution::
This method loads the entire index into memory on the provided device. If the index is too large to fit in memory,
consider using a different method that does not fully load the index into memory, like :meth:`np_vecs` or
:meth:`get_corpus_iter`.
This method loads the entire index into memory on the provided device.
If the index is too large to fit in memory, consider using :meth:`np_vecs`
or :meth:`get_corpus_iter`.

Args:
device: The device to use for the tensor. If not provided, the default device is used (cuda if available, otherwise cpu).
fp16: Whether to use half precision (fp16) for the tensor.
device: The device to use for the tensor. If not provided, the default
device is used (cuda if available, otherwise cpu).
fp16: Whether to use half precision (fp16).

Returns:
:class:`torch.Tensor`: The indexed vectors as a torch tensor.
:class:`torch.Tensor`: The indexed vectors.
"""
device = infer_device(device)
key = ('torch_vecs', device, fp16)

if key not in self._cache:
dvecs, meta = self.payload(return_docnos=False)
data = torch.frombuffer(dvecs, dtype=torch.float32).reshape(*dvecs.shape)
# Load numpy-backed vectors (memory-mapped)
dvecs, _ = self.payload(return_docnos=False)

# Important: frombuffer avoids an extra copy
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this just formatting? What's the change here?

data = torch.frombuffer(
dvecs,
dtype=torch.float32
).reshape(*dvecs.shape)

if fp16:
data = data.half()

self._cache[key] = data.to(device)

return self._cache[key]
FlexIndex.torch_vecs = _torch_vecs

Expand Down Expand Up @@ -152,6 +199,7 @@ def _torch_retriever(self,
fp16: bool = False,
qbatch: int = 64,
drop_query_vec: bool = False
index_select: Optional[np.ndarray] = None,
):
"""Return a retriever that uses pytorch to perform brute-force retrieval results using the indexed vectors.

Expand All @@ -167,9 +215,13 @@ def _torch_retriever(self,
fp16: Whether to use half precision (fp16) for scoring.
qbatch: The number of queries to score in each batch.
drop_query_vec: Whether to drop the query vector from the output.
index_select: Optional list or array of document ids to restrict retrieval to.
If provided, retrieval is performed only over this subset of the index,
which is internally converted to a torch tensor on the target device.

Returns:
:class:`~pyterrier.Transformer`: A transformer that retrieves using pytorch.
"""
return TorchRetriever(self, self.torch_vecs(device=device, fp16=fp16), num_results=num_results, qbatch=qbatch, drop_query_vec=drop_query_vec)
return TorchRetriever(self, self.torch_vecs(device=device, fp16=fp16), num_results=num_results, qbatch=qbatch, drop_query_vec=drop_query_vec, index_select=index_select)

FlexIndex.torch_retriever = _torch_retriever