Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions paddle/fluid/eager/api/utils/global_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,10 @@ bool Controller::UseLayoutAutoTune() {
return use_autotune;
}

void Controller::SetIsInBackward(bool is_in_backward) {
is_in_backward_ = is_in_backward;
}

bool Controller::GetIsInBackward() const { return is_in_backward_; }

} // namespace egr
13 changes: 13 additions & 0 deletions paddle/fluid/eager/api/utils/global_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ class Controller {
return force_sequential_nodes_;
}

TEST_API void SetIsInBackward(bool is_in_backward);
TEST_API bool GetIsInBackward() const;

private:
Controller() = default;
static Controller* controller_;
Expand All @@ -145,7 +148,17 @@ class Controller {
custom_edges_slot_map_;
std::vector<std::shared_ptr<VoidHook>> final_backward_hooks_;
std::queue<GradNodeBase*> force_sequential_nodes_;
bool is_in_backward_{false};
DISABLE_COPY_AND_ASSIGN(Controller);
};

class EagerBackwardStateGuard {
public:
EagerBackwardStateGuard() { Controller::GetInstance().SetIsInBackward(true); }

~EagerBackwardStateGuard() {
Controller::GetInstance().SetIsInBackward(false);
}
};

} // namespace egr
1 change: 1 addition & 0 deletions paddle/fluid/eager/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ std::vector<paddle::Tensor> RunBackward(
const std::vector<paddle::Tensor>& no_grad_vars = {}) {
VLOG(3) << "Start Backward";

egr::EagerBackwardStateGuard guard;
auto place = egr::Controller::Instance().GetExpectedPlace();

// *Gradient Hook should happen at node-level
Expand Down
14 changes: 14 additions & 0 deletions paddle/fluid/pybind/eager_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1350,6 +1350,16 @@ static PyObject* eager_api_set_master_grads(PyObject* self,
EAGER_CATCH_AND_THROW_RETURN_NULL
}

PyObject* eager__is_run_in_backward(PyObject* self,
PyObject* args,
PyObject* kwargs) {
EAGER_TRY

return ToPyObject(egr::Controller::Instance().GetIsInBackward());

EAGER_CATCH_AND_THROW_RETURN_NULL
}

PyMethodDef variable_functions[] = { // NOLINT
// TODO(jiabin): Remove scale when we have final state tests
{"scale",
Expand Down Expand Up @@ -1423,6 +1433,10 @@ PyMethodDef variable_functions[] = { // NOLINT
(PyCFunction)(void (*)())eager_api_set_master_grads,
METH_VARARGS | METH_KEYWORDS,
nullptr},
{"_is_run_in_backward",
(PyCFunction)(void (*)())eager__is_run_in_backward,
METH_VARARGS | METH_KEYWORDS,
nullptr},
/**sparse functions**/
#if defined(PADDLE_WITH_CUDA)
{"async_read",
Expand Down