Skip to content

Commit 2c2fd43

Browse files
committed
Implement huggingface checkpoint loading
1 parent a1b067e commit 2c2fd43

1 file changed

Lines changed: 118 additions & 14 deletions

File tree

examples/pre-training/ernie/pretrain.py

Lines changed: 118 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@
6363

6464
from config import get_config
6565

66+
from safetensors import safe_open
67+
6668
try:
6769
from paddleformers.trainer.trainer_utils import log_trainer_start
6870
except ImportError:
@@ -202,6 +204,117 @@ def _collate_data(data, stack_fn=Stack()):
202204
return train_dataset, valid_dataset, test_dataset, _collate_data
203205

204206

207+
def load_huggingface_checkpoint(model, args):
208+
fused_rms_norm_replace = [
209+
("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"),
210+
("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"),
211+
]
212+
shared_layers_prefix = "shared_layers.embed_weight_share."
213+
unnamed_layers = ["ernie.norm.weight", "lm_head.weight"]
214+
215+
logger.info(f"Loading huggingface checkpoint from {args.model_name_or_path}")
216+
with open(
217+
os.path.join(args.model_name_or_path, "model.safetensors.index.json")
218+
) as f:
219+
weight_map = json.load(f)["weight_map"]
220+
221+
ep_degree = fleet.get_hybrid_communicate_group().get_expert_parallel_world_size()
222+
ep_rank = fleet.get_hybrid_communicate_group().get_expert_parallel_rank()
223+
expert_offset = (model.config.moe_num_experts // ep_degree) * ep_rank
224+
225+
def param_to_weight(name):
226+
# for PP=1, we only need to substitute the fused_rms_norm and expert_id
227+
for src, dst in fused_rms_norm_replace:
228+
name = name.replace(src, dst)
229+
if m := re.search(r"mlp\.experts\.(\d+)", name):
230+
expert_id = expert_offset + int(m.group(1))
231+
s, e = m.span()
232+
name = name[:s] + f"mlp.experts.{expert_id}" + name[e:]
233+
if isinstance(model, ErnieMoEForCausalLM):
234+
return name
235+
236+
# for PP>1, we also need to handle special layers and adjust layer_idx
237+
if name.startswith(shared_layers_prefix):
238+
return "ernie." + name[len(shared_layers_prefix) :]
239+
layer_idx, stem = name.split(".", maxsplit=1)
240+
if stem == "weight":
241+
return unnamed_layers.pop(0)
242+
if stem.startswith("mtp"):
243+
return f"ernie.{stem}"
244+
return f"ernie.layers.{int(layer_idx) - 1}.{stem}"
245+
246+
def try_torch_format(weight_key):
247+
if weight_key.startswith("ernie."):
248+
weight_key = "model." + weight_key[6:]
249+
250+
key_decompose = [weight_key]
251+
if ".up_gate_proj." in weight_key:
252+
key_decompose = [
253+
weight_key.replace(".up_gate_proj.", ".gate_proj."),
254+
weight_key.replace(".up_gate_proj.", ".up_proj."),
255+
]
256+
elif ".qkv_proj." in weight_key:
257+
key_decompose = [
258+
weight_key.replace(".qkv_proj.", ".q_proj."),
259+
weight_key.replace(".qkv_proj.", ".k_proj."),
260+
weight_key.replace(".qkv_proj.", ".v_proj."),
261+
]
262+
263+
tensor_decompose = []
264+
for key in key_decompose:
265+
if not (weight_file := weight_map.get(key)):
266+
return None
267+
with safe_open(
268+
os.path.join(args.model_name_or_path, weight_file),
269+
framework="numpy",
270+
) as f:
271+
tensor = paddle.to_tensor(f.get_tensor(key))
272+
if "_proj." in key or ".gate." in key:
273+
tensor = tensor.T.contiguous()
274+
tensor_decompose.append(tensor)
275+
276+
if len(tensor_decompose) == 1:
277+
return tensor_decompose[0]
278+
else:
279+
return paddle.concat(tensor_decompose, axis=-1)
280+
281+
def auto_fix_shape(param, weight):
282+
assert len(param.shape) == len(weight.shape), "rank not match"
283+
if (
284+
len(param.shape) == 2
285+
and param.shape[0] == weight.shape[1]
286+
and param.shape[1] == weight.shape[0]
287+
):
288+
return weight.T.contiguous()
289+
assert all(
290+
p_dim <= w_dim for p_dim, w_dim in zip(param.shape, weight.shape)
291+
), "weight too small"
292+
indices = tuple(slice(0, dim) for dim in param.shape)
293+
return weight[indices].contiguous()
294+
295+
for name, param in model.named_parameters():
296+
weight_key = param_to_weight(name)
297+
if weight_file := weight_map.get(weight_key):
298+
with safe_open(
299+
os.path.join(args.model_name_or_path, weight_file),
300+
framework="numpy",
301+
) as f:
302+
weight = paddle.to_tensor(f.get_tensor(weight_key))
303+
elif (weight := try_torch_format(weight_key)) is None:
304+
logger.warning(
305+
f"param `{name}`'s weight `{weight_key}` not found. "
306+
"Skip initializing."
307+
)
308+
continue
309+
if param.shape != weight.shape:
310+
logger.warning(
311+
f"param `{name}`'s shape doesn't match weight `{weight_key}`: "
312+
f"{param.shape} and {weight.shape}. Auto fixing."
313+
)
314+
weight = auto_fix_shape(param, weight)
315+
param.copy_(weight)
316+
317+
205318
def main():
206319
if set_affinity is not None:
207320
set_affinity_code = set_affinity()
@@ -520,21 +633,12 @@ def sname_to_tname(pp_model):
520633
cfg.enable_delay_scale_loss = args.enable_delay_scale_loss
521634
register_pp_reshard_information(cfg.num_hidden_layers)
522635

523-
if args.from_scratch:
524-
model = ErnieMoEForCausalLMPipe(cfg)
525-
else:
526-
model = ErnieMoEForCausalLMPipe.from_pretrained(
527-
args.model_name_or_path,
528-
config=cfg,
529-
)
636+
model = ErnieMoEForCausalLMPipe(cfg)
530637
else:
531-
if args.from_scratch:
532-
model = ErnieMoEForCausalLM(cfg)
533-
else:
534-
model = ErnieMoEForCausalLM.from_pretrained(
535-
args.model_name_or_path,
536-
config=cfg,
537-
)
638+
model = ErnieMoEForCausalLM(cfg)
639+
640+
if not args.from_scratch:
641+
load_huggingface_checkpoint(model, args)
538642

539643
cfg = model.config
540644
logger.info(f"using model type:{type(model)}")

0 commit comments

Comments
 (0)