
Conversation

@davidheineman (Member) commented May 17, 2025

Adds an option to run MCQA tasks with 1 forward pass instead of 4 (only works with single-token continuations; raises an error otherwise).
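For intuition, here is a minimal sketch of the single-pass trick, assuming a Hugging Face-style causal LM and tokenizer (hypothetical helper, not the implementation in this PR): rather than scoring each continuation with its own forward pass, run the prompt once and compare the logits of the single-token answer labels at the final position.

    import torch

    def fast_mcqa_predict(model, tokenizer, prompt, labels=(" A", " B", " C", " D")):
        """Score one MCQA example with a single forward pass.

        Each answer label must encode to exactly one token; otherwise we raise,
        mirroring the single-token-continuation restriction described above.
        """
        label_ids = []
        for label in labels:
            ids = tokenizer.encode(label, add_special_tokens=False)
            if len(ids) != 1:
                raise ValueError(f"label {label!r} is not a single token")
            label_ids.append(ids[0])

        input_ids = torch.tensor([tokenizer.encode(prompt)])
        with torch.no_grad():
            logits = model(input_ids).logits  # (1, seq_len, vocab_size)

        # The standard approach runs len(labels) forward passes, one per
        # continuation; here we read all label logits off the last position.
        final_logits = logits[0, -1, :]
        return int(torch.argmax(final_logits[torch.tensor(label_ids)]))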

To test:

torchrun --nproc-per-node=1 src/scripts/train/OLMo2-ladder.py train 190M 5xC ai2/jupiter-cirrascale-2 \
    --trainer.callbacks.downstream_evaluator.eval_on_startup=true \
    --trainer.callbacks.downstream_evaluator.cancel_after_first_eval=true \
    --launch.num_gpus=1 \
    --ladder.save_folder=/tmp/debug \
    --trainer.save_folder=/tmp/debug \
    --launch.allow_dirty=True \
    --trainer.callbacks.downstream_evaluator.tasks=[arc_challenge_test_mc_5shot_fast,arc_challenge_test_mc_5shot]

Outputs:

2025-05-16 20:26:47.681 triton-cs-aus-454.reviz.ai2.in:0        olmo_core.train.callbacks.evaluator_callback:151    INFO     Finished downstream evals in 45.5 seconds. Metrics:
    arc_challenge_test_mc_5shot (accuracy)=0.2654
    arc_challenge_test_mc_5shot (accuracy v2)=0.2654
    arc_challenge_test_mc_5shot (CE loss)=11.76
    arc_challenge_test_mc_5shot (CE loss v2)=5.880
    arc_challenge_test_mc_5shot (BPB)=16.96
    arc_challenge_test_mc_5shot (BPB v2)=8.482
    arc_challenge_test_mc_5shot (soft loss)=0.2594
    arc_challenge_test_mc_5shot (soft loss v2)=0.2594
    arc_challenge_test_mc_5shot (log soft loss)=-2.13E+00
    arc_challenge_test_mc_5shot (log soft loss v2)=-2.13E+00
2025-05-16 20:27:00.868 triton-cs-aus-454.reviz.ai2.in:0        olmo_core.train.callbacks.evaluator_callback:151    INFO     Finished downstream evals in 13.2 seconds. Metrics:
    arc_challenge_test_mc_5shot_fast (accuracy)=0.2654
    arc_challenge_test_mc_5shot_fast (accuracy v2)=0.2654
    arc_challenge_test_mc_5shot_fast (CE loss)=11.76
    arc_challenge_test_mc_5shot_fast (CE loss v2)=5.880
    arc_challenge_test_mc_5shot_fast (BPB)=16.96
    arc_challenge_test_mc_5shot_fast (BPB v2)=8.482
    arc_challenge_test_mc_5shot_fast (soft loss)=0.2594
    arc_challenge_test_mc_5shot_fast (soft loss v2)=0.2594
    arc_challenge_test_mc_5shot_fast (log soft loss)=-2.13E+00
    arc_challenge_test_mc_5shot_fast (log soft loss v2)=-2.13E+00
2025-05-16 20:27:00.868 triton-cs-aus-454.reviz.ai2.in:0        olmo_core.train.callbacks.evaluator_callback:171    INFO     Evaluation speed:
    arc_challenge_test_mc_5shot_fast (accuracy) (+variants): 13.2 sec (23 batches)
    arc_challenge_test_mc_5shot (accuracy) (+variants): 45.5 sec (91 batches)
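Note the fast variant covers the same task in 23 batches instead of 91 (one forward pass per question rather than four), cutting wall-clock time from 45.5 s to 13.2 s (~3.4x) while producing identical metrics.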

To test on the canonical task suite:

# Run on all tasks (slow)
torchrun --nproc-per-node=1 src/scripts/train/OLMo2-ladder.py train 1B 5xC ai2/jupiter-cirrascale-2 \
    --trainer.callbacks.downstream_evaluator.eval_on_startup=true \
    --trainer.callbacks.downstream_evaluator.cancel_after_first_eval=true \
    --launch.num_gpus=1 \
    --ladder.save_folder=/tmp/debug \
    --trainer.save_folder=/tmp/debug \
    --launch.allow_dirty=True \
    --trainer.callbacks.downstream_evaluator.tasks=[arc_challenge_test_mc_5shot,arc_easy_test_mc_5shot,hellaswag_rc_5shot,csqa_val_mc_5shot,piqa_val_mc_5shot,socialiqa_val_mc_5shot,winogrande_val_rc_5shot,mmlu_stem_val_mc_5shot,mmlu_humanities_val_mc_5shot,mmlu_social_sciences_val_mc_5shot,mmlu_other_val_mc_5shot,mmlu_stem_test_mc_5shot,mmlu_humanities_test_mc_5shot,mmlu_social_sciences_test_mc_5shot,mmlu_other_test_mc_5shot,gsm8k_gold_bpb_5shot,minerva_math_algebra_gold_bpb_0shot,minerva_math_counting_and_probability_gold_bpb_0shot,minerva_math_geometry_gold_bpb_0shot,minerva_math_intermediate_algebra_gold_bpb_0shot,minerva_math_number_theory_gold_bpb_0shot,minerva_math_prealgebra_gold_bpb_0shot,minerva_math_precalculus_gold_bpb_0shot,codex_humaneval_gold_bpb_0shot,codex_mbpp_gold_bpb_0shot,copycolors_10way]

# Run on all tasks (fast)
torchrun --nproc-per-node=1 src/scripts/train/OLMo2-ladder.py train 1B 5xC ai2/jupiter-cirrascale-2 \
    --trainer.callbacks.downstream_evaluator.eval_on_startup=true \
    --trainer.callbacks.downstream_evaluator.cancel_after_first_eval=true \
    --launch.num_gpus=1 \
    --ladder.save_folder=/tmp/debug \
    --trainer.save_folder=/tmp/debug \
    --launch.allow_dirty=True \
    --trainer.callbacks.downstream_evaluator.tasks=[arc_challenge_test_mc_5shot_fast,arc_easy_test_mc_5shot_fast,hellaswag_rc_5shot,csqa_val_mc_5shot_fast,piqa_val_mc_5shot_fast,socialiqa_val_mc_5shot_fast,winogrande_val_rc_5shot,mmlu_stem_val_mc_5shot_fast,mmlu_humanities_val_mc_5shot_fast,mmlu_social_sciences_val_mc_5shot_fast,mmlu_other_val_mc_5shot_fast,mmlu_stem_test_mc_5shot_fast,mmlu_humanities_test_mc_5shot_fast,mmlu_social_sciences_test_mc_5shot_fast,mmlu_other_test_mc_5shot_fast,gsm8k_gold_bpb_5shot,minerva_math_algebra_gold_bpb_0shot,minerva_math_counting_and_probability_gold_bpb_0shot,minerva_math_geometry_gold_bpb_0shot,minerva_math_intermediate_algebra_gold_bpb_0shot,minerva_math_number_theory_gold_bpb_0shot,minerva_math_prealgebra_gold_bpb_0shot,minerva_math_precalculus_gold_bpb_0shot,codex_humaneval_gold_bpb_0shot,codex_mbpp_gold_bpb_0shot,copycolors_10way_fast]

Output on original setup:

2025-05-16 21:22:15.614 triton-cs-aus-454.reviz.ai2.in:0        olmo_core.train.callbacks.evaluator_callback:171        INFO    Evaluation speed:
    copycolors_10way (accuracy) (+variants): 4.8 sec (10 batches)
    codex_humaneval_gold_bpb_0shot (BPB) (+variants): 5.3 sec (4 batches)
    mmlu_other_val_mc_5shot (length-normalized accuracy) (+variants): 15.2 sec (71 batches)
    codex_mbpp_gold_bpb_0shot (BPB) (+variants): 15.4 sec (30 batches)
    mmlu_social_sciences_val_mc_5shot (length-normalized accuracy) (+variants): 15.5 sec (75 batches)
    mmlu_stem_val_mc_5shot (length-normalized accuracy) (+variants): 15.6 sec (50 batches)
    winogrande_val_rc_5shot (length-normalized accuracy) (+variants): 16.6 sec (25 batches)
    piqa_val_mc_5shot (accuracy) (+variants): 34.2 sec (109 batches)
    minerva_math_counting_and_probability_gold_bpb_0shot (BPB) (+variants): 38.9 sec (37 batches)
    socialiqa_val_mc_5shot (accuracy) (+variants): 41.6 sec (85 batches)
    csqa_val_mc_5shot (accuracy) (+variants): 42.8 sec (89 batches)
    arc_challenge_test_mc_5shot (accuracy) (+variants): 43.6 sec (91 batches)
    mmlu_humanities_val_mc_5shot (length-normalized accuracy) (+variants): 44.3 sec (160 batches)
    minerva_math_number_theory_gold_bpb_0shot (BPB) (+variants): 48.9 sec (39 batches)
    minerva_math_prealgebra_gold_bpb_0shot (BPB) (+variants): 61.9 sec (55 batches)
    arc_easy_test_mc_5shot (accuracy) (+variants): 68.7 sec (183 batches)
    gsm8k_gold_bpb_5shot (BPB) (+variants): 69.0 sec (58 batches)
    minerva_math_geometry_gold_bpb_0shot (BPB) (+variants): 69.2 sec (37 batches)
    hellaswag_rc_5shot (length-normalized accuracy) (+variants): 84.3 sec (77 batches)
    minerva_math_precalculus_gold_bpb_0shot (BPB) (+variants): 94.2 sec (42 batches)
    minerva_math_algebra_gold_bpb_0shot (BPB) (+variants): 96.9 sec (92 batches)
    mmlu_other_test_mc_5shot (length-normalized accuracy) (+variants): 139.8 sec (721 batches)
    mmlu_social_sciences_test_mc_5shot (length-normalized accuracy) (+variants): 140.1 sec (684 batches)
    minerva_math_intermediate_algebra_gold_bpb_0shot (BPB) (+variants): 140.2 sec (70 batches)
    mmlu_stem_test_mc_5shot (length-normalized accuracy) (+variants): 142.4 sec (525 batches)
    mmlu_humanities_test_mc_5shot (length-normalized accuracy) (+variants): 397.5 sec (1448 batches)
    Total evaluation time: 1887.0 seconds (4867 batches)

Output on new setup:

2025-05-17 02:34:16.010 triton-cs-aus-454.reviz.ai2.in:0        olmo_core.train.callbacks.evaluator_callback:171        INFO       Evaluation speed:
    copycolors_10way_fast (accuracy) (+variants): 0.8 sec (1 batches)
    mmlu_stem_val_mc_5shot_fast (length-normalized accuracy) (+variants): 4.5 sec (13 batches)
    mmlu_other_val_mc_5shot_fast (length-normalized accuracy) (+variants): 4.6 sec (18 batches)
    mmlu_social_sciences_val_mc_5shot_fast (length-normalized accuracy) (+variants): 4.8 sec (19 batches)
    codex_humaneval_gold_bpb_0shot (BPB) (+variants): 4.8 sec (4 batches)
    csqa_val_mc_5shot_fast (accuracy) (+variants): 10.3 sec (18 batches)
    mmlu_humanities_val_mc_5shot_fast (length-normalized accuracy) (+variants): 12.3 sec (40 batches)
    codex_mbpp_gold_bpb_0shot (BPB) (+variants): 14.4 sec (30 batches)
    arc_challenge_test_mc_5shot_fast (accuracy) (+variants): 15.0 sec (23 batches)
    socialiqa_val_mc_5shot_fast (accuracy) (+variants): 15.2 sec (29 batches)
    winogrande_val_rc_5shot (length-normalized accuracy) (+variants): 15.7 sec (25 batches)
    piqa_val_mc_5shot_fast (accuracy) (+variants): 19.0 sec (55 batches)
    arc_easy_test_mc_5shot_fast (accuracy) (+variants): 20.5 sec (46 batches)
    minerva_math_counting_and_probability_gold_bpb_0shot (BPB) (+variants): 36.0 sec (37 batches)
    mmlu_stem_test_mc_5shot_fast (length-normalized accuracy) (+variants): 40.4 sec (132 batches)
    mmlu_social_sciences_test_mc_5shot_fast (length-normalized accuracy) (+variants): 40.4 sec (171 batches)
    mmlu_other_test_mc_5shot_fast (length-normalized accuracy) (+variants): 40.5 sec (181 batches)
    minerva_math_number_theory_gold_bpb_0shot (BPB) (+variants): 45.4 sec (39 batches)
    minerva_math_prealgebra_gold_bpb_0shot (BPB) (+variants): 58.1 sec (55 batches)
    minerva_math_geometry_gold_bpb_0shot (BPB) (+variants): 62.8 sec (37 batches)
    gsm8k_gold_bpb_5shot (BPB) (+variants): 65.0 sec (58 batches)
    hellaswag_rc_5shot (length-normalized accuracy) (+variants): 79.9 sec (77 batches)
    minerva_math_precalculus_gold_bpb_0shot (BPB) (+variants): 87.0 sec (42 batches)
    minerva_math_algebra_gold_bpb_0shot (BPB) (+variants): 88.7 sec (92 batches)
    mmlu_humanities_test_mc_5shot_fast (length-normalized accuracy) (+variants): 108.1 sec (362 batches)
    minerva_math_intermediate_algebra_gold_bpb_0shot (BPB) (+variants): 127.5 sec (70 batches)
    Total evaluation time: 1021.8 seconds (1674 batches)
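Compared to the original setup, total evaluation time drops from 1887.0 s to 1021.8 s (~1.8x) and the batch count from 4867 to 1674; the remaining time is dominated by the BPB and RC tasks, which the fast MC path leaves unchanged.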

Here are the new and old configs based on this:

CURRENT_CONFIG = [
    # OLMES Core 9(-ish) RC
    "arc_challenge_test_rc_5shot",
    "arc_easy_test_rc_5shot",
    "hellaswag_rc_5shot",  # 1K subset of HellaSwag
    "winogrande_val_rc_5shot",  # Helpful after 750M-5xC scale
    "csqa_val_rc_5shot",
    "piqa_val_rc_5shot",
    "socialiqa_val_rc_5shot",
    # MMLU RC
    "mmlu_stem_val_rc_5shot",
    "mmlu_humanities_val_rc_5shot",
    "mmlu_social_sciences_val_rc_5shot",
    "mmlu_other_val_rc_5shot",
    "mmlu_stem_test_rc_5shot",
    "mmlu_humanities_test_rc_5shot",
    "mmlu_social_sciences_test_rc_5shot",
    "mmlu_other_test_rc_5shot",
    # OLMES Core 9(-ish) MC
    "arc_challenge_test_mc_5shot",
    "arc_easy_test_mc_5shot",
    "hellaswag_rc_5shot",  # 1K subset of HellaSwag
    "csqa_val_mc_5shot",
    "piqa_val_mc_5shot",
    "socialiqa_val_mc_5shot",
    "winogrande_val_rc_5shot",
    # MMLU MC BPB
    "mmlu_stem_val_mc_5shot",
    "mmlu_humanities_val_mc_5shot",
    "mmlu_social_sciences_val_mc_5shot",
    "mmlu_other_val_mc_5shot",
    "mmlu_stem_test_mc_5shot",
    "mmlu_humanities_test_mc_5shot",
    "mmlu_social_sciences_test_mc_5shot",
    "mmlu_other_test_mc_5shot",
    # Gen tasks BPB
    "gsm8k_gold_bpb_5shot",
    "minerva_math_algebra_gold_bpb_0shot",
    "minerva_math_counting_and_probability_gold_bpb_0shot",
    "minerva_math_geometry_gold_bpb_0shot",
    "minerva_math_intermediate_algebra_gold_bpb_0shot",
    "minerva_math_number_theory_gold_bpb_0shot",
    "minerva_math_prealgebra_gold_bpb_0shot",
    "minerva_math_precalculus_gold_bpb_0shot",
    "codex_humaneval_gold_bpb_0shot",
    "codex_mbpp_gold_bpb_0shot",
    # Sanity check for MCQA ability
    "copycolors_10way",
]

PROPOSED_CONFIG = [
    # OLMES Core 9(-ish) BPB
    "arc_challenge_test_bpb_5shot",
    "arc_easy_test_bpb_5shot",
    "hellaswag_bpb_5shot",  # 1K subset of HellaSwag
    # MMLU BPB
    "mmlu_stem_test_bpb_5shot",
    "mmlu_humanities_test_bpb_5shot",
    "mmlu_social_sciences_test_bpb_5shot",
    "mmlu_other_test_bpb_5shot",
    # OLMES Core 9(-ish) MC
    "arc_challenge_test_mc_5shot_fast",
    "arc_easy_test_mc_5shot_fast",
    # MMLU MC
    "mmlu_stem_test_mc_5shot_fast",
    "mmlu_humanities_test_mc_5shot_fast",
    "mmlu_social_sciences_test_mc_5shot_fast",
    "mmlu_other_test_mc_5shot_fast",
    # Gen tasks BPB
    "gsm8k_gold_bpb_5shot",
    "minerva_math_algebra_gold_bpb_0shot",
    "minerva_math_counting_and_probability_gold_bpb_0shot",
    "minerva_math_geometry_gold_bpb_0shot",
    "minerva_math_intermediate_algebra_gold_bpb_0shot",
    "minerva_math_number_theory_gold_bpb_0shot",
    "minerva_math_prealgebra_gold_bpb_0shot",
    "minerva_math_precalculus_gold_bpb_0shot",
    "codex_humaneval_gold_bpb_0shot",
    "codex_mbpp_gold_bpb_0shot",
    # Sanity check for MCQA ability
    "copycolors_10way_fast",
    # Basic Skills
    "basic_skills_arithmetic_rc_5shot",
    "basic_skills_coding_rc_5shot",
    "basic_skills_common_knowledge_rc_5shot",
    "basic_skills_logical_reasoning_rc_5shot",
    "basic_skills_pattern_rc_5shot",
    "basic_skills_string_operations_rc_5shot",
]
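For reference, a quick way to diff the two suites (a small script using the lists above):

    current, proposed = set(CURRENT_CONFIG), set(PROPOSED_CONFIG)
    print("Removed:", sorted(current - proposed))
    print("Added:", sorted(proposed - current))
    print("Kept:", sorted(current & proposed))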

Here is the speed for that set:

2025-05-17 11:25:23.908 triton-cs-aus-454.reviz.ai2.in:0        olmo_core.train.callbacks.evaluator_callback:171        INFO    Evaluation speed:
    copycolors_10way_fast (accuracy) (+variants): 0.9 sec (1 batches)
    codex_humaneval_gold_bpb_0shot (BPB) (+variants): 5.1 sec (4 batches)
    arc_challenge_test_mc_5shot_fast (accuracy) (+variants): 12.1 sec (23 batches)
    arc_challenge_test_bpb_5shot (BPB) (+variants): 12.5 sec (17 batches)
    codex_mbpp_gold_bpb_0shot (BPB) (+variants): 15.1 sec (30 batches)
    arc_easy_test_bpb_5shot (BPB) (+variants): 15.8 sec (35 batches)
    arc_easy_test_mc_5shot_fast (accuracy) (+variants): 20.9 sec (46 batches)
    hellaswag_bpb_5shot (BPB) (+variants): 21.1 sec (20 batches)
    mmlu_social_sciences_test_bpb_5shot (BPB) (+variants): 31.9 sec (76 batches)
    mmlu_other_test_bpb_5shot (BPB) (+variants): 33.9 sec (141 batches)
    mmlu_stem_test_bpb_5shot (BPB) (+variants): 34.0 sec (105 batches)
    minerva_math_counting_and_probability_gold_bpb_0shot (BPB) (+variants): 39.6 sec (37 batches)
    mmlu_stem_test_mc_5shot_fast (length-normalized accuracy) (+variants): 40.9 sec (132 batches)
    mmlu_social_sciences_test_mc_5shot_fast (length-normalized accuracy) (+variants): 41.0 sec (171 batches)
    mmlu_other_test_mc_5shot_fast (length-normalized accuracy) (+variants): 41.2 sec (181 batches)
    minerva_math_number_theory_gold_bpb_0shot (BPB) (+variants): 48.8 sec (39 batches)
    minerva_math_prealgebra_gold_bpb_0shot (BPB) (+variants): 63.4 sec (55 batches)
    minerva_math_geometry_gold_bpb_0shot (BPB) (+variants): 68.5 sec (37 batches)
    gsm8k_gold_bpb_5shot (BPB) (+variants): 68.5 sec (58 batches)
    minerva_math_precalculus_gold_bpb_0shot (BPB) (+variants): 96.7 sec (42 batches)
    minerva_math_algebra_gold_bpb_0shot (BPB) (+variants): 97.0 sec (92 batches)
    mmlu_humanities_test_bpb_5shot (BPB) (+variants): 102.5 sec (362 batches)
    mmlu_humanities_test_mc_5shot_fast (length-normalized accuracy) (+variants): 109.9 sec (362 batches)
    minerva_math_intermediate_algebra_gold_bpb_0shot (BPB) (+variants): 137.0 sec (70 batches)
    Total evaluation time: 1158.4 seconds (2136 batches)
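For comparison, even with the added BPB variants this proposed suite runs in 1158.4 s (2136 batches), versus 1887.0 s (4867 batches) for the original setup.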

@davidheineman davidheineman self-assigned this May 17, 2025
@davidheineman davidheineman merged commit b635bb9 into main May 18, 2025
8 checks passed
@davidheineman davidheineman deleted the fast-mc branch May 18, 2025 00:50
epwalsh pushed a commit to allenai/OLMo-core that referenced this pull request May 19, 2025
Bump in-loop evals to v0.8.1. This will add "fast" MCQA, which performs
MC tasks in 1 forward pass instead of 4 forward passes. (We extract the
A/B/C/D logits from a single pass).

allenai/OLMo-in-loop-evals#8

This makes the MC tasks 4x faster and produces the same numbers.

Also adds Java-, Rust-, and C++-translated MBPP BPB tasks.
TianhuaTao pushed a commit to allenai/OLMo-core that referenced this pull request May 28, 2025 (same commit message as above).
