Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions trl/models/activation_offloading.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:
verify_sufficient_virtual_memory()

self.is_first_backward_call = False
self.is_first_forward_call = True

if unpack_tensor_id not in self.tracker:
raise ValueError(f"Untracked tensor with id {unpack_tensor_id}")
Expand All @@ -231,6 +230,9 @@ def unpack_tensor_single_stream(unpack_tensor_id: int) -> torch.Tensor:

# clear tensor from tracking
del self.tracker[unpack_tensor_id]
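# Re-arm is_first_forward_call only once the tracker is empty, i.e. every tracked tensor
# has been unpacked; setting it on each unpack would flip the flag while tensors are still tracked.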
if len(self.tracker) == 0:
self.is_first_forward_call = True
return maybe_accelerator_tensor

def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
Expand All @@ -254,7 +256,6 @@ def wait_and_del_remaining_references() -> None:
verify_sufficient_virtual_memory()

self.is_first_backward_call = False
self.is_first_forward_call = True

if unpack_tensor_id not in self.tracker:
raise ValueError(f"untracked tensor with id {unpack_tensor_id}")
Expand Down Expand Up @@ -359,6 +360,9 @@ def hook(outputs, inputs):

# clear tensor from tracking
del self.tracker[unpack_tensor_id]
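# Re-arm is_first_forward_call only once the tracker is empty, i.e. every tracked tensor
# has been unpacked; setting it on each unpack would flip the flag while tensors are still tracked.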
if len(self.tracker) == 0:
self.is_first_forward_call = True
return maybe_accelerator_tensor

unpack_tensor = unpack_tensor_with_streams if self.use_streams else unpack_tensor_single_stream
Expand Down
Loading