Skip to content
Merged
3 changes: 1 addition & 2 deletions jiant/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,7 @@ def evaluate(
n_task_examples = 0
task_preds = [] # accumulate DataFrames
assert split in ["train", "val", "test"]
dataset = getattr(task, "%s_data" % split)
generator = iterator(dataset, num_epochs=1, shuffle=False)
generator = iterator(task.get_instance_generator(split), num_epochs=1, shuffle=False)
for batch_idx, batch in enumerate(generator):
with torch.no_grad():
if isinstance(cuda_device, int):
Expand Down
26 changes: 20 additions & 6 deletions jiant/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,22 +416,36 @@ def build_tasks(
target_tasks = []
for task in tasks:
# Replace lists of instances with lazy generators from disk.
task.val_data = _get_instance_generator(task.name, "val", preproc_dir)
task.test_data = _get_instance_generator(task.name, "test", preproc_dir)
task.set_instance_generator(
dataset_name="val",
instance_generator=_get_instance_generator(task.name, "val", preproc_dir),
)
task.set_instance_generator(
dataset_name="test",
instance_generator=_get_instance_generator(task.name, "test", preproc_dir),
)
# When using pretrain_data_fraction, we need modified iterators for use
# only on training datasets at pretraining time.
if task.name in pretrain_task_names:
log.info("\tCreating trimmed pretraining-only version of " + task.name + " train.")
task.train_data = _get_instance_generator(
task.name, "train", preproc_dir, fraction=args.pretrain_data_fraction
task.set_instance_generator(
dataset_name="train",
instance_generator=_get_instance_generator(
task.name, "train", preproc_dir, fraction=args.pretrain_data_fraction
),
phase="pretrain",
)
pretrain_tasks.append(task)
# When using target_train_data_fraction, we need modified iterators
# only for training datasets at do_target_task_training time.
if task.name in target_task_names:
log.info("\tCreating trimmed target-only version of " + task.name + " train.")
task.train_data = _get_instance_generator(
task.name, "train", preproc_dir, fraction=args.target_train_data_fraction
task.set_instance_generator(
dataset_name="train",
instance_generator=_get_instance_generator(
task.name, "train", preproc_dir, fraction=args.target_train_data_fraction
),
phase="target_train",
)
target_tasks.append(task)

Expand Down
39 changes: 38 additions & 1 deletion jiant/tasks/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging as log
import os
from typing import Any, Dict, Iterable, List, Sequence, Type
from typing import Any, Dict, Iterable, List, Sequence, Type, Union, Generator

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -35,6 +35,7 @@
tokenize_and_truncate,
load_pair_nli_jsonl,
)
from jiant.utils.serialize import RepeatableIterator
from jiant.utils.tokenizers import get_tokenizer
from jiant.utils.retokenize import get_aligner_fn
from jiant.tasks.registry import register_task # global task registry
Expand Down Expand Up @@ -228,6 +229,7 @@ def __init__(self, name, tokenizer_name):
self.sentences = None
self.example_counts = None
self.contributes_to_aggregate_score = True
self._instance_generators = {}

def load_data(self):
""" Load data from path and create splits. """
Expand Down Expand Up @@ -293,6 +295,41 @@ def handle_preds(self, preds, batch):
"""
return preds

def set_instance_generator(
self, dataset_name: str, instance_generator: Iterable, phase: str = None
):
"""Takes a data instance generator and stores it in a private field of this Task instance

Parameters
----------
dataset_name : string
instance_generator : Iterable
phase : str

"""
self._instance_generators[(dataset_name, phase)] = instance_generator

def get_instance_generator(
self, dataset_name: str, phase: str = None
) -> Union[RepeatableIterator, Generator]:
"""Returns an instance generator for the specified dataset and phase.

Parameters
----------
dataset_name : string
phase : string

Returns
-------
Union[RepeatableIterator, Generator]

"""
if not self._instance_generators:
raise ValueError("set_instance_generator must be called before get_instance_generator")
if dataset_name == "train" and phase is None:
raise ValueError("phase must be specified to get relevant training data")
return self._instance_generators[(dataset_name, phase)]


class ClassificationTask(Task):
""" General classification task """
Expand Down
13 changes: 10 additions & 3 deletions jiant/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,12 @@ def _setup_training(
os.mkdir(os.path.join(self._serialization_dir, task.name))

# Adding task-specific smart iterator to speed up training
instance = [i for i in itertools.islice(task.train_data, 1)][0]
instance = [
i
for i in itertools.islice(
task.get_instance_generator(dataset_name="train", phase=phase), 1
)
][0]
pad_dict = instance.get_padding_lengths()
sorting_keys = []
for field in pad_dict:
Expand All @@ -335,7 +340,9 @@ def _setup_training(
biggest_batch_first=True,
)
task_info["iterator"] = iterator
task_info["tr_generator"] = iterator(task.train_data, num_epochs=None)
task_info["tr_generator"] = iterator(
task.get_instance_generator(dataset_name="train", phase=phase), num_epochs=None
)

n_training_examples = task.n_train_examples
if phase == "pretrain":
Expand Down Expand Up @@ -834,7 +841,7 @@ def _calculate_validation_performance(
else:
max_data_points = task.n_val_examples
val_generator = BasicIterator(batch_size, instances_per_epoch=max_data_points)(
task.val_data, num_epochs=1, shuffle=False
task.get_instance_generator(dataset_name="val"), num_epochs=1, shuffle=False
)
n_val_batches = math.ceil(max_data_points / batch_size)
all_val_metrics["%s_loss" % task.name] = 0.0
Expand Down
72 changes: 58 additions & 14 deletions tests/tasks/test_tasks.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,65 @@
import logging
import unittest

from jiant.tasks import Task
from jiant.tasks.registry import REGISTRY


def test_instantiate_all_tasks():
"""
All tasks should be able to be instantiated without needing to access actual data
class TestTasks(unittest.TestCase):
def test_instantiate_all_tasks(self):
"""
All tasks should be able to be instantiated without needing to access actual data

Test may change if task instantiation signature changes.
"""
logger = logging.getLogger()
logger.setLevel(level=logging.CRITICAL)
for name, (cls, _, kw) in REGISTRY.items():
cls(
"dummy_path",
max_seq_len=1,
name="dummy_name",
tokenizer_name="dummy_tokenizer_name",
**kw,
Test may change if task instantiation signature changes.
"""
logger = logging.getLogger()
logger.setLevel(level=logging.CRITICAL)
for name, (cls, _, kw) in REGISTRY.items():
cls(
"dummy_path",
max_seq_len=1,
name="dummy_name",
tokenizer_name="dummy_tokenizer_name",
**kw,
)

def test_tasks_get_train_instance_generators_without_phase(self):
task = Task(name="dummy_name", tokenizer_name="dummy_tokenizer_name")
train_iterable_instance_generator = [1, 2, 3]
task.set_instance_generator("train", train_iterable_instance_generator, "target_train")
self.assertRaises(ValueError, task.get_instance_generator, "train")

def test_tasks_set_and_get_instance_generators(self):
task = Task(name="dummy_name", tokenizer_name="dummy_tokenizer_name")
val_iterable_instance_generator = [1, 2, 3]
test_iterable_instance_generator = [4, 5, 6]
train_pretrain_iterable_instance_generator = [7, 8]
train_target_train_iterable_instance_generator = [9]
task.set_instance_generator("val", val_iterable_instance_generator)
task.set_instance_generator("test", test_iterable_instance_generator)
task.set_instance_generator("train", train_pretrain_iterable_instance_generator, "pretrain")
task.set_instance_generator(
"train", train_target_train_iterable_instance_generator, "target_train"
)
retreived_val_iterable_instance_generator = task.get_instance_generator("val")
retreived_test_iterable_instance_generator = task.get_instance_generator("test")
retreived_train_pretrain_iterable_instance_generator = task.get_instance_generator(
"train", "pretrain"
)
retreived_train_target_iterable_instance_generator = task.get_instance_generator(
"train", "target_train"
)
self.assertListEqual(
val_iterable_instance_generator, retreived_val_iterable_instance_generator
)
self.assertListEqual(
test_iterable_instance_generator, retreived_test_iterable_instance_generator
)
self.assertListEqual(
train_pretrain_iterable_instance_generator,
retreived_train_pretrain_iterable_instance_generator,
)
self.assertListEqual(
train_target_train_iterable_instance_generator,
retreived_train_target_iterable_instance_generator,
)
Loading