@@ -126,14 +126,6 @@ async def release_memory(self):
     async def wake_up(self):
         get_torch_device().empty_cache()
 
-        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
-            if self.multi_stage_wake_up:
-                await self.inference_engine.resume_memory_occupation(tags=["weights"])
-                log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
-            else:
-                await self.inference_engine.resume_memory_occupation()
-                log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
-
         log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
         if self.offload_param:
             load_fsdp_model_to_gpu(self.module)
@@ -147,13 +139,24 @@ async def wake_up(self):
         # convert weight keys to match the model config
         params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
 
+        if self.offload_param:
+            offload_fsdp_model_to_cpu(self.module)
+
+        log_gpu_memory_usage("After offload_param in sharding manager memory", logger=logger)
+
+        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
+            if self.multi_stage_wake_up:
+                await self.inference_engine.resume_memory_occupation(tags=["weights"])
+                log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
+            else:
+                await self.inference_engine.resume_memory_occupation()
+                log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
+
         # Copy, not share memory
         await self.update_weights(params)
         log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
 
         del params
-        if self.offload_param:
-            offload_fsdp_model_to_cpu(self.module)
         get_torch_device().empty_cache()
         log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)
 
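For context, the net effect of the two hunks is a reordering inside wake_up: the FSDP module is now offloaded back to CPU before the SGLang engine's memory occupation is resumed, rather than after the weight sync, presumably so the trainer's weights and the engine's weights do not both sit on the GPU at peak. Below is a minimal, self-contained sketch of the resulting control flow; _StubEngine and wake_up_sketch are hypothetical stand-ins for the real verl sharding manager and SGLang engine, and only the ordering mirrors the diff.

import asyncio


class _StubEngine:
    """Hypothetical stand-in for the SGLang inference engine."""

    async def resume_memory_occupation(self, tags=None):
        print(f"resume_memory_occupation(tags={tags})")

    async def update_weights(self, params):
        print(f"update_weights: {len(params)} tensors copied")


async def wake_up_sketch(engine, state_dict, offload_param=True, multi_stage_wake_up=False):
    # 1. Gather the full params while the FSDP module is still on GPU
    #    (stands in for state_dict() + convert_weight_keys in the real code).
    params = dict(state_dict)

    # 2. Offload the FSDP module to CPU *before* waking the engine --
    #    this is the step the commit moves earlier
    #    (stands in for offload_fsdp_model_to_cpu(self.module)).
    if offload_param:
        print("offload_fsdp_model_to_cpu(module)")

    # 3. Only now resume the engine's memory; in multi-stage mode only the
    #    weights are resumed here, the KV cache comes back later.
    if multi_stage_wake_up:
        await engine.resume_memory_occupation(tags=["weights"])
    else:
        await engine.resume_memory_occupation()

    # 4. Copy (not share) the gathered weights into the engine, then free them.
    await engine.update_weights(params)
    del params


asyncio.run(wake_up_sketch(_StubEngine(), {"w": 0}))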