[misc] fix: fix DataProto __getstate__ bug #2962
Conversation
Code Review
This pull request aims to fix a bug in DataProto.__getstate__ where in-place consolidation of self.batch caused issues with subsequent modifications. The approach of creating a temporary consolidated variable is correct. However, the current implementation introduces a critical bug: an UnboundLocalError will occur if the consolidation logic is skipped, because the variable passed to torch.save would not be defined. I've provided a comment with a suggested fix to handle this case correctly.
verl/protocol.py
Outdated
```diff
 if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None:
     self.batch = self.batch.contiguous()
-    self.batch = self.batch.consolidate()
-torch.save(self.batch, buffer)
+    batch_consolidated = self.batch.consolidate()
+torch.save(batch_consolidated, buffer)
```
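To make the side effect described above concrete, here is a self-contained toy illustration; plain `pickle` stands in for `torch.save`, a list stands in for the TensorDict batch, and `Holder` is a hypothetical stand-in rather than verl code:

```python
import io
import pickle

class Holder:
    """Toy stand-in for DataProto whose __getstate__ mutates self.batch."""

    def __init__(self, batch):
        self.batch = batch

    def __getstate__(self):
        buffer = io.BytesIO()
        # The old pattern: the consolidated result is written back onto the instance.
        self.batch = sorted(self.batch)  # stand-in for self.batch.consolidate()
        pickle.dump(self.batch, buffer)
        return (buffer.getvalue(),)

h = Holder([3, 1, 2])
before = h.batch
pickle.dumps(h)            # serialization triggers __getstate__
print(h.batch is before)   # False: pickling silently replaced the batch
```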
This change introduces a potential UnboundLocalError. If the condition on line 334 is false (e.g., self.batch is None or the tensordict version is older), the if block is skipped. In that case, batch_consolidated is never defined, but it's used in torch.save on line 337, which will cause a crash.
To fix this, you should handle the case where the condition is false. A good approach is to move the torch.save call into an if/else block to ensure a value is always saved, preserving the original behavior for the case where consolidation doesn't happen.
Suggested change:

```diff
 if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None:
     self.batch = self.batch.contiguous()
     batch_consolidated = self.batch.consolidate()
-torch.save(batch_consolidated, buffer)
+    torch.save(batch_consolidated, buffer)
+else:
+    torch.save(self.batch, buffer)
```
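The failure mode the reviewer describes is easy to reproduce in isolation; the following sketch is hypothetical code, not verl's, with `list(batch)` standing in for `batch.consolidate()`:

```python
def serialize(batch, new_tensordict):
    if new_tensordict and batch is not None:
        batch_consolidated = list(batch)  # stand-in for batch.consolidate()
    return batch_consolidated  # undefined whenever the `if` branch is skipped

print(serialize([1, 2], new_tensordict=True))  # works: [1, 2]
try:
    serialize(None, new_tensordict=True)       # branch skipped: crash
except UnboundLocalError as exc:
    print(exc)  # local variable 'batch_consolidated' referenced before assignment
```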
/gemini review
Code Review
This pull request addresses a significant bug in DataProto.__getstate__ by preventing the in-place modification of self.batch with a consolidated version, which previously caused issues with subsequent operations. The fix correctly uses a temporary variable for the consolidated batch during serialization. However, this change introduces a critical UnboundLocalError because the temporary variable is not defined if the consolidation logic is skipped. I have provided a code suggestion to resolve this by ensuring the variable passed to torch.save is always initialized.
```diff
 buffer = io.BytesIO()
 if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None:
     self.batch = self.batch.contiguous()
-    self.batch = self.batch.consolidate()
-torch.save(self.batch, buffer)
+    batch_consolidated = self.batch.consolidate()
+torch.save(batch_consolidated, buffer)
 buffer_bytes = buffer.getvalue()
 return buffer_bytes, self.non_tensor_batch, self.meta_info
```
This change introduces a potential UnboundLocalError. The variable batch_consolidated is only defined within the if block on line 336. If the condition on line 334 is false (e.g., self.batch is None or an older tensordict version is used), torch.save on line 337 will be called with an undefined variable, causing a crash.
To fix this, you should ensure the variable passed to torch.save is always defined, regardless of the if condition's outcome.
Suggested change:

```diff
 buffer = io.BytesIO()
+batch_to_save = self.batch
 if version.parse(tensordict.__version__) >= version.parse("0.5.0") and self.batch is not None:
     self.batch = self.batch.contiguous()
-    batch_consolidated = self.batch.consolidate()
-torch.save(batch_consolidated, buffer)
+    batch_to_save = self.batch.consolidate()
+torch.save(batch_to_save, buffer)
 buffer_bytes = buffer.getvalue()
 return buffer_bytes, self.non_tensor_batch, self.meta_info
```
/gemini review
Code Review
The code changes fix a severe bug in DataProto where serialization mutated the instance, causing subsequent modifications to operate on incorrect state. The changes isolate the tensordict consolidation to a local variable, ensuring the DataProto instance's state remains unchanged.
verl/protocol.py
Outdated
```python
    batch = self.batch.contiguous()
    batch_consolidated = batch.consolidate()
else:
    batch_consolidated = self.batch
torch.save(batch_consolidated, buffer)
```
The variable name batch_consolidated is slightly misleading, as the else branch assigns self.batch to it, which is not consolidated. This could be confusing for future maintainers. For improved clarity, use a more neutral name like batch_to_save. Additionally, the intermediate batch variable can be removed by chaining the .contiguous().consolidate() calls:

```python
    batch_to_save = self.batch.contiguous().consolidate()
else:
    batch_to_save = self.batch
torch.save(batch_to_save, buffer)
```
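For contrast with the earlier toy example of the bug, the same hypothetical stand-in class with this pattern applied no longer mutates the instance during pickling (again, not verl code):

```python
import io
import pickle

class Holder:
    """Toy stand-in for DataProto with the fix: consolidate into a local."""

    def __init__(self, batch):
        self.batch = batch

    def __getstate__(self):
        buffer = io.BytesIO()
        batch_to_save = sorted(self.batch)  # stand-in for .contiguous().consolidate()
        pickle.dump(batch_to_save, buffer)
        return (buffer.getvalue(),)

h = Holder([3, 1, 2])
before = h.batch
pickle.dumps(h)
print(h.batch is before)   # True: the live object is left untouched
```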
/gemini review
Code Review
This pull request addresses a critical bug in DataProto.__getstate__ where the object's state was mutated during serialization. The fix correctly avoids this side effect by using a temporary variable for the batch to be saved, ensuring self.batch is not modified. This prevents potential hard-to-debug issues with object state. I've also suggested a performance improvement to cache the tensordict version check, as it's currently performed on every call to __getstate__.
```diff
-    self.batch = self.batch.contiguous()
-    self.batch = self.batch.consolidate()
-torch.save(self.batch, buffer)
+    batch_to_save = self.batch.contiguous().consolidate()
```
This line is executed based on a condition on line 334 that repeatedly parses the tensordict version on every __getstate__ call. This is inefficient, as __getstate__ can be a hot path during serialization. For better performance, the version check should be performed once at module startup and the result cached in a constant. For example:

```python
_TENSORDICT_V0_5_0_OR_GREATER = version.parse(tensordict.__version__) >= version.parse("0.5.0")
```

This constant can then be used in the if condition, avoiding repeated parsing.
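A sketch of how the cached constant might be wired into the class; the class below is a simplified stand-in for verl.protocol.DataProto, assembled from the diffs above rather than copied from the actual file:

```python
import io

import tensordict
import torch
from packaging import version

# Parsed once at import time instead of on every __getstate__ call.
_TENSORDICT_V0_5_0_OR_GREATER = version.parse(tensordict.__version__) >= version.parse("0.5.0")

class DataProto:  # simplified stand-in, not the real verl class
    def __init__(self, batch=None, non_tensor_batch=None, meta_info=None):
        self.batch = batch
        self.non_tensor_batch = non_tensor_batch
        self.meta_info = meta_info

    def __getstate__(self):
        buffer = io.BytesIO()
        if _TENSORDICT_V0_5_0_OR_GREATER and self.batch is not None:
            # Consolidate a local copy; self.batch itself stays untouched.
            batch_to_save = self.batch.contiguous().consolidate()
        else:
            batch_to_save = self.batch
        torch.save(batch_to_save, buffer)
        return buffer.getvalue(), self.non_tensor_batch, self.meta_info
```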
What does this PR do?
Checklist Before Starting
- Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API, add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

Test
API and Usage Example
```python
# Add code snippet or script demonstrating how to use this
```

Design & Code Changes
Checklist Before Submitting
Important
Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.
- Apply pre-commit checks: `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- Once your PR is ready for CI, send a message in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)