Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
bd084cc
Add model_transform in create supervised trainer
guptaaryan16 Feb 4, 2023
35b5f72
autopep8 fix
guptaaryan16 Feb 4, 2023
6e2116d
Made changes in the model_transform
guptaaryan16 Feb 4, 2023
febd297
Merge branch 'master' of github.com:guptaaryan16/ignite
guptaaryan16 Feb 4, 2023
1597942
autopep8 fix
guptaaryan16 Feb 4, 2023
c330ca9
Add test for Supervised trainer output transform
guptaaryan16 Feb 7, 2023
ddf7741
autopep8 fix
guptaaryan16 Feb 7, 2023
0cfef6a
changed code formatting
guptaaryan16 Feb 7, 2023
53c770f
Merge branch 'master' of https://github.com/guptaaryan16/ignite
guptaaryan16 Feb 7, 2023
91414a3
Add necessary changes to tests for model transform
guptaaryan16 Feb 14, 2023
b18a4fa
autopep8 fix
guptaaryan16 Feb 14, 2023
6e920bc
Some code formatting changes
guptaaryan16 Feb 14, 2023
5fab439
Merged conflict changes
guptaaryan16 Feb 14, 2023
408e456
autopep8 fix
guptaaryan16 Feb 14, 2023
c6b674d
Made code formatting changes
guptaaryan16 Feb 14, 2023
eeb24e4
Merge changes
guptaaryan16 Feb 14, 2023
7fb561d
autopep8 fix
guptaaryan16 Feb 14, 2023
02b9549
Code formatting changes
guptaaryan16 Feb 14, 2023
ac22c38
Merge branch 'master' of https://github.com/guptaaryan16/ignite
guptaaryan16 Feb 14, 2023
8ff15db
Added test for model_output_transform
guptaaryan16 Feb 15, 2023
c1072c3
autopep8 fix
guptaaryan16 Feb 15, 2023
64abf7f
Changed somethng in the test
guptaaryan16 Feb 15, 2023
4e66c46
Merge branch 'master' of https://github.com/guptaaryan16/ignite
guptaaryan16 Feb 15, 2023
35b1bc7
Updated tests
vfdev-5 Feb 17, 2023
b549d18
Merge branch 'master' into model-transform
vfdev-5 Feb 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions ignite/engine/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def supervised_training_step(
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
gradient_accumulation_steps: int = 1,
) -> Callable:
Expand All @@ -64,6 +65,8 @@ def supervised_training_step(
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
model_transform: function that receives the output from the model and convert it into the form as required
by the loss function
output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
Expand Down Expand Up @@ -99,7 +102,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
optimizer.zero_grad()
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
y_pred = model(x)
output = model(x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
loss = loss / gradient_accumulation_steps
Expand All @@ -118,6 +122,7 @@ def supervised_training_step_amp(
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
scaler: Optional["torch.cuda.amp.GradScaler"] = None,
gradient_accumulation_steps: int = 1,
Expand All @@ -135,6 +140,8 @@ def supervised_training_step_amp(
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
model_transform: function that receives the output from the model and convert it into the form as required
by the loss function
output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
scaler: GradScaler instance for gradient scaling. (default: None)
Expand Down Expand Up @@ -179,7 +186,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
with autocast(enabled=True):
y_pred = model(x)
output = model(x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
loss = loss / gradient_accumulation_steps
Expand All @@ -204,6 +212,7 @@ def supervised_training_step_apex(
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
gradient_accumulation_steps: int = 1,
) -> Callable:
Expand All @@ -220,6 +229,8 @@ def supervised_training_step_apex(
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
model_transform: function that receives the output from the model and convert it into the form as required
by the loss function
output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
Expand Down Expand Up @@ -261,7 +272,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
optimizer.zero_grad()
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
y_pred = model(x)
output = model(x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
loss = loss / gradient_accumulation_steps
Expand All @@ -281,6 +293,7 @@ def supervised_training_step_tpu(
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
gradient_accumulation_steps: int = 1,
) -> Callable:
Expand All @@ -297,6 +310,8 @@ def supervised_training_step_tpu(
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
model_transform: function that receives the output from the model and convert it into the form as required
by the loss function
output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
gradient_accumulation_steps: Number of steps the gradients should be accumulated across.
Expand Down Expand Up @@ -337,7 +352,8 @@ def update(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[to
optimizer.zero_grad()
model.train()
x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
y_pred = model(x)
output = model(x)
y_pred = model_transform(output)
loss = loss_fn(y_pred, y)
if gradient_accumulation_steps > 1:
loss = loss / gradient_accumulation_steps
Expand Down Expand Up @@ -384,6 +400,7 @@ def create_supervised_trainer(
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any, torch.Tensor], Any] = lambda x, y, y_pred, loss: loss.item(),
deterministic: bool = False,
amp_mode: Optional[str] = None,
Expand All @@ -403,6 +420,8 @@ def create_supervised_trainer(
with respect to the host. For other cases, this argument has no effect.
prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
tuple of tensors `(batch_x, batch_y)`.
model_transform: function that receives the output from the model and convert it into the form as required
by the loss function
output_transform: function that receives 'x', 'y', 'y_pred', 'loss' and returns value
to be assigned to engine's state.output after each iteration. Default is returning `loss.item()`.
deterministic: if True, returns deterministic engine of type
Expand Down Expand Up @@ -510,6 +529,7 @@ def output_transform_fn(x, y, y_pred, loss):
device,
non_blocking,
prepare_batch,
model_transform,
output_transform,
_scaler,
gradient_accumulation_steps,
Expand All @@ -522,6 +542,7 @@ def output_transform_fn(x, y, y_pred, loss):
device,
non_blocking,
prepare_batch,
model_transform,
output_transform,
gradient_accumulation_steps,
)
Expand All @@ -533,6 +554,7 @@ def output_transform_fn(x, y, y_pred, loss):
device,
non_blocking,
prepare_batch,
model_transform,
output_transform,
gradient_accumulation_steps,
)
Expand All @@ -544,6 +566,7 @@ def output_transform_fn(x, y, y_pred, loss):
device,
non_blocking,
prepare_batch,
model_transform,
output_transform,
gradient_accumulation_steps,
)
Expand Down Expand Up @@ -662,6 +685,7 @@ def create_supervised_evaluator(
device: Optional[Union[str, torch.device]] = None,
non_blocking: bool = False,
prepare_batch: Callable = _prepare_batch,
model_transform: Callable[[Any], Any] = lambda output: output,
output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
amp_mode: Optional[str] = None,
) -> Engine:
Expand Down
Loading