@@ -21,6 +21,7 @@ limitations under the License. */
2121#include " paddle/pten/core/utils/intrusive_ref_counter.h"
2222#include " paddle/pten/core/utils/type_info.h"
2323
24+ #include " paddle/fluid/memory/memory.h"
2425#include " paddle/fluid/platform/place.h"
2526#include " paddle/pten/core/allocator.h"
2627
@@ -35,14 +36,32 @@ class Storage : public intrusive_ref_counter<Storage> {
3536 Storage () = default ;
3637 Storage (const Storage&) = delete ;
3738
38- explicit Storage (Allocation&& data) : data_(std::move(data)) {}
39+ /* --------- shared_ptr<Allocation> -------- */
40+ // Initialize a Storage with unique Allocation
41+ explicit Storage (std::shared_ptr<paddle::memory::Allocation>&& data)
42+ : data_(std::move(data)) {}
3943
40- virtual ~Storage () = default ;
44+ // Initialize a Storage shareing Allocation with another storage
45+ explicit Storage (const std::shared_ptr<paddle::memory::Allocation>& data)
46+ : data_(data) {}
47+
48+ void * data () const {
49+ return data_ ? reinterpret_cast <void *>(
50+ reinterpret_cast <uintptr_t >(data_->ptr ()) + offset_)
51+ : nullptr ;
52+ }
53+
54+ const std::shared_ptr<paddle::memory::Allocation> data_shared () const {
55+ return data_;
56+ }
4157
42- // / \brief Get the mutable data pointer of the storage.
43- // / This function is set to inline to improve performance.
44- // / \return The mutable data pointer of the storage.
45- void * data () const noexcept { return data_.operator ->(); }
58+ virtual void ReallocShared (size_t n) {
59+ PADDLE_THROW (paddle::platform::errors::Unimplemented (
60+ " ReallocShared has not been overrided by the current Storage" ));
61+ }
62+ /* --------- shared_ptr<Allocation> -------- */
63+
64+ virtual ~Storage () = default ;
4665
4766 virtual void Clear () = 0;
4867
@@ -52,31 +71,47 @@ class Storage : public intrusive_ref_counter<Storage> {
5271 virtual void Realloc (size_t n) = 0;
5372
5473 protected:
55- Allocation data_;
74+ size_t offset_{0 };
75+ std::shared_ptr<paddle::memory::Allocation> data_;
5676};
5777
5878class TensorStorage : public Storage {
5979 public:
6080 using Place = paddle::platform::Place;
6181
6282 explicit TensorStorage (const std::shared_ptr<Allocator>& a) : alloc_(a) {}
83+
6384 TensorStorage (const std::shared_ptr<Allocator>& a, size_t size)
64- : Storage(Allocate(a, size)), alloc_(a), size_(size) {}
85+ : Storage(paddle::memory::AllocShared(a->place (), size)), alloc_(a) {
86+ size_ = data_->size ();
87+ }
88+
89+ void Clear () override {
90+ data_ = nullptr ;
91+ size_ = 0 ;
92+ offset_ = 0 ;
93+ }
94+
95+ void Realloc (size_t size) override ;
6596
6697 ~TensorStorage () = default ;
6798
6899 static const char * name () { return " TensorStorage" ; }
69100
70- void Realloc (size_t size) override ;
71-
72101 size_t size () const noexcept override { return size_; }
73102
74- void Clear () override {
75- data_.Clear ();
76- size_ = 0 ;
103+ const Place& place () const override {
104+ if (!data_ && !alloc_) {
105+ PADDLE_THROW (paddle::platform::errors::Unimplemented (
106+ " Unable to visit place: either data_ or alloc_ has to be initialized "
107+ " first." ));
108+ }
109+ if (data_) {
110+ return data_->place ();
111+ }
112+ return alloc_->place ();
77113 }
78114
79- const Place& place () const override { return data_.place (); }
80115 bool OwnsMemory () const noexcept override { return true ; }
81116 const std::shared_ptr<Allocator>& allocator () const noexcept {
82117 return alloc_;
0 commit comments