1 change: 1 addition & 0 deletions verl/utils/fsdp_utils.py
@@ -128,6 +128,7 @@ def offload_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
def load_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
if not optimizer.state:
return

for param_group in optimizer.param_groups:
for param in param_group["params"]:
state = optimizer.state[param]
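
Note: the early return in load_fsdp_optimizer guards against an optimizer that has not yet produced any state (i.e., before the first step() call). Below is a minimal sketch of the load/offload round trip these helpers implement, assuming AdamW-style per-parameter state tensors; the helper name _move_optimizer_state is illustrative, not verl's exact implementation:

import torch
from torch.optim import Optimizer

def _move_optimizer_state(optimizer: Optimizer, device: torch.device) -> None:
    # Nothing to move before the first optimizer.step(): state is empty.
    if not optimizer.state:
        return
    for param_group in optimizer.param_groups:
        for param in param_group["params"]:
            state = optimizer.state[param]
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    # e.g. exp_avg / exp_avg_sq for AdamW
                    state[key] = value.to(device, non_blocking=True)

# offload to CPU:   _move_optimizer_state(optimizer, torch.device("cpu")); torch.cuda.empty_cache()
# load back to GPU: _move_optimizer_state(optimizer, torch.device("cuda"))
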
4 changes: 2 additions & 2 deletions verl/utils/model_utils.py
@@ -30,8 +30,8 @@ def is_rank0() -> int:

def print_gpu_memory_usage(prefix: str) -> None:
if is_rank0():
memory_allocated = torch.cuda.memory_allocated() / 1024**3
memory_reserved = torch.cuda.memory_reserved() / 1024**3
memory_allocated = torch.cuda.memory_allocated() / (1024**3)
memory_reserved = torch.cuda.memory_reserved() / (1024**3)
print(f"{prefix} memory allocated: {memory_allocated:.2f} GB, memory reserved: {memory_reserved:.2f} GB.")


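
Note: the added parentheses are purely cosmetic, since Python's ** binds tighter than /, so both spellings divide by 1024**3. A quick illustrative check:

# ** has higher precedence than /, so the parentheses do not change the result
assert 7 / 1024**3 == 7 / (1024**3)
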
22 changes: 6 additions & 16 deletions verl/workers/fsdp_workers.py
@@ -261,7 +261,7 @@ def _build_model_optimizer(
if self.rank == 0:
print_model_size(model)

print_gpu_memory_usage("After init from huggingface model")
print_gpu_memory_usage("After huggingface model init")
mixed_precision = MixedPrecision(
param_dtype=PrecisionType.to_dtype(fsdp_config.mp_param_dtype),
reduce_dtype=PrecisionType.to_dtype(fsdp_config.mp_reduce_dtype),
@@ -306,7 +306,7 @@ def _build_model_optimizer(
use_orig_params=fsdp_config.use_orig_params,
device_mesh=self.device_mesh,
)
print_gpu_memory_usage("After actor FSDP init")
print_gpu_memory_usage("After FSDP module init")

if self._is_actor or self._is_critic:
self.optimizer = torch.optim.AdamW(
@@ -322,30 +322,27 @@ def _build_model_optimizer(
else:
self.optimizer, self.lr_scheduler = None, None

print_gpu_memory_usage("After actor optimizer init")
print_gpu_memory_usage("After optimizer init")

def _build_rollout(self) -> None:
# TODO(sgm): support FSDP hybrid shard for larger model
tp_size = self.config.rollout.tensor_parallel_size
dp_size = self.world_size // tp_size
assert self.world_size % tp_size == 0, (
f"rollout world_size: {self.world_size} is not divisible by tp_size: {tp_size}"
f"rollout world size: {self.world_size} is not divisible by tp size: {tp_size}"
)
rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=["dp", "tp"])
print_gpu_memory_usage("Before building vllm rollout")
rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=("dp", "tp"))
self.rollout = vLLMRollout(
model_path=self.config.actor.model.model_path,
config=self.config.rollout,
tokenizer=self.tokenizer,
)
print_gpu_memory_usage("After building vllm rollout")

self.rollout_sharding_manager = FSDPVLLMShardingManager(
module=self.fsdp_module,
inference_engine=self.rollout.inference_engine,
device_mesh=rollout_device_mesh,
)
print_gpu_memory_usage("After building sharding manager")
print_gpu_memory_usage("After vllm init")

@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def init_model(self):
@@ -453,7 +450,6 @@ def update_actor(self, data: DataProto):
if self._use_optimizer_offload:
load_fsdp_optimizer(optimizer=self.optimizer)

print_gpu_memory_usage("Before update policy")
with self.ulysses_sharding_manager:
data = self.ulysses_sharding_manager.preprocess_data(data=data)
with Timer(name="update_policy", logger=None) as timer:
@@ -472,7 +468,6 @@ def update_actor(self, data: DataProto):
self.lr_scheduler.step()
lr = self.lr_scheduler.get_last_lr()[0]
metrics["actor/lr"] = lr
print_gpu_memory_usage("After update policy")

# TODO: here, we should return all metrics
output = DataProto(meta_info={"metrics": metrics})
@@ -511,14 +506,11 @@ def generate_sequences(self, prompts: DataProto):
if self._use_optimizer_offload:
offload_fsdp_optimizer(optimizer=self.optimizer)

print_gpu_memory_usage("After entering rollout sharding manager")
prompts = self.rollout_sharding_manager.preprocess_data(prompts)
output = self.rollout.generate_sequences(prompts=prompts)
output = self.rollout_sharding_manager.postprocess_data(output)
print_gpu_memory_usage("After rollout generation")

output = output.to("cpu")
print_gpu_memory_usage("After recompute log prob")
return output

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@@ -548,7 +540,6 @@ def compute_log_prob(self, data: DataProto):
offload_fsdp_model(self.fsdp_module)

output = output.to("cpu")
print_gpu_memory_usage("After compute_log_prob")
return output

@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@@ -574,7 +565,6 @@ def compute_ref_log_prob(self, data: DataProto):
offload_fsdp_model(self.fsdp_module)

output = output.to("cpu")
print_gpu_memory_usage("After compute_ref_log_prob")
return output

"""CriticWorker"""