Commit 2458d10 (1 parent: 8cae55d)

support backward for distribute pir.

8 files changed: 122 additions & 16 deletions

paddle/fluid/pir/dialect/distributed/ir/dist_interface.h

Lines changed: 26 additions & 3 deletions
@@ -13,6 +13,7 @@
 // limitations under the License.
 #pragma once
 
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_attribute.h"
 #include "paddle/pir/include/core/cast_utils.h"
 #include "paddle/pir/include/core/dll_decl.h"
 #include "paddle/pir/include/core/type.h"
@@ -25,24 +26,46 @@ class IR_API DistTypeInterface
  public:
   struct Concept {
     /// Defined these methods with the interface.
-    explicit Concept(pir::Type (*local_type)(pir::Type))
-        : local_type(local_type) {}
+    explicit Concept(pir::Type (*local_type)(pir::Type),
+                     ProcessMeshAttribute (*process_mesh_attr)(pir::Type),
+                     TensorDistAttribute (*tensor_dist_attr)(pir::Type))
+        : local_type(local_type),
+          process_mesh_attr(process_mesh_attr),
+          tensor_dist_attr(tensor_dist_attr) {}
     pir::Type (*local_type)(pir::Type);
+    ProcessMeshAttribute (*process_mesh_attr)(pir::Type);
+    TensorDistAttribute (*tensor_dist_attr)(pir::Type);
   };
 
   template <class ConcreteType>
   struct Model : public Concept {
     static Type local_type(Type type) {
       return pir::cast<ConcreteType>(type).local_type();
     }
-    Model() : Concept(local_type) {}
+    static ProcessMeshAttribute process_mesh_attr(Type type) {
+      return pir::cast<ConcreteType>(type).process_mesh_attr();
+    }
+
+    static TensorDistAttribute tensor_dist_attr(Type type) {
+      return pir::cast<ConcreteType>(type).tensor_dist_attr();
+    }
+
+    Model() : Concept(local_type, process_mesh_attr, tensor_dist_attr) {}
   };
 
   DistTypeInterface(pir::Type type, Concept *impl)
       : pir::TypeInterfaceBase<DistTypeInterface>(type), impl_(impl) {}
 
   pir::Type local_type() { return impl_->local_type(*this); }
 
+  ProcessMeshAttribute process_mesh_attr() {
+    return impl_->process_mesh_attr(*this);
+  }
+
+  TensorDistAttribute tensor_dist_attr() {
+    return impl_->tensor_dist_attr(*this);
+  }
+
  private:
  Concept *impl_;
 };

paddle/fluid/pir/dialect/distributed/ir/dist_op.cc

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/pir/include/core/builtin_attribute.h"
+#include "paddle/pir/include/core/builtin_op.h"
 #include "paddle/pir/include/core/ir_context.h"
 
 namespace paddle {
@@ -155,6 +156,7 @@ void ShardTensorOp::Build(pir::Builder& builder,
                           tensor_dist_attr,
                           local_shape);
   argument.AddOutput(out_dist_tensor_type);
+  ::pir::PassStopGradientsDefaultly(argument);
 }
 
 void ReShardOp::VerifySig() {

paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py

Lines changed: 2 additions & 2 deletions
@@ -613,8 +613,8 @@ def GenDistBranch(args, op_info):
     ProcessMeshAttribute op_mesh;
     auto ctx = pir::IrContext::Instance();
     for(auto value : input_values) {{
-      if (auto dist_type = value.type().dyn_cast<DistDenseTensorType>()) {{
-        op_mesh = dist_type.process_mesh_attr();
+      if (auto dist_interface = value.type().dyn_cast<DistTypeInterface>()) {{
+        op_mesh = dist_interface.process_mesh_attr();
         break;
       }}
     }}"""

paddle/fluid/pir/dialect/operator/ir/manual_api.cc

Lines changed: 12 additions & 2 deletions
@@ -13,6 +13,7 @@
 // limitations under the License.
 
 #include "paddle/fluid/pir/dialect/operator/ir/manual_api.h"
+#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
 #include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
 #include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
 #include "paddle/fluid/pir/dialect/operator/ir/pd_api.h"
@@ -63,8 +64,17 @@ void set_parameter(const pir::Value& parameter, const std::string& name) {
 }
 
 void shadow_output(const pir::Value& persist_value, const std::string& name) {
-  ApiBuilder::Instance().GetBuilder()->Build<pir::ShadowOutputOp>(persist_value,
-                                                                  name);
+  auto& builder = ApiBuilder::Instance().GetBuilder();
+  auto op = builder->Build<pir::ShadowOutputOp>(persist_value, name);
+  if (auto dist_interface =
+          persist_value.type().dyn_cast<DistTypeInterface>()) {
+    op->set_attribute(
+        kAttrOpDistAttr,
+        OperationDistAttribute::get(builder->ir_context(),
+                                    dist_interface.process_mesh_attr(),
+                                    {dist_interface.tensor_dist_attr()},
+                                    {}));
+  }
 }
 
 pir::Value embedding_grad(const pir::Value& x,

paddle/fluid/pybind/pir.cc

Lines changed: 19 additions & 0 deletions
@@ -118,6 +118,7 @@ using pir::Block;
 using pir::BlockArgument;
 using pir::BoolAttribute;
 using pir::CloneOptions;
+using pir::IrContext;
 using pir::IrMapping;
 using pir::IrParser;
 using pir::Operation;
@@ -223,6 +224,20 @@ std::string GetValueInfo(Value v) {
   return ss.str();
 }
 
+Value GetOutputValueByName(const Program &program, const std::string &name) {
+  auto &block = *program.block();
+  pir::StrAttribute name_attr =
+      pir::StrAttribute::get(IrContext::Instance(), name);
+  for (auto &op : block) {
+    if (op.isa<pir::ShadowOutputOp>()) {
+      if (op.attribute("output_name") == name_attr) {
+        return op.operand_source(0);
+      }
+    }
+  }
+  return nullptr;
+}
+
 void BindProgram(py::module *m) {
   py::class_<Program, std::shared_ptr<Program>> program(
       *m, "Program", py::dynamic_attr(), R"DOC(
@@ -334,6 +349,10 @@ void BindProgram(py::module *m) {
           [](std::shared_ptr<Program> self, int64_t random_seed) {
             SetProgramInt64Attr(self, "random_seed", random_seed);
           })
+      .def("get_output_value_by_name",
+           [](Program &self, const std::string &name) {
+             return GetOutputValueByName(self, name);
+           })
      .def("num_ops", [](Program &self) { return self.num_ops(); });
 }
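
Note: from Python, this binding resolves a program output by the output_name recorded on its pir.shadow_output op. A minimal usage sketch (the program setup and the name "loss_0" are assumed for illustration; Engine._parallel_pir below calls it the same way):

    # Assumed setup: `dist_program` is a PIR Program containing a
    # pir.shadow_output op whose "output_name" attribute is "loss_0".
    loss = dist_program.get_output_value_by_name("loss_0")
    # The C++ helper returns a null Value when no op matches, so check
    # the result before differentiating from it.
    if loss is not None:
        paddle.autograd.ir_backward.append_backward(loss)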

python/paddle/distributed/auto_parallel/static/engine.py

Lines changed: 5 additions & 5 deletions
@@ -638,11 +638,10 @@ def _parallel_pir(self, mode):
         dist_program = paddle.base.libpaddle.pir.apply_mix2dist_pass(
             mix_fw_program
         )
-
-        # TODO(winter-wang) Step 1.2: pir backward
-        # with program_guard(dist_program):
-        # params_grads = append_backward_pir(self._loss, parameter_list=self._parameter_list)
-
+        # Step 1.2: pir backward
+        if mode != "predict" and self._loss:
+            loss = dist_program.get_output_value_by_name(self._loss_names[0])
+            paddle.autograd.ir_backward.append_backward(loss)
         # TODO(winter-wang) Step 1.3: adapot opt.minimize() for pir-auto-parallel
         # with program_guard(dist_program):
         # ptimizer_ops = self._optimizer.apply_gradients(params_grads)
@@ -767,6 +766,7 @@ def _build(self, mode):
         # self._process_dist_input_specs()
         outputs = self.program_helper.output_vars
         self._losses = self.program_helper.loss_vars
+        self._loss_names = self.program_helper.loss_names
         metrics = self.program_helper.metric_vars
 
         paddle.enable_static()

python/paddle/distributed/auto_parallel/static/helper.py

Lines changed: 23 additions & 0 deletions
@@ -58,6 +58,7 @@ def __init__(self, layer, loss_func, metrics):
         self._label_vars = defaultdict(list)
         self._output_vars = defaultdict(list)
         self._loss_vars = defaultdict(list)
+        self._loss_names = defaultdict(list)
         self._metric_vars = defaultdict(list)
 
     # Consider ProxyLayer as not Paddle inner function because it contains
@@ -66,6 +67,12 @@ def __init__(self, layer, loss_func, metrics):
         inspect.getmodule(ProxyLayer).__name__ + ".ProxyLayer"
     )
 
+    @paddle.jit.not_to_static
+    def append_loss_to_shadow_output(self, mode):
+        name = paddle.utils.unique_name.generate('loss')
+        paddle._pir_ops.set_persistable_value(self._loss_vars[mode], name)
+        self._loss_names[mode] = name
+
     def _train(self, inputs, labels):
         """
         Train process of inner_layer with forward/loss/metric logic.
@@ -81,6 +88,10 @@ def _train(self, inputs, labels):
         # step 3. calculate loss if needed
         new_inputs = self._prepare(self.output_vars, labels)
         self._loss_vars[mode] = self.call_loss(new_inputs)
+        if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
+            "FLAGS_enable_pir_api"
+        ]:
+            self.append_loss_to_shadow_output(mode)
 
         # step 4. calculate metrics if needed
         self._metric_vars[mode] = self.call_metrics(new_inputs)
@@ -103,6 +114,10 @@ def _eval(self, inputs, labels):
         # step 3. calculate loss if needed
         new_inputs = self._prepare(self.output_vars, labels)
         self._loss_vars[mode] = self.call_loss(new_inputs)
+        if paddle.base.framework.get_flags("FLAGS_enable_pir_api")[
+            "FLAGS_enable_pir_api"
+        ]:
+            self.append_loss_to_shadow_output(mode)
 
         # step 4. calculate metrics if needed
         self._metric_vars[mode] = self.call_metrics(new_inputs)
@@ -180,6 +195,10 @@ def output_vars(self):
     def loss_vars(self):
         return self._loss_vars[self.mode]
 
+    @property
+    def loss_names(self):
+        return self._loss_names[self.mode]
+
     @property
     def metric_vars(self):
         return self._metric_vars[self.mode]
@@ -521,6 +540,10 @@ def label_vars(self):
     def loss_vars(self):
         return to_list(self.proxy_layer.loss_vars)
 
+    @property
+    def loss_names(self):
+        return to_list(self.proxy_layer.loss_names)
+
     @property
     def metric_vars(self):
         return to_list(self.proxy_layer.metric_vars)
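
Taken together, the helper and engine changes form a name-based handshake across the mix-to-dist lowering. A condensed sketch of the round trip, using only calls added in this commit (variable names are illustrative):

    # During forward tracing (ProxyLayer.append_loss_to_shadow_output):
    # tag the loss with a fresh unique name and persist it, which records
    # a shadow_output op in the traced program.
    name = paddle.utils.unique_name.generate('loss')   # e.g. "loss_0"
    paddle._pir_ops.set_persistable_value(loss_var, name)

    # After apply_mix2dist_pass (Engine._parallel_pir): look the loss back
    # up in the dist program by that name and run autodiff from it.
    loss = dist_program.get_output_value_by_name(name)
    paddle.autograd.ir_backward.append_backward(loss)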

test/auto_parallel/pir/test_to_static_pir_program.py

Lines changed: 33 additions & 4 deletions
@@ -97,6 +97,8 @@ def test_to_static_program(self):
         main_program = dist_model._engine._pir_main_progs["eval"]
 
         for op in main_program.global_block().ops:
+            if op.num_results() == 0:
+                continue
             tensor = op.result(0)
             if op.name() == 'pd_op.data':
                 self.assertTrue(tensor.is_dist_dense_tensor_type())
@@ -128,9 +130,24 @@ def test_to_static_program(self):
 
         relu_idx = 0
         matmul_idx = 0
-
-        for op in main_program.global_block().ops:
+        matmul_grad_idx = 0
+        ops = main_program.global_block().ops
+        self.assertEqual(ops[-1].name(), "pd_op.matmul_grad")
+        self.assertEqual(ops[-2].name(), "pd_op.relu_grad")
+        self.assertEqual(ops[-3].name(), "pd_op.matmul_grad")
+        self.assertEqual(ops[-4].name(), "pd_op.relu_grad")
+        self.assertEqual(ops[-5].name(), "pd_op.subtract_grad")
+        self.assertEqual(ops[-6].name(), "pd_op.square_grad")
+        self.assertEqual(ops[-7].name(), "pd_op.mean_grad")
+
+        for op in ops:
+            # skip ops with no results (e.g. shadow_output)
+            if op.num_results() == 0:
+                continue
             tensor = op.result(0)
+            # when a tensor's stop_gradient is true, the corresponding grad
+            # tensor is not initialized, so skip it.
+            if not tensor.initialized():
+                continue
             self.assertTrue(tensor.is_dist_dense_tensor_type())
             self.assertEqual(tensor.dist_attr().process_mesh.shape, [2])
             self.assertEqual(
@@ -143,8 +160,6 @@ def test_to_static_program(self):
             elif op.name() == 'builtin.parameter':
                 self.assertTrue(tensor.is_dense_tensor_type())
                 self.assertTrue(tensor.is_dist_dense_tensor_type())
-                self.assertTrue(tensor.has_one_use())
-
                 self.assertTrue(tensor.is_dist_dense_tensor_type())
                 self.assertEqual(tensor.dist_attr().process_mesh.shape, [2])
                 self.assertEqual(
@@ -189,6 +204,20 @@ def test_to_static_program(self):
                     tensor._local_shape, [BATCH_SIZE, CLASS_NUM]
                 )
                 matmul_idx += 1
+            if op.name() == 'pd_op.matmul_grad':
+                if matmul_grad_idx == 0:
+                    self.assertEqual(tensor.dist_attr().dims_mapping, [-1, 0])
+                    self.assertEqual(tensor.dist_attr().partial_dims, set())
+                    self.assertEqual(
+                        tensor._local_shape, [BATCH_SIZE, CLASS_NUM]
+                    )
+                elif matmul_grad_idx == 1:
+                    self.assertEqual(tensor.dist_attr().dims_mapping, [-1, 0])
+                    self.assertEqual(tensor.dist_attr().partial_dims, set())
+                    self.assertEqual(
+                        tensor._local_shape, [BATCH_SIZE, IMAGE_SIZE // 2]
+                    )
+                matmul_grad_idx += 1
 
         # dist_model.train()
         # for batch_id, (image, label) in enumerate(dist_loader()):
