Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion monai/engines/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,6 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten

# put iteration outputs into engine.state
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}

# execute forward computation
with engine.mode(engine.network):
if engine.amp:
Expand Down
18 changes: 11 additions & 7 deletions monai/networks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,17 +375,19 @@ def eval_mode(*nets: nn.Module):
print(p(t).sum().backward()) # will correctly raise an exception as gradients are calculated
"""

# Get original state of network(s)
training = [n for n in nets if n.training]
# Get original state of network(s).
# Check the training attribute in case it's TensorRT based models which don't have this attribute.
training = [n for n in nets if hasattr(n, "training") and n.training]

try:
# set to eval mode
with torch.no_grad():
yield [n.eval() for n in nets]
yield [n.eval() if hasattr(n, "eval") else n for n in nets]
finally:
# Return required networks to training
for n in training:
Comment thread
wyli marked this conversation as resolved.
n.train()
if hasattr(n, "train"):
n.train()


@contextmanager
Expand All @@ -410,16 +412,18 @@ def train_mode(*nets: nn.Module):
"""

# Get original state of network(s)
eval_list = [n for n in nets if not n.training]
# Check the training attribute in case it's TensorRT based models which don't have this attribute.
eval_list = [n for n in nets if hasattr(n, "training") and (not n.training)]

try:
# set to train mode
with torch.set_grad_enabled(True):
yield [n.train() for n in nets]
yield [n.train() if hasattr(n, "train") else n for n in nets if hasattr(n, "train")]
Comment thread
wyli marked this conversation as resolved.
Outdated
finally:
# Return required networks to eval_list
for n in eval_list:
n.eval()
if hasattr(n, "eval"):
n.eval()


def get_state_dict(obj: torch.nn.Module | Mapping):
Expand Down