Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion lm_eval/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from . import tydiqa
from . import wino_bias
from . import wmt
from . import xcopa
from . import xquad


Expand Down Expand Up @@ -79,7 +80,7 @@
"hans": hans.HANS,
# CNN Daily Mail
"cnn_dailymail": cnn_dailymail.CnnDailyMail,
# GEM/xum
# GEM/xsum
"gem_xsum": gem_xsum.GEMXSUM,
"gem_xsum_challenge_sample": gem_xsum.GEMXSUMChallgeSample,
"gem_xsum_challenge_test_backtranslation": gem_xsum.GEMXSUMChallgeTestBacktranslation,
Expand Down Expand Up @@ -198,6 +199,8 @@
# TyDi QA
"tydiqa_primary": tydiqa.TyDiQAPrimaryClassification,
"tydiqa_secondary": tydiqa.TyDiQAGoldPGeneration,
# XCOPA
**xcopa.construct_tasks(),
#######################################################
# TODO: Not Yet Available in `promptsource/eval-hackathon`
########################################################
Expand Down
85 changes: 85 additions & 0 deletions lm_eval/tasks/xcopa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
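"""
XCOPA tasks: one `PromptSourceTask` subclass per language configuration of the
`xcopa` dataset, plus a helper that builds the task-name -> task-class mapping.

Homepage: TODO: add
"""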
import typing

from lm_eval.api.task import PromptSourceTask


# Placeholder: the citation text for this dataset has not been filled in yet.
_CITATION = """
TODO: add
"""

class XCopaBase(PromptSourceTask):
    """Shared configuration for all tasks backed by the `xcopa` dataset.

    Subclasses select a language by overriding `DATASET_NAME`. Only the
    validation and test splits are exposed; no training split is offered.
    """

    VERSION = 0
    DATASET_PATH = "xcopa"
    # Overridden by each language-specific subclass.
    DATASET_NAME = None

    def has_training_docs(self):
        """Report that no training split is exposed."""
        return False

    def has_validation_docs(self):
        """Report that a validation split is exposed."""
        return True

    def has_test_docs(self):
        """Report that a test split is exposed."""
        return True

    def validation_docs(self):
        """Return the "validation" split of the loaded dataset."""
        # NOTE(review): `self.dataset` is presumably populated by the base
        # class -- confirm against `PromptSourceTask`.
        return self.dataset["validation"]

    def test_docs(self):
        """Return the "test" split of the loaded dataset."""
        return self.dataset["test"]

    def invalid_doc_for_prompt(self, doc) -> bool:
        """Return True when the current prompt template cannot be used on `doc`.

        A document is treated as invalid when applying the template yields
        `['']` or when applying it raises any exception.
        """
        # HACK: Some copa templates contain conditionals that skip documents
        # when the condition is not met, like
        # `{if doc['question'] != \"cause\"}`, so the prompt never produces
        # an input and target for those documents.
        try:
            # The comparison stays inside the `try` so that any failure while
            # rendering or comparing marks the document as invalid.
            return bool(self.prompt_template.apply(doc) == [''])
        except Exception:
            return True

class XCopaId(XCopaBase):
    """XCOPA task bound to the `id` dataset configuration."""

    DATASET_NAME = "id"

class XCopaIt(XCopaBase):
    """XCOPA task bound to the `it` dataset configuration."""

    DATASET_NAME = "it"

class XCopaSw(XCopaBase):
    """XCOPA task bound to the `sw` dataset configuration."""

    DATASET_NAME = "sw"

class XCopaTa(XCopaBase):
    """XCOPA task bound to the `ta` dataset configuration."""

    DATASET_NAME = "ta"

class XCopaVi(XCopaBase):
    """XCOPA task bound to the `vi` dataset configuration."""

    DATASET_NAME = "vi"

class XCopaZh(XCopaBase):
    """XCOPA task bound to the `zh` dataset configuration."""

    DATASET_NAME = "zh"

# Every language-specific task class defined above; `construct_tasks` iterates
# this list to build the task-name -> task-class mapping.
XCOPA_TASKS = [
    XCopaId,
    XCopaIt,
    XCopaSw,
    XCopaTa,
    XCopaVi,
    XCopaZh,
]

def construct_tasks() -> typing.Dict[str, typing.Type[XCopaBase]]:
    """Build the task-name -> task-class mapping for all XCOPA tasks.

    Each key has the form ``"<DATASET_PATH>_<DATASET_NAME>"``; for example,
    ``"xcopa_id"`` maps to the ``XCopaId`` class.

    Returns:
        A dictionary with one entry per class in ``XCOPA_TASKS``. The values
        are the task classes themselves, not instances.
    """
    return {
        f"{task_class.DATASET_PATH}_{task_class.DATASET_NAME}": task_class
        for task_class in XCOPA_TASKS
    }
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
"tqdm-multiprocess==0.0.11",
"accelerate@git+https://github.com/huggingface/accelerate@main",
"transformers@git+https://github.com/huggingface/transformers@main",
"promptsource@git+https://github.com/bigscience-workshop/promptsource@eval-hackathon",
#"promptsource@git+https://github.com/bigscience-workshop/promptsource@eval-hackathon",
# install promptsource manually to ensure it's up-to-date with the correct branch
]
dependency_links = []

Expand Down