Merged
Changes from all commits
26 commits
57b22cd
clean loss compute and reduction, make it consistent to megatron
xrennvidia Apr 27, 2025
1090156
fix loss func
xrennvidia Apr 27, 2025
36be710
Merge branch 'main' into xren/loss_func
xrennvidia Apr 27, 2025
f1fe4ab
minor fix
xrennvidia Apr 27, 2025
54aa270
revert two file changes
xrennvidia Apr 27, 2025
5156e6f
Merge branch 'main' into xren/loss_func
xrennvidia Apr 27, 2025
9f9d761
Apply isort and black reformatting
xrennvidia Apr 27, 2025
e0ced98
Merge branch 'main' into xren/loss_func
xrennvidia Apr 30, 2025
67dc76e
import get_batch_on_this_cp_rank from mcore.utils
xrennvidia May 1, 2025
7edd983
Apply isort and black reformatting
xrennvidia May 1, 2025
e3f6243
more cleaning
xrennvidia May 1, 2025
30e844a
Merge branch 'xren/loss_func' of github.com:xrennvidia/NeMo into xren…
xrennvidia May 1, 2025
ff116dd
Apply isort and black reformatting
xrennvidia May 1, 2025
d1bc6ad
more cleaning
xrennvidia May 1, 2025
5d305fc
Merge branch 'main' into xren/loss_func
ko3n1g May 2, 2025
c43aae0
remove redundant loss reduce in BertLossReduction
xrennvidia May 2, 2025
b107c5b
Merge branch 'xren/loss_func' of github.com:xrennvidia/NeMo into xren…
xrennvidia May 2, 2025
a39f6ae
Merge branch 'main' into xren/loss_func
xrennvidia May 2, 2025
6ea30c0
more cleaning
xrennvidia May 2, 2025
e897791
add fixed loss reduce func
xrennvidia May 2, 2025
8ceca57
Apply isort and black reformatting
xrennvidia May 2, 2025
97beb35
minor change
xrennvidia May 2, 2025
ca7551e
Merge branch 'xren/loss_func' of github.com:xrennvidia/NeMo into xren…
xrennvidia May 2, 2025
fc80af9
remove unused import
xrennvidia May 2, 2025
cd70095
Merge branch 'main' into xren/loss_func
xrennvidia May 2, 2025
acd4043
git
xrennvidia May 2, 2025
91 changes: 53 additions & 38 deletions nemo/collections/llm/bert/loss.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Literal, Tuple
from typing import Dict, Literal, Tuple

import torch
import torch.nn.functional as F
@@ -68,30 +68,35 @@ def forward(

loss_for_ub = sop_loss_for_ub + lm_loss_for_ub
reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
return loss_for_ub * cp_size, {"avg": reduced_loss}
return loss_for_ub, {"avg": reduced_loss}

def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor:
"""Taken from: https://github.com/NVIDIA/NeMo/blob/main
/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 ."""
if losses_reduced_per_micro_batch:
if "avg" in losses_reduced_per_micro_batch[0]:
loss_tensors_list = [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)
# legacy behavior, average over the number of microbatches
avg = [x["avg"] for x in losses_reduced_per_micro_batch]
loss = torch.cat(avg).mean()
return loss

return loss_tensor.mean()
from megatron.core import parallel_state

# Get the total loss since micro batches sizes are not uniform
loss_sum_tensors_list: List[torch.Tensor] = [
loss_sum["loss_sum_and_ub_size"]
for loss_sum in losses_reduced_per_micro_batch
if loss_sum["loss_sum_and_ub_size"][1] > 0
loss_sum_and_ub_size = [
x["loss_sum_and_ub_size"] for x in losses_reduced_per_micro_batch if x["loss_sum_and_ub_size"][1] > 0
]
loss_sum = (
torch.vstack(loss_sum_tensors_list).sum(dim=0)
if len(loss_sum_tensors_list) > 0
loss = (
torch.vstack(loss_sum_and_ub_size).sum(dim=0)
if len(loss_sum_and_ub_size) > 0
else torch.tensor([0.0, 0.0], device=torch.cuda.current_device())
)
return loss_sum
torch.distributed.all_reduce(
loss,
group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)
# average over the total number of tokens across the global batch.
loss = loss[0] / loss[1]
return loss

return torch.tensor(0.0, device=torch.cuda.current_device())

@@ -158,23 +163,28 @@ def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor:
/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 ."""
if losses_reduced_per_micro_batch:
if "avg" in losses_reduced_per_micro_batch[0]:
loss_tensors_list = [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)
# legacy behavior, average over the number of microbatches
avg = [x["avg"] for x in losses_reduced_per_micro_batch]
loss = torch.cat(avg).mean()
return loss

return loss_tensor.mean()
from megatron.core import parallel_state

# Get the total loss since micro batches sizes are not uniform
loss_sum_tensors_list: List[torch.Tensor] = [
loss_sum["loss_sum_and_ub_size"]
for loss_sum in losses_reduced_per_micro_batch
if loss_sum["loss_sum_and_ub_size"][1] > 0
loss_sum_and_ub_size = [
x["loss_sum_and_ub_size"] for x in losses_reduced_per_micro_batch if x["loss_sum_and_ub_size"][1] > 0
]
loss_sum = (
torch.vstack(loss_sum_tensors_list).sum(dim=0)
if len(loss_sum_tensors_list) > 0
loss = (
torch.vstack(loss_sum_and_ub_size).sum(dim=0)
if len(loss_sum_and_ub_size) > 0
else torch.tensor([0.0, 0.0], device=torch.cuda.current_device())
)
return loss_sum
torch.distributed.all_reduce(
loss,
group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)
# average over the total number of tokens across the global batch.
loss = loss[0] / loss[1]
return loss

return torch.tensor(0.0, device=torch.cuda.current_device())

@@ -277,23 +287,28 @@ def reduce(self, losses_reduced_per_micro_batch) -> torch.Tensor:
/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L535-L552 ."""
if losses_reduced_per_micro_batch:
if "avg" in losses_reduced_per_micro_batch[0]:
loss_tensors_list = [loss_reduced["avg"] for loss_reduced in losses_reduced_per_micro_batch]
loss_tensor = torch.concat(loss_tensors_list)
# legacy behavior, average over the number of microbatches
avg = [x["avg"] for x in losses_reduced_per_micro_batch]
loss = torch.cat(avg).mean()
return loss

return loss_tensor.mean()
from megatron.core import parallel_state

# Get the total loss since micro batches sizes are not uniform
loss_sum_tensors_list: List[torch.Tensor] = [
loss_sum["loss_sum_and_ub_size"]
for loss_sum in losses_reduced_per_micro_batch
if loss_sum["loss_sum_and_ub_size"][1] > 0
loss_sum_and_ub_size = [
x["loss_sum_and_ub_size"] for x in losses_reduced_per_micro_batch if x["loss_sum_and_ub_size"][1] > 0
]
loss_sum = (
torch.vstack(loss_sum_tensors_list).sum(dim=0)
if len(loss_sum_tensors_list) > 0
loss = (
torch.vstack(loss_sum_and_ub_size).sum(dim=0)
if len(loss_sum_and_ub_size) > 0
else torch.tensor([0.0, 0.0], device=torch.cuda.current_device())
)
return loss_sum
torch.distributed.all_reduce(
loss,
group=parallel_state.get_data_parallel_group(with_context_parallel=True),
)
# average over the total number of tokens across the global batch.
loss = loss[0] / loss[1]
return loss

return torch.tensor(0.0, device=torch.cuda.current_device())

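The three reduction classes in this file now share the same `reduce` path: when "avg" is present, the per-microbatch averages are concatenated and averaged (the legacy behavior); otherwise the per-microbatch `[loss_sum, num_tokens]` pairs are summed, all-reduced over the data-parallel group (with context parallelism), and divided, giving a token-weighted average over the global batch. A minimal single-process sketch of that second path; the helper name is hypothetical and the `all_reduce` is elided:

```python
# Minimal single-process sketch of the token-weighted reduction above.
# In the real classes the summed pair is torch.distributed.all_reduce'd over
# parallel_state.get_data_parallel_group(with_context_parallel=True)
# before the division.
import torch


def reduce_token_weighted(losses_reduced_per_micro_batch):
    # Each entry carries a 2-element tensor: [sum of masked losses, valid-token count].
    pairs = [
        x["loss_sum_and_ub_size"]
        for x in losses_reduced_per_micro_batch
        if x["loss_sum_and_ub_size"][1] > 0
    ]
    if not pairs:
        return torch.tensor(0.0)
    totals = torch.vstack(pairs).sum(dim=0)  # [total_loss_sum, total_valid_tokens]
    return totals[0] / totals[1]             # mean over all valid tokens in the batch


# Why token weighting matters when microbatch sizes are not uniform:
# (loss_sum, tokens) of (4.0, 4) and (9.0, 1) give a mean-of-means of
# (1.0 + 9.0) / 2 = 5.0, but a token-weighted mean of 13.0 / 5 = 2.6.
print(reduce_token_weighted([
    {"loss_sum_and_ub_size": torch.tensor([4.0, 4.0])},
    {"loss_sum_and_ub_size": torch.tensor([9.0, 1.0])},
]))  # tensor(2.6000)
```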
40 changes: 2 additions & 38 deletions nemo/collections/llm/bert/model/base.py
@@ -29,7 +29,7 @@
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
from megatron.core.transformer.utils import get_linear_layer as mcore_get_linear_layer
from megatron.core.utils import make_viewless_tensor
from megatron.core.utils import get_batch_on_this_cp_rank, make_viewless_tensor
from torch import Tensor, nn

from nemo.collections.llm import fn
@@ -73,7 +73,7 @@ def bert_data_step(dataloder_iter) -> Dict[str, torch.Tensor]:

_batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in _batch.items()}
# slice batch along sequence dimension for context parallelism
output = get_batch_on_this_context_parallel_rank(_batch)
output = get_batch_on_this_cp_rank(_batch)

return output

@@ -628,42 +628,6 @@ def validation_loss_reduction(self) -> BERTLossReduction:  # pylint: disable=C01
return self._validation_loss_reduction


def get_batch_on_this_context_parallel_rank(batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Modifies the batch data based on the context parallel rank,
if the context parallel world size is greater than 1. Otherwise the batch is returned as-is.

Args:
batch (dict): The input batch data.

Returns:
dict: The modified batch data based on the context parallel rank.
"""
if cp_size := parallel_state.get_context_parallel_world_size() > 1:
num_valid_tokens_in_ub = None
if "loss_mask" in batch and batch["loss_mask"] is not None:
num_valid_tokens_in_ub = batch["loss_mask"].sum()

cp_rank = parallel_state.get_context_parallel_rank()
for key, val in batch.items():
if val is not None:
seq_dim = 1 if key != "attention_mask" else 2
_val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
non_blocking=True
)
_val = _val.index_select(seq_dim, index)
_val = _val.view(*val.shape[0:seq_dim], -1, *_val.shape[(seq_dim + 2) :])
batch[key] = _val
batch["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub
return batch


def get_packed_seq_params(batch: Dict[str, torch.Tensor]) -> PackedSeqParams:
"""
Get the packed sequence parameters for the given batch.
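In the data step above, the local `get_batch_on_this_context_parallel_rank` helper is deleted in favor of Megatron Core's `get_batch_on_this_cp_rank`, which does the same slicing along the sequence dimension. A rough usage sketch; the dictionary keys and tensor shapes are invented, and it assumes an initialized Megatron context-parallel group with CUDA tensors, as in `bert_data_step`:

```python
# Hypothetical usage of the Megatron Core helper that replaces the deleted
# local copy; keys and shapes are invented, and an initialized context-parallel
# group is assumed (with a context-parallel size of 1 the batch should come
# back unchanged).
import torch
from megatron.core.utils import get_batch_on_this_cp_rank

batch = {
    "text": torch.randint(0, 30000, (2, 16)).cuda(),    # [batch, seq]
    "labels": torch.randint(0, 30000, (2, 16)).cuda(),
    "loss_mask": torch.ones(2, 16).cuda(),
}
# Each rank keeps 2 of the 2 * cp_size sequence chunks, so every tensor's
# sequence dimension shrinks by a factor of cp_size.
batch = get_batch_on_this_cp_rank(batch)
```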
5 changes: 3 additions & 2 deletions nemo/collections/llm/bert/model/embedding.py
@@ -19,12 +19,13 @@
import torch
import torch.nn.functional as F
from megatron.core import parallel_state
from megatron.core.utils import get_batch_on_this_cp_rank
from torch import Tensor, nn

from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.collections.llm.bert.loss import BERTInBatchExclusiveHardNegativesRankingLoss
from nemo.collections.llm.bert.model import BertConfig, BertModel
from nemo.collections.llm.bert.model.base import get_batch_on_this_context_parallel_rank, get_packed_seq_params
from nemo.collections.llm.bert.model.base import get_packed_seq_params
from nemo.collections.llm.bert.model.bert import HuggingFaceBertImporter
from nemo.lightning import io
from nemo.lightning.pytorch.optim import OptimizerModule
@@ -49,7 +50,7 @@ def bert_embedding_data_step(dataloder_iter) -> Dict[str, torch.Tensor]:

_batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in _batch.items()}
# slice batch along sequence dimension for context parallelism
output = get_batch_on_this_context_parallel_rank(_batch)
output = get_batch_on_this_cp_rank(_batch)

return output

41 changes: 2 additions & 39 deletions nemo/collections/llm/gpt/model/base.py
@@ -26,6 +26,7 @@
from megatron.core.optimizer import OptimizerConfig
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import get_batch_on_this_cp_rank
from torch import nn

from nemo.collections.llm import fn
@@ -103,7 +104,7 @@ def gpt_data_step(dataloader_iter, use_mtp=False) -> dict[str, torch.Tensor]:
_batch_required_keys[key] = None

# slice batch along sequence dimension for context parallelism
output = get_batch_on_this_context_parallel_rank(_batch_required_keys)
output = get_batch_on_this_cp_rank(_batch_required_keys)

return output

@@ -731,44 +732,6 @@ def validation_loss_reduction(self) -> MaskedTokenLossReduction:
return self._validation_loss_reduction


def get_batch_on_this_context_parallel_rank(batch) -> dict[str, torch.Tensor]:
"""Process batch data for the current context parallel rank.

Handles the slicing of batch data across context parallel dimensions.

Args:
batch: Input batch

Returns:
dict[str, torch.Tensor]: Processed batch for the current context parallel rank
"""
from megatron.core import parallel_state

if (cp_size := parallel_state.get_context_parallel_world_size()) > 1:
num_valid_tokens_in_ub = None
if "loss_mask" in batch and batch["loss_mask"] is not None:
num_valid_tokens_in_ub = batch["loss_mask"].sum()

cp_rank = parallel_state.get_context_parallel_rank()
for key, val in batch.items():
if val is not None:
seq_dim = 1 if key != "attention_mask" else 2
_val = val.view(
*val.shape[0:seq_dim],
2 * cp_size,
val.shape[seq_dim] // (2 * cp_size),
*val.shape[(seq_dim + 1) :],
)
index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).to(
_val.device, non_blocking=True
)
_val = _val.index_select(seq_dim, index)
_val = _val.view(*val.shape[0:seq_dim], -1, *_val.shape[(seq_dim + 2) :])
batch[key] = _val
batch["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub
return batch


def get_packed_seq_params(batch):
"""Extract packed sequence parameters from the batch.

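The helper deleted here is the same routine as in the BERT base module, and the chunk selection it implements (kept in mcore's `get_batch_on_this_cp_rank`) is worth spelling out: the sequence is split into `2 * cp_size` chunks and rank `r` keeps chunks `r` and `2 * cp_size - 1 - r`, which balances causal-attention work across ranks. A self-contained sketch of just that index math, with no Megatron state (function name invented):

```python
# Standalone sketch of the sequence-chunk selection the deleted helper (and
# mcore's get_batch_on_this_cp_rank) performs; pure PyTorch, no parallel state.
import torch


def slice_for_cp_rank(val: torch.Tensor, cp_rank: int, cp_size: int, seq_dim: int = 1) -> torch.Tensor:
    """Keep chunks cp_rank and (2 * cp_size - 1 - cp_rank) out of 2 * cp_size sequence chunks."""
    chunked = val.view(
        *val.shape[:seq_dim], 2 * cp_size, val.shape[seq_dim] // (2 * cp_size), *val.shape[seq_dim + 1 :]
    )
    index = torch.tensor([cp_rank, 2 * cp_size - 1 - cp_rank], device=val.device)
    picked = chunked.index_select(seq_dim, index)
    return picked.view(*val.shape[:seq_dim], -1, *picked.shape[seq_dim + 2 :])


# With cp_size=2 and a sequence of length 8, rank 0 keeps chunks 0 and 3
# (tokens 0-1 and 6-7) and rank 1 keeps chunks 1 and 2 (tokens 2-5), so the
# heavier late-sequence causal-attention work is spread across ranks.
tokens = torch.arange(8).unsqueeze(0)                   # shape [1, 8]
print(slice_for_cp_rank(tokens, cp_rank=0, cp_size=2))  # tensor([[0, 1, 6, 7]])
```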
3 changes: 2 additions & 1 deletion nemo/collections/llm/gpt/model/llama_embedding.py
@@ -22,6 +22,7 @@
from megatron.core import parallel_state
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.utils import get_batch_on_this_cp_rank
from torch import Tensor, nn

import nemo.collections.llm.gpt.model.base as GPTBase
@@ -91,7 +92,7 @@ def nv_embedding_data_step(dataloder_iter) -> Dict[str, torch.Tensor]:

_batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in _batch.items()}
# slice batch along sequence dimension for context parallelism
output = GPTBase.get_batch_on_this_context_parallel_rank(_batch)
output = get_batch_on_this_cp_rank(_batch)

return output

26 changes: 12 additions & 14 deletions nemo/collections/llm/modelopt/distill/model.py
@@ -16,10 +16,10 @@

import torch
from megatron.core import parallel_state
from megatron.core.utils import get_batch_on_this_cp_rank
from torch import Tensor, nn

from nemo.collections import llm
from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank
from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group
from nemo.lightning.megatron_parallel import MaskedTokenLossReduction
from nemo.utils.import_utils import safe_import
@@ -75,7 +75,7 @@ def gpt_distillation_data_step(dataloader_iter, attn_mask_cpu=False) -> Dict[str
batch_required_keys[key] = None

# slice batch along sequence dimension for context parallelism
output = get_batch_on_this_context_parallel_rank(batch_required_keys)
output = get_batch_on_this_cp_rank(batch_required_keys)

return output

@@ -95,33 +95,31 @@ def forward(self, batch: Dict[str, Tensor], forward_out: Tensor) -> Tuple[Tensor

# [ModelOpt]: KD loss calculation.
loss_for_ub = self._distillation_loss_fn(
loss_reduction_fn=lambda x: self._masked_token_loss(
x, batch["loss_mask"], batch.get("num_valid_tokens_in_ub")
)
loss_reduction_fn=lambda x: self._masked_token_loss(x, batch["loss_mask"])
)

reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
return loss_for_ub * self._cp_size, {"avg": reduced_loss}
return loss_for_ub, {"avg": reduced_loss}

def _masked_token_loss(self, loss_output: Tensor, mask: Tensor, num_valid_tokens_in_ub: Optional[int] = None):
def _masked_token_loss(self, loss_output: Tensor, mask: Tensor):
"""The function takes as input per-token loss and masks non-required values."""
if isinstance(loss_output, tuple):
# [ModelOpt]: Losses can return extra flag to indicate additional TP-reduction (often required)
loss_output, tp_reduce = loss_output
else:
tp_reduce = False
losses = loss_output.float()

losses = loss_output.view(-1).float()
loss_mask = mask.view(-1).float()
loss_sum = torch.sum(losses * loss_mask)
num_valid_tokens = loss_mask.sum()

if self._cp_size > 1:
if num_valid_tokens_in_ub is None:
num_valid_tokens_in_ub = loss_mask.sum()
if num_valid_tokens_in_ub < 0.5: # no valid tokens
num_valid_tokens_in_ub += 1.0
loss = torch.sum(losses.view(-1) * loss_mask) / num_valid_tokens_in_ub # sequence level nll
loss = torch.cat([loss_sum.view(1), num_valid_tokens.view(1)])
torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
loss = loss[0] / loss[1] # sequence level nll
else:
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() # sequence level nll
loss = loss_sum / num_valid_tokens # sequence level nll

if tp_reduce is True:
torch.distributed.all_reduce(loss, group=parallel_state.get_tensor_model_parallel_group())
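`_masked_token_loss` now returns a token-weighted mean computed from a local `[loss_sum, num_valid_tokens]` pair that is all-reduced over the context-parallel group, instead of dividing by a `num_valid_tokens_in_ub` value precomputed in the data step. A minimal single-process sketch of the math (helper name invented; the CP `all_reduce` is elided):

```python
# Minimal single-process sketch of the masked-token loss math above; the
# context-parallel all_reduce of [loss_sum, num_valid_tokens] is elided.
import torch


def masked_token_loss(per_token_loss: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    losses = per_token_loss.view(-1).float()
    mask = loss_mask.view(-1).float()
    loss_sum = torch.sum(losses * mask)
    num_valid_tokens = mask.sum()
    # With context parallelism, loss_sum and num_valid_tokens would each be
    # summed across the CP group before dividing, so every CP rank gets the
    # same token-weighted mean.
    return loss_sum / num_valid_tokens


per_token_loss = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
loss_mask = torch.tensor([[1.0, 1.0, 0.0, 1.0]])
print(masked_token_loss(per_token_loss, loss_mask))  # tensor(2.3333)
```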