diff --git a/lm_eval/tasks/__init__.py b/lm_eval/tasks/__init__.py index b169e24d26..9d6c87d56a 100644 --- a/lm_eval/tasks/__init__.py +++ b/lm_eval/tasks/__init__.py @@ -32,6 +32,7 @@ from . import tydiqa from . import wino_bias from . import wmt +from . import xnli from . import xquad @@ -111,6 +112,22 @@ # XQuAD "xquad_en": xquad.XQuADEnglish, "xquad_ar": xquad.XQuADArabic, + # XNLI + "xnli_en": xnli.XNLIEn, + "xnli_fr": xnli.XNLIFr, + "xnli_es": xnli.XNLIEs, + "xnli_de": xnli.XNLIDe, + "xnli_el": xnli.XNLIEl, + "xnli_bg": xnli.XNLIBg, + "xnli_ru": xnli.XNLIRu, + "xnli_tr": xnli.XNLITr, + "xnli_ar": xnli.XNLIAr, + "xnli_vi": xnli.XNLIVi, + "xnli_th": xnli.XNLITh, + "xnli_zh": xnli.XNLIZh, + "xnli_hi": xnli.XNLIHi, + "xnli_sw": xnli.XNLISw, + "xnli_ur": xnli.XNLIUr, # PIAF "piaf": piaf.PIAF, # Flores 101 (MT) diff --git a/lm_eval/tasks/xnli.py b/lm_eval/tasks/xnli.py new file mode 100644 index 0000000000..928816168c --- /dev/null +++ b/lm_eval/tasks/xnli.py @@ -0,0 +1,122 @@ +""" +XNLI is an evaluation corpus for language transfer and cross-lingual sentence classification in 15 languages. +https://arxiv.org/abs/1809.05053 + +Homepage: None, Repo: https://github.com/facebookresearch/XNLI +""" +import typing + +from lm_eval.api.task import PromptSourceTask + + +_CITATION = """ +@inproceedings{conneau2018xnli, + title={XNLI: Evaluating Cross-lingual Sentence Representations}, + author={Conneau, Alexis and Rinott, Ruty and Lample, Guillaume and Williams, Adina and Bowman, Samuel and Schwenk, Holger and Stoyanov, Veselin}, + booktitle={Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing}, + pages={2475--2485}, + year={2018} +} +}""" + + +class XNLI(PromptSourceTask): + VERSION = 1 + DATASET_PATH = "xnli" + DATASET_NAME = None + + def has_training_docs(self): + return True + + def has_validation_docs(self): + return True + + def has_test_docs(self): + return True + + def training_docs(self): + if self.has_training_docs(): + return self.dataset["train"] + + def validation_docs(self): + if self.has_validation_docs(): + return self.dataset["validation"] + + +class XNLIEn(XNLI): + DATASET_NAME = "en" + +class XNLIFr(XNLI): + DATASET_NAME = "fr" + +class XNLIEs(XNLI): + DATASET_NAME = "es" + +class XNLIDe(XNLI): + DATASET_NAME = "de" + +class XNLIEl(XNLI): + DATASET_NAME = "el" + +class XNLIBg(XNLI): + DATASET_NAME = "bg" + +class XNLIRu(XNLI): + DATASET_NAME = "ru" + +class XNLITr(XNLI): + DATASET_NAME = "tr" + +class XNLIAr(XNLI): + DATASET_NAME = "ar" + +class XNLIVi(XNLI): + DATASET_NAME = "vi" + +class XNLITh(XNLI): + DATASET_NAME = "th" + +class XNLIZh(XNLI): + DATASET_NAME = "zh" + +class XNLIHi(XNLI): + DATASET_NAME = "hi" + +class XNLISw(XNLI): + DATASET_NAME = "sw" + +class XNLIUr(XNLI): + DATASET_NAME = "ur" + + +XNLI_TASKS = [ + XNLIEn, + XNLIFr, + XNLIEs, + XNLIDe, + XNLIEl, + XNLIBg, + XNLIRu, + XNLITr, + XNLIAr, + XNLIVi, + XNLITh, + XNLIZh, + XNLIHi, + XNLISw, + XNLIUr +] + + +def construct_tasks() -> typing.Dict[str, XNLI]: + """ + Returns a dictionary of tasks keyed by task name, for example: + "GEM/wiki_lingua_ar" + will dispatch to the GEM WikiLingua Arabic class. + """ + tasks = {} + for task_class in XNLI_TASKS: + benchmark = task_class.DATASET_PATH + lang = task_class.DATASET_NAME + tasks[f"{benchmark}_{lang}"] = task_class + return tasks