Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions paddle/fluid/imperative/layer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,15 @@ void OpBase::InvokeBackwardHooks() {
}
}

void OpBase::RegisterBackwardHooks(const py::object& callable) {
void OpBase::RegisterBackwardHooks(const py::object& callable, bool front) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why need front?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

trace.py append the release_op as the first grad hook, and we can get the op object in other hooks, maybe we can store the hooks using stack instead of vector ?

VLOG(3) << "Register backward hooks " << trace_id_;

// TODO(minqiyang): check the callable format
backward_hooks_.push_back(callable);
if (front) {
backward_hooks_.insert(backward_hooks_.begin(), callable);
} else {
backward_hooks_.push_back(callable);
}
}

void VarBase::RunBackward() {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/imperative/layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ class PYBIND11_HIDDEN OpBase {
return grad_op_descs_[index]->Type();
}

void RegisterBackwardHooks(const py::object& callable);
void RegisterBackwardHooks(const py::object& callable, bool front = false);

void InvokeBackwardHooks();

Expand Down
111 changes: 24 additions & 87 deletions paddle/fluid/operators/distributed_ops/allreduce_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,91 +15,22 @@ limitations under the License. */
#include <future> // NOLINT
#include <ostream>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
#include "paddle/fluid/operators/distributed_ops/allreduce_op.h"

namespace paddle {
namespace operators {

struct MutableDataFunctor {
MutableDataFunctor(void** data, framework::LoDTensor* tensor,
const platform::Place& place)
: data_(data), tensor_(tensor), place_(place) {}

template <typename T>
void apply() {
*data_ = tensor_->mutable_data<T>(place_);
}
class AllReduceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void** data_;
framework::LoDTensor* tensor_;
platform::Place place_;
};
void InferShape(framework::InferShapeContext* ctx) const override {}

class AllReduceOp : public framework::OperatorBase {
using OperatorBase::OperatorBase;

void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
PADDLE_ENFORCE(is_gpu_place(place),
"AllReduce op can run on gpu place only for now.");
#ifdef PADDLE_WITH_CUDA
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* ctx = pool.Get(place);
auto in_names = Inputs("X");
auto out_names = Outputs("Out");
PADDLE_ENFORCE_EQ(in_names.size(), 1, "Only support one input");
PADDLE_ENFORCE_EQ(out_names.size(), 1, "Only support one output");

auto* in = scope.FindVar(in_names[0]);
auto* out = scope.FindVar(out_names[0]);

PADDLE_ENFORCE(in->IsType<framework::LoDTensor>() ||
out->IsType<framework::LoDTensor>(),
"Only support allreduce LoDTensors");

int dtype = -1;
auto in_tensor = in->Get<framework::LoDTensor>();
dtype = platform::ToNCCLDataType(in_tensor.type());

int64_t numel = in_tensor.numel();
auto* sendbuff = in_tensor.data<void>();
auto* out_tensor = out->GetMutable<framework::LoDTensor>();
out_tensor->Resize(in_tensor.dims());
void* recvbuff = nullptr;
framework::VisitDataType(in_tensor.type(),
MutableDataFunctor(&recvbuff, out_tensor, place));

auto cuda_ctx = static_cast<platform::CUDADeviceContext*>(ctx);
auto* comm = cuda_ctx->nccl_comm();
// FIXME(typhoonzero): should use nccl stream here.
auto stream = cuda_ctx->stream();

int reduce_type = Attr<int>("reduce_type");
ncclRedOp_t red_type = ncclSum;
switch (reduce_type) {
case 0:
red_type = ncclSum;
break;
case 1:
red_type = ncclProd;
break;
case 2:
red_type = ncclMax;
break;
case 3:
red_type = ncclMin;
break;
}

PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
comm, stream));
#endif
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
}
};

Expand All @@ -110,6 +41,10 @@ class AllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("Out", "(Tensor) the result of allreduced.");
AddAttr<int>("reduce_type", "(int) determin the reduce type.")
.SetDefault(0);
AddAttr<bool>(
"sync_mode",
"(bool) whether to synchronize the CUDA stream after nccl call.")
.SetDefault(false);
AddComment(R"DOC(
***AllReduce Operator***

Expand All @@ -128,16 +63,18 @@ If input and output are the same variable, in-place allreduce will be used.
}
};

class AllReduceOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(allreduce, ops::AllReduceOp,
ops::AllReduceOpMaker);

REGISTER_OPERATOR(allreduce, ops::AllReduceOp,
paddle::framework::EmptyGradOpMaker, ops::AllReduceOpMaker,
ops::AllReduceOpShapeInference);
REGISTER_OP_CPU_KERNEL(
allreduce, ops::AllReduceOpKernel<plat::CPUDeviceContext, float>,
ops::AllReduceOpKernel<plat::CPUDeviceContext, double>,
ops::AllReduceOpKernel<plat::CPUDeviceContext, int>,
ops::AllReduceOpKernel<plat::CPUDeviceContext, int64_t>,
ops::AllReduceOpKernel<plat::CPUDeviceContext, plat::float16>);
25 changes: 25 additions & 0 deletions paddle/fluid/operators/distributed_ops/allreduce_op.cu.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/distributed_ops/allreduce_op.h"

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_CUDA_KERNEL(
allreduce, ops::AllReduceOpKernel<plat::CUDADeviceContext, float>,
ops::AllReduceOpKernel<plat::CUDADeviceContext, double>,
ops::AllReduceOpKernel<plat::CUDADeviceContext, int>,
ops::AllReduceOpKernel<plat::CUDADeviceContext, int64_t>,
ops::AllReduceOpKernel<plat::CUDADeviceContext, plat::float16>);
87 changes: 87 additions & 0 deletions paddle/fluid/operators/distributed_ops/allreduce_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class AllReduceOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto place = ctx.GetPlace();
PADDLE_ENFORCE(is_gpu_place(place),
"AllReduce op can run on gpu place only for now.");
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto in = ctx.Input<framework::Tensor>("X");
auto out = ctx.Output<framework::Tensor>("Out");
int dtype = platform::ToNCCLDataType(in->type());
int64_t numel = in->numel();
auto* sendbuff = in->data<void>();
out->Resize(in->dims());
void* recvbuff = out->mutable_data<T>(place);

auto* comm = dev_ctx.nccl_comm();
// FIXME(typhoonzero): should use nccl stream here.
auto stream = dev_ctx.stream();
PADDLE_ENFORCE_NOT_NULL(stream, "Should initialize NCCL firstly.");

int reduce_type = ctx.Attr<int>("reduce_type");
ncclRedOp_t red_type = ncclSum;
switch (reduce_type) {
case 0:
red_type = ncclSum;
break;
case 1:
red_type = ncclProd;
break;
case 2:
red_type = ncclMax;
break;
case 3:
red_type = ncclMin;
break;
}
VLOG(0) << "call allreduce with type: " << reduce_type;
PADDLE_ENFORCE(platform::dynload::ncclAllReduce(
sendbuff, recvbuff, numel, static_cast<ncclDataType_t>(dtype), red_type,
comm, stream));
if (ctx.Attr<bool>("sync_mode")) {
VLOG(0) << "sync allreduce...";
cudaError_t e_sync = cudaStreamSynchronize(stream);
if (e_sync != 0) {
LOG(FATAL) << "cudaStreamSynchronize " << cudaGetErrorString(e_sync);
}
}
#else
PADDLE_THROW("PaddlePaddle should compile with GPU.");
#endif
}
};

} // namespace operators
} // namespace paddle
8 changes: 5 additions & 3 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,11 @@ PYBIND11_MODULE(core, m) {
py::class_<imperative::OpBase, PyOpBase>(m, "OpBase", R"DOC()DOC")
.def(py::init<const std::string &>())
.def("register_backward_hooks",
[](imperative::OpBase &self, const py::object &callable) {
self.RegisterBackwardHooks(callable);
})
[](imperative::OpBase &self, const py::object &callable,
bool front = false) {
self.RegisterBackwardHooks(callable, front);
},
py::arg("callable"), py::arg("front") = false)
.def_property("_trace_id",
[](const imperative::OpBase &self) {
pybind11::gil_scoped_release release;
Expand Down
47 changes: 46 additions & 1 deletion python/paddle/fluid/dygraph/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six

from .. import core
from . import layers
from .. import framework

from ..layers import collective

__all__ = ["prepare_context"]

Expand All @@ -21,9 +27,13 @@
__parallel_ctx__clz__ = None


def prepare_context(parallel_strategy, place):
def prepare_context(parallel_strategy):
global __parallel_ctx__clz__
assert __parallel_ctx__clz__ is None, "ParallelContext can only be initialized once."
assert framework.in_dygraph_mode(
) is True, "dygraph.parallel.prepare_context should be used with dygrahp mode."
place = framework._current_expected_place()
assert place is not None, "dygraph.parallel.prepare_context should be used in fluid.dygraph.guard(place) guard."

if isinstance(place, core.CUDAPlace):
__parallel_ctx__clz__ = core.NCCLParallelContext(parallel_strategy,
Expand Down Expand Up @@ -58,3 +68,38 @@ def dev_id(self):
@property
def current_endpoint(self):
return self._current_endpoint

@property
def trainer_endpoints(self):
return self._trainer_endpoints


class DataParallel(layers.Layer):
def __init__(self, layers):
super(DataParallel,
self).__init__(layers.full_name() + "_data_parallel")
self._layers = layers

def build_once(self, *inputs, **kwargs):
#TODO(Yancey1989): broadcast all the paramters
pass

def forward(self, *inputs, **kwargs):
def _collective_hook(iop):
op = framework._dygraph_tracer()._ops[iop._trace_id]
for k, v in six.iteritems(op.inputs):
for ivar in v:
g = ivar._grad_ivar()
if g:
g_var = framework.Variable(
block=self._helper.main_program.current_block(),
name=ivar._grad_name(),
stop_gradient=True,
ivar=g)
collective._allreduce(g_var, g_var, sync_mode=True)

outs = self._layers(*inputs, **kwargs)
for _, op in six.iteritems(framework._dygraph_tracer()._ops):
# hook collective ops
op.iop.register_backward_hooks(_collective_hook, front=True)
return outs
5 changes: 3 additions & 2 deletions python/paddle/fluid/layers/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ..layer_helper import LayerHelper, unique_name


def _allreduce(x, out=None, reduce_type="sum"):
def _allreduce(x, out=None, reduce_type="sum", sync_mode=False):
helper = LayerHelper("allreduce", **locals())
# Convert string reduce type to op int type
red_typ_int = 0
Expand All @@ -43,5 +43,6 @@ def _allreduce(x, out=None, reduce_type="sum"):
type='allreduce',
inputs={'X': [x]},
outputs={'Out': [out]},
attrs={"reduce_type": red_typ_int})
attrs={"reduce_type": red_typ_int,
"sync_mode": sync_mode})
return out
1 change: 1 addition & 0 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ endif(NOT WITH_DISTRIBUTE)

if (NOT ${WITH_GPU})
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
LIST(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mnist) # TODO(Yancey1989): parallel dygraph support CPU device in future
elseif(${CUDNN_VERSION} VERSION_LESS 7100)
LIST(REMOVE_ITEM TEST_OPS test_conv2d_fusion_op)
endif()
Expand Down
Loading