Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions paddle/framework/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,19 @@ limitations under the License. */

namespace paddle {

namespace framework {
namespace pybind {
namespace details {
template <bool less, size_t i, typename... args>
struct CastToPyBufferImpl;
}
} // namespace pybind

namespace framework {

class Tensor {
public:
template <bool less, size_t i, typename... args>
friend struct details::CastToPyBufferImpl;
friend struct pybind::details::CastToPyBufferImpl;

template <typename T, size_t D, int MajorType, typename IndexType>
friend struct EigenTensor;
Expand Down
15 changes: 7 additions & 8 deletions paddle/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,7 @@ limitations under the License. */
namespace py = pybind11;

namespace paddle {
namespace framework {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using LoD = framework::LoD;

namespace pybind {
static size_t UniqueIntegerGenerator() {
static std::atomic<size_t> generator;
return generator.fetch_add(1);
Expand All @@ -56,14 +51,18 @@ bool IsCompileGPU() {
PYBIND11_PLUGIN(core) {
py::module m("core", "C++ core of PaddlePaddle");

// using framework in this function. Since it is inside a function, it will
// not cause namespace pollution.
using namespace paddle::framework; // NOLINT

py::class_<Tensor>(m, "Tensor", py::buffer_protocol())
.def_buffer(
[](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
.def("get_dims",
[](const Tensor &self) { return vectorize(self.dims()); })
.def("set_dims",
[](Tensor &self, const std::vector<int64_t> &dim) {
self.Resize(make_ddim(dim));
self.Resize(framework::make_ddim(dim));
})
.def("alloc_float",
[](Tensor &self, paddle::platform::GPUPlace &place) {
Expand Down Expand Up @@ -317,5 +316,5 @@ All parameter, weight, gradient are variables in Paddle.

return m.ptr();
}
} // namespace framework
} // namespace pybind
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/pybind/tensor_py.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace py = pybind11;

namespace paddle {

namespace framework {
namespace pybind {

namespace details {

Expand Down