@@ -3241,3 +3241,47 @@ def get_metrics(self, reset=False):
32413241 """Get metrics specific to the task"""
32423242 acc = self .scorer1 .get_metric (reset )
32433243 return {"accuracy" : acc }
3244+
3245+
3246+ @register_task ("scitail" , rel_path = "SciTailV1.1/tsv_format/" )
3247+ class SciTailTask (PairClassificationTask ):
3248+ """ Task class for SciTail http://data.allenai.org/scitail/ """
3249+
3250+ def __init__ (self , path , max_seq_len , name , ** kw ):
3251+ super ().__init__ (name , n_classes = 2 , ** kw )
3252+ self .path = path
3253+ self .max_seq_len = max_seq_len
3254+
3255+ self .train_data_text = None
3256+ self .val_data_text = None
3257+ self .test_data_text = None
3258+
3259+ def load_data (self ):
3260+ """Process and load Scitail data"""
3261+ targ_map = {"neutral" : 0 , "entails" : 1 }
3262+ self .train_data_text = load_tsv (
3263+ self ._tokenizer_name ,
3264+ os .path .join (self .path , "scitail_1.0_train.tsv" ),
3265+ max_seq_len = self .max_seq_len ,
3266+ label_fn = targ_map .__getitem__ ,
3267+ )
3268+ self .val_data_text = load_tsv (
3269+ self ._tokenizer_name ,
3270+ os .path .join (self .path , "scitail_1.0_dev.tsv" ),
3271+ max_seq_len = self .max_seq_len ,
3272+ label_fn = targ_map .__getitem__ ,
3273+ )
3274+ self .test_data_text = load_tsv (
3275+ self ._tokenizer_name ,
3276+ os .path .join (self .path , "scitail_1.0_test.tsv" ),
3277+ max_seq_len = self .max_seq_len ,
3278+ label_fn = targ_map .__getitem__ ,
3279+ return_indices = True ,
3280+ )
3281+ self .sentences = (
3282+ self .train_data_text [0 ]
3283+ + self .train_data_text [1 ]
3284+ + self .val_data_text [0 ]
3285+ + self .val_data_text [1 ]
3286+ )
3287+ log .info ("\t Finished loading SciTail" )
0 commit comments