Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,17 @@ Python version:
from jiant.proj.simple import runscript as run
import jiant.scripts.download_data.runscript as downloader

EXP_DIR = "/path/to/exp"

# Download the Data
downloader.download_data(["mrpc"], "/content/data")
downloader.download_data(["mrpc"], f"{EXP_DIR}/tasks")

# Set up the arguments for the Simple API
args = run.RunConfiguration(
run_name="simple",
exp_dir="/path/to/exp",
data_dir="/path/to/exp/tasks",
model_type="roberta-base",
exp_dir=EXP_DIR,
data_dir=f"{EXP_DIR}/tasks",
hf_pretrained_model_name_or_path="roberta-base",
tasks="mrpc",
train_batch_size=16,
num_train_epochs=3
Expand All @@ -91,15 +93,17 @@ run.run_simple(args)

Bash version:
```bash
EXP_DIR=/path/to/exp

python jiant/scripts/download_data/runscript.py \
download \
--tasks mrpc \
--output_path /path/to/exp/tasks
--output_path ${EXP_DIR}/tasks
python jiant/proj/simple/runscript.py \
run \
--run_name simple \
--exp_dir /path/to/exp \
--data_dir /path/to/exp/tasks \
--exp_dir ${EXP_DIR}/ \
--data_dir ${EXP_DIR}/tasks \
--model_type roberta-base \
--tasks mrpc \
--train_batch_size 16 \
Expand Down
7 changes: 3 additions & 4 deletions jiant/proj/simple/runscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def run_simple(args: RunConfiguration, with_continue: bool = False):
else:
model_load_mode = "from_transformers"
model_weights_path = os.path.join(
model_cache_path, hf_config.model_type, "model", f"{hf_config.model_type}.p"
model_cache_path, hf_config.model_type, "model", "model.p"
)
run_output_dir = os.path.join(args.exp_dir, "runs", args.run_name)

Expand All @@ -213,10 +213,9 @@ def run_simple(args: RunConfiguration, with_continue: bool = False):
output_dir=run_output_dir,
# === Model parameters === #
hf_pretrained_model_name_or_path=args.hf_pretrained_model_name_or_path,
model_type=hf_config.model_type,
model_path=model_weights_path,
model_config_path=os.path.join(
model_cache_path, hf_config.model_type, "model", f"{hf_config.model_type}.json"
model_cache_path, hf_config.model_type, "model", "config.json",
),
model_load_mode=model_load_mode,
# === Running Setup === #
Expand Down Expand Up @@ -258,7 +257,7 @@ def main():
args = RunConfiguration.default_run_cli(cl_args=cl_args)
if mode == "run":
run_simple(args, with_continue=False)
if mode == "run_with_continue":
elif mode == "run_with_continue":
run_simple(args, with_continue=True)
else:
raise zconf.ModeLookupError(mode)
Expand Down
1 change: 1 addition & 0 deletions jiant/scripts/download_data/runscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def download_data_cli(args):


def download_data(task_names, output_base_path):
output_base_path = os.path.abspath(output_base_path)
task_data_base_path = py_io.create_dir(output_base_path, "data")
task_config_base_path = py_io.create_dir(output_base_path, "configs")

Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/acceptability_judgement/coord.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.acceptability_judgement.base as base
from . import base as base


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/acceptability_judgement/definiteness.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.acceptability_judgement.base as base
from . import base as base


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/acceptability_judgement/eos.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.acceptability_judgement.base as base
from . import base as base


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/acceptability_judgement/whwords.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.acceptability_judgement.base as base
from . import base as base


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/senteval/bigram_shift.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.senteval.base as base
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/senteval/coordination_inversion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.senteval.base as base
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/senteval/obj_number.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.senteval.base as base
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/senteval/odd_man_out.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.senteval.base as base
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/senteval/past_present.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.senteval.base as base
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/senteval/sentence_length.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.senteval.base as base
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/senteval/subj_number.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.senteval.base as base
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/senteval/top_constituents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.senteval.base as base
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/senteval/tree_depth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.senteval.base as base
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


Expand Down
2 changes: 1 addition & 1 deletion jiant/tasks/lib/senteval/word_content.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
import jiant.tasks.lib.senteval.base as base
from . import base as base
from jiant.tasks.lib.templates.shared import labels_to_bimap


Expand Down