Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions guides/tasks/supported_tasks.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,11 @@
| [Argument Reasoning Comprehension](https://arxiv.org/abs/1708.01425) | arct | ✅ | ✅ | arct | [Github](https://github.com/UKPLab/argument-reasoning-comprehension-task) |
| Abductive NLI | abductive_nli | ✅ | ✅ | abductive_nli | |
| SuperGLUE Winogender Diagnostic | superglue_axg | ✅ | ✅ | superglue_axg | SuperGLUE |
| Acceptability Definiteness | acceptability_definiteness | ✅ | | acceptability_definiteness | Function Words |
| Adversarial NLI | `adversarial_nli_{round}` | ✅ | | adversarial_nli | 3 rounds |
| Acceptability Definiteness | acceptability_definiteness | ✅ | ✅ | acceptability_definiteness | Function Words |
| Acceptability Coord | acceptability_coord | ✅ | ✅ | acceptability_coord | Function Words |
| Acceptability EOS | acceptability_eos | ✅ | ✅ | acceptability_eos | Function Words |
| Acceptability WH Words | acceptability_whwords | ✅ | ✅ | acceptability_whwords | Function Words |
| Adversarial NLI | `adversarial_nli_{round}` | ✅ | ✅ | adversarial_nli | 3 rounds |
| ARC ("easy" version) | arc_easy | ✅ | ✅ | arc_easy | [site](https://allenai.org/data/arc) |
| ARC ("challenge" version) | arc_challenge | ✅ | ✅ | arc_challenge | [site](https://allenai.org/data/arc) |
| BoolQ | boolq | ✅ | ✅ | boolq | SuperGLUE |
Expand Down Expand Up @@ -55,7 +58,16 @@
| ReCord | record | ✅ | ✅ | record | SuperGLUE |
| RTE | rte | ✅ | ✅ | rte | GLUE, SuperGLUE |
| SciTail | scitail | ✅ | ✅ | scitail | |
| SentEval: Tense | senteval_tense | ✅ | | senteval_tense | SentEval |
| SentEval: Bigram Shift | senteval_bigram_shift | ✅ | ✅ | senteval_bigram_shift | SentEval |
| SentEval: Coord Inversion | senteval_coordination_inversion | ✅ | ✅ | senteval_coordination_inversion | SentEval |
| SentEval: Obj number | senteval_obj_number | ✅ | ✅ | senteval_obj_number | SentEval |
| SentEval: Odd Man Out | senteval_odd_man_out | ✅ | ✅ | senteval_odd_man_out | SentEval |
| SentEval: Past-Present | senteval_past_present | ✅ | ✅ | senteval_past_present | SentEval |
| SentEval: Sentence Length | senteval_sentence_length | ✅ | ✅ | senteval_sentence_length | SentEval |
| SentEval: Subj Number | senteval_subj_number | ✅ | ✅ | senteval_subj_number | SentEval |
| SentEval: Top Constituents | senteval_top_constituents | ✅ | ✅ | senteval_top_constituents | SentEval |
| SentEval: Tree Depth | senteval_tree_depth | ✅ | ✅ | senteval_tree_depth | SentEval |
| SentEval: Word Content | senteval_word_content | ✅ | ✅ | senteval_word_content | SentEval |
| EP-Rel | semeval | ✅ | | semeval | Edge-Probing |
| SNLI | snli | ✅ | ✅ | snli | |
| SocialIQA | socialiqa | ✅ | ✅ | socialiqa | |
Expand Down
14 changes: 14 additions & 0 deletions jiant/scripts/download_data/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@
"piqa",
"winogrande",
"ropes",
"acceptability_definiteness",
"acceptability_coord",
"acceptability_eos",
"acceptability_whwords",
"senteval_bigram_shift",
"senteval_coordination_inversion",
"senteval_obj_number",
"senteval_odd_man_out",
"senteval_past_present",
"senteval_sentence_length",
"senteval_subj_number",
"senteval_top_constituents",
"senteval_tree_depth",
"senteval_word_content",
}

DIRECT_DOWNLOAD_TASKS = set(
Expand Down
90 changes: 90 additions & 0 deletions jiant/scripts/download_data/dl_datasets/files_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,30 @@ def download_task_data_and_write_config(task_name: str, task_data_path: str, tas
download_ropes_data_and_write_config(
task_name=task_name, task_data_path=task_data_path, task_config_path=task_config_path
)
elif task_name in [
"acceptability_definiteness",
"acceptability_coord",
"acceptability_eos",
"acceptability_whwords",
]:
download_acceptability_judgments_data_and_write_config(
task_name=task_name, task_data_path=task_data_path, task_config_path=task_config_path
)
elif task_name in [
"senteval_bigram_shift",
"senteval_coordination_inversion",
"senteval_obj_number",
"senteval_odd_man_out",
"senteval_past_present",
"senteval_sentence_length",
"senteval_subj_number",
"senteval_top_constituents",
"senteval_tree_depth",
"senteval_word_content",
]:
download_senteval_data_and_write_config(
task_name=task_name, task_data_path=task_data_path, task_config_path=task_config_path
)
else:
raise KeyError(task_name)

Expand Down Expand Up @@ -1012,3 +1036,69 @@ def download_ropes_data_and_write_config(
},
path=task_config_path,
)


def download_acceptability_judgments_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    """Download a DNC acceptability-judgment dataset and write its task config.

    Fetches the data and metadata JSON files for the given acceptability task
    from the DNC repository, then writes a task config pointing at both. The
    metadata file carries the split indicators for 10 CV folds; the config
    pins "fold1" as the default.
    """
    task_to_dataset = {
        "acceptability_definiteness": "definiteness",
        "acceptability_coord": "coordinating-conjunctions",
        "acceptability_whwords": "whwords",
        "acceptability_eos": "eos",
    }
    dataset_name = task_to_dataset[task_name]  # KeyError for unknown task names
    os.makedirs(task_data_path, exist_ok=True)
    base_url = (
        "https://raw.githubusercontent.com/decompositional-semantics-initiative/DNC/master/"
        "function_words/ACCEPTABILITY"
    )
    # "data" holds all train/val/test examples; "metadata" holds the split
    # indicators (10 CV folds; fold1 is used by default, see kwargs below).
    paths = {}
    for kind in ("data", "metadata"):
        local_path = os.path.join(task_data_path, f"{kind}.json")
        download_utils.download_file(
            url=f"{base_url}/acceptability-{dataset_name}_{kind}.json",
            file_path=local_path,
        )
        paths[kind] = local_path
    py_io.write_json(
        data={
            "task": task_name,
            "paths": paths,
            "name": task_name,
            "kwargs": {"fold": "fold1"},  # use fold1 (out of 10) by default
        },
        path=task_config_path,
    )


def download_senteval_data_and_write_config(
    task_name: str, task_data_path: str, task_config_path: str
):
    """Download a SentEval probing dataset and write its task config.

    Retrieves the single TSV file for the given SentEval probing task from the
    SentEval repository; the first column of that file marks the
    train/val/test split.
    """
    # Explicit mapping (rather than prefix-stripping) so that unknown task
    # names fail fast with a KeyError.
    dataset_name = {
        "senteval_bigram_shift": "bigram_shift",
        "senteval_coordination_inversion": "coordination_inversion",
        "senteval_obj_number": "obj_number",
        "senteval_odd_man_out": "odd_man_out",
        "senteval_past_present": "past_present",
        "senteval_sentence_length": "sentence_length",
        "senteval_subj_number": "subj_number",
        "senteval_top_constituents": "top_constituents",
        "senteval_tree_depth": "tree_depth",
        "senteval_word_content": "word_content",
    }[task_name]
    os.makedirs(task_data_path, exist_ok=True)
    # Single file containing all examples; first column indicates the split.
    data_path = os.path.join(task_data_path, "data.tsv")
    source_url = (
        "https://raw.githubusercontent.com/facebookresearch/SentEval/master/data/probing/"
        f"{dataset_name}.txt"
    )
    download_utils.download_file(url=source_url, file_path=data_path)
    py_io.write_json(
        data={"task": task_name, "paths": {"data": data_path}, "name": task_name},
        path=task_config_path,
    )
4 changes: 3 additions & 1 deletion jiant/scripts/download_data/runscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ def download_data(task_names, output_base_path):
task_data_base_path = py_io.create_dir(output_base_path, "data")
task_config_base_path = py_io.create_dir(output_base_path, "configs")

assert set(task_names).issubset(SUPPORTED_TASKS)
assert set(task_names).issubset(SUPPORTED_TASKS), "Following tasks are not supported: {}".format(
    ",".join(set(task_names) - SUPPORTED_TASKS)
)

# Download specified tasks and generate configs for specified tasks
for i, task_name in enumerate(task_names):
Expand Down
14 changes: 13 additions & 1 deletion jiant/tasks/evaluate/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,6 +957,9 @@ def get_evaluation_scheme_for_task(task) -> BaseEvaluationScheme:
tasks.AdversarialNliTask,
tasks.AbductiveNliTask,
tasks.AcceptabilityDefinitenessTask,
tasks.AcceptabilityCoordTask,
tasks.AcceptabilityEOSTask,
tasks.AcceptabilityWHwordsTask,
tasks.BoolQTask,
tasks.CopaTask,
tasks.FeverNliTask,
Expand All @@ -966,7 +969,16 @@ def get_evaluation_scheme_for_task(task) -> BaseEvaluationScheme:
tasks.RaceTask,
tasks.RteTask,
tasks.SciTailTask,
tasks.SentevalTenseTask,
tasks.SentEvalBigramShiftTask,
tasks.SentEvalCoordinationInversionTask,
tasks.SentEvalObjNumberTask,
tasks.SentEvalOddManOutTask,
tasks.SentEvalPastPresentTask,
tasks.SentEvalSentenceLengthTask,
tasks.SentEvalSubjNumberTask,
tasks.SentEvalTopConstituentsTask,
tasks.SentEvalTreeDepthTask,
tasks.SentEvalWordContentTask,
tasks.SnliTask,
tasks.SstTask,
tasks.WiCTask,
Expand Down
107 changes: 107 additions & 0 deletions jiant/tasks/lib/acceptability_judgement/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import numpy as np
import torch
from dataclasses import dataclass
from typing import List

from jiant.tasks.core import (
BaseExample,
BaseTokenizedExample,
BaseDataRow,
BatchMixin,
Task,
TaskTypes,
)
from jiant.tasks.lib.templates.shared import single_sentence_featurize, labels_to_bimap
from jiant.utils.python.io import read_json


@dataclass
class Example(BaseExample):
    """A raw acceptability-judgment example: one sentence and a string label."""

    guid: str
    text: str
    label: str

    def tokenize(self, tokenizer):
        """Tokenize the sentence and convert the label to its integer id."""
        tokens = tokenizer.tokenize(self.text)
        label_id = BaseAcceptabilityTask.LABEL_TO_ID[self.label]
        return TokenizedExample(guid=self.guid, text=tokens, label_id=label_id)


@dataclass
class TokenizedExample(BaseTokenizedExample):
    """An acceptability example after tokenization, before featurization."""

    guid: str
    text: List
    label_id: int

    def featurize(self, tokenizer, feat_spec):
        """Build a DataRow via the shared single-sentence featurization path."""
        return single_sentence_featurize(
            guid=self.guid,
            input_tokens=self.text,
            tokenizer=tokenizer,
            feat_spec=feat_spec,
            label_id=self.label_id,
            data_row_class=DataRow,
        )


@dataclass
class DataRow(BaseDataRow):
    """Featurized single example: model-ready arrays plus the label id."""

    guid: str
    input_ids: np.ndarray
    input_mask: np.ndarray
    segment_ids: np.ndarray
    label_id: int
    tokens: list


@dataclass
class Batch(BatchMixin):
    """A batch of DataRows collated into torch tensors (tokens kept as a list)."""

    input_ids: torch.LongTensor
    input_mask: torch.LongTensor
    segment_ids: torch.LongTensor
    label_id: torch.LongTensor
    tokens: list


class BaseAcceptabilityTask(Task):
    """Base class for DNC function-word acceptability-judgment tasks.

    Binary sentence classification ("acceptable" vs "unacceptable"). The
    dataset is one JSON file of examples plus a parallel metadata JSON whose
    ``misc`` entries assign each example to train/dev/test within each of 10
    cross-validation folds; ``fold`` selects which fold's split to use.
    """

    Example = Example
    TokenizedExample = TokenizedExample
    DataRow = DataRow
    Batch = Batch

    TASK_TYPE = TaskTypes.CLASSIFICATION
    LABELS = ["acceptable", "unacceptable"]
    LABEL_TO_ID, ID_TO_LABEL = labels_to_bimap(LABELS)
    # Maps the split names used in the metadata file to jiant set_type names.
    DATA_PHASE_MAP = {"train": "train", "dev": "val", "test": "test"}

    def __init__(self, name, path_dict, fold: str):
        """
        Args:
            name: task name
            path_dict: must contain "data" and "metadata" file paths
            fold: CV-fold selector, a string like "fold1"
        """
        super().__init__(name=name, path_dict=path_dict)
        self.fold = fold

    def get_train_examples(self):
        return self._create_examples(set_type="train")

    def get_val_examples(self):
        return self._create_examples(set_type="val")

    def get_test_examples(self):
        return self._create_examples(set_type="test")

    def _create_examples(self, set_type):
        """Read data + metadata and keep examples whose fold split matches set_type."""
        data = read_json(self.path_dict["data"])
        metadata = read_json(self.path_dict["metadata"])
        # data and metadata are parallel arrays; a length mismatch means the
        # downloaded files are inconsistent. Raise a real exception rather
        # than an assert (asserts are stripped under `python -O`, and this
        # check validates external data, not internal invariants).
        if len(data) != len(metadata):
            raise ValueError(
                "data/metadata length mismatch: {} != {}".format(len(data), len(metadata))
            )
        examples = []
        for data_row, metadata_row in zip(data, metadata):
            row_phase = self.DATA_PHASE_MAP[metadata_row["misc"][self.fold]]
            if row_phase != set_type:
                continue
            examples.append(
                Example(
                    guid=data_row["pair-id"], text=data_row["context"], label=data_row["label"],
                )
            )
        return examples
26 changes: 26 additions & 0 deletions jiant/tasks/lib/acceptability_judgement/coord.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from dataclasses import dataclass
import jiant.tasks.lib.acceptability_judgement.base as base


@dataclass
class Example(base.Example):
    """Identical to the base acceptability Example; aliased for this task module."""

    pass


@dataclass
class TokenizedExample(base.TokenizedExample):
    """Identical to the base acceptability TokenizedExample; aliased for this task."""

    pass


@dataclass
class DataRow(base.DataRow):
    """Identical to the base acceptability DataRow; aliased for this task."""

    pass


@dataclass
class Batch(base.Batch):
    """Identical to the base acceptability Batch; aliased for this task."""

    pass


class AcceptabilityCoordTask(base.BaseAcceptabilityTask):
    """Acceptability judgments for coordinating conjunctions (Function Words / DNC)."""

    pass
Loading