2323from nemo .collections .llm .fn .mixin import FNMixin
2424from nemo .collections .llm .peft .lora import LoRA , LoRALinear
2525from nemo .collections .nlp .modules .common .megatron .adapters .parallel_adapters import ParallelLinearAdapter
26- from nemo .collections .nlp .modules .common .megatron .utils import average_losses_across_data_parallel_group
2726from nemo .lightning .megatron_parallel import masked_token_loss
2827from torch import Tensor , nn
2928
3635 unreduced_token_loss_fn ,
3736)
3837from bionemo .llm .utils import iomixin_utils as iom
38+ from bionemo .llm .utils .megatron_utils import average_losses_across_data_parallel_group
3939
4040
4141# This package demonstrates how you can take a pretrained geneformer module and fine-tune the classifier
@@ -98,13 +98,19 @@ def forward(
9898
9999 # TODO(@jstjohn) also handle different output keys, like the sequence loss.
100100
101+ # Compute loss over "valid" tokens in the microbatch, i.e. the non-masked tokens.
102+ # The loss is not normalized, so you need to divide by the number of non-masked
103+ # tokens (loss_mask.sum()) to compute the mean loss per token.
104+ loss_for_microbatch , num_valid_tokens_in_microbatch = masked_token_loss (
105+ unreduced_token_loss , batch ["loss_mask" ]
106+ )
107+
108+ # Get the context parallel size for some normalizations and reductions.
101109 cp_size = parallel_state .get_context_parallel_world_size ()
102- loss_for_microbatch = masked_token_loss (unreduced_token_loss , batch ["loss_mask" ], cp_size )
103110
104111 # If we do not drop the last partial batch of validation, we need to do fancy reduction handling to support
105112 # reducing the loss across the data parallel group.
106113 if self .validation_step and not self .val_drop_last :
107- num_valid_tokens_in_microbatch = batch ["loss_mask" ].sum ()
108114 if loss_for_microbatch .isnan ():
109115 # TODO(@jomitchell): Add a unit test for this. This is the case where there are no valid tokens in the microbatch for the loss
110116 # to be computed over, so we expect a NaN loss (divide by zero for a mean) but we make this an expected and non-breaking case,
@@ -113,7 +119,7 @@ def forward(
113119 raise ValueError ("Got NaN loss with non-empty input" )
114120 loss_sum_for_microbatch = torch .zeros_like (num_valid_tokens_in_microbatch )
115121 else :
116- loss_sum_for_microbatch = num_valid_tokens_in_microbatch * loss_for_microbatch
122+ loss_sum_for_microbatch = loss_for_microbatch
117123
118124 # In this case we need to store the loss sum as well as the number of valid tokens in the microbatch.
119125 loss_sum_and_microbatch_size_all_gpu = torch .cat (
@@ -123,14 +129,20 @@ def forward(
123129 ]
124130 )
125131 torch .distributed .all_reduce (
126- loss_sum_and_microbatch_size_all_gpu , group = parallel_state .get_data_parallel_group ()
132+ loss_sum_and_microbatch_size_all_gpu ,
133+ group = parallel_state .get_data_parallel_group (with_context_parallel = True ),
127134 )
128135 return loss_for_microbatch * cp_size , {
129136 "loss_sum_and_microbatch_size" : loss_sum_and_microbatch_size_all_gpu
130137 }
138+
131139 loss_for_microbatch = loss_for_microbatch + rmse_loss # add in the RMSE loss after reducing the logit loss
140+
132141 # average the losses across the data parallel group, but also return the unreduced loss
133- reduced_loss : Tensor = average_losses_across_data_parallel_group ([loss_for_microbatch ])
142+ reduced_loss : Tensor = (
143+ average_losses_across_data_parallel_group ([loss_for_microbatch ], with_context_parallel = True )
144+ / num_valid_tokens_in_microbatch
145+ )
134146 return loss_for_microbatch * cp_size , {"avg" : reduced_loss }
135147
136148
0 commit comments