Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions trl/models/activation_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
verify_sufficient_virtual_memory()

self.is_first_backward_call = False
self.is_first_forward_call = True

if unpack_tensor_id not in self.tracker:
raise ValueError(f"Untracked tensor with id {unpack_tensor_id}")
Expand All @@ -231,6 +230,8 @@ def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:

# clear tensor from tracking
del self.tracker[unpack_tensor_id]
if len(self.tracker) == 0:
self.is_first_forward_call = True
return maybe_accelerator_tensor

def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
Expand All @@ -254,7 +255,6 @@ def wait_and_del_remaining_references() -> None:
verify_sufficient_virtual_memory()

self.is_first_backward_call = False
self.is_first_forward_call = True

if unpack_tensor_id not in self.tracker:
raise ValueError(f"untracked tensor with id {unpack_tensor_id}")
Expand Down Expand Up @@ -359,6 +359,8 @@ def hook(outputs, inputs):

# clear tensor from tracking
del self.tracker[unpack_tensor_id]
if len(self.tracker) == 0:
self.is_first_forward_call = True
return maybe_accelerator_tensor

unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream
Expand Down
Loading