forked from PaddlePaddle/FastDeploy
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdynamic_weight_manager.py
More file actions
278 lines (236 loc) · 11.5 KB
/
dynamic_weight_manager.py
File metadata and controls
278 lines (236 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
"""
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import os
import time
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Dict, List
import numpy as np
import paddle
from paddleformers.utils.log import logger
from fastdeploy.config import FDConfig
class DynamicWeightManager:
    """Manages model weights loading, updating and shared state across processes.

    On construction it records every parameter of the given model(s) by name,
    performs the initial load with ``load_config.load_strategy`` ("ipc_snapshot"
    or "ipc") and verifies the result. Afterwards ``clear_parameters`` releases
    the parameter storage, and ``update_parameters`` + ``finalize_update``
    re-attach new storage to the same parameter objects. Progress is reported
    through the shared-memory flag ``model_weights_status.{pid}``: ``-2`` after a
    clear, ``0`` after a (non-initial) update.
    """

    def __init__(self, fd_config: FDConfig, models):
        """Initialize with config and model instances, then perform the first load.

        Args:
            fd_config: Deployment config; ``load_config`` selects the load strategy
                and ``parallel_config`` supplies the ranks and process groups.
            models: A single model or a list of models whose ``state_dict()``
                parameters are managed.
        """
        self.fd_config = fd_config
        self.load_config = fd_config.load_config
        self.parallel_config = fd_config.parallel_config
        self.state_dict: Dict[str, paddle.Tensor] = {}
        self.rank = fd_config.parallel_config.tensor_parallel_rank
        self.nranks = paddle.distributed.get_world_size()
        self.meta_src_id = self._get_gpu_id()
        self.first_load = True
        self.ipc_path = f"/shared_ipc_meta/ipc_metas_{self.meta_src_id}"
        self.model_list = models if isinstance(models, list) else [models]
        self._capture_model_state()
        self.update_parameters()
        self.finalize_update()

        logger.info(
            f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
            f" tp rank={self.rank}, dp rank={fd_config.parallel_config.local_data_parallel_id}, ep rank={fd_config.parallel_config.expert_parallel_rank}, ranks={self.nranks}, "
        )

    @paddle.no_grad()
    def _capture_model_state(self):
        """Record every model parameter by name so later loads can update it in place."""
        for model in self.model_list:
            for name, param in model.state_dict().items():
                logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
                self.state_dict[name] = param

    def update_parameters(self, pid: int = 0) -> None:
        """Core method to update model parameters based on strategy.

        On reloads (anything after the first load) the process groups shut down
        by ``clear_parameters`` are restarted. With expert parallelism, the DeepEP
        buffer is also recreated before new weights are attached.

        Args:
            pid: Not used by this method; accepted for signature symmetry with
                ``clear_parameters`` and ``finalize_update``.

        Raises:
            ValueError: If ``load_config.load_strategy`` is neither "ipc_snapshot"
                nor "ipc".
        """
        start_time = time.perf_counter()
        paddle.device.cuda.empty_cache()

        # step1 : restart paddle process group
        if not self.first_load:
            paddle.distributed.restart_process_group(self.parallel_config.tp_group)
            if self.parallel_config.enable_expert_parallel:
                paddle.distributed.restart_process_group(self.parallel_config.ep_group)

        # step2 : recreate deepep buffer when expert parallel is enabled
        if self.parallel_config.enable_expert_parallel and not self.first_load:
            from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager

            DeepEPBufferManager.recreate_buffer()
            # ep barrier
            paddle.distributed.barrier(self.parallel_config.ep_group)

        # step3 : update model weight
        strategy_handlers = {
            "ipc_snapshot": self._update_ipc_snapshot,
            "ipc": self._update_ipc,
        }

        if handler := strategy_handlers.get(self.load_config.load_strategy):
            handler()
        else:
            raise ValueError(f"Unsupported strategy: {self.load_config.load_strategy}")

        logger.info(f"Update parameters in {time.perf_counter()-start_time:.2f}s")

        # steps in the runner
        # step 4: reinitialize kv_cache
        # step 5: recapture CUDAGraph
        # step 6: update weight status signal

    def _update_ipc_snapshot(self):
        """Update using IPC snapshot strategy for elastic recovery.

        Loads ``model_state.tp0{meta_src_id}.pdparams`` from the model directory.
        If that raises ``FileNotFoundError``, the same file name is loaded from
        ``/shared_ipc_meta`` instead.
        """
        # NOTE(review): "tp0{id}" produces e.g. "tp010" for device id 10 — confirm
        # this matches the file names written by the producer side.
        file_name = f"model_state.tp0{self.meta_src_id}.pdparams"
        model_path = os.path.join(self.fd_config.model_config.model, file_name)
        try:
            ipc_state_dict = paddle.load(model_path)
        except FileNotFoundError:
            model_path = os.path.join("/shared_ipc_meta", file_name)
            ipc_state_dict = paddle.load(model_path)

        self._update_model_from_state(ipc_state_dict, "snapshot")
        # Report the path actually loaded (the fallback may have been used).
        logger.info(f"IPC snapshot update parameters completed from {model_path}")

    def _update_ipc(self):
        """Update using standard IPC strategy (requires Training Worker).

        Reads IPC metadata from ``self.ipc_path`` and attaches the shared CUDA
        tensors it describes to the managed parameters.
        """
        ipc_meta = paddle.load(self.ipc_path)
        state_dict = self._convert_ipc_meta_to_tensor(ipc_meta)
        self._update_model_from_state(state_dict, "raw")
        logger.info(f"IPC update parameters completed from file: {self.ipc_path}")

    def clear_parameters(self, pid: int = 0) -> None:
        """Clear all model parameters and free memory.

        Releases the DeepEP buffer and shuts down the EP group (when expert
        parallelism is on). Then it drops every parameter's storage, verifies
        that none remain initialized, shuts down the TP group (when TP > 1), and
        finally writes ``-2`` to the shared status flag for ``pid``.

        Raises:
            RuntimeError: If any parameter is still initialized after clearing.
        """
        logger.info("start clear paramaters")

        # step1: release deepep buffer
        if self.parallel_config.enable_expert_parallel:
            from fastdeploy.model_executor.layers.moe.ep import DeepEPBufferManager

            DeepEPBufferManager.clear_buffer()
            # ep barrier
            paddle.distributed.barrier(self.parallel_config.ep_group)
            # shutdown ep group
            paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)

        paddle.device.cuda.empty_cache()

        # step2: release model weight
        for model in self.model_list:
            for param in model.state_dict().values():
                param._clear_data()

        self._verify_parameters("clearance")

        if self.parallel_config.tensor_parallel_size > 1:
            # tp barrier
            paddle.distributed.barrier(self.parallel_config.tp_group)
            # shutdown tp group
            paddle.distributed.shutdown_process_group(self.parallel_config.tp_group)

        # step3: update model weight signal
        # step4: release kv cache in the runner
        self._update_shared_status(pid, -2)

    def _update_model_from_state(self, state_dict: Dict[str, paddle.Tensor], src_type: str):
        """Update model parameters from given state dictionary.

        Entries whose name is not a captured parameter are skipped (logged at
        debug level). Matching entries share their buffer into the existing
        parameter object.

        Raises:
            ValueError: If ``state_dict`` is empty, or on a shape mismatch.
            TypeError: On a dtype mismatch.
        """
        if not state_dict:
            raise ValueError(f"No parameter found in state dict {state_dict}")

        update_count = 0
        for name, new_param in state_dict.items():
            if name not in self.state_dict:
                logger.debug(f"Ignoring unmatched {src_type} param: {name}")
                continue

            target_param = self.state_dict[name]
            self._validate_parameter_match(name, new_param, target_param)
            new_param._share_buffer_to(target_param)
            update_count += 1
        logger.info(f"🆗 Updated {update_count}/{len(state_dict)} parameters from {src_type} source")

    def _validate_parameter_match(self, name: str, src: paddle.Tensor, dst: paddle.Tensor):
        """Validate that ``src`` and ``dst`` agree in dtype and shape.

        Raises:
            TypeError: If the dtypes differ.
            ValueError: If the shapes differ.
        """
        if src.dtype != dst.dtype:
            raise TypeError(f"Type mismatch for {name}: {src.dtype} vs {dst.dtype}")
        if src.shape != dst.shape:
            raise ValueError(f"Shape mismatch for {name}: {src.shape} vs {dst.shape}")

    def finalize_update(self, pid: int = 0):
        """Finalize update process with verification.

        Checks every parameter is initialized and synchronizes the TP/EP groups.
        For any load after the first, it writes ``0`` to the shared status flag
        for ``pid``.

        Raises:
            RuntimeError: If any parameter is not initialized.
        """
        self._verify_parameters("update")

        if self.parallel_config.tensor_parallel_size > 1:
            paddle.distributed.barrier(self.parallel_config.tp_group)

        if self.parallel_config.enable_expert_parallel:
            paddle.distributed.barrier(self.parallel_config.ep_group)

        # The initial load happens in __init__, before any status flag is watched.
        if not self.first_load:
            self._update_shared_status(pid, 0)
        self.first_load = False

    def _get_gpu_id(self) -> int:
        """Get current GPU device ID.

        ``FLAGS_selected_gpus`` (default "0") is used as an index into the
        comma-separated ``CUDA_VISIBLE_DEVICES`` list (default "0").
        """
        visible_devices = os.getenv("CUDA_VISIBLE_DEVICES", "0").split(",")
        return int(visible_devices[int(os.getenv("FLAGS_selected_gpus", "0"))])

    def _verify_parameters(self, operation: str):
        """Verify parameters are in expected state after operation.

        After ``"update"`` every parameter must be initialized; for any other
        operation name (e.g. ``"clearance"``) every parameter must be
        uninitialized.

        Raises:
            RuntimeError: If any parameter is in the wrong state.
        """
        expected_initialized = operation == "update"
        all_valid = True
        for name, param in self.state_dict.items():
            is_initialized = param._is_initialized()
            if is_initialized != expected_initialized:
                logger.error(
                    f"Verification failed after {operation}: "
                    f"Param {name} initialized={is_initialized} (expected {expected_initialized})"
                )
                all_valid = False

        if all_valid:
            logger.info(f"💡 Model Parameter {operation} verified successfully")
        else:
            raise RuntimeError(f"❌ Model Parameter {operation} verification failed")

    @staticmethod
    def _convert_ipc_meta_to_tensor(
        ipc_meta: Dict[str, Any],
    ) -> Dict[str, paddle.Tensor]:
        """Convert IPC metadata to tensor dictionary.

        Each metadata entry is rebuilt with element 0 encoded to latin-1 bytes
        and element 6 replaced by this process's ``FLAGS_selected_gpus`` value.
        It is then passed to ``LoDTensor._new_shared_cuda``. The caller's
        ``ipc_meta`` is not modified.
        """
        # NOTE(review): assumes meta[0] is a str handle and meta[6] the device id
        # slot of the shared-CUDA descriptor — confirm against the producer.
        device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
        converted = {}
        for name, meta in ipc_meta.items():
            # Work on a copy so the caller's metadata stays intact (and tuples work).
            meta = list(meta)
            meta[0] = meta[0].encode("latin-1")
            meta[6] = device_id
            tensor = paddle.base.core.LoDTensor._new_shared_cuda(tuple(meta))
            converted[name] = paddle.to_tensor(tensor)
        return converted

    def _log_memory(self, context: str):
        """Log current and peak GPU memory usage (in GiB) at warning level."""
        max_alloc = paddle.device.cuda.max_memory_allocated() / (1024**3)
        max_reserved = paddle.device.cuda.max_memory_reserved() / (1024**3)
        curr_alloc = paddle.device.cuda.memory_allocated() / (1024**3)
        curr_reserved = paddle.device.cuda.memory_reserved() / (1024**3)

        logger.warning(
            f"GPU memory usage {context}:\n"
            f"max_allocated: {max_alloc:.2f}GB\n"
            f"max_reserved: {max_reserved:.2f}GB\n"
            f"current_allocated: {curr_alloc:.2f}GB\n"
            f"current_reserved: {curr_reserved:.2f}GB"
        )

    def _update_shared_status(self, pid: int, status: int) -> None:
        """Update shared memory status flag for inter-process communication.

        Attaches to the existing segment ``model_weights_status.{pid}``, which
        holds one int32. Only tensor-parallel rank 0 writes ``status``. The
        mapping is always detached before returning; the segment itself is not
        unlinked.
        """
        array = np.zeros([1], dtype=np.int32)
        shm = SharedMemory(create=False, size=array.nbytes, name=f"model_weights_status.{pid}")
        try:
            if self.rank == 0:
                value = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
                value[0] = status
                # Drop the numpy view first: close() raises BufferError while an
                # exported view of shm.buf is still alive.
                del value
        finally:
            shm.close()

    @staticmethod
    def check_model_weights_status(model_weights_status, model_runner, pid):
        """Service weight load/clear commands signalled via ``model_weights_status``.

        Polls ``model_weights_status.value[0]`` until it reads ``0``:
          *  ``1``  -> clear requests, then ``model_runner.update_parameters(pid)``;
                      wait until the flag returns to ``0``.
          * ``-1``  -> clear requests, then ``model_runner.clear_parameters(pid)``;
                      wait until the flag reads ``-2``, then keep idling/polling.
        (``1``/``-1`` are presumably written by the controlling process; ``0`` and
        ``-2`` are written by ``finalize_update`` / ``clear_parameters``.)
        """
        is_stop = False
        while model_weights_status.value[0] != 0:
            status = model_weights_status.value[0]
            if status == 1:
                logger.info("infer engine stopped! start to load new checkpoint...")
                # A new command was dispatched: wait for *its* completion signal
                # rather than reusing a stale "already cleared" state, which would
                # re-dispatch the command before rank 0 publishes completion.
                is_stop = False
                model_runner.clear_requests()
                model_runner.update_parameters(pid)
            elif status == -1:
                logger.info("infer engine stopped! start to clear checkpoint...")
                is_stop = False
                model_runner.clear_requests()
                model_runner.clear_parameters(pid)

            while True:
                status = model_weights_status.value[0]
                if status == 0:
                    logger.info("finished loading new checkpoint")
                    break
                if is_stop or status == -2:
                    if not is_stop:
                        logger.info("finished clearing checkpoint")
                        is_stop = True
                    time.sleep(0.001)
                    break
                time.sleep(0.001)