Merged
Changes from all commits
41 commits
069a273
drop unused is_mlp
farhadrgh Apr 1, 2025
44874b3
update the in/out shapes
farhadrgh Apr 2, 2025
b62aab3
bump NeMo
farhadrgh Apr 9, 2025
071243a
change checkpoint name pattern (#786)
farhadrgh Apr 1, 2025
46cfde2
Revert commit 67a869b3c601096475b82c1da1d2339d860dbce1 (TE_VERSION=v1…
dorotat-nv Apr 1, 2025
8a1232a
[cye/subpack-gpu-testing] Add GPU runner to testing job. (#776)
cspades Apr 2, 2025
17f97b0
Dockerfile improvements for ARM (#777)
trvachov Apr 2, 2025
739d178
Remove llama-index from container to fix CVEs (#800)
trvachov Apr 3, 2025
345c9aa
update
farhadrgh May 5, 2025
b56ce3e
Add local clone script (#787)
nvdreidenbach Apr 4, 2025
08c1833
Fix ARM docker build (#801)
trvachov Apr 4, 2025
391c04c
[cye/ml-subpackage-ci] Onboard bionemo-llm and bionemo-noodles to the…
cspades Apr 9, 2025
af6446b
Update README.md link (#812)
nvdreidenbach Apr 9, 2025
f8246b4
Have dependabot update our docker base image (#813)
pstjohn Apr 10, 2025
019b9fc
Add .codecov.yml status checks (#618)
pstjohn Apr 10, 2025
36fe406
Add AMPLIFY model documentation, minor type fixes (#788)
pstjohn Apr 10, 2025
21fd6c3
Remove import guard in bionemo-llm (#804)
pstjohn Apr 10, 2025
c96de97
Bump rust from 1.82.0 to 1.86.0 (#819)
dependabot[bot] Apr 10, 2025
6c47b7a
Bump crossbeam-channel from 0.5.13 to 0.5.15 in /sub-packages/bionemo…
dependabot[bot] Apr 10, 2025
eb569d5
Pbinder/geneformer partial conv (#802)
polinabinder1 Apr 10, 2025
14825c7
[cye/rapids-sc-install] Add rapids_singlecell import to BioNeMo FW co…
cspades Apr 10, 2025
3723d9b
Biopharma mailing list docs addition. (#822)
trvachov Apr 11, 2025
13904b6
unify the implementation of early training termination across BioNeMo…
dorotat-nv Apr 11, 2025
0bd33a8
Fix bitsandbytes issue on ARM (#824)
trvachov Apr 11, 2025
7cd0c4d
Fixes for AMPLIFY QA scripts (#825)
pstjohn Apr 14, 2025
159886c
remove call to context_parallel loss
farhadrgh Apr 14, 2025
b6de3b5
updated configs for benchmarks (#833)
dorotat-nv Apr 15, 2025
4262343
temporary dependency fixes for upstream nemo changes
pstjohn Apr 16, 2025
3cd4c4b
pin ngcsdk
pstjohn Apr 17, 2025
dddd377
Remove temporary pins in docs build (#828)
pstjohn Apr 17, 2025
5a30c50
Adding baseline metrics for benchmarking ESM2 model (#831)
ShevaNguyen Apr 17, 2025
3ad546a
Updates docs for geneformer training, inference, and cellxclassificat…
jomitchellnv Apr 21, 2025
a32fe83
Add pre commit to verify test status (#841)
pstjohn Apr 22, 2025
91b2feb
fix geneformer image paths (#839)
jomitchellnv Apr 22, 2025
2c27849
fix geneformer image links (#844)
jomitchellnv Apr 24, 2025
26684c5
Pbinder/auto resume (#766)
polinabinder1 Apr 25, 2025
7b4c9b9
Pbinder/esm2 document (#846)
polinabinder1 Apr 25, 2025
dda686c
h11 CRIT vuln fix (#847)
trvachov Apr 25, 2025
ca6e41f
Docs fix (#826)
trvachov Apr 25, 2025
917c9c4
Fix masked token loss refactor from NeMo bump. (#855)
cspades May 5, 2025
1550e85
resolve conflicts
farhadrgh May 5, 2025
2 changes: 1 addition & 1 deletion 3rdparty/Megatron-LM
Submodule Megatron-LM updated 969 files
2 changes: 1 addition & 1 deletion 3rdparty/NeMo
Submodule NeMo updated from b68596 to 42d2b5
Original file line number Diff line number Diff line change
@@ -14,7 +14,7 @@
# limitations under the License.

import torch
from megatron.core.models.common.embeddings.rope_utils import apply_rotary_pos_emb
from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
from transformers import AutoConfig

@@ -47,8 +47,20 @@ def test_rope_embeddings():
seq_len_interpolation_factor=nemo_config.seq_len_interpolation_factor,
)
rotary_pos_emb = rotary_pos_layer(q.shape[1])
q_post_nemo = apply_rotary_pos_emb(q.transpose(0, 1).cuda(), rotary_pos_emb.cuda(), config=nemo_config).cpu()
k_post_nemo = apply_rotary_pos_emb(k.transpose(0, 1).cuda(), rotary_pos_emb.cuda(), config=nemo_config).cpu()
# Note: Use the backend implementation of the RoPE to avoid
# getting or instantiating a CP process group.
q_post_nemo = _apply_rotary_pos_emb_bshd(
q.transpose(0, 1).cuda(),
rotary_pos_emb.cuda(),
rotary_interleaved=nemo_config.rotary_interleaved,
multi_latent_attention=nemo_config.multi_latent_attention,
).cpu()
k_post_nemo = _apply_rotary_pos_emb_bshd(
k.transpose(0, 1).cuda(),
rotary_pos_emb.cuda(),
rotary_interleaved=nemo_config.rotary_interleaved,
multi_latent_attention=nemo_config.multi_latent_attention,
).cpu()

torch.testing.assert_close(q_post, q_post_nemo.transpose(0, 1))
torch.testing.assert_close(k_post, k_post_nemo.transpose(0, 1))
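
Editorial note on the hunk above (not part of the diff): the test now calls the private bshd backend directly instead of the public apply_rotary_pos_emb wrapper, so no context-parallel process group has to exist or be created. A minimal sketch of wrapping that call, assuming only the rope_utils API visible in this diff; the helper name apply_rope_without_cp_group is hypothetical:

import torch
from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd


def apply_rope_without_cp_group(t: torch.Tensor, freqs: torch.Tensor, config) -> torch.Tensor:
    # t and freqs are laid out exactly as in the test above (the tensor passed after
    # q.transpose(0, 1) and the output of RotaryEmbedding, respectively). config is
    # assumed to expose rotary_interleaved and multi_latent_attention, as the
    # NeMo/Megatron config used in the test does.
    return _apply_rotary_pos_emb_bshd(
        t,
        freqs,
        rotary_interleaved=config.rotary_interleaved,
        multi_latent_attention=config.multi_latent_attention,
    )
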
Original file line number Diff line number Diff line change
@@ -327,7 +327,7 @@ def test_main_runs(tmp_path, dummy_protein_dataset, dummy_parquet_train_val_inpu
event_files = list(log_dir.rglob("events.out.tfevents*"))
assert event_files, f"No TensorBoard event files found under {log_dir}"
assert "val_ppl" in trainer.logged_metrics # validation logging on by default
assert "tflops_per_sec_per_gpu" in trainer.logged_metrics # ensuring that tflops logger can be added
assert "TFLOPS_per_GPU" in trainer.logged_metrics # ensuring that tflops logger can be added
assert "train_step_timing in s" in trainer.logged_metrics


4 changes: 3 additions & 1 deletion sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py
@@ -157,7 +157,9 @@ def predict_step(self, batch, batch_idx: Optional[int] = None) -> Tensor:
return forward_out
# Reminder: the model's predictions for input i land at output i+1. To get everything to align, we prepend the
# EOS token to the input sequences and take the outputs for all but the first token.
forward_out_tp_gathered = _gather_along_last_dim(forward_out)
forward_out_tp_gathered = _gather_along_last_dim(
forward_out, group=parallel_state.get_tensor_model_parallel_group()
)
# else:
# forward_out_tp_gathered = _collect_into_dim(forward_out, dim=-1)
forward_out_gathered = _gather_along_cp_dim(forward_out_tp_gathered)
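
Editorial sketch (plain PyTorch, not code from this PR) of the indexing described in the "Reminder" comment above; the functional change in this hunk only passes the tensor-model-parallel group explicitly to _gather_along_last_dim:

import torch

# Per the reminder: with an EOS token prepended, the prediction for input token i lands at
# output position i + 1, so dropping the first output row re-aligns outputs with inputs.
vocab_size, seq_len, eos_id = 11, 6, 0     # toy sizes; eos_id is hypothetical
tokens = torch.randint(1, vocab_size, (1, seq_len))
inputs = torch.cat([torch.full((1, 1), eos_id), tokens], dim=1)   # length seq_len + 1
logits = torch.randn(1, inputs.shape[1], vocab_size)              # one output row per input position
aligned = logits[:, 1:, :]                                        # row i now corresponds to tokens[:, i]
assert aligned.shape[1] == tokens.shape[1]
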
Original file line number Diff line number Diff line change
@@ -146,7 +146,7 @@ def test_train_evo2_stops(tmp_path):
)

assert "reduced_train_loss" in trainer.logged_metrics # validation logging on by default
assert "tflops_per_sec_per_gpu" in trainer.logged_metrics # ensuring that tflops logger can be added
assert "TFLOPS_per_GPU" in trainer.logged_metrics # ensuring that tflops logger can be added
assert "train_step_timing in s" in trainer.logged_metrics


Original file line number Diff line number Diff line change
@@ -68,14 +68,14 @@ def test_gpu_forward(self, operator: ParallelHyenaOperator):
g = operator.num_groups
dg = operator.group_dim

x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
v = torch.ones((batch_size, seq_len, g, dg), device=device)
x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
v = torch.ones((batch_size, (g * dg), seq_len), device=device)

output = operator(x1, x2, v)
assert output.shape[0] == batch_size
assert output.shape[1] == seq_len
assert output.shape[2] == operator.hidden_size
assert output.shape[1] == operator.hidden_size
assert output.shape[2] == seq_len


class TestParallelShortHyenaOperator:
@@ -89,7 +89,6 @@ def operator(self, transformer_config: TransformerConfig, hyena_config: HyenaCon
init_method="small_init",
short_conv_class=ParallelCausalDepthwiseConv1d,
use_fast_causal_conv=False,
is_mlp=False,
local_init=False,
use_conv_bias=False,
)
@@ -109,14 +108,14 @@ def test_gpu_forward(self, operator: ParallelShortHyenaOperator):
g = operator.num_groups
dg = operator.group_dim

x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
v = torch.ones((batch_size, seq_len, g, dg), device=device)
x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
v = torch.ones((batch_size, (g * dg), seq_len), device=device)

output = operator(x1, x2, v)
assert output.shape[0] == batch_size
assert output.shape[1] == seq_len
assert output.shape[2] == operator.hidden_size
assert output.shape[1] == operator.hidden_size
assert output.shape[2] == seq_len


class TestParallelShortHyenaOperatorWithConvBias:
@@ -130,7 +129,6 @@ def operator(self, transformer_config: TransformerConfig, hyena_config: HyenaCon
init_method="small_init",
short_conv_class=ParallelCausalDepthwiseConv1d,
use_fast_causal_conv=False,
is_mlp=False,
local_init=False,
use_conv_bias=True,
)
@@ -150,14 +148,14 @@ def test_gpu_forward(self, operator: ParallelShortHyenaOperator):
g = operator.num_groups
dg = operator.group_dim

x1 = torch.ones((batch_size, seq_len, g, dg), device=device)
x2 = torch.ones((batch_size, seq_len, g, dg), device=device)
v = torch.ones((batch_size, seq_len, g, dg), device=device)
x1 = torch.ones((batch_size, (g * dg), seq_len), device=device)
x2 = torch.ones((batch_size, (g * dg), seq_len), device=device)
v = torch.ones((batch_size, (g * dg), seq_len), device=device)

output = operator(x1, x2, v)
assert output.shape[0] == batch_size
assert output.shape[1] == seq_len
assert output.shape[2] == operator.hidden_size
assert output.shape[1] == operator.hidden_size
assert output.shape[2] == seq_len
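
Editorial illustration of the shape change in these operator tests (not part of the diff): the inputs move from a (batch, seq_len, num_groups, group_dim) layout to a channels-first (batch, num_groups * group_dim, seq_len) layout, and the output checks swap the hidden and sequence dimensions accordingly. A plain PyTorch conversion between the two layouts, assuming hidden_size == num_groups * group_dim:

import torch

batch_size, seq_len, g, dg = 2, 1024, 8, 64
x_old = torch.ones((batch_size, seq_len, g, dg))                      # previous test layout
x_new = x_old.reshape(batch_size, seq_len, g * dg).permute(0, 2, 1)   # -> (batch, g * dg, seq_len)
x_new = x_new.contiguous()
assert x_new.shape == (batch_size, g * dg, seq_len)                   # layout used by the updated tests
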


class TestParallelCausalDepthwiseConv1d:
Original file line number Diff line number Diff line change
@@ -24,10 +24,7 @@
from nemo.collections.llm.peft.lora import LoRA, LoRALinear
from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
from nemo.lightning.megatron_parallel import (
masked_token_loss,
masked_token_loss_context_parallel,
)
from nemo.lightning.megatron_parallel import masked_token_loss
from torch import Tensor, nn

from bionemo.llm.model.biobert.model import BioBertConfig, BioBertOutput, MegatronBioBertModel
@@ -102,17 +99,7 @@ def forward(
# TODO(@jstjohn) also handle different output keys, like the sequence loss.

cp_size = parallel_state.get_context_parallel_world_size()
if cp_size == 1:
# reduce the loss across the micro batch
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
else:
# reduce the loss across the micro batch.
# TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
# This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
# other necessary keys to the batch. Thanks!
loss_for_microbatch = masked_token_loss_context_parallel(
unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
)
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size)

# If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
# reducing the loss across the data parallel group.
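
Context for the API change above (editorial): masked_token_loss_context_parallel is gone, and masked_token_loss now takes the context-parallel world size as an argument. A rough single-rank sketch of what such a masked reduction computes, offered as an assumption about the helper's behavior rather than NeMo's actual implementation (which also all-reduces the sum across the context-parallel group):

import torch


def masked_token_loss_sketch(unreduced_token_loss: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # unreduced_token_loss: per-token losses, e.g. shape (batch, seq_len).
    # loss_mask: 1.0 for tokens that count toward the loss, 0.0 for masked/padding tokens.
    # Returns an unnormalized sum; see the loss.py hunk below for the explicit division
    # by loss_mask.sum() that recovers the per-token mean.
    return torch.sum(unreduced_token_loss * loss_mask)
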
2 changes: 1 addition & 1 deletion sub-packages/bionemo-llm/pyproject.toml
@@ -16,7 +16,7 @@ dependencies = [
# external
'lightning>=2.2.1',
'megatron-core',
'nemo_toolkit[nlp]>=2.2.1',
'nemo_toolkit[nlp,eval]>=2.2.1',
'nemo-run',
'hatchling',
]
40 changes: 21 additions & 19 deletions sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py
@@ -22,7 +22,6 @@
from nemo.lightning.megatron_parallel import (
MegatronLossReduction,
masked_token_loss,
masked_token_loss_context_parallel,
)
from torch import Tensor

@@ -179,24 +178,17 @@ def forward(

# TODO(@jstjohn) also handle different output keys, like the sequence loss.

# compute loss
# Compute loss over "valid" tokens in the microbatch, i.e. the non-masked tokens.
# The loss is not normalized, only potentially reduced via torch.distributed.ReduceOp.SUM
# across the context parallel process group, so you need to divide by the number
# of non-masked tokens (loss_mask.sum()) to compute the mean reduced loss per token.
cp_size = parallel_state.get_context_parallel_world_size()
if cp_size == 1:
# reduce the loss across the micro batch per valid token
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"])
else:
# reduce the loss across the micro batch per valid token.
# TODO(@jomitchell): Figure out who defines "num_valid_tokens_in_ub" in the batch and document/understand this.
# This has something to do with context parallel, and there is probably a megatron or nemo function that adds this and
# other necessary keys to the batch. Thanks!
loss_for_microbatch = masked_token_loss_context_parallel(
unreduced_token_loss, batch["loss_mask"], batch["num_valid_tokens_in_ub"]
)
loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size=cp_size)
num_valid_tokens_in_microbatch = batch["loss_mask"].sum()

# If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
# reducing the loss across the data parallel group.
if self.validation_step and not self.val_drop_last:
num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
if loss_for_microbatch.isnan():
# TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
# to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,
@@ -205,9 +197,8 @@ def forward(
raise ValueError("Got NaN loss with non-empty input")
loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
else:
loss_sum_for_microbatch = (
num_valid_tokens_in_microbatch * loss_for_microbatch
) # sum over all valid tokens
# The reduced loss is already the sum of all losses from masked_token_loss().
loss_sum_for_microbatch = loss_for_microbatch

# In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
loss_sum_and_microbatch_size_all_gpu = torch.cat(
@@ -216,17 +207,28 @@ def forward(
Tensor([num_valid_tokens_in_microbatch]).cuda().clone().detach(),
]
)

# Reduce the loss sum across the data parallel group to get the total loss
# for all data parallel / distributed microbatches.
torch.distributed.all_reduce(
loss_sum_and_microbatch_size_all_gpu,
group=parallel_state.get_data_parallel_group(),
op=torch.distributed.ReduceOp.SUM,
)

# Return the loss tensor multiplied by the context parallel size,
# and the data & context parallel reduced loss sum.
return loss_for_microbatch * cp_size, {
"loss_sum_and_microbatch_size": loss_sum_and_microbatch_size_all_gpu
}

# average the losses across the data parallel group, but also return the unreduced loss
reduced_loss = average_losses_across_data_parallel_group([loss_for_microbatch])
# Return the loss tensor multiplied by the context parallel size, as well as
# the data-parallel averaged loss, i.e. the loss divided by the DP size.
# Normalize the loss by the number of "valid" tokens, because masked_token_loss
# no longer does this normalization, and BioNeMo losses expect this normalization.
reduced_loss = (
average_losses_across_data_parallel_group([loss_for_microbatch]) / num_valid_tokens_in_microbatch
)
return loss_for_microbatch * cp_size, {"avg": reduced_loss}
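
A short worked example of the normalization described in the comments above (editorial, not code from the PR): once masked_token_loss returns a token-loss sum, the per-valid-token mean has to be recovered by dividing by loss_mask.sum().

import torch

per_token_loss = torch.tensor([[2.0, 4.0, 6.0, 0.0]])
loss_mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])    # last position is masked out

loss_sum = torch.sum(per_token_loss * loss_mask)    # 12.0 -- the sum-style reduction
mean_loss = loss_sum / loss_mask.sum()              # 4.0  -- the per-valid-token mean reported downstream
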


Original file line number Diff line number Diff line change
@@ -75,7 +75,7 @@ def test_loss_equivalency_nemo_vs_pytorch():
batch=batch_megatron,
forward_out=unreduced_megatron_loss, # wants the loss directly
)
final_nemo_loss = nemo_default_loss_fn.reduce([forward_nemo_loss[1]])
final_nemo_loss = nemo_default_loss_fn.reduce([forward_nemo_loss[2]])

# First check, nemo+megatron loss
torch.testing.assert_close(expected_loss, final_nemo_loss)