 from vllm.distributed import parallel_state as vllm_ps
 
 from ...protocol import DataProto, all_gather_data_proto
-from ...utils.model_utils import print_gpu_memory_usage
+from ...utils.fsdp_utils import load_fsdp_model, offload_fsdp_model
+from ...utils.model_utils import is_rank0, print_gpu_memory_usage
 from .base import BaseShardingManager
 
 
@@ -37,10 +38,13 @@ def __init__(
         module: FSDP,
         inference_engine: LLM,
         device_mesh: DeviceMesh,
+        use_param_offload: bool,
     ):
         self.module = module
         self.inference_engine = inference_engine
         self.device_mesh = device_mesh
+        self.use_param_offload = use_param_offload
+        self.skip_vllm_sync_once = False
 
         self.world_size = dist.get_world_size()
         self.tp_size = vllm_ps.get_tensor_model_parallel_world_size()
@@ -85,6 +89,24 @@ def _make_weight_iterator(
         for name, tensor in actor_weights.items():
             yield name, tensor.full_tensor() if self.world_size != 1 else tensor
 
+    def _sync_weight_to_vllm(self):
+        if self.use_param_offload:
+            load_fsdp_model(self.module)
+
+        actor_weights = get_model_state_dict(self.module)
+        actor_weights = self._rename_weight_keys(actor_weights, self.module._fsdp_wrapped_module)
+        print_gpu_memory_usage("After gather model weights in sharding manager")
+
+        model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
+        model.load_weights(self._make_weight_iterator(actor_weights))
+
+        del actor_weights
+        if self.use_param_offload:
+            offload_fsdp_model(self.module)
+
+        torch.cuda.empty_cache()
+        print_gpu_memory_usage("After sync model weights in sharding manager")
+
     def __enter__(self):
         # NOTE: Basically, we only need `torch.cuda.empty_cache()` before vllm wake_up and
         # after vllm sleep, since vllm has its own caching memory allocator CuMemAllocator.
@@ -94,27 +116,23 @@ def __enter__(self):
         # pytorch: https://pytorch.org/docs/stable/notes/cuda.html#memory-management
         # vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/device_allocator/cumem.py#L103
         torch.cuda.empty_cache()
-        print_gpu_memory_usage("Before state_dict() in sharding manager")
-        actor_weights = get_model_state_dict(self.module)
-        actor_weights = self._rename_weight_keys(actor_weights, self.module._fsdp_wrapped_module)
-        print_gpu_memory_usage("After state_dict() in sharding manager")
-
+        print_gpu_memory_usage("Before vllm wake up in sharding manager")
         if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
             self.inference_engine.wake_up(tags=["weights"])
         else:
             self.inference_engine.wake_up()
 
-        model = self.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
-        model.load_weights(self._make_weight_iterator(actor_weights))
-        print_gpu_memory_usage("After sync model weights in sharding manager")
-
-        del actor_weights
-        torch.cuda.empty_cache()
+        if self.skip_vllm_sync_once:
+            self.skip_vllm_sync_once = False  # reset the flag
+            if is_rank0():
+                print("Skip vllm weight sync in sharding manager once.")
+        else:
+            self._sync_weight_to_vllm()
 
         if "tags" in inspect.signature(self.inference_engine.wake_up).parameters:
             self.inference_engine.wake_up(tags=["kv_cache"])
 
-        print_gpu_memory_usage("After del state_dict and empty_cache in sharding manager")
+        print_gpu_memory_usage("After vllm wake up in sharding manager")
         # important: need to manually set the random states of each tp to be identical.
         if self.device_mesh is not None:
             self.torch_random_states = torch.cuda.get_rng_state()
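
For context (not part of the diff above), a minimal caller-side sketch of how the two new knobs are meant to interact; the class and variable names (FSDPVLLMShardingManager, fsdp_module, vllm_engine, rollout_device_mesh, prompts, sampling_params) are assumptions for illustration only:

# Hypothetical usage sketch, not part of this commit. Assumes the sharding manager
# is used as a context manager whose __enter__ wakes vLLM and syncs actor weights.
sharding_manager = FSDPVLLMShardingManager(
    module=fsdp_module,              # FSDP-wrapped actor model (assumed)
    inference_engine=vllm_engine,    # vllm.LLM created with enable_sleep_mode=True (assumed)
    device_mesh=rollout_device_mesh,
    use_param_offload=True,          # gather params to GPU only for the sync, then offload back
)

# If the vLLM weights are already known to match the actor (e.g. the engine was just
# built from the same checkpoint), the next weight sync can be skipped exactly once.
sharding_manager.skip_vllm_sync_once = True

with sharding_manager:
    # __enter__: wake_up(tags=["weights"]) -> optional _sync_weight_to_vllm() -> wake_up(tags=["kv_cache"])
    outputs = vllm_engine.generate(prompts, sampling_params)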