Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 111 additions & 1 deletion python/paddle/hapi/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@
# limitations under the License.

import os
import numbers

from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.utils import try_import

from .progressbar import ProgressBar

__all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint']
__all__ = ['Callback', 'ProgBarLogger', 'ModelCheckpoint', 'VisualDL']


def config_callbacks(callbacks=None,
Expand Down Expand Up @@ -469,3 +471,111 @@ def on_train_end(self, logs=None):
path = '{}/final'.format(self.save_dir)
print('save checkpoint at {}'.format(os.path.abspath(path)))
self.model.save(path)


class VisualDL(Callback):
"""VisualDL callback function
Args:
log_dir (str): The directory to save visualdl log file.

Examples:
.. code-block:: python

import paddle
from paddle.static import InputSpec

inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]

train_dataset = paddle.vision.datasets.MNIST(mode='train')
eval_dataset = paddle.vision.datasets.MNIST(mode='test')

net = paddle.vision.LeNet()
model = paddle.Model(net, inputs, labels)

optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
model.prepare(optimizer=optim,
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())

## uncomment following lines to fit model with visualdl callback function
# callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
# model.fit(train_dataset, eval_dataset, batch_size=64, callbacks=callback)

"""

def __init__(self, log_dir):
self.log_dir = log_dir
self.epochs = None
self.steps = None
self.epoch = 0

def _is_write(self):
return ParallelEnv().local_rank == 0

def on_train_begin(self, logs=None):
self.epochs = self.params['epochs']
assert self.epochs
self.train_metrics = self.params['metrics']
assert self.train_metrics
self._is_fit = True
self.train_step = 0

def on_epoch_begin(self, epoch=None, logs=None):
self.steps = self.params['steps']
self.epoch = epoch

def _updates(self, logs, mode):
if not self._is_write():
return
if not hasattr(self, 'writer'):
visualdl = try_import('visualdl')
self.writer = visualdl.LogWriter(self.log_dir)

metrics = getattr(self, '%s_metrics' % (mode))
current_step = getattr(self, '%s_step' % (mode))

if mode == 'train':
total_step = current_step
else:
total_step = self.epoch

for k in metrics:
if k in logs:
temp_tag = mode + '/' + k

if isinstance(logs[k], (list, tuple)):
temp_value = logs[k][0]
elif isinstance(logs[k], numbers.Number):
temp_value = logs[k]
else:
continue

self.writer.add_scalar(
tag=temp_tag, step=total_step, value=temp_value)

def on_train_batch_end(self, step, logs=None):
logs = logs or {}
self.train_step += 1

if self._is_write():
self._updates(logs, 'train')

def on_eval_begin(self, logs=None):
self.eval_steps = logs.get('steps', None)
self.eval_metrics = logs.get('metrics', [])
self.eval_step = 0
self.evaled_samples = 0

def on_train_end(self, logs=None):
if hasattr(self, 'writer'):
self.writer.close()
delattr(self, 'writer')

def on_eval_end(self, logs=None):
if self._is_write():
self._updates(logs, 'eval')

if (not hasattr(self, '_is_fit')) and hasattr(self, 'writer'):
self.writer.close()
delattr(self, 'writer')
28 changes: 28 additions & 0 deletions python/paddle/tests/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest
import time
import random
import tempfile
import shutil
import paddle

from paddle import Model
from paddle.static import InputSpec
Expand Down Expand Up @@ -102,6 +104,32 @@ def test_callback_verbose_2(self):
self.verbose = 2
self.run_callback()

def test_visualdl_callback(self):
# visualdl not support python3
if sys.version_info < (3, ):
return

inputs = [InputSpec([-1, 1, 28, 28], 'float32', 'image')]
labels = [InputSpec([None, 1], 'int64', 'label')]

train_dataset = paddle.vision.datasets.MNIST(mode='train')
eval_dataset = paddle.vision.datasets.MNIST(mode='test')

net = paddle.vision.LeNet()
model = paddle.Model(net, inputs, labels)

optim = paddle.optimizer.Adam(0.001, parameters=net.parameters())
model.prepare(
optimizer=optim,
loss=paddle.nn.CrossEntropyLoss(),
metrics=paddle.metric.Accuracy())

callback = paddle.callbacks.VisualDL(log_dir='visualdl_log_dir')
model.fit(train_dataset,
eval_dataset,
batch_size=64,
callbacks=callback)


if __name__ == '__main__':
unittest.main()
1 change: 1 addition & 0 deletions python/unittest_py/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ PyGithub
coverage
pycrypto
mock
visualdl ; python_version>="3.5"