29 changes: 29 additions & 0 deletions README.md
@@ -77,6 +77,35 @@ We have evaluated our system on `SQuAD 1.1` and `CMRC2018` development sets.
Please see the following documents for details:
- [SQuAD experiments](docs/experiments-squad.md)
- [CMRC experiments](docs/experiments-cmrc.md)

## DPR support

We enabled the DPR retriever with a Pyserini-indexed corpus.
The dense index is created with the following command:
```
python -m pyserini.encode \
input --corpus <original_corpus_dir> \
--delimiter "DoNotApplyDelimiterPlease" \
--shard-id 0 \
--shard-num 1 \
output --embeddings dpr-ctx_encoder-multiset-base.corpus \
--to-faiss \
encoder --encoder facebook/dpr-ctx_encoder-multiset-base \
--batch-size 16 \
--device cuda:0 \
        --fp16  # enable this flag if running inference with autocast()
```
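The dense index stores only embeddings; at retrieval time the raw passage text is fetched from a companion BM25 index (passed as `--sparse_index` below). A typical Pyserini command for building that sparse index might look like the following sketch (paths illustrative, assuming the corpus is in Pyserini's JSONL collection format):
```
python -m pyserini.index.lucene \
  --collection JsonCollection \
  --input <original_corpus_dir> \
  --index <bm25 indexed corpus dir> \
  --generator DefaultLuceneDocumentGenerator \
  --threads 8 \
  --storeRaw
```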

When enabling the DPR option in end-to-end inference, please set the following arguments:

```
--retriever dpr \
--encoder <path to dpr query encoder> \
--index_path <pyserini indexed dpr dir> \
--sparse_index <bm25 indexed corpus dir> \ # the dense index does not store raw text, so the original text is fetched from the sparse index
--device cuda:0
```
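Putting it together, a full end-to-end DPR run could look like this (the entry point and paths are illustrative; adapt them to your setup):
```
python -m bertserini.experiments.inference \
  --dataset_path data/squad/dev-v1.1.json \
  --retriever dpr \
  --encoder facebook/dpr-question_encoder-multiset-base \
  --index_path dpr-ctx_encoder-multiset-base.corpus \
  --sparse_index <bm25 indexed corpus dir> \
  --model_name_or_path <path to fine-tuned reader> \
  --output runs/dpr_squad \
  --device cuda:0
```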

## Citation

Please cite [the NAACL 2019 paper](https://www.aclweb.org/anthology/N19-4013/):
55 changes: 49 additions & 6 deletions bertserini/experiments/args.py
@@ -2,26 +2,64 @@

parser = argparse.ArgumentParser()

parser.add_argument(
"--device",
default="cpu",
type=str,
help="Device to run query encoder, cpu or [cuda:0, cuda:1, ...]",
)
parser.add_argument(
"--dataset_path",
default=None,
type=str,
required=True,
help="Path to the [dev, test] dataset",
)

parser.add_argument(
"--retriever",
default="bm25",
type=str,
help="define the indexer type",
)
parser.add_argument(
"--k1",
default=0.9,
type=float,
help="k1, parameter for bm25 retriever",
)
parser.add_argument(
"--b",
default=0.4,
type=float,
help="b, parameter for bm25 retriever",
)
parser.add_argument(
"--encoder",
default="facebook/dpr-question_encoder-multiset-base",
type=str,
help="dpr encoder path or name",
)
parser.add_argument(
"--query_tokenizer_name",
default=None,
type=str,
help="tokenizer for dpr encoder",
)
parser.add_argument(
"--index_path",
default=None,
type=str,
required=True,
help="Path to the indexes of contexts",
)
parser.add_argument(
"--sparse_index",
default=None,
type=str,
help="Path to the indexes of sarse tokenizer, required when using dense index, in order to retrieve the raw document",
)
parser.add_argument(
"--model_name_or_path",
default=None,
type=str,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models",
)
parser.add_argument(
@@ -34,12 +72,11 @@
"--output",
default=None,
type=str,
help="The output file where the runs results will be written to",
)
parser.add_argument(
"--output_nbest_file",
default="./tmp.nbest",
default=None,
type=str,
help="The output file for store nbest results temporarily",
)
@@ -49,6 +86,12 @@
type=str,
help="The language of task",
)
parser.add_argument(
"--eval_batch_size",
default=32,
type=int,
help="batch size for evaluation",
)
parser.add_argument(
"--topk",
default=10,
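One dependency these options imply: when `--retriever dpr` is selected, `--sparse_index` must also be supplied, because the dense index stores no raw text. A hypothetical post-parse check (not part of this PR) could enforce that:

```python
# Hypothetical validation sketch: fail fast when the DPR retriever is
# selected without the sparse index it needs for raw-text lookup.
args = parser.parse_args()
if args.retriever == "dpr" and args.sparse_index is None:
    parser.error("--sparse_index is required with --retriever dpr "
                 "(the dense index does not store raw text)")
```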
9 changes: 5 additions & 4 deletions bertserini/experiments/evaluate.py
@@ -55,16 +55,17 @@ def get_score_with_results(eval_data, predictions, mu, dataset):
return eval_result, answers


def get_best_mu_with_scores(eval_data, predictions, mu_range, dataset, output_path, standard="f1"):
# standard = "f1" or "exact_match"
score_test = {}
best_mu = 0
best_score = 0
for mu in mu_range:
eval_result, answers = get_score_with_results(eval_data, predictions, mu, dataset)
score_test[mu] = eval_result
if eval_result["exact_match"] > best_em:
if eval_result[standard] > best_score:
best_mu = mu
best_score = eval_result[standard]
json.dump(answers, open(output_path + "/prediction.json", 'w'))

json.dump(score_test, open(output_path + "/score.json", 'w'))
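For context, the `mu` swept by `get_best_mu_with_scores` is the interpolation weight from the BERTserini paper, which blends retriever and reader scores. A minimal sketch of that scoring rule:

```python
def interpolate(retriever_score: float, reader_score: float, mu: float) -> float:
    # Overall answer score per the BERTserini paper:
    # S = (1 - mu) * S_retriever + mu * S_reader
    return (1 - mu) * retriever_score + mu * reader_score
```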
5 changes: 3 additions & 2 deletions bertserini/experiments/inference.py
@@ -7,8 +7,9 @@

if __name__ == "__main__":
questions = extract_squad_questions(args.dataset_path)
#bert_reader = BERT(args.model_name_or_path, args.tokenizer_name)
bert_reader = BERT(args)
searcher = build_searcher(args)

all_answer = []
for question in tqdm(questions):
2 changes: 2 additions & 0 deletions bertserini/reader/base.py
@@ -49,10 +49,12 @@ class Context:

def __init__(self,
text: str,
title: Optional[str] = "",
language: str = "en",
metadata: Mapping[str, Any] = None,
score: Optional[float] = 0):
self.text = text
self.title = title
self.language = language
if metadata is None:
metadata = dict()
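Note that the new `title` parameter sits before `language`, so positional callers written against the old signature would silently bind their `language` argument to `title`. A safe construction sketch (values illustrative):

```python
ctx = Context(
    text="Paris is the capital of France.",
    title="Paris",               # new optional field added in this PR
    language="en",
    metadata={"docid": "D123"},  # illustrative metadata
    score=12.5,
)
```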
24 changes: 15 additions & 9 deletions bertserini/reader/bert_reader.py
@@ -26,19 +26,19 @@ def craft_squad_examples(question: Question, contexts: List[Context]) -> List[SquadExample]:
title="",
is_impossible=False,
answers=[],
language=ctx.language
)
)
return examples


class BERT(Reader):
def __init__(self, args):
self.model_args = args
if self.model_args.tokenizer_name is None:
self.model_args.tokenizer_name = self.model_args.model_name_or_path
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model = AutoModelForQuestionAnswering.from_pretrained(self.model_args.model_name_or_path).to(self.device).eval()
self.tokenizer = AutoTokenizer.from_pretrained(self.model_args.tokenizer_name, do_lower_case=True, use_fast=False)
self.args = {
"max_seq_length": 384,
"doc_stride": 128,
@@ -49,7 +49,7 @@ def __init__(self, args):
"max_answer_length": 30,
"do_lower_case": True,
"output_prediction_file": False,
"output_nbest_file": output_nbest_file,
"output_nbest_file": self.model_args.output_nbest_file,
"output_null_log_odds_file": None,
"verbose_logging": False,
"version_2_with_negative": True,
@@ -77,7 +77,7 @@ def predict(self, question: Question, contexts: List[Context]) -> List[Answer]:

# Note that DistributedSampler samples randomly
eval_sampler = SequentialSampler(dataset)
eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=self.model_args.eval_batch_size)

all_results = []

@@ -98,8 +98,14 @@ def predict(self, question: Question, contexts: List[Context]) -> List[Answer]:
unique_id = int(eval_feature.unique_id)

                # take the i-th example's logits and convert them to plain Python lists,
                # which is what SquadResult / compute_predictions_logits expect
                start_logits = outputs.start_logits[i].detach().cpu().tolist()
                end_logits = outputs.end_logits[i].detach().cpu().tolist()

                result = SquadResult(unique_id, start_logits, end_logits)

all_results.append(result)
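Since `BERT` now takes the whole args namespace instead of individual names, standalone construction looks roughly like the following sketch (the checkpoint name is illustrative; any SQuAD-style QA checkpoint should work):

```python
from argparse import Namespace

from bertserini.reader.bert_reader import BERT

args = Namespace(
    model_name_or_path="rsvp-ai/bertserini-bert-base-squad",  # illustrative checkpoint
    tokenizer_name=None,        # falls back to model_name_or_path
    output_nbest_file=None,
    eval_batch_size=32,
)
reader = BERT(args)
```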
66 changes: 45 additions & 21 deletions bertserini/retriever/pyserini_retriever.py
@@ -1,35 +1,58 @@
from typing import List
import json

from pyserini.search import FaissSearcher, DprQueryEncoder
from pyserini.search.lucene import LuceneSearcher, JLuceneSearcherResult
from bertserini.utils.utils import init_logger
from bertserini.reader.base import Context

logger = init_logger("retriever")


def build_searcher(args):
if args.retriever == "bm25":
searcher = LuceneSearcher(args.index_path)
searcher.set_bm25(args.k1, args.b)
searcher.object.setLanguage(args.language)
elif args.retriever == "dpr":
query_encoder = DprQueryEncoder(
encoder_dir=args.encoder,
tokenizer_name=args.query_tokenizer_name,
device=args.device)
searcher = FaissSearcher(args.index_path, query_encoder)
ssearcher = LuceneSearcher(args.sparse_index)
searcher.ssearcher = ssearcher
else:
raise Exception("Non-Defined Retriever:", args.retriever)
return searcher

def build_searcher_from_prebuilt_index(args):
if args.retriever == "bm25":
searcher = LuceneSearcher.from_prebuilt_index(args.index_path)
searcher.set_bm25(args.k1, args.b)
searcher.object.setLanguage(args.language)
else:
raise Exception("Not implemented regriever from prebuilt index:", args.retirever)
return searcher

def retriever(question, searcher, para_num=20):
language = question.language
    if isinstance(searcher, FaissSearcher):
results = searcher.search(question.text, para_num)
hits = []
for r in results:
            # FaissSearcher stores only vectors; fetch the raw text from the attached sparse searcher
            hit = searcher.ssearcher.doc(r.docid).get("raw")
hits.append((hit, r.score))
else:
try:
if language == "zh":
hits = searcher.search(question.text.encode("utf-8"), k=para_num)
else:
hits = searcher.search(question.text, k=para_num)
except ValueError as e:
logger.error("Search failure: {}, {}".format(question.text, e))
return []
hits = [(h.raw, h.score) for h in hits]
return hits_to_contexts(hits, language)


@@ -53,14 +76,15 @@ def hits_to_contexts(hits: List[JLuceneSearcherResult], language="en", field='raw'):
"""
contexts = []
for i in range(0, len(hits)):
hit, score = hits[i]
        try:
            # the previous Chinese index stores the contents as "raw", while the English index stores a JSON string
            t = json.loads(hit)["contents"]
        except (json.JSONDecodeError, TypeError):
            t = hit
        # skip hits that contain a blacklisted string
        if any(s in t for s in blacklist):
            continue
        metadata = {}
        # keyword arguments avoid misbinding now that Context takes a `title` before `language`
        contexts.append(Context(t, language=language, metadata=metadata, score=score))
return contexts
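To see how the pieces fit together, here is a minimal wiring sketch of the new DPR path (argument values are examples, and `Question` is assumed to be the question class defined in `bertserini/reader/base.py`):

```python
from argparse import Namespace

from bertserini.reader.base import Question
from bertserini.retriever.pyserini_retriever import build_searcher, retriever

args = Namespace(
    retriever="dpr",
    encoder="facebook/dpr-question_encoder-multiset-base",
    query_tokenizer_name=None,
    index_path="dpr-ctx_encoder-multiset-base.corpus",
    sparse_index="indexes/bm25-corpus",  # raw text is fetched from here
    device="cuda:0",
)

searcher = build_searcher(args)
contexts = retriever(Question("Who wrote Hamlet?"), searcher, para_num=20)
```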