1- from typing import Callable , cast , Optional , Sequence , Union
1+ import warnings
2+ from typing import Any , Callable , cast , Optional , Union
23
34import torch
45
56import ignite .distributed as idist
67from ignite .engine import Engine , Events
7- from ignite .metrics .metric import EpochWise , Metric , MetricUsage , reinit__is_reduced , sync_all_reduce
8+ from ignite .metrics .metric import Metric , MetricUsage , reinit__is_reduced , RunningBatchWise , SingleEpochRunningBatchWise
89
910__all__ = ["RunningAverage" ]
1011
@@ -18,8 +19,10 @@ class RunningAverage(Metric):
1819 alpha: running average decay factor, default 0.98
1920 output_transform: a function to use to transform the output if `src` is None and
2021 corresponds the output of process function. Otherwise it should be None.
21- epoch_bound: whether the running average should be reset after each epoch (defaults
22- to True).
epoch_bound: whether the running average should be reset after each epoch. It is deprecated in favor of
    ``usage`` argument in :meth:`attach` method. Setting ``epoch_bound`` to ``True`` is equivalent to
    ``usage=SingleEpochRunningBatchWise()`` and setting it to ``False`` is equivalent to
    ``usage=RunningBatchWise()`` in the :meth:`attach` method. Default None.
2326 device: specifies which device updates are accumulated on. Should be
2427 None when ``src`` is an instance of :class:`~ignite.metrics.metric.Metric`, as the running average will
2528 use the ``src``'s device. Otherwise, defaults to CPU. Only applicable when the computed value
@@ -90,7 +93,7 @@ def __init__(
9093 src : Optional [Metric ] = None ,
9194 alpha : float = 0.98 ,
9295 output_transform : Optional [Callable ] = None ,
93- epoch_bound : bool = True ,
96+ epoch_bound : Optional [ bool ] = None ,
9497 device : Optional [Union [str , torch .device ]] = None ,
9598 ):
9699 if not (isinstance (src , Metric ) or src is None ):
@@ -101,70 +104,119 @@ def __init__(
101104 if isinstance (src , Metric ):
102105 if output_transform is not None :
103106 raise ValueError ("Argument output_transform should be None if src is a Metric." )
107+
108+ def output_transform (x : Any ) -> Any :
109+ return x
110+
104111 if device is not None :
105112 raise ValueError ("Argument device should be None if src is a Metric." )
106- self .src = src
107- self ._get_src_value = self ._get_metric_value
108- setattr (self , "iteration_completed" , self ._metric_iteration_completed )
113+ self .src : Union [Metric , None ] = src
109114 device = src ._device
110115 else :
111116 if output_transform is None :
112117 raise ValueError (
113118 "Argument output_transform should not be None if src corresponds "
114119 "to the output of process function."
115120 )
116- self ._get_src_value = self ._get_output_value
117- setattr (self , "update" , self ._output_update )
121+ self .src = None
118122 if device is None :
119123 device = torch .device ("cpu" )
120124
121- self .alpha = alpha
125+ if epoch_bound is not None :
126+ warnings .warn (
127+ "`epoch_bound` is deprecated and will be removed in the future. Consider using `usage` argument of"
128+ "`attach` method instead. `epoch_bound=True` is equivalent with `usage=SingleEpochRunningBatchWise()`"
129+ " and `epoch_bound=False` is equivalent with `usage=RunningBatchWise()`."
130+ )
122131 self .epoch_bound = epoch_bound
123- super (RunningAverage , self ).__init__ (output_transform = output_transform , device = device ) # type: ignore[arg-type]
132+ self .alpha = alpha
133+ super (RunningAverage , self ).__init__ (output_transform = output_transform , device = device )
124134
@reinit__is_reduced
def reset(self) -> None:
    """Reset the running-average state.

    Clears the accumulated running value and, when this metric wraps a
    ``src`` :class:`Metric`, resets that source metric as well.
    """
    if isinstance(self.src, Metric):
        self.src.reset()
    self._value: Optional[Union[float, torch.Tensor]] = None
128140
@reinit__is_reduced
def update(self, output: Union[torch.Tensor, float]) -> None:
    """Fold one new sample into the exponential running average.

    Args:
        output: the new sample. When ``src`` is a wrapped metric this argument is
            ignored and the sample is ``src.compute()`` instead; otherwise it is the
            (transformed) engine output, averaged across distributed processes.
    """
    if self.src is not None:
        # Source-metric path: take its computed value as the new sample,
        # then reset it so the next batch starts fresh.
        new_sample = self.src.compute()
        self.src.reset()
    else:
        # Engine-output path: detach/move tensors, then average across workers.
        if isinstance(output, torch.Tensor):
            output = output.detach().to(self._device, copy=True)
        new_sample = idist.all_reduce(output) / idist.get_world_size()

    # First sample seeds the average; afterwards apply the EMA decay ``alpha``.
    self._value = (
        new_sample
        if self._value is None
        else self._value * self.alpha + (1.0 - self.alpha) * new_sample
    )
def compute(self) -> Union[torch.Tensor, float]:
    """Return the current running-average value.

    The value is whatever :meth:`update` has accumulated so far; the ``cast`` only
    narrows the ``Optional`` type for the type checker and does not convert anything.
    """
    current = self._value
    return cast(Union[torch.Tensor, float], current)
def attach(self, engine: Engine, name: str, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None:
    r"""Attach the metric to ``engine`` on the events determined by ``usage``.

    Args:
        engine: the engine to get attached to.
        name: key under which the computed value is stored in ``engine.state.metrics``.
        usage: the usage determining on which events the metric is reset, updated and
            computed. One of the following :class:`~ignite.metrics.metric.MetricUsage`
            subclasses:

            - :class:`~.metrics.metric.RunningBatchWise` (default): running average of
              the ``src`` metric or ``engine.state.output``, computed across batches.
              With a ``src`` metric, ``src`` is reset, updated and computed on each
              batch and its value folded into the average.
            - :class:`~.metrics.metric.SingleEpochRunningBatchWise`: same as above but
              the running average is reset at the end of each epoch.
            - :class:`~.metrics.metric.RunningEpochWise`: running average computed
              across epochs; ``src`` works as if attached
              :class:`~ignite.metrics.metric.EpochWise` and its value is retrieved at
              epoch end. Without a ``src`` metric this usage retrieves only the last
              batch's ``engine.state.output``, which is rarely meaningful.

    ``RunningAverage`` retrieves ``engine.state.output`` at ``usage.ITERATION_COMPLETED``
    if ``src`` is not given; otherwise ``src`` is updated at ``usage.ITERATION_COMPLETED``
    and its computed value is retrieved by manually calling its ``compute`` method. The
    reset event of ``src`` follows the ``usage`` type: ``BatchWise().STARTED`` for
    batch-wise usages, ``EpochWise().STARTED`` for epoch-wise ones.

    .. versionchanged:: 0.5.1
        Added `usage` argument
    """
    resolved_usage = self._check_usage(usage)
    # Deprecated ``epoch_bound`` still takes precedence over ``usage`` for
    # backward compatibility.
    if self.epoch_bound is not None:
        resolved_usage = SingleEpochRunningBatchWise() if self.epoch_bound else RunningBatchWise()

    # Make sure a wrapped src metric gets its per-iteration update, without
    # registering the same handler twice.
    if isinstance(self.src, Metric) and not engine.has_event_handler(
        self.src.iteration_completed, Events.ITERATION_COMPLETED
    ):
        engine.add_event_handler(Events.ITERATION_COMPLETED, self.src.iteration_completed)

    super().attach(engine, name, resolved_usage)
def detach(self, engine: Engine, usage: Union[str, MetricUsage] = RunningBatchWise()) -> None:
    """Detach the metric from ``engine``, mirroring :meth:`attach`.

    Also removes the per-iteration update handler of a wrapped ``src`` metric,
    if one was registered. The deprecated ``epoch_bound`` flag, when set,
    overrides ``usage`` exactly as in :meth:`attach`.
    """
    resolved_usage = self._check_usage(usage)
    if self.epoch_bound is not None:
        resolved_usage = SingleEpochRunningBatchWise() if self.epoch_bound else RunningBatchWise()

    if isinstance(self.src, Metric) and engine.has_event_handler(
        self.src.iteration_completed, Events.ITERATION_COMPLETED
    ):
        engine.remove_event_handler(self.src.iteration_completed, Events.ITERATION_COMPLETED)

    super().detach(engine, resolved_usage)
0 commit comments