-
Notifications
You must be signed in to change notification settings - Fork 5.9k
Add COWPtr and its unittest #7240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,94 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #pragma once | ||
| #include <memory> | ||
| #include <thread> | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
| namespace details { | ||
|
|
||
| // Change it to thread safe flags if needed. | ||
| class ThreadUnsafeOwnershipFlags { | ||
| public: | ||
| ThreadUnsafeOwnershipFlags(bool flag) : flag_(flag) {} | ||
|
|
||
| ThreadUnsafeOwnershipFlags(const ThreadUnsafeOwnershipFlags& o) = delete; | ||
| ThreadUnsafeOwnershipFlags& operator=(const ThreadUnsafeOwnershipFlags& o) = | ||
| delete; | ||
| ThreadUnsafeOwnershipFlags(ThreadUnsafeOwnershipFlags&& o) = default; | ||
|
|
||
| void SetOwnership(bool flag) { flag_ = flag; } | ||
|
|
||
| template <typename Callback> | ||
| void AcquireOwnershipOnce(Callback acquire) { | ||
| if (!flag_) { | ||
| acquire(); | ||
| flag_ = true; | ||
| } | ||
| } | ||
|
|
||
| private: | ||
| bool flag_; | ||
| }; | ||
|
|
||
| // Copy On Write pointer. | ||
|
||
| // It will hold a T* pointer, and only copy once when `MutableData` is invoked. | ||
| // | ||
| // The template parameter OwnershipFlags should have: | ||
| // * a constructor takes a bool. True if own. | ||
| // * SetOwnership(bool flag). | ||
| // * AcquireOwnershipOnce(Callback). It will invoke the callback if it is not | ||
| // owned. | ||
| template <typename T, typename OwnershipFlags = ThreadUnsafeOwnershipFlags> | ||
| class COWPtr { | ||
| public: | ||
| // Ctor from raw pointer. | ||
| explicit COWPtr(T* ptr) : payload_(ptr), ownership_{true} {} | ||
|
|
||
| // Move methods. Steal ownership from origin | ||
| COWPtr(COWPtr&& o) | ||
|
||
| : payload_(o.payload_), ownership_{std::move(o.ownership_)} {} | ||
| COWPtr& operator=(COWPtr&& origin) = default; | ||
|
|
||
| // Copy methods. Not own payload | ||
| COWPtr(const COWPtr& o) : payload_(o.payload_), ownership_{false} {} | ||
| COWPtr& operator=(const COWPtr& o) { | ||
| payload_ = o.payload_; | ||
| ownership_.SetOwnership(false); | ||
| return *this; | ||
| } | ||
|
|
||
| const T& Data() const { return *payload_; } | ||
|
|
||
| T* MutableData() { | ||
| ownership_.AcquireOwnershipOnce( | ||
| [this] { payload_.reset(new T(*payload_)); }); | ||
| return payload_.get(); | ||
| } | ||
|
|
||
| void Reset() { | ||
| ownership_.AcquireOwnershipOnce([this] { payload_.reset(); }); | ||
| payload_.reset(new T()); | ||
| } | ||
|
||
|
|
||
| private: | ||
| std::shared_ptr<T> payload_; | ||
| OwnershipFlags ownership_; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add comments for
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
| }; | ||
|
|
||
| } // namespace details | ||
| } // namespace framework | ||
| } // namespace paddle | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); | ||
| you may not use this file except in compliance with the License. | ||
| You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software | ||
| distributed under the License is distributed on an "AS IS" BASIS, | ||
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| See the License for the specific language governing permissions and | ||
| limitations under the License. */ | ||
|
|
||
| #include "paddle/framework/details/cow_ptr.h" | ||
| #include "gtest/gtest.h" | ||
|
|
||
| namespace paddle { | ||
| namespace framework { | ||
| namespace details { | ||
|
|
||
| TEST(COWPtr, all) { | ||
| COWPtr<int> ptr(new int{0}); | ||
| ASSERT_EQ(ptr.Data(), 0); | ||
| COWPtr<int> ptr2 = ptr; | ||
| ASSERT_EQ(ptr2.Data(), 0); | ||
| ASSERT_EQ(&ptr2.Data(), &ptr.Data()); | ||
| *ptr2.MutableData() = 10; | ||
| ASSERT_EQ(ptr.Data(), 0); | ||
| ASSERT_EQ(ptr2.Data(), 10); | ||
|
|
||
| auto ptr_before = ptr2.MutableData(); | ||
| ptr2.Reset(); | ||
| ASSERT_NE(ptr2.MutableData(), ptr_before); | ||
| } | ||
|
|
||
| } // namespace details | ||
| } // namespace framework | ||
| } // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use the Maroc
DISABLE_COPY_AND_ASSIGNin https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/platform/macros.h#L19There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
They are not exactly same. OwnershipFlag just disable copy constructor, but enable move constructor.