51 changes: 45 additions & 6 deletions ignite/engine/__init__.py
@@ -51,6 +51,7 @@ def supervised_training_step(
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""Factory function for supervised training.

@@ -71,6 +72,8 @@ def supervised_training_step(
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives ``model`` and ``x``, and returns ``y_pred``. Default is ``model(x)``.

Returns:
Callable: update function.

@@ -91,6 +94,8 @@ def supervised_training_step(
Added Gradient Accumulation.
.. versionchanged:: 0.4.11
Added `model_transform` to transform model's output
.. versionchanged:: 0.4.13
Added ``model_fn`` to customize how the model is applied to the input ``x``
"""

if gradient_accumulation_steps <= 0:
@@ -104,7 +109,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
optimizer.zero_grad()
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
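
This hunk replaces the hard-coded `model(x)` call with `model_fn(model, x)`. Below is a minimal sketch of what the hook enables, assuming a hypothetical model whose `forward` takes keyword arguments; the model, data, and hyperparameters are illustrative and not part of this diff:

import torch
import torch.nn as nn
from ignite.engine import Engine, supervised_training_step

# Hypothetical model whose forward takes keyword arguments rather than
# a single positional tensor.
class KwargModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, features):
        return self.fc(features)

model = KwargModel()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

# model_fn overrides the default `lambda model, x: model(x)`, so `x` can be
# a dict that is unpacked into keyword arguments at call time.
update = supervised_training_step(
    model,
    optimizer,
    loss_fn,
    model_fn=lambda model, x: model(**x),
)

trainer = Engine(update)
batch = ({"features": torch.randn(8, 4)}, torch.randint(0, 2, (8,)))
trainer.run([batch], max_epochs=1)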
@@ -128,6 +133,7 @@ def supervised_training_step_amp(
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
scaler: Optional["torch.cuda.amp.GradScaler"] = None,
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""Factory function for supervised training using ``torch.cuda.amp``.

@@ -149,6 +155,7 @@ def supervised_training_step_amp(
scaler: GradScaler instance for gradient scaling. (default: None)
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives ``model`` and ``x``, and returns ``y_pred``. Default is ``model(x)``.

Returns:
Callable: update function
@@ -171,6 +178,8 @@ def supervised_training_step_amp(
Added Gradient Accumulation.
.. versionchanged:: 0.4.11
Added `model_transform` to transform model's output
.. versionchanged:: 0.4.13
Added ``model_fn`` to customize how the model is applied to the input ``x``
"""

try:
@@ -190,7 +199,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
with autocast(enabled=True):
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
@@ -219,6 +228,7 @@ def supervised_training_step_apex(
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""Factory function for supervised training using apex.

@@ -239,6 +249,7 @@ def supervised_training_step_apex(
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives ``model`` and ``x``, and returns ``y_pred``. Default is ``model(x)``.

Returns:
Callable: update function.
@@ -260,6 +271,8 @@ def supervised_training_step_apex(
Added Gradient Accumulation.
.. versionchanged:: 0.4.11
Added `model_transform` to transform model's output
.. versionchanged:: 0.4.13
Added ``model_fn`` to customize how the model is applied to the input ``x``
"""

try:
@@ -278,7 +291,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
optimizer.zero_grad()
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
@@ -302,6 +315,7 @@ def supervised_training_step_tpu(
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""Factory function for supervised training using ``torch_xla``.

@@ -322,6 +336,7 @@ def supervised_training_step_tpu(
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives ``model`` and ``x``, and returns ``y_pred``. Default is ``model(x)``.

Returns:
Callable: update function.
@@ -343,6 +358,8 @@ def supervised_training_step_tpu(
Added Gradient Accumulation argument for all supervised training methods.
.. versionchanged:: 0.4.11
Added `model_transform` to transform model's output
.. versionchanged:: 0.4.13
Added ``model_fn`` to customize how the model is applied to the input ``x``
"""
try:
import torch_xla.core.xla_model as xm
@@ -360,7 +377,7 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
optimizer.zero_grad()
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
@@ -414,6 +431,7 @@ def create_supervised_trainer(
amp_mode: Optional[str] = None,
scaler: Union[bool, "torch.cuda.amp.GradScaler"] = False,
gradient_accumulation_steps: int = 1,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Engine:
"""Factory function for creating a trainer for supervised models.

@@ -444,6 +462,7 @@ def create_supervised_trainer(
(default: False)
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
(default: 1 (means no gradient accumulation))
model_fn: the model function that receives ``model`` and ``x``, and returns ``y_pred``. Default is ``model(x)``.

Returns:
a trainer engine with supervised update function.
@@ -525,6 +544,8 @@ def output_transform_fn(x, y, y_pred, loss):
Added Gradient Accumulation argument for all supervised training methods.
.. versionchanged:: 0.4.11
Added ``model_transform`` to transform model's output
.. versionchanged:: 0.4.13
Added ``model_fn`` to customize how the model is applied to the input ``x``
"""

device_type = device.type if isinstance(device, torch.device) else device
@@ -543,6 +564,7 @@ def output_transform_fn(x, y, y_pred, loss):
output_transform,
_scaler,
gradient_accumulation_steps,
model_fn,
)
elif mode == "apex":
_update = supervised_training_step_apex(
@@ -555,6 +577,7 @@ def output_transform_fn(x, y, y_pred, loss):
model_transform,
output_transform,
gradient_accumulation_steps,
model_fn,
)
elif mode == "tpu":
_update = supervised_training_step_tpu(
@@ -567,6 +590,7 @@ def output_transform_fn(x, y, y_pred, loss):
model_transform,
output_transform,
gradient_accumulation_steps,
model_fn,
)
else:
_update = supervised_training_step(
@@ -579,6 +603,7 @@ def output_transform_fn(x, y, y_pred, loss):
model_transform,
output_transform,
gradient_accumulation_steps,
model_fn,
)

trainer = Engine(_update) if not deterministic else DeterministicEngine(_update)
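
As the dispatch above shows, `create_supervised_trainer` simply forwards `model_fn` to whichever backend-specific step it selects (plain, amp, apex, or tpu). A sketch continuing the first example:

from ignite.engine import create_supervised_trainer

trainer = create_supervised_trainer(
    model,
    optimizer,
    loss_fn,
    model_fn=lambda model, x: model(**x),  # forwarded to the selected step
)
trainer.run([batch], max_epochs=2)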
@@ -595,6 +620,7 @@ def supervised_evaluation_step(
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""
Factory function for supervised evaluation.
@@ -612,6 +638,7 @@ def supervised_evaluation_step(
output_transform: function that receives 'x', 'y', 'y_pred' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
output expected by metrics. If you change it you should use `output_transform` in metrics.
model_fn: the model function that receives ``model`` and ``x``, and returns ``y_pred``. Default is ``model(x)``.

Returns:
Inference function.
@@ -629,13 +656,15 @@
.. versionadded:: 0.4.5
.. versionchanged:: 0.4.12
Added ``model_transform`` to transform model's output
.. versionchanged:: 0.4.13
Added ``model_fn`` to customize how the model is applied to the input ``x``
"""

def evaluate_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
model.eval()
with torch.no_grad():
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
return output_transform(x, y, y_pred)

@@ -649,6 +678,7 @@ def supervised_evaluation_step_amp(
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Callable:
"""
Factory function for supervised evaluation using ``torch.cuda.amp``.
@@ -666,6 +696,7 @@ def supervised_evaluation_step_amp(
output_transform: function that receives 'x', 'y', 'y_pred' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
output expected by metrics. If you change it you should use `output_transform` in metrics.
model_fn: the model function that receives ``model`` and ``x``, and returns ``y_pred``. Default is ``model(x)``.

Returns:
Inference function.
@@ -683,6 +714,8 @@ def supervised_evaluation_step_amp(
.. versionadded:: 0.4.5
.. versionchanged:: 0.4.12
Added ``model_transform`` to transform model's output
.. versionchanged:: 0.4.13
Added ``model_fn`` to customize how the model is applied to the input ``x``
"""
try:
from torch.cuda.amp import autocast
@@ -694,7 +727,7 @@ def evaluate_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, T
with torch.no_grad():
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
with autocast(enabled=True):
output = model(x)
output = model_fn(model, x)
y_pred = model_transform(output)
return output_transform(x, y, y_pred)

@@ -710,6 +743,7 @@ def create_supervised_evaluator(
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
amp_mode: Optional[str] = None,
model_fn: Callable[[torch.nn.Module, Any], Any] = lambda model, x: model(x),
) -> Engine:
"""
Factory function for creating an evaluator for supervised models.
@@ -730,6 +764,7 @@ def create_supervised_evaluator(
output expected by metrics. If you change it you should use `output_transform` in metrics.
amp_mode: can be ``amp``, model will be casted to float16 using
`torch.cuda.amp <https://pytorch.org/docs/stable/amp.html>`_
model_fn: the model function that receives ``model`` and ``x``, and returns ``y_pred``. Default is ``model(x)``.

Returns:
an evaluator engine with supervised inference function.
@@ -754,6 +789,8 @@ def create_supervised_evaluator(
Added ``amp_mode`` argument for automatic mixed precision.
.. versionchanged:: 0.4.12
Added ``model_transform`` to transform model's output
.. versionchanged:: 0.4.13
Added ``model_fn`` to customize how the model is applied to the input ``x``
"""
device_type = device.type if isinstance(device, torch.device) else device
on_tpu = "xla" in device_type if device_type is not None else False
@@ -768,6 +805,7 @@ def create_supervised_evaluator(
prepare_batch=prepare_batch,
model_transform=model_transform,
output_transform=output_transform,
model_fn=model_fn,
)
else:
evaluate_step = supervised_evaluation_step(
@@ -777,6 +815,7 @@ def create_supervised_evaluator(
prepare_batch=prepare_batch,
model_transform=model_transform,
output_transform=output_transform,
model_fn=model_fn,
)

evaluator = Engine(evaluate_step)
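
On the evaluation side, `model_fn` is applied under `torch.no_grad()` (and additionally under `autocast` in the amp variant), so the same hook works at inference time. A sketch reusing the hypothetical model and batch from the first example; the metric name is arbitrary:

from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy

evaluator = create_supervised_evaluator(
    model,
    metrics={"accuracy": Accuracy()},
    model_fn=lambda model, x: model(**x),
)
state = evaluator.run([batch])
print(state.metrics["accuracy"])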