Commit ac004e8
Bump NeMo to use a trunk commit instead of a branch for Evo2 fixes and inference. (#861)
### Description

- #798 (#855) depends on a NeMo branch that has since been merged into NeMo `main` (NVIDIA-NeMo/NeMo#13436). Update the submodule to point to this trunk commit.

### Details

- NeMo ToT reverted the `cp_size` argument for `masked_token_loss` (NVIDIA-NeMo/NeMo#13295), so we now do the CP reduction on our side...
- A future Megatron bump will add the `* cp_size` multiplier to the loss and will break our inference unit tests due to `torch.inference_mode()` usage in Megatron.

---------

Signed-off-by: Cory Ye <[email protected]>
1 parent 635a8f4 commit ac004e8
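To make the loss change concrete, here is a minimal, hedged sketch of the caller-side pattern this commit adopts. The helper name `reduce_masked_loss` is invented for illustration; the real call sites are in the per-file diffs below, and running this requires an initialized Megatron parallel state.

```python
import torch
from megatron.core import parallel_state
from nemo.lightning.megatron_parallel import masked_token_loss

from bionemo.llm.utils.megatron_utils import average_losses_across_data_parallel_group


def reduce_masked_loss(unreduced_token_loss: torch.Tensor, loss_mask: torch.Tensor):
    """Illustrative only: mirrors the reduction now done on the BioNeMo side."""
    # New NeMo contract: an unnormalized local loss sum plus the count of valid
    # (non-masked) tokens. The old contract took cp_size and returned a loss that
    # was already summed across the context-parallel group.
    loss_sum, num_valid_tokens = masked_token_loss(unreduced_token_loss, loss_mask)

    # Average across the flattened DP + CP group, then normalize per valid token ourselves.
    reduced = (
        average_losses_across_data_parallel_group([loss_sum], with_context_parallel=True)
        / num_valid_tokens
    )

    # Apply the * cp_size multiplier here until the future Megatron bump adds it upstream.
    cp_size = parallel_state.get_context_parallel_world_size()
    return loss_sum * cp_size, {"avg": reduced}
```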

5 files changed: +50 −18

3rdparty/NeMo

Submodule NeMo updated from 42d2b55 to 6a78ab8

sub-packages/bionemo-evo2/src/bionemo/evo2/run/predict.py

Lines changed: 3 additions & 2 deletions
@@ -27,7 +27,8 @@
 from lightning.pytorch import LightningDataModule
 from megatron.core import parallel_state
 from megatron.core.tensor_parallel.mappings import _gather_along_last_dim
-from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank, get_packed_seq_params
+from megatron.core.utils import get_batch_on_this_cp_rank
+from nemo.collections.llm.gpt.model.base import get_packed_seq_params
 from nemo.collections.llm.gpt.model.hyena import HYENA_MODEL_OPTIONS, HyenaModel
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
 from nemo.lightning import NeMoLogger

@@ -254,7 +255,7 @@ def hyena_predict_data_step(dataloader_iter) -> dict[str, torch.Tensor]:
             _batch_required_keys[key] = None

     # slice batch along sequence dimension for context parallelism
-    output = get_batch_on_this_context_parallel_rank(_batch_required_keys)
+    output = get_batch_on_this_cp_rank(_batch_required_keys)

     return output

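For reference, a hedged sketch of how the Megatron-core replacement is typically invoked. The tensor shapes are made up, and the call assumes torch.distributed and Megatron's parallel state are already initialized (e.g. under torchrun).

```python
import torch
from megatron.core.utils import get_batch_on_this_cp_rank

# Hypothetical microbatch with (batch, sequence) tensors; only the layout matters here.
batch = {
    "tokens": torch.randint(0, 512, (2, 8192), device="cuda"),
    "position_ids": torch.arange(8192, device="cuda").repeat(2, 1),
    "loss_mask": torch.ones(2, 8192, device="cuda"),
}

# Slice the batch along the sequence dimension so each context-parallel rank keeps only
# its own share; with context_parallel_size == 1 the batch should pass through unchanged.
cp_local_batch = get_batch_on_this_cp_rank(batch)
```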
sub-packages/bionemo-geneformer/src/bionemo/geneformer/model/finetune_token_regressor.py

Lines changed: 18 additions & 6 deletions
@@ -23,7 +23,6 @@
 from nemo.collections.llm.fn.mixin import FNMixin
 from nemo.collections.llm.peft.lora import LoRA, LoRALinear
 from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter
-from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
 from nemo.lightning.megatron_parallel import masked_token_loss
 from torch import Tensor, nn

@@ -36,6 +35,7 @@
     unreduced_token_loss_fn,
 )
 from bionemo.llm.utils import iomixin_utils as iom
+from bionemo.llm.utils.megatron_utils import average_losses_across_data_parallel_group


 # This package demonstrates how you can take a pretrained geneformer module and fine-tune the classifier

@@ -98,13 +98,19 @@ def forward(

         # TODO(@jstjohn) also handle different output keys, like the sequence loss.

+        # Compute loss over "valid" tokens in the microbatch, i.e. the non-masked tokens.
+        # The loss is not normalized, so you need to divide by the number of non-masked
+        # tokens (loss_mask.sum()) to compute the mean loss per token.
+        loss_for_microbatch, num_valid_tokens_in_microbatch = masked_token_loss(
+            unreduced_token_loss, batch["loss_mask"]
+        )
+
+        # Get the context parallel size for some normalizations and reductions.
         cp_size = parallel_state.get_context_parallel_world_size()
-        loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size)

         # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
         # reducing the loss across the data parallel group.
         if self.validation_step and not self.val_drop_last:
-            num_valid_tokens_in_microbatch = batch["loss_mask"].sum()
             if loss_for_microbatch.isnan():
                 # TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
                 # to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,

@@ -113,7 +119,7 @@ def forward(
                     raise ValueError("Got NaN loss with non-empty input")
                 loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
             else:
-                loss_sum_for_microbatch = num_valid_tokens_in_microbatch * loss_for_microbatch
+                loss_sum_for_microbatch = loss_for_microbatch

             # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
             loss_sum_and_microbatch_size_all_gpu = torch.cat(

@@ -123,14 +129,20 @@ def forward(
                 ]
             )
             torch.distributed.all_reduce(
-                loss_sum_and_microbatch_size_all_gpu, group=parallel_state.get_data_parallel_group()
+                loss_sum_and_microbatch_size_all_gpu,
+                group=parallel_state.get_data_parallel_group(with_context_parallel=True),
             )
             return loss_for_microbatch * cp_size, {
                 "loss_sum_and_microbatch_size": loss_sum_and_microbatch_size_all_gpu
             }
+
         loss_for_microbatch = loss_for_microbatch + rmse_loss  # add in the RMSE loss after reducing the logit loss
+
         # average the losses across the data parallel group, but also return the unreduced loss
-        reduced_loss: Tensor = average_losses_across_data_parallel_group([loss_for_microbatch])
+        reduced_loss: Tensor = (
+            average_losses_across_data_parallel_group([loss_for_microbatch], with_context_parallel=True)
+            / num_valid_tokens_in_microbatch
+        )
         return loss_for_microbatch * cp_size, {"avg": reduced_loss}

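As a hedged illustration of the validation-path bookkeeping changed above: the names follow the diff, the tensor values are invented, and the all-reduce needs an initialized Megatron parallel state.

```python
import torch
from megatron.core import parallel_state

# Invented per-rank values: the unnormalized masked loss sum and the valid-token count.
loss_sum_for_microbatch = torch.tensor([123.4], device="cuda")
num_valid_tokens_in_microbatch = torch.tensor([2048.0], device="cuda")

# Pack [loss_sum, token_count] so a single all-reduce accumulates both quantities.
loss_sum_and_microbatch_size_all_gpu = torch.cat(
    [loss_sum_for_microbatch, num_valid_tokens_in_microbatch]
)

# After this commit the reduction spans the flattened DP + CP group, so the sums cover
# every context-parallel shard of the sequence as well as every data-parallel replica.
torch.distributed.all_reduce(
    loss_sum_and_microbatch_size_all_gpu,
    group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)

# The mean per-token validation loss can then be recovered as element 0 / element 1.
```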
sub-packages/bionemo-llm/src/bionemo/llm/model/loss.py

Lines changed: 13 additions & 9 deletions
@@ -18,13 +18,14 @@
 import torch
 from megatron.core import parallel_state, tensor_parallel
 from megatron.core.fusions.fused_cross_entropy import fused_vocab_parallel_cross_entropy
-from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
 from nemo.lightning.megatron_parallel import (
     MegatronLossReduction,
     masked_token_loss,
 )
 from torch import Tensor

+from bionemo.llm.utils.megatron_utils import average_losses_across_data_parallel_group
+

 __all__: Sequence[str] = (
     "BERTMLMLossWithReduction",

@@ -179,12 +180,14 @@ def forward(
         # TODO(@jstjohn) also handle different output keys, like the sequence loss.

         # Compute loss over "valid" tokens in the microbatch, i.e. the non-masked tokens.
-        # The loss is not normalized, only potentially reduced via torch.distributed.ReduceOp.SUM
-        # across the context parallel process group, so you need to divide by the number
-        # of non-masked tokens (loss_mask.sum()) to compute the mean reduced loss per token.
+        # The loss is not normalized, so you need to divide by the number of non-masked
+        # tokens (loss_mask.sum()) to compute the mean loss per token.
+        loss_for_microbatch, num_valid_tokens_in_microbatch = masked_token_loss(
+            unreduced_token_loss, batch["loss_mask"]
+        )
+
+        # Get the context parallel size for some normalizations and reductions.
         cp_size = parallel_state.get_context_parallel_world_size()
-        loss_for_microbatch = masked_token_loss(unreduced_token_loss, batch["loss_mask"], cp_size=cp_size)
-        num_valid_tokens_in_microbatch = batch["loss_mask"].sum()

         # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
         # reducing the loss across the data parallel group.

@@ -197,7 +200,7 @@ def forward(
                     raise ValueError("Got NaN loss with non-empty input")
                 loss_sum_for_microbatch = torch.zeros_like(num_valid_tokens_in_microbatch)
             else:
-                # The reduced loss is already the sum of all losses from masked_token_loss().
+                # The loss is already the sum of all losses from masked_token_loss().
                 loss_sum_for_microbatch = loss_for_microbatch

             # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.

@@ -212,7 +215,7 @@ def forward(
             # for all data parallel / distributed microbatches.
             torch.distributed.all_reduce(
                 loss_sum_and_microbatch_size_all_gpu,
-                group=parallel_state.get_data_parallel_group(),
+                group=parallel_state.get_data_parallel_group(with_context_parallel=True),
                 op=torch.distributed.ReduceOp.SUM,
             )

@@ -227,7 +230,8 @@ def forward(
         # Normalize the loss by the number of "valid" tokens, because masked_token_loss
         # no longer does this normalization, and BioNeMo losses expect this normalization.
         reduced_loss = (
-            average_losses_across_data_parallel_group([loss_for_microbatch]) / num_valid_tokens_in_microbatch
+            average_losses_across_data_parallel_group([loss_for_microbatch], with_context_parallel=True)
+            / num_valid_tokens_in_microbatch
         )
         return loss_for_microbatch * cp_size, {"avg": reduced_loss}

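The normalization comment above can be checked with a tiny single-process example. These are toy tensors, not the real NeMo function: the unnormalized masked loss sum divided by loss_mask.sum() gives the mean loss per non-masked token.

```python
import torch

# Toy per-token losses and mask; masked_token_loss-style behavior is emulated inline.
unreduced_token_loss = torch.tensor([[0.5, 1.0, 2.0, 4.0]])
loss_mask = torch.tensor([[1.0, 0.0, 1.0, 1.0]])

loss_sum = torch.sum(unreduced_token_loss * loss_mask)  # 0.5 + 2.0 + 4.0 = 6.5
num_valid_tokens = loss_mask.sum()                      # 3 non-masked tokens
mean_loss_per_token = loss_sum / num_valid_tokens       # 6.5 / 3 ≈ 2.1667
print(mean_loss_per_token)
```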
sub-packages/bionemo-llm/src/bionemo/llm/utils/megatron_utils.py

Lines changed: 15 additions & 0 deletions
@@ -34,3 +34,18 @@ def is_only_data_parallel() -> bool:
     world_size: int = torch.distributed.get_world_size()
     dp_world_size: int = parallel_state.get_data_parallel_world_size()
     return world_size == dp_world_size
+
+
+def average_losses_across_data_parallel_group(losses, with_context_parallel: bool = False):
+    """Reduce a tensor of losses across all GPUs."""
+    averaged_losses = torch.cat([loss.clone().detach().view(1) for loss in losses])
+    # Reduce across the DP (or optionally, the flattened DP + CP) group.
+    # Refer to the ring attention algorithm on why we always must reduce across the CP group.
+    torch.distributed.all_reduce(
+        averaged_losses, group=parallel_state.get_data_parallel_group(with_context_parallel=with_context_parallel)
+    )
+    averaged_losses = averaged_losses / torch.distributed.get_world_size(
+        # Only average losses across the data parallel group, not the context parallel group!
+        group=parallel_state.get_data_parallel_group()
+    )
+    return averaged_losses

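Finally, a hedged usage sketch for the new helper. It assumes torch.distributed and Megatron's parallel_state are initialized; the values are invented.

```python
import torch

from bionemo.llm.utils.megatron_utils import average_losses_across_data_parallel_group

# Invented per-rank loss sum and local valid-token count.
loss_for_microbatch = torch.tensor(2.5, device="cuda")
num_valid_tokens_in_microbatch = torch.tensor(1024.0, device="cuda")

# Sum the loss across the flattened DP + CP group, divide by the DP world size,
# then normalize per valid token on the caller side, as the loss classes above do.
reduced_loss = (
    average_losses_across_data_parallel_group([loss_for_microbatch], with_context_parallel=True)
    / num_valid_tokens_in_microbatch
)
```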