
Commit c181273

pyeres authored and sleepinyourhat committed
add adversarial_nli tasks (#966)
1 parent 8a059b8 commit c181273

1 file changed

Lines changed: 95 additions & 0 deletions


jiant/tasks/tasks.py

@@ -1097,6 +1097,101 @@ def load_data(self):
         log.info("\tFinished loading SNLI data.")
 
 
+@register_task("adversarial_nli_a1", rel_path="AdversarialNLI/", datasets=["R1"])
+@register_task("adversarial_nli_a2", rel_path="AdversarialNLI/", datasets=["R2"])
+@register_task("adversarial_nli_a3", rel_path="AdversarialNLI/", datasets=["R3"])
+@register_task("adversarial_nli", rel_path="AdversarialNLI/", datasets=["R1", "R2", "R3"])
+class AdversarialNLITask(PairClassificationTask):
+    """Task class for use with the Adversarial Natural Language Inference dataset.
+
+    Configures a 3-class PairClassificationTask using Adversarial NLI data.
+    Requires the original ANLI dataset file structure under the relative path.
+    Data: https://dl.fbaipublicfiles.com/anli/anli_v0.1.zip
+    Paper: https://arxiv.org/abs/1910.14599
+
+    Attributes:
+        path (str): AdversarialNLI path relative to JIANT_DATA_DIR
+        max_seq_len (int): max tokens allowed in a sequence
+        train_data_text (list[list[str], list[str], list[int]]):
+            list of lists of context, hypothesis, and target training data
+        val_data_text (list[list[str], list[str], list[int]]):
+            list of lists of context, hypothesis, and target val data
+        test_data_text (list[list[str], list[str], list[int]]):
+            list of lists of context, hypothesis, and target test data
+        datasets (list[str]): list of sub-datasets used in task (e.g., R1)
+        sentences (list): list of all (tokenized) context and hypothesis
+            texts from train and val data.
+    """
+
+    def __init__(self, path, max_seq_len, name, datasets, **kw):
+        """Initialize an AdversarialNLITask task.
+
+        Args:
+            path (str): AdversarialNLI path relative to the data dir
+            max_seq_len (int): max tokens allowed in a sequence
+            name (str): task name, specified in @register_task
+            datasets (list[str]): list of ANLI sub-datasets used in task
+        """
+        super(AdversarialNLITask, self).__init__(name, n_classes=3, **kw)
+        self.path = path
+        self.max_seq_len = max_seq_len
+        self.train_data_text = None
+        self.val_data_text = None
+        self.test_data_text = None
+        self.datasets = datasets
+
+    def _read_data(self, path: str) -> pd.core.frame.DataFrame:
+        """Read json, tokenize text, encode labels as int, return dataframe."""
+        df = pd.read_json(path_or_buf=path, encoding="UTF-8", lines=True)
+        # for ANLI datasets n=neutral, e=entailment, c=contradiction
+        df["target"] = df["label"].map({"n": 0, "e": 1, "c": 2})
+        tokenizer = get_tokenizer(self._tokenizer_name)
+        df["context"] = df["context"].apply(tokenizer.tokenize)
+        df["hypothesis"] = df["hypothesis"].apply(tokenizer.tokenize)
+        return df[["context", "hypothesis", "target"]]
+
+    def load_data(self):
+        """Read, preprocess and load data into an AdversarialNLITask.
+
+        Assumes the original dataset file structure under `self.rel_path`.
+        Loads only the datasets (e.g., "R1") specified in the `datasets` attr.
+        Populates the task's train_, val_, test_data_text and `sentences` attrs.
+        """
+        train_dfs, val_dfs, test_dfs = [], [], []
+        for dataset in self.datasets:
+            train_dfs.append(self._read_data(os.path.join(self.path, dataset, "train.jsonl")))
+            val_dfs.append(self._read_data(os.path.join(self.path, dataset, "dev.jsonl")))
+            test_dfs.append(self._read_data(os.path.join(self.path, dataset, "test.jsonl")))
+        train_df = pd.concat(train_dfs, axis=0, ignore_index=True)
+        val_df = pd.concat(val_dfs, axis=0, ignore_index=True)
+        test_df = pd.concat(test_dfs, axis=0, ignore_index=True)
+
+        self.train_data_text = [
+            train_df["context"].tolist(),
+            train_df["hypothesis"].tolist(),
+            train_df["target"].tolist(),
+        ]
+        self.val_data_text = [
+            val_df["context"].tolist(),
+            val_df["hypothesis"].tolist(),
+            val_df["target"].tolist(),
+        ]
+        self.test_data_text = [
+            test_df["context"].tolist(),
+            test_df["hypothesis"].tolist(),
+            test_df["target"].tolist(),
+        ]
+
+        self.sentences = (
+            train_df["context"].tolist()
+            + train_df["hypothesis"].tolist()
+            + val_df["context"].tolist()
+            + val_df["hypothesis"].tolist()
+        )
+
+        log.info("\tFinished loading ANLI data: " + self.name)
+
+
 @register_task("mnli", rel_path="MNLI/")
 # second copy for different params
 @register_task("mnli-alt", rel_path="MNLI/")

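The os.path.join calls in load_data imply a per-round directory layout under the task path. Below is a small sanity-check sketch of that layout; reading the data root from a JIANT_DATA_DIR environment variable is an assumption here, and the exact way jiant resolves the data directory may differ.

import os

data_dir = os.environ.get("JIANT_DATA_DIR", "/path/to/data")  # assumed root; adjust as needed
anli_root = os.path.join(data_dir, "AdversarialNLI")
for round_name in ["R1", "R2", "R3"]:
    for split_file in ["train.jsonl", "dev.jsonl", "test.jsonl"]:
        split_path = os.path.join(anli_root, round_name, split_file)
        # load_data reads exactly these paths for each entry in self.datasets
        print(split_path, "ok" if os.path.exists(split_path) else "MISSING")

With this layout in place, the adversarial_nli_a1 registration loads only R1, while adversarial_nli pools R1, R2, and R3, since load_data concatenates one dataframe per entry in datasets.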