@@ -1097,6 +1097,101 @@ def load_data(self):
10971097 log .info ("\t Finished loading SNLI data." )
10981098
10991099
@register_task("adversarial_nli_a1", rel_path="AdversarialNLI/", datasets=["R1"])
@register_task("adversarial_nli_a2", rel_path="AdversarialNLI/", datasets=["R2"])
@register_task("adversarial_nli_a3", rel_path="AdversarialNLI/", datasets=["R3"])
@register_task("adversarial_nli", rel_path="AdversarialNLI/", datasets=["R1", "R2", "R3"])
class AdversarialNLITask(PairClassificationTask):
    """Pair-classification task over the Adversarial NLI (ANLI) corpus.

    Three-way NLI classification (neutral / entailment / contradiction)
    assembled from one or more ANLI rounds (R1-R3). Expects the original
    ANLI dataset file layout under the task's relative path.
    Data: https://dl.fbaipublicfiles.com/anli/anli_v0.1.zip
    Paper: https://arxiv.org/abs/1910.14599

    Attributes:
        path (str): AdversarialNLI path relative to JIANT_DATA_DIR
        max_seq_len (int): max tokens allowed in a sequence
        train_data_text (list[list[str], list[str], list[int]]):
            list of lists of context, hypothesis, and target training data
        val_data_text (list[list[str], list[str], list[int]]):
            list of lists of context, hypothesis, and target val data
        test_data_text (list[list[str], list[str], list[int]]):
            list of lists of context, hypothesis, and target test data
        datasets (list[str]): ANLI rounds included in this task (e.g., "R1")
        sentences (list): all (tokenized) context and hypothesis texts
            from the train and val splits.
    """

    def __init__(self, path, max_seq_len, name, datasets, **kw):
        """Set up an AdversarialNLITask.

        Args:
            path (str): AdversarialNLI path relative to the data dir
            max_seq_len (int): max tokens allowed in a sequence
            name (str): task name, specified in @register_task
            datasets (list[str]): ANLI rounds to load (e.g., ["R1", "R2"])
        """
        super().__init__(name, n_classes=3, **kw)
        self.path = path
        self.max_seq_len = max_seq_len
        self.datasets = datasets
        # Filled in by load_data().
        self.train_data_text = None
        self.val_data_text = None
        self.test_data_text = None

    def _read_data(self, path: str) -> pd.core.frame.DataFrame:
        """Read one ANLI jsonl file; tokenize text and int-encode labels."""
        frame = pd.read_json(path_or_buf=path, encoding="UTF-8", lines=True)
        # ANLI label codes: n=neutral, e=entailment, c=contradiction
        frame["target"] = frame["label"].map({"n": 0, "e": 1, "c": 2})
        tokenize = get_tokenizer(self._tokenizer_name).tokenize
        frame["context"] = frame["context"].apply(tokenize)
        frame["hypothesis"] = frame["hypothesis"].apply(tokenize)
        return frame[["context", "hypothesis", "target"]]

    def load_data(self):
        """Read, preprocess and load data into an AdversarialNLITask.

        Assumes the original dataset file structure under `self.path`.
        Loads only the rounds (e.g., "R1") named in `self.datasets`, and
        populates train_/val_/test_data_text plus the `sentences` attr.
        """
        # One frame list per split file, accumulated across rounds.
        split_frames = {"train.jsonl": [], "dev.jsonl": [], "test.jsonl": []}
        for dataset in self.datasets:
            for filename, frames in split_frames.items():
                frames.append(self._read_data(os.path.join(self.path, dataset, filename)))

        def merge_and_columns(frames):
            # Concatenate a split's frames and pull out the three columns.
            merged = pd.concat(frames, axis=0, ignore_index=True)
            columns = [
                merged["context"].tolist(),
                merged["hypothesis"].tolist(),
                merged["target"].tolist(),
            ]
            return merged, columns

        train_df, self.train_data_text = merge_and_columns(split_frames["train.jsonl"])
        val_df, self.val_data_text = merge_and_columns(split_frames["dev.jsonl"])
        _, self.test_data_text = merge_and_columns(split_frames["test.jsonl"])

        # Vocabulary source: every tokenized text from train and val only.
        self.sentences = (
            train_df["context"].tolist()
            + train_df["hypothesis"].tolist()
            + val_df["context"].tolist()
            + val_df["hypothesis"].tolist()
        )

        log.info("\tFinished loading ANLI data: " + self.name)
1193+
1194+
11001195@register_task ("mnli" , rel_path = "MNLI/" )
11011196# second copy for different params
11021197@register_task ("mnli-alt" , rel_path = "MNLI/" )
0 commit comments