Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 32 additions & 10 deletions ignite/handlers/state_param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,11 @@ class LambdaStateScheduler(StateParamScheduler):

Examples:

.. code-block:: python
.. testsetup::

...
engine = Engine(train_step)
default_trainer = get_default_trainer()

.. testcode::

class LambdaState:
def __init__(self, initial_value, gamma):
Expand All @@ -139,18 +140,39 @@ def __init__(self, initial_value, gamma):
def __call__(self, event_index):
return self.initial_value * self.gamma ** (event_index % 9)


param_scheduler = LambdaStateScheduler(
param_name="param",
lambda_obj=LambdaState(10, 0.99),
param_name="param", lambda_obj=LambdaState(1, 0.9), create_new=True
)

param_scheduler.attach(engine, Events.EPOCH_COMPLETED)
# parameter is param, initial_value sets param to 1 and in this example gamma = 1
# using class 'LambdaState' user defined callable object can be created
# update a parameter during training by using a user defined callable object
# user defined callable object is taking an event index as input and returns parameter value
# in this example, we update as initial_value * gamma ** (event_endex % 9)
# in every Epoch the parameter is updated as 1 * 0.9 ** (Epoch % 9)
# In Epoch 3, parameter param = 1 * 0.9 ** (3 % 9) = 0.729
# In Epoch 10, parameter param = 1 * 0.9 ** (10 % 9) = 0.9

# basic handler to print scheduled state parameter
engine.add_event_handler(Events.EPOCH_COMPLETED, lambda _ : print(engine.state.param))
param_scheduler.attach(default_trainer, Events.EPOCH_COMPLETED)

@default_trainer.on(Events.EPOCH_COMPLETED)
def print_param():
print(default_trainer.state.param)

default_trainer.run([0], max_epochs=10)

engine.run([0] * 8, max_epochs=2)
.. testoutput::

0.9
0.81
0.7290...
0.6561
0.5904...
0.5314...
0.4782...
0.4304...
1.0
0.9

.. versionadded:: 0.5.0

Expand Down