Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,9 @@ def fn(layer):

# very last
self._memory_tracker.stop_and_update_metrics()
if self.args.count_trained_tokens:
self.trained_effective_tokens = 0
self.trained_tokens = 0

def _wrap_amp_model(self, args, model):
logger.info("Using half precision")
Expand Down Expand Up @@ -1122,6 +1125,9 @@ def _inner_training_loop(
is_no_sync = True

sync_context = model.no_sync() if is_no_sync else contextlib.nullcontext()
if self.args.count_trained_tokens:
self.trained_effective_tokens += (inputs["input_ids"] != self.args.pad_token_id).sum()
self.trained_tokens += inputs["input_ids"].numel()
with sync_context:
if "step_control" in inspect.signature(self.training_step).parameters:
tr_loss_step = self.training_step(model, inputs, step_control=step_control)
Expand Down Expand Up @@ -1570,6 +1576,27 @@ def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval,
self._save_checkpoint(model, metrics=metrics)
logger.info(f"{self.runtime_timer.log()}")
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
self.log_trained_tokens()

def log_trained_tokens(self):
    """Aggregate and log the token counters accumulated during training.

    Sums ``self.trained_effective_tokens`` (tokens != pad_token_id) and
    ``self.trained_tokens`` (all input tokens) across the sharding and
    data-parallel communication groups, then logs the global totals on the
    local process zero.

    No-op unless ``self.args.count_trained_tokens`` is set.

    NOTE(review): assumes ``self.hcg`` (hybrid-parallel communication group)
    exists whenever this is called, and that the counters have already been
    accumulated into paddle tensors (plain int 0 has no ``reshape``) —
    confirm for non-hybrid / zero-step runs.
    """
    if not self.args.count_trained_tokens:
        return

    token_list = []
    for token_num in (self.trained_effective_tokens, self.trained_tokens):
        # Counters are 0-D tensors; all_gather needs at least 1-D.
        tensors = token_num.reshape([1])
        # Identical gather-and-sum reduction over each parallel group that
        # actually partitions the batch (previously duplicated inline).
        for degree, group in (
            (self.hcg._sharding_degree, self.hcg._sharding_comm_group),
            (self.hcg._dp_degree, self.hcg._dp_comm_group),
        ):
            if degree > 1:
                output_tensors = []
                paddle.distributed.all_gather(output_tensors, tensors, group=group)
                tensors = paddle.concat(output_tensors).sum().reshape([1])
        token_list.append(tensors.item())

    if self.is_local_process_zero():
        logger.info(
            f"Update to now, trained_effective_tokens: {token_list[0]}, trained_tokens: {token_list[1]}."
        )

def _get_learning_rate(self):
    """Return the current learning rate as reported by the optimizer."""
    current_lr = self.optimizer.get_lr()
    return current_lr
Expand Down
8 changes: 8 additions & 0 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,14 @@ class TrainingArguments:
default=300,
metadata={"help": "Timeout seconds for downloading checkpoint from remote cluster."},
)
# Enables accumulation and periodic logging of trained-token totals in the
# Trainer (effective = non-pad tokens, plus raw token count).
count_trained_tokens: bool = field(
    default=False,
    metadata={"help": "Whether to count trained tokens."},
)
# Token id treated as padding when counting effective tokens.
# NOTE(review): default 0 matches many tokenizers but not all — verify it
# agrees with the tokenizer actually in use before relying on the counts.
pad_token_id: int = field(
    default=0,
    metadata={"help": "The id of the padding token."},
)

def __post_init__(self):
if in_auto_parallel_align_mode():
Expand Down