diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py
index 8a649a12..1ca563a7 100644
--- a/verl/utils/fsdp_utils.py
+++ b/verl/utils/fsdp_utils.py
@@ -128,6 +128,7 @@ def offload_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
 def load_fsdp_optimizer(optimizer: Optimizer, empty_cache: bool = True):
     if not optimizer.state:
         return
+
     for param_group in optimizer.param_groups:
         for param in param_group["params"]:
             state = optimizer.state[param]
diff --git a/verl/utils/model_utils.py b/verl/utils/model_utils.py
index c0528f94..c5c51305 100644
--- a/verl/utils/model_utils.py
+++ b/verl/utils/model_utils.py
@@ -30,8 +30,8 @@ def is_rank0() -> int:
 
 def print_gpu_memory_usage(prefix: str) -> None:
     if is_rank0():
-        memory_allocated = torch.cuda.memory_allocated() / 1024**3
-        memory_reserved = torch.cuda.memory_reserved() / 1024**3
+        memory_allocated = torch.cuda.memory_allocated() / (1024**3)
+        memory_reserved = torch.cuda.memory_reserved() / (1024**3)
         print(f"{prefix} memory allocated: {memory_allocated:.2f} GB, memory reserved: {memory_reserved:.2f} GB.")
 
 
diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py
index 9b5b5d2e..0b641f9c 100644
--- a/verl/workers/fsdp_workers.py
+++ b/verl/workers/fsdp_workers.py
@@ -261,7 +261,7 @@ def _build_model_optimizer(
         if self.rank == 0:
             print_model_size(model)
 
-        print_gpu_memory_usage("After init from huggingface model")
+        print_gpu_memory_usage("After huggingface model init")
         mixed_precision = MixedPrecision(
             param_dtype=PrecisionType.to_dtype(fsdp_config.mp_param_dtype),
             reduce_dtype=PrecisionType.to_dtype(fsdp_config.mp_reduce_dtype),
@@ -306,7 +306,7 @@
             use_orig_params=fsdp_config.use_orig_params,
             device_mesh=self.device_mesh,
         )
-        print_gpu_memory_usage("After actor FSDP init")
+        print_gpu_memory_usage("After FSDP module init")
 
         if self._is_actor or self._is_critic:
             self.optimizer = torch.optim.AdamW(
@@ -322,30 +322,27 @@
         else:
             self.optimizer, self.lr_scheduler = None, None
 
-        print_gpu_memory_usage("After actor optimizer init")
+        print_gpu_memory_usage("After optimizer init")
 
     def _build_rollout(self) -> None:
         # TODO(sgm): support FSDP hybrid shard for larger model
         tp_size = self.config.rollout.tensor_parallel_size
         dp_size = self.world_size // tp_size
         assert self.world_size % tp_size == 0, (
-            f"rollout world_size: {self.world_size} is not divisible by tp_size: {tp_size}"
+            f"rollout world size: {self.world_size} is not divisible by tp size: {tp_size}"
         )
-        rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=["dp", "tp"])
-        print_gpu_memory_usage("Before building vllm rollout")
+        rollout_device_mesh = init_device_mesh("cuda", mesh_shape=(dp_size, tp_size), mesh_dim_names=("dp", "tp"))
         self.rollout = vLLMRollout(
             model_path=self.config.actor.model.model_path,
             config=self.config.rollout,
             tokenizer=self.tokenizer,
         )
-        print_gpu_memory_usage("After building vllm rollout")
-
         self.rollout_sharding_manager = FSDPVLLMShardingManager(
             module=self.fsdp_module,
             inference_engine=self.rollout.inference_engine,
             device_mesh=rollout_device_mesh,
         )
-        print_gpu_memory_usage("After building sharding manager")
+        print_gpu_memory_usage("After vllm init")
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def init_model(self):
@@ -453,7 +450,6 @@ def update_actor(self, data: DataProto):
         if self._use_optimizer_offload:
             load_fsdp_optimizer(optimizer=self.optimizer)
 
-        print_gpu_memory_usage("Before update policy")
         with self.ulysses_sharding_manager:
             data = self.ulysses_sharding_manager.preprocess_data(data=data)
             with Timer(name="update_policy", logger=None) as timer:
@@ -472,7 +468,6 @@ def update_actor(self, data: DataProto):
                 self.lr_scheduler.step()
                 lr = self.lr_scheduler.get_last_lr()[0]
                 metrics["actor/lr"] = lr
-                print_gpu_memory_usage("After update policy")
 
             # TODO: here, we should return all metrics
             output = DataProto(meta_info={"metrics": metrics})
@@ -511,14 +506,11 @@ def generate_sequences(self, prompts: DataProto):
             if self._use_optimizer_offload:
                 offload_fsdp_optimizer(optimizer=self.optimizer)
 
-            print_gpu_memory_usage("After entering rollout sharding manager")
             prompts = self.rollout_sharding_manager.preprocess_data(prompts)
             output = self.rollout.generate_sequences(prompts=prompts)
             output = self.rollout_sharding_manager.postprocess_data(output)
-            print_gpu_memory_usage("After rollout generation")
 
         output = output.to("cpu")
-        print_gpu_memory_usage("After recompute log prob")
         return output
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@@ -548,7 +540,6 @@ def compute_log_prob(self, data: DataProto):
             offload_fsdp_model(self.fsdp_module)
 
         output = output.to("cpu")
-        print_gpu_memory_usage("After compute_log_prob")
         return output
 
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
@@ -574,7 +565,6 @@ def compute_ref_log_prob(self, data: DataProto):
            offload_fsdp_model(self.fsdp_module)
 
         output = output.to("cpu")
-        print_gpu_memory_usage("After compute_ref_log_prob")
         return output
 
     """CriticWorker"""