diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py
index ce52956d042..abff9e1690c 100644
--- a/verl/workers/actor/megatron_actor.py
+++ b/verl/workers/actor/megatron_actor.py
@@ -530,9 +530,17 @@ def logits_processor(logits, label, label_mask):
                 assert label.shape == label_mask.shape
                 ret = {}
                 if calculate_entropy:
+                    logits_bak = logits.clone()
+                    logger.warning_once(
+                        "For memory-efficient computation, enable fused kernels via "
+                        "`actor_rollout_ref.model.use_fused_kernels=True`. "
+                        "The current `clone()` operation ensures correctness but increases memory usage."
+                    )
                     entropy = vocab_parallel_entropy(logits)
                     ret["entropy"] = entropy
-                log_probs = vocab_parallel_log_probs_from_logits(logits, label)
+                else:
+                    logits_bak = logits
+                log_probs = vocab_parallel_log_probs_from_logits(logits_bak, label)
                 log_probs = log_probs.masked_fill(~label_mask, 0.0)
                 ret["log_probs"] = log_probs
                 return ret