Skip to content

Commit 0a904fb

Browse files
authored
6124-add-training-attribute-check (#6132)
Fixes #6124 . ### Description When running the inference with torchscript wrapped TensorRT models, the evaluator would give an error. This is caused by the `with engine.mode()` code run the `training` method of `engine.network` without checking. In this PR, an attribute check has been added to cover this issue. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: binliu <[email protected]>
1 parent a8302ec commit 0a904fb

2 files changed

Lines changed: 11 additions & 8 deletions

File tree

monai/engines/evaluator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,6 @@ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Ten
295295

296296
# put iteration outputs into engine.state
297297
engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
298-
299298
# execute forward computation
300299
with engine.mode(engine.network):
301300
if engine.amp:

monai/networks/utils.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -375,17 +375,19 @@ def eval_mode(*nets: nn.Module):
375375
print(p(t).sum().backward()) # will correctly raise an exception as gradients are calculated
376376
"""
377377

378-
# Get original state of network(s)
379-
training = [n for n in nets if n.training]
378+
# Get original state of network(s).
379+
# Check the training attribute in case it's TensorRT based models which don't have this attribute.
380+
training = [n for n in nets if hasattr(n, "training") and n.training]
380381

381382
try:
382383
# set to eval mode
383384
with torch.no_grad():
384-
yield [n.eval() for n in nets]
385+
yield [n.eval() if hasattr(n, "eval") else n for n in nets]
385386
finally:
386387
# Return required networks to training
387388
for n in training:
388-
n.train()
389+
if hasattr(n, "train"):
390+
n.train()
389391

390392

391393
@contextmanager
@@ -410,16 +412,18 @@ def train_mode(*nets: nn.Module):
410412
"""
411413

412414
# Get original state of network(s)
413-
eval_list = [n for n in nets if not n.training]
415+
# Check the training attribute in case it's TensorRT based models which don't have this attribute.
416+
eval_list = [n for n in nets if hasattr(n, "training") and (not n.training)]
414417

415418
try:
416419
# set to train mode
417420
with torch.set_grad_enabled(True):
418-
yield [n.train() for n in nets]
421+
yield [n.train() if hasattr(n, "train") else n for n in nets]
419422
finally:
420423
# Return required networks to eval_list
421424
for n in eval_list:
422-
n.eval()
425+
if hasattr(n, "eval"):
426+
n.eval()
423427

424428

425429
def get_state_dict(obj: torch.nn.Module | Mapping):

0 commit comments

Comments
 (0)