42 changes: 35 additions & 7 deletions pyterrier_dr/flex/np_retr.py
@@ -8,49 +8,74 @@
import pyterrier_alpha as pta

class NumpyRetriever(pt.Transformer):
def __init__(self,
def __init__(
self,
flex_index: FlexIndex,
*,
num_results: int = 1000,
batch_size: Optional[int] = None,
drop_query_vec: bool = False
mask: Optional[np.ndarray] = None,
drop_query_vec: bool = False,
):
self.flex_index = flex_index
self.num_results = num_results
self.batch_size = batch_size or 4096
self.mask = mask
self.drop_query_vec = drop_query_vec

def fuse_rank_cutoff(self, k):
if k < self.num_results:
return NumpyRetriever(self.flex_index, num_results=k, batch_size=self.batch_size, drop_query_vec=self.drop_query_vec)
return NumpyRetriever(
self.flex_index,
num_results=k,
batch_size=self.batch_size,
mask=self.mask,
drop_query_vec=self.drop_query_vec,
)

def transform(self, inp: pd.DataFrame) -> pd.DataFrame:
pta.validate.query_frame(inp, extra_columns=['query_vec'])

if not len(inp):
result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank'])
if self.drop_query_vec:
inp = inp.drop(columns='query_vec')
return result.to_df(inp)

inp = inp.reset_index(drop=True)
query_vecs = np.stack(inp['query_vec'])

docnos, dvecs, config = self.flex_index.payload()

if self.flex_index.sim_fn == SimFn.cos:
query_vecs = query_vecs / np.linalg.norm(query_vecs, axis=1, keepdims=True)
elif self.flex_index.sim_fn == SimFn.dot:
pass # nothing to do
else:
elif self.flex_index.sim_fn != SimFn.dot:
raise ValueError(f'{self.flex_index.sim_fn} not supported')

num_q = query_vecs.shape[0]
ranked_lists = RankedLists(self.num_results, num_q)

batch_it = range(0, dvecs.shape[0], self.batch_size)
if self.flex_index.verbose:
batch_it = pt.tqdm(batch_it, desc='NumpyRetriever scoring', unit='docbatch')

for idx_start in batch_it:
doc_batch = dvecs[idx_start:idx_start+self.batch_size].T

if self.flex_index.sim_fn == SimFn.cos:
doc_batch = doc_batch / np.linalg.norm(doc_batch, axis=0, keepdims=True)

scores = query_vecs @ doc_batch
dids = np.arange(idx_start, idx_start+doc_batch.shape[1], dtype='i4').reshape(1, -1).repeat(num_q, axis=0)

if self.mask is not None:
scores *= self.mask[idx_start:idx_start+doc_batch.shape[1]].reshape(1, -1)

dids = np.arange(
idx_start,
idx_start + doc_batch.shape[1],
dtype='i4'
).reshape(1, -1).repeat(num_q, axis=0)

ranked_lists.update(scores, dids)

result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank'])
@@ -64,6 +89,7 @@ def transform(self, inp: pd.DataFrame) -> pd.DataFrame:

if self.drop_query_vec:
inp = inp.drop(columns='query_vec')

return result.to_df(inp)
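
For context, a minimal sketch of how the new mask argument could be exercised (the index path, query frame, and vector dimensionality below are placeholder assumptions, not part of this change):

import numpy as np
import pandas as pd
from pyterrier_dr import FlexIndex
from pyterrier_dr.flex.np_retr import NumpyRetriever

index = FlexIndex('index.flex')          # placeholder path
docnos, dvecs, _ = index.payload()

# Hypothetical filter: keep only the first half of the corpus. Masked-out
# documents are not removed; their scores are multiplied by zero.
mask = np.zeros(dvecs.shape[0], dtype='f4')
mask[:dvecs.shape[0] // 2] = 1.0

retriever = NumpyRetriever(index, num_results=100, mask=mask)
queries = pd.DataFrame([{
    'qid': 'q1',
    'query': 'example query',
    'query_vec': np.random.rand(dvecs.shape[1]).astype('f4'),
}])
results = retriever.transform(queries)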


@@ -136,6 +162,7 @@ def _np_vecs(self) -> np.ndarray:
return dvecs
FlexIndex.np_vecs = _np_vecs


def _np_retriever(self, *, num_results: int = 1000, batch_size: Optional[int] = None, drop_query_vec: bool = False):
"""Return a retriever that uses numpy to perform a brute force search over the index.

@@ -151,6 +178,7 @@ def _np_retriever(self, *, num_results: int = 1000, batch_size: Optional[int] =
:class:`~pyterrier.Transformer`: A retriever that uses numpy to perform a brute force search.
"""
return NumpyRetriever(self, num_results=num_results, batch_size=batch_size, drop_query_vec=drop_query_vec)

FlexIndex.np_retriever = _np_retriever
FlexIndex.retriever = _np_retriever # default retriever
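
A short usage sketch for the factory registered here (the index path and pipeline context are illustrative assumptions):

from pyterrier_dr import FlexIndex

index = FlexIndex('index.flex')                 # placeholder path
retriever = index.np_retriever(num_results=10)  # same as index.retriever(num_results=10)
# The retriever expects a query frame carrying a query_vec column, typically
# produced upstream by a query encoder stage in a PyTerrier pipeline.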

110 changes: 79 additions & 31 deletions pyterrier_dr/flex/torch_retr.py
@@ -34,91 +34,138 @@ def score(self, query_vecs, docids):

# transform inherited from NumpyScorer


class TorchRetriever(pt.Transformer):
def __init__(self,
def __init__(
self,
flex_index: FlexIndex,
torch_vecs: torch.Tensor,
*,
num_results: int = 1000,
qbatch: int = 64,
drop_query_vec: bool = False
index_select: Optional[np.ndarray] = None,
drop_query_vec: bool = False,
):
self.flex_index = flex_index
self.torch_vecs = torch_vecs
self.num_results = num_results or 1000
self.docnos, meta = flex_index.payload(return_dvecs=False)
self.num_results = num_results
self.qbatch = qbatch
self.drop_query_vec = drop_query_vec

self.docnos, _ = flex_index.payload(return_dvecs=False)

self.index_select = None
if index_select is not None:
self.index_select = torch.as_tensor(
index_select,
dtype=torch.long,
device=torch_vecs.device
)

def fuse_rank_cutoff(self, k):
if k < self.num_results:
return TorchRetriever(
self.flex_index,
self.torch_vecs,
num_results=k,
qbatch=self.qbatch,
drop_query_vec=self.drop_query_vec)
self.flex_index,
self.torch_vecs,
num_results=k,
qbatch=self.qbatch,
index_select=self.index_select,
drop_query_vec=self.drop_query_vec,
)

def transform(self, inp):
pta.validate.query_frame(inp, extra_columns=['query_vec'])
inp = inp.reset_index(drop=True)
query_vecs = np.stack(inp['query_vec'])
query_vecs = torch.from_numpy(query_vecs).to(self.torch_vecs)

query_vecs = torch.from_numpy(
np.stack(inp['query_vec'])
).to(self.torch_vecs)

tv = (
self.torch_vecs[self.index_select].T
if self.index_select is not None
else self.torch_vecs.T
)

result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank'])

it = range(0, query_vecs.shape[0], self.qbatch)
if self.flex_index.verbose:
it = pt.tqdm(it, desc='TorchRetriever', unit='qbatch')

result = pta.DataFrameBuilder(['docno', 'docid', 'score', 'rank'])
for start_idx in it:
end_idx = start_idx + self.qbatch
batch = query_vecs[start_idx:end_idx]
if self.flex_index.sim_fn == SimFn.dot:
scores = batch @ self.torch_vecs.T
else:
for start in it:
batch = query_vecs[start:start+self.qbatch]

if self.flex_index.sim_fn != SimFn.dot:
raise ValueError(f'{self.flex_index.sim_fn} not supported')

scores = batch @ tv

if scores.shape[1] > self.num_results:
scores, docids = scores.topk(self.num_results, dim=1)
else:
docids = scores.argsort(descending=True, dim=1)
scores = torch.gather(scores, dim=1, index=docids)
for s, d in zip(scores.cpu().numpy(), docids.cpu().numpy()):
scores = torch.gather(scores, 1, docids)

scores = scores.cpu().numpy()
docids = docids.cpu().numpy()

if self.index_select is not None:
docids = self.index_select.cpu().numpy()[docids]

for s, d in zip(scores, docids):
result.extend({
'docno': self.docnos[d],
'docid': d,
'score': s,
'rank': np.arange(s.shape[0]),
'rank': np.arange(len(s)),
})

if self.drop_query_vec:
inp = inp.drop(columns='query_vec')

return result.to_df(inp)
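
A minimal sketch of the new index_select path, assuming a FlexIndex at a placeholder location and retrieval restricted to an arbitrary subset of docids:

import numpy as np
from pyterrier_dr import FlexIndex
from pyterrier_dr.flex.torch_retr import TorchRetriever

index = FlexIndex('index.flex')           # placeholder path
vecs = index.torch_vecs()                 # full corpus tensor on the default device

subset = np.arange(0, vecs.shape[0], 2)   # hypothetical subset: even docids only
retriever = TorchRetriever(index, vecs, num_results=100, index_select=subset)
# Scores are computed over vecs[subset] only; the returned docids are mapped
# back to positions in the full index before docnos are resolved.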


def _torch_vecs(self, *, device: Optional[str] = None, fp16: bool = False) -> torch.Tensor:

def _torch_vecs(
self,
*,
device: Optional[str] = None,
fp16: bool = False
) -> torch.Tensor:
"""Return the indexed vectors as a pytorch tensor.

.. caution::
This method loads the entire index into memory on the provided device. If the index is too large to fit in memory,
consider using a different method that does not fully load the index into memory, like :meth:`np_vecs` or
:meth:`get_corpus_iter`.
This method loads the entire index into memory on the provided device.
If the index is too large to fit in memory, consider using :meth:`np_vecs`
or :meth:`get_corpus_iter`.

Args:
device: The device to use for the tensor. If not provided, the default device is used (cuda if available, otherwise cpu).
fp16: Whether to use half precision (fp16) for the tensor.
device: The device to use for the tensor. If not provided, the default
device is used (cuda if available, otherwise cpu).
fp16: Whether to use half precision (fp16).

Returns:
:class:`torch.Tensor`: The indexed vectors as a torch tensor.
:class:`torch.Tensor`: The indexed vectors.
"""
device = infer_device(device)
key = ('torch_vecs', device, fp16)

if key not in self._cache:
dvecs, meta = self.payload(return_docnos=False)
data = torch.frombuffer(dvecs, dtype=torch.float32).reshape(*dvecs.shape)
# Load numpy-backed vectors (memory-mapped)
dvecs, _ = self.payload(return_docnos=False)

# Important: frombuffer avoids an extra copy
data = torch.frombuffer(
dvecs,
dtype=torch.float32
).reshape(*dvecs.shape)

if fp16:
data = data.half()

self._cache[key] = data.to(device)

return self._cache[key]
FlexIndex.torch_vecs = _torch_vecs
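
Illustrative usage of the caching behaviour above (the index path is a placeholder):

from pyterrier_dr import FlexIndex

index = FlexIndex('index.flex')             # placeholder path
vecs_gpu = index.torch_vecs(fp16=True)      # default device (cuda if available), half precision
vecs_cpu = index.torch_vecs(device='cpu')   # cached separately under the ('torch_vecs', device, fp16) key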

@@ -172,4 +219,5 @@ def _torch_retriever(self,
:class:`~pyterrier.Transformer`: A transformer that retrieves using pytorch.
"""
return TorchRetriever(self, self.torch_vecs(device=device, fp16=fp16), num_results=num_results, qbatch=qbatch, drop_query_vec=drop_query_vec)

FlexIndex.torch_retriever = _torch_retriever
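
Usage sketch for the registered retriever (the index path and upstream encoder are assumptions, not part of this diff):

from pyterrier_dr import FlexIndex

index = FlexIndex('index.flex')          # placeholder path
retriever = index.torch_retriever(num_results=100, fp16=True)
# As with np_retriever, the input frame must provide a query_vec column,
# e.g. from a query encoder stage earlier in the pipeline.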