Merged
paddlenlp/trainer/auto_trainer.py (1 change: 0 additions & 1 deletion)

@@ -692,7 +692,6 @@ def _save(
         output_dir: Optional[str] = None,
         state_dict=None,
         merge_tensor_parallel=False,
-        signal_dir: Optional[str] = None,
     ):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
paddlenlp/trainer/trainer.py (18 changes: 10 additions & 8 deletions)

@@ -2292,7 +2292,6 @@ def save_model(
         self,
         output_dir: Optional[str] = None,
         merge_tensor_parallel: Optional[bool] = False,
-        signal_dir: Optional[str] = None,
     ):
         """
         Will save the model, so you can reload it using `from_pretrained()`.
@@ -2303,14 +2302,16 @@ def save_model(
         if output_dir is None:
             output_dir = self.args.output_dir
 
-        if signal_dir is None:
+        if PREFIX_CHECKPOINT_DIR in output_dir:
+            signal_dir = os.path.join(self.args.output_signal_dir, os.path.split(output_dir)[-1])
+        else:
             signal_dir = self.args.output_signal_dir
 
         if ShardingOption.FULL_SHARD in self.args.sharding:
             self.model_wrapped.get_all_parameters(convert2cpu=True)
 
         if self.args.should_save_model_state:
-            self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel, signal_dir=signal_dir)
+            self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 os.makedirs(signal_dir, exist_ok=True)
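With the parameter gone, save_model derives the signal directory from output_dir itself. A minimal standalone sketch of that resolution rule (the name resolve_signal_dir is illustrative, and the literal "checkpoint" value for PREFIX_CHECKPOINT_DIR is an assumption; in PaddleNLP the prefix constant and output_signal_dir come from the trainer's arguments):

    import os

    PREFIX_CHECKPOINT_DIR = "checkpoint"  # assumed value of PaddleNLP's constant

    def resolve_signal_dir(output_dir: str, output_signal_dir: str) -> str:
        # For a checkpoint folder such as "out/checkpoint-500", mirror the
        # folder name under the signal directory so each checkpoint gets its
        # own finish-signal location; otherwise use the base signal dir.
        if PREFIX_CHECKPOINT_DIR in output_dir:
            return os.path.join(output_signal_dir, os.path.split(output_dir)[-1])
        return output_signal_dir

    assert resolve_signal_dir("out/checkpoint-500", "signals") == os.path.join("signals", "checkpoint-500")
    assert resolve_signal_dir("out", "signals") == "signals"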
@@ -2370,11 +2371,11 @@ def _save_checkpoint(self, model, metrics=None):
         signal_dir = os.path.join(run_signal_dir, checkpoint_folder)
 
         if isinstance(self.model, LoRAModel) and (self.model.quantized or self.args.pipeline_parallel_degree > 1):
-            self.save_model(output_dir, False, signal_dir)
+            self.save_model(output_dir)
         elif isinstance(self.model, LoRAModel) or isinstance(self.model, PrefixModelForCausalLM):
-            self.save_model(output_dir, True, signal_dir)
+            self.save_model(output_dir, True)
         else:
-            self.save_model(output_dir, False, signal_dir)
+            self.save_model(output_dir)
 
         # only save model state dict, ignore optimizer and scheduler
         if not self.args.ignore_save_lr_and_optim:
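Call sites in _save_checkpoint shrink accordingly. One behavioral note, sketched below with a hypothetical stand-in class: save_model no longer accepts a third positional argument, so any external caller still passing signal_dir would now fail with a TypeError.

    from typing import Optional

    class TrainerSketch:  # stand-in mirroring only the new signature
        def save_model(
            self,
            output_dir: Optional[str] = None,
            merge_tensor_parallel: Optional[bool] = False,
        ):
            pass

    t = TrainerSketch()
    t.save_model("out/checkpoint-500", True)  # fine after this change
    # t.save_model("out/checkpoint-500", True, "signals")  -> TypeError now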
@@ -2591,15 +2592,16 @@ def _save(
         output_dir: Optional[str] = None,
         state_dict=None,
         merge_tensor_parallel=False,
-        signal_dir: Optional[str] = None,
     ):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         logger.info(f"Saving model checkpoint to {output_dir}")
 
         # signal_dir is used for asynchronous saving situations.
+        signal_dir = self.args.output_signal_dir
         if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
-            signal_dir = signal_dir if signal_dir is not None else self.args.output_signal_dir
+            if PREFIX_CHECKPOINT_DIR in output_dir:
+                signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1])
             os.makedirs(signal_dir, exist_ok=True)
             logger.info(f"Saving model checkpoint finish signal to {signal_dir}")
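Inside _save, the finish-signal directory is still only created (and logged) when unified checkpointing runs with async_save enabled. A hedged sketch of the new control flow, with plain parameters standing in for the self.args fields:

    import os

    def maybe_create_signal_dir(output_dir, output_signal_dir,
                                unified_checkpoint, unified_checkpoint_config):
        # Mirrors the new _save logic: resolve the signal dir up front, but
        # only materialize it for async unified-checkpoint saves.
        signal_dir = output_signal_dir
        if unified_checkpoint and "async_save" in unified_checkpoint_config:
            if "checkpoint" in output_dir:  # PREFIX_CHECKPOINT_DIR, assumed value
                signal_dir = os.path.join(signal_dir, os.path.split(output_dir)[-1])
            os.makedirs(signal_dir, exist_ok=True)
            return signal_dir
        return None

    print(maybe_create_signal_dir("out/checkpoint-500", "signals", True, ["async_save"]))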