Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 26 additions & 3 deletions ignite/handlers/param_scheduler.py
Comment thread
sihyeong671 marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ def __init__(
cycle_mult: float = 1.0,
start_value_mult: float = 1.0,
end_value_mult: float = 1.0,
warmup_each_cycle: bool = False,
warmup_duration: Optional[int] = None,
save_history: bool = False,
param_group_index: Optional[int] = None,
):
Expand All @@ -313,6 +315,18 @@ def __init__(
self.cycle = 0
self.start_value_mult = start_value_mult
self.end_value_mult = end_value_mult
self.warmup_each_cycle = warmup_each_cycle
if not self.warmup_each_cycle:
if warmup_duration is not None:
warnings.warn(
Comment thread
sihyeong671 marked this conversation as resolved.
Outdated
Comment thread
sihyeong671 marked this conversation as resolved.
Outdated
f"warmup_each_cycle=False but your warmup_duration is {warmup_duration}. "
f"so warmup_duration will be set to 0. "
f"If you want to use warmup each cycle, pleas set warmup_each_cycle=True"
Comment thread
sihyeong671 marked this conversation as resolved.
Outdated
)
self.warmup_duration = 0
else:
self.warmup_duration = warmup_duration
Comment thread
sihyeong671 marked this conversation as resolved.
Outdated
self.total_cycle_size = self.warmup_duration + self.cycle_size

if self.cycle_size < 2:
raise ValueError(f"Argument cycle_size should be positive and larger than 1, but given {cycle_size}")
Expand All @@ -325,18 +339,26 @@ def __init__(
"cycle",
"start_value_mult",
"end_value_mult",
"warmup_duration",
]

def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None:
if self.event_index != 0 and self.event_index % self.cycle_size == 0:
if self.event_index != 0 and self.event_index % self.total_cycle_size == 0:
self.event_index = 0
self.cycle_size = int(self.cycle_size * self.cycle_mult)
self.warmup_duration = int(self.warmup_duration * self.cycle_mult)
self.total_cycle_size = int(self.warmup_duration + self.cycle_size)
self.cycle += 1
self.start_value *= self.start_value_mult
if self.event_index != 0 and self.event_index == self.warmup_duration:
self.end_value *= self.end_value_mult

return super(CyclicalScheduler, self).__call__(engine, name)

def _get_cycle_param(self):
Comment thread
sihyeong671 marked this conversation as resolved.
Outdated
cycle_progress = (self.event_index - self.warmup_duration) / self.cycle_size
return self.start_value + ((self.end_value - self.start_value) / 2) * (1 - math.cos(math.pi * cycle_progress))


class LinearCyclicalScheduler(CyclicalScheduler):
"""Linearly adjusts param value to 'end_value' for a half-cycle, then linearly
Expand Down Expand Up @@ -538,8 +560,9 @@ def print_lr():

def get_param(self) -> float:
"""Method to get current optimizer's parameter value"""
Comment thread
sihyeong671 marked this conversation as resolved.
cycle_progress = self.event_index / self.cycle_size
return self.start_value + ((self.end_value - self.start_value) / 2) * (1 - math.cos(math.pi * cycle_progress))
if self.warmup_each_cycle and self.event_index < self.warmup_duration:
return self.end_value + (self.start_value - self.end_value) * self.event_index / self.warmup_duration
return self._get_cycle_param()


class ConcatScheduler(ParamScheduler):
Expand Down