@@ -524,7 +524,7 @@ def output_transform_fn(x, y, y_pred, loss):
     .. versionchanged:: 0.4.7
         Added Gradient Accumulation argument for all supervised training methods.
     .. versionchanged:: 0.4.11
-        Added `model_transform` to transform model's output
+        Added ``model_transform`` to transform model's output
     """

     device_type = device.type if isinstance(device, torch.device) else device
@@ -593,6 +593,7 @@ def supervised_evaluation_step(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
 ) -> Callable:
     """
@@ -606,6 +607,8 @@ def supervised_evaluation_step(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and converts it into the predictions:
+            ``y_pred = model_transform(model(x))``.
         output_transform: function that receives 'x', 'y', 'y_pred' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
             output expected by metrics. If you change it you should use `output_transform` in metrics.
@@ -624,13 +627,16 @@ def supervised_evaluation_step(
         The `model` should be moved by the user before creating an optimizer.

     .. versionadded:: 0.4.5
+    .. versionchanged:: 0.4.12
+        Added ``model_transform`` to transform model's output
     """

     def evaluate_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
         model.eval()
         with torch.no_grad():
             x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
-            y_pred = model(x)
+            output = model(x)
+            y_pred = model_transform(output)
             return output_transform(x, y, y_pred)

     return evaluate_step
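With this change, ``supervised_evaluation_step`` can be used directly with models whose forward pass does not return ``y_pred`` itself. A minimal sketch (not part of the diff) of building a custom evaluator around a hypothetical model that returns a dict with a ``"logits"`` key, assuming a build of ignite that includes this change:

    import torch
    import torch.nn as nn

    from ignite.engine import Engine, supervised_evaluation_step

    # Hypothetical model whose forward() returns a dict rather than a bare tensor.
    class DictOutputNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return {"logits": self.fc(x), "features": x}

    model = DictOutputNet()

    # model_transform extracts the prediction tensor before it reaches
    # output_transform, so the default (y_pred, y) output still works.
    evaluate_step = supervised_evaluation_step(
        model,
        model_transform=lambda output: output["logits"],
    )
    evaluator = Engine(evaluate_step)
    state = evaluator.run([(torch.randn(8, 4), torch.randint(0, 2, (8,)))])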
@@ -641,6 +647,7 @@ def supervised_evaluation_step_amp(
     device: Optional[Union[str, torch.device]] = None,
     non_blocking: bool = False,
     prepare_batch: Callable = _prepare_batch,
+    model_transform: Callable[[Any], Any] = lambda output: output,
     output_transform: Callable[[Any, Any, Any], Any] = lambda x, y, y_pred: (y_pred, y),
 ) -> Callable:
     """
@@ -654,6 +661,8 @@ def supervised_evaluation_step_amp(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and converts it into the predictions:
+            ``y_pred = model_transform(model(x))``.
         output_transform: function that receives 'x', 'y', 'y_pred' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
             output expected by metrics. If you change it you should use `output_transform` in metrics.
@@ -672,6 +681,8 @@ def supervised_evaluation_step_amp(
         The `model` should be moved by the user before creating an optimizer.

     .. versionadded:: 0.4.5
+    .. versionchanged:: 0.4.12
+        Added ``model_transform`` to transform model's output
     """
     try:
         from torch.cuda.amp import autocast
@@ -683,7 +694,8 @@ def evaluate_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, T
         with torch.no_grad():
             x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
             with autocast(enabled=True):
-                y_pred = model(x)
+                output = model(x)
+                y_pred = model_transform(output)
             return output_transform(x, y, y_pred)

     return evaluate_step
@@ -711,6 +723,8 @@ def create_supervised_evaluator(
             with respect to the host. For other cases, this argument has no effect.
         prepare_batch: function that receives `batch`, `device`, `non_blocking` and outputs
             tuple of tensors `(batch_x, batch_y)`.
+        model_transform: function that receives the output from the model and converts it into the predictions:
+            ``y_pred = model_transform(model(x))``.
         output_transform: function that receives 'x', 'y', 'y_pred' and returns value
             to be assigned to engine's state.output after each iteration. Default is returning `(y_pred, y,)` which fits
             output expected by metrics. If you change it you should use `output_transform` in metrics.
@@ -737,17 +751,33 @@ def create_supervised_evaluator(
         - `PyTorch's Explanation <https://github.com/pytorch/pytorch/issues/7844#issuecomment-503713840>`_

     .. versionchanged:: 0.4.5
-        - Added ``amp_mode`` argument for automatic mixed precision.
+        Added ``amp_mode`` argument for automatic mixed precision.
+    .. versionchanged:: 0.4.12
+        Added ``model_transform`` to transform model's output
     """
     device_type = device.type if isinstance(device, torch.device) else device
     on_tpu = "xla" in device_type if device_type is not None else False
     mode, _ = _check_arg(on_tpu, amp_mode, None)

     metrics = metrics or {}
     if mode == "amp":
-        evaluate_step = supervised_evaluation_step_amp(model, device, non_blocking, prepare_batch, output_transform)
+        evaluate_step = supervised_evaluation_step_amp(
+            model,
+            device,
+            non_blocking=non_blocking,
+            prepare_batch=prepare_batch,
+            model_transform=model_transform,
+            output_transform=output_transform,
+        )
     else:
-        evaluate_step = supervised_evaluation_step(model, device, non_blocking, prepare_batch, output_transform)
+        evaluate_step = supervised_evaluation_step(
+            model,
+            device,
+            non_blocking=non_blocking,
+            prepare_batch=prepare_batch,
+            model_transform=model_transform,
+            output_transform=output_transform,
+        )

     evaluator = Engine(evaluate_step)
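At the user-facing level the new argument is simply forwarded to whichever step function gets selected, so the usual entry point remains ``create_supervised_evaluator``. A hedged usage sketch, reusing the hypothetical dict-returning ``model`` from the previous snippet together with a stock ``Accuracy`` metric:

    from ignite.engine import create_supervised_evaluator
    from ignite.metrics import Accuracy

    # Metrics still receive the default (y_pred, y) pair because model_transform
    # already unpacked the raw model output; no custom output_transform is needed.
    evaluator = create_supervised_evaluator(
        model,
        metrics={"accuracy": Accuracy()},
        model_transform=lambda output: output["logits"],
    )
    state = evaluator.run([(torch.randn(8, 4), torch.randint(0, 2, (8,)))])
    print(state.metrics["accuracy"])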