Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
67a042b
clean code
MXueguang Sep 17, 2020
a42911a
temp
MXueguang Sep 19, 2020
f5215dd
fix_bug
MXueguang Sep 19, 2020
4222fb1
search Question.text
MXueguang Sep 19, 2020
09a4e2c
fix bug in bert reader
MXueguang Sep 19, 2020
1ea494a
test chinese
MXueguang Sep 20, 2020
0c7ef5e
create utils_new
MXueguang Sep 20, 2020
f6b11fb
add experiment_squad.py
MXueguang Sep 22, 2020
d71ade8
fix bug of experiment_squad
MXueguang Sep 22, 2020
2e36092
add experiment_cmrc
MXueguang Sep 22, 2020
a521329
change repo structure
MXueguang Sep 22, 2020
d819109
update for cmrc
MXueguang Sep 22, 2020
25ec43e
remove args.py
MXueguang Sep 22, 2020
4e72b57
replace hardcode args by argparse
MXueguang Sep 22, 2020
45bd735
remove some scripts that won't be used in the future
MXueguang Sep 22, 2020
fde68cf
improve repo structure
MXueguang Sep 22, 2020
b5451d5
add experiments documents
MXueguang Sep 22, 2020
73b285f
Merge pull request #2 from MXueguang/temp
MXueguang Sep 22, 2020
49b7301
fix typo in doc
MXueguang Sep 22, 2020
fdf5c23
Merge pull request #3 from MXueguang/temp
MXueguang Sep 22, 2020
6b2e3e7
bug fix
MXueguang Sep 24, 2020
07e8e3e
update setup.py
MXueguang Sep 24, 2020
570eea8
specify package versions
MXueguang Sep 24, 2020
e57bce6
update README
MXueguang Sep 24, 2020
48db464
make bert args configurable
MXueguang Sep 24, 2020
97dbf35
add document for chinese qa
MXueguang Sep 24, 2020
17d0fb2
fix requirements.txt
MXueguang Sep 25, 2020
3c9c53f
Take advantage of pyserini's new prebuilt index features (#10)
qguo96 Oct 5, 2020
70c878e
fix experiment documents and download punkt corpus (#11)
qguo96 Oct 5, 2020
cb0be10
runnable when updated to huggingface 4.5. Possible issue: inference sup…
amyxie361 Feb 4, 2022
835ab58
runable squad inference+eval. Changes: fix transformers version, upda…
amyxie361 Mar 3, 2022
26f169a
Merge branch 'development' of github.com:rsvp-ai/bertserini into deve…
amyxie361 Mar 3, 2022
618ad87
revise version
amyxie361 Mar 5, 2022
d9d15ab
runable chinese
amyxie361 Mar 11, 2022
f7c79e9
update url (#22)
akkefa Mar 18, 2022
1e57e96
Merge branch 'development' of github.com:castorini/bertserini into de…
amyxie361 Mar 19, 2022
bb7cffb
add string representations for base classes (#24)
amyxie361 Mar 19, 2022
9a627aa
Update apis for transformers 4.17 and also update pyserini api (#23)
amyxie361 Mar 19, 2022
d84e462
bug fixed, (#25)
akkefa Mar 20, 2022
5ef0645
Add dpr retriever (#26)
amyxie361 Mar 21, 2022
58f33eb
Update requirements.txt (#27)
akkefa Mar 21, 2022
247f455
merge upstream
amyxie361 Mar 21, 2022
6cf264d
add dpr reader (#28)
amyxie361 Mar 21, 2022
1fab31a
Merge branch 'development' of github.com:castorini/bertserini into de…
amyxie361 Mar 21, 2022
50758ec
clean up code
amyxie361 Mar 22, 2022
f11f638
fix typo and remove un-used file
amyxie361 Mar 22, 2022
71ae035
clean up (#29)
amyxie361 Mar 22, 2022
50f96d9
Merge branch 'castorini:development' into development
amyxie361 Mar 22, 2022
93e018d
merge development to master
amyxie361 Mar 22, 2022
0b86c86
fix minor bug
amyxie361 Mar 22, 2022
ff6d802
address comments'
amyxie361 Mar 30, 2022
24a3c5f
Merge branch 'master' of github.com:castorini/bertserini
amyxie361 Aug 18, 2022
eac213b
update to fast-tokenize
amyxie361 Aug 23, 2022
6cd9518
runable fast-tokenizer, need to fix accuracy issues
amyxie361 Aug 24, 2022
0ba273f
add utils
amyxie361 Aug 24, 2022
676ef62
refactor utils and fix minor bug
amyxie361 Aug 24, 2022
5c962a3
fix minor bug
amyxie361 Aug 24, 2022
4ef5e5e
clean up code and tqdm
amyxie361 Aug 24, 2022
3ffd1e0
add T5 reader and corresponding test
AileenLin Sep 19, 2022
c3d7273
Merge pull request #1 from AileenLin/add_t5_aileen
amyxie361 Sep 20, 2022
f16ab3a
Merge pull request #2 from amyxie361/amyxie361/fix-speed
amyxie361 Oct 13, 2022
28e756b
Merge branch 'castorini:master' into master
amyxie361 Oct 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Below is a example for English Question-Answering. We also provide an example fo
```python
from bertserini.reader.base import Question, Context
from bertserini.reader.bert_reader import BERT
from bertserini.utils.utils_new import get_best_answer
from bertserini.utils.utils import get_best_answer

model_name = "rsvp-ai/bertserini-bert-base-squad"
tokenizer_name = "rsvp-ai/bertserini-bert-base-squad"
Expand Down
53 changes: 48 additions & 5 deletions bertserini/experiments/eval/evaluate_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
import argparse
import json

from rouge_metric import PyRouge
rouge = PyRouge(rouge_n=(2,), rouge_su=True, skip_gap=4)
#from rouge_score import rouge_scorer
#rouge1_scorer = rouge_scorer.RougeScorer(['rouge1'], use_stemmer=True)
#rougel_scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

from bertserini.utils.utils import normalize_answer, init_logger

logger = init_logger("evluation")
Expand Down Expand Up @@ -67,6 +73,32 @@ def overlap_score(prediction, ground_truth):
def exact_match_score(prediction, ground_truth):
    """Return True iff the prediction and reference match after normalization."""
    normalized_prediction = normalize_answer(prediction)
    normalized_reference = normalize_answer(ground_truth)
    return normalized_prediction == normalized_reference

def rouge2_r_score(prediction, ground_truth):
    """Return the ROUGE-2 recall score for `prediction` vs. `ground_truth`.

    Args:
        prediction: Candidate answer string.
        ground_truth: Reference answer string.

    Returns:
        float: the "r" (recall) component of PyRouge's ROUGE-2 result;
        0.0 when the prediction is empty or None, since PyRouge cannot
        score an empty hypothesis.

    NOTE(review): the ground truth is passed in PyRouge's hypothesis slot
    and the prediction in the reference slot -- confirm this ordering is
    intended against the PyRouge API.
    """
    if not prediction:  # empty string or None: nothing to score
        return 0.0
    # PyRouge expects a list of hypotheses and a list of reference lists.
    return rouge.evaluate([ground_truth], [[prediction]])["rouge-2"]["r"]

def rouge2_f_score(prediction, ground_truth):
    """Return the ROUGE-2 F-measure for `prediction` vs. `ground_truth`.

    Args:
        prediction: Candidate answer string.
        ground_truth: Reference answer string.

    Returns:
        float: the "f" component of PyRouge's ROUGE-2 result; 0.0 when the
        prediction is empty or None, since PyRouge cannot score an empty
        hypothesis.

    NOTE(review): the ground truth is passed in PyRouge's hypothesis slot
    and the prediction in the reference slot -- confirm this ordering is
    intended against the PyRouge API.
    """
    if not prediction:  # empty string or None: nothing to score
        return 0.0
    # PyRouge expects a list of hypotheses and a list of reference lists.
    return rouge.evaluate([ground_truth], [[prediction]])["rouge-2"]["f"]

def rougesu4_r_score(prediction, ground_truth):
    """Return the ROUGE-SU4 recall score for `prediction` vs. `ground_truth`.

    Args:
        prediction: Candidate answer string.
        ground_truth: Reference answer string.

    Returns:
        float: the "r" (recall) component of PyRouge's ROUGE-SU4 result
        (skip-bigram with unigrams, skip gap 4 -- see the module-level
        `PyRouge(rouge_su=True, skip_gap=4)` configuration); 0.0 when the
        prediction is empty or None.

    NOTE(review): the ground truth is passed in PyRouge's hypothesis slot
    and the prediction in the reference slot -- confirm this ordering is
    intended against the PyRouge API.
    """
    if not prediction:  # empty string or None: nothing to score
        return 0.0
    # PyRouge expects a list of hypotheses and a list of reference lists.
    return rouge.evaluate([ground_truth], [[prediction]])["rouge-su4"]["r"]

def rougesu4_f_score(prediction, ground_truth):
    """Return the ROUGE-SU4 F-measure for `prediction` vs. `ground_truth`.

    Args:
        prediction: Candidate answer string.
        ground_truth: Reference answer string.

    Returns:
        float: the "f" component of PyRouge's ROUGE-SU4 result; 0.0 when
        the prediction is empty or None, since PyRouge cannot score an
        empty hypothesis.

    NOTE(review): the ground truth is passed in PyRouge's hypothesis slot
    and the prediction in the reference slot -- confirm this ordering is
    intended against the PyRouge API.
    """
    if not prediction:  # empty string or None: nothing to score
        return 0.0
    # PyRouge expects a list of hypotheses and a list of reference lists.
    return rouge.evaluate([ground_truth], [[prediction]])["rouge-su4"]["f"]

#def rougel_score(prediction, ground_truth):
# print(rougel_scorer.score(prediction, ground_truth))
# input()
# return rougel_scorer.score(prediction, ground_truth)


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
scores_for_ground_truths = []
Expand All @@ -92,7 +124,7 @@ def metric_max_recall(metric_fn, prediction, ground_truths):


def evaluate(dataset, predictions):
sentence_cover = precision = cover = sentence_recall = recall = f1 = exact_match = total = overlap = 0
sentence_cover = precision = cover = sentence_recall = recall = f1 = exact_match = total = overlap = rouge2_r = rouge2_f = rougesu4_r = rougesu4_f = 0
for article in dataset:
for paragraph in article['paragraphs']:
for qa in paragraph['qas']:
Expand All @@ -104,6 +136,11 @@ def evaluate(dataset, predictions):
ground_truths = list(map(lambda x: x['text'], qa['answers']))
prediction = [predictions[qa['id']]]
#prediction_sentence = predictions[qa['id']]['sentences']
rouge2_r += metric_max_recall(rouge2_r_score, prediction, ground_truths)
rouge2_f += metric_max_recall(rouge2_f_score, prediction, ground_truths)
rougesu4_r += metric_max_recall(rougesu4_r_score, prediction, ground_truths)
rougesu4_f += metric_max_recall(rougesu4_f_score, prediction, ground_truths)
#rougel += metric_max_recall(rougel_score, prediction, ground_truths)
cover += metric_max_recall(cover_score, prediction, ground_truths)
exact_match += metric_max_recall(
exact_match_score, prediction, ground_truths)
Expand All @@ -124,21 +161,27 @@ def evaluate(dataset, predictions):
overlap = 100.0 * overlap / total
cover = 100.0 * cover / total
precision = 100.0 * precision / total
rouge2_r = 100.0 * rouge2_r / total
rouge2_f = 100.0 * rouge2_f / total
rougesu4_r = 100.0 * rougesu4_r / total
rougesu4_f = 100.0 * rougesu4_f / total
#rougel = 100.0 * rougel / total
#sentence_recall = 100.0 * sentence_recall / total
#sentence_cover = 100.0 * sentence_cover / total

return {'exact_match': exact_match, 'f1': f1, "recall": recall,
#"sentence_recall": sentence_recall, "sentence_cover": sentence_cover,
"precision": precision, "cover": cover, "overlap": overlap}
"precision": precision, "cover": cover, "overlap": overlap,
"rouge2_r": rouge2_r, "rouge2_f":rouge2_f, "rougesu4_r": rougesu4_r, "rougesu4_f": rougesu4_f}


def squad_v1_eval(dataset_filename, prediction_filename):
expected_version = '1.1'
with open(dataset_filename) as dataset_file:
dataset_json = json.load(dataset_file)
if dataset_json['version'] != expected_version:
logger.error('Evaluation expects v-{}, but got dataset with v-{}'.format(
expected_version, dataset_json['version']))
#if dataset_json['version'] != expected_version:
# logger.error('Evaluation expects v-{}, but got dataset with v-{}'.format(
# expected_version, dataset_json['version']))
dataset = dataset_json['data']
with open(prediction_filename) as prediction_file:
predictions = json.load(prediction_file)
Expand Down
6 changes: 5 additions & 1 deletion bertserini/experiments/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
from tqdm import tqdm
from bertserini.reader.bert_reader import BERT
from bertserini.retriever.pyserini_retriever import retriever, build_searcher
from bertserini.utils.utils_new import extract_squad_questions
from bertserini.utils.utils import extract_squad_questions
from bertserini.experiments.args import *
import time

if __name__ == "__main__":
questions = extract_squad_questions(args.dataset_path, do_strip_accents=args.strip_accents)
Expand All @@ -13,8 +14,11 @@

all_answer = []
for question in tqdm(questions):
# print("before retriever:", time.time())
contexts = retriever(question, searcher, args.topk)
# print("before reader:", time.time())
final_answers = bert_reader.predict(question, contexts)
# print("after reader:", time.time())
final_answers_lst = []
for ans in final_answers:
final_answers_lst.append(
Expand Down
Loading