Skip to content

Commit c1603ce

Browse files
committed
Fix error when loading models with enable_ddp=True
1 parent 47ca73f commit c1603ce

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

d3rlpy/torch_utility.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,8 @@ def save(self, f: BinaryIO) -> None:
430430
def load(self, f: BinaryIO) -> None:
431431
chkpt = torch.load(f, map_location=map_location(self._device))
432432
for k, v in self._modules.items():
433+
if isinstance(v, nn.Module):
434+
v = unwrap_ddp_model(v)
433435
v.load_state_dict(chkpt[k])
434436

435437
@property

0 commit comments

Comments
 (0)