Merged
44 changes: 40 additions & 4 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -149,7 +149,24 @@ def _load_micro_batch_impl(self, inputs, micro_step):
                else:
                    output.append(None)
            return tuple(output)

        elif isinstance(inputs, dict):
            output_dict = {}
            for key, data in inputs.items():
                if isinstance(data, list):
                    assert len(data) == self._acc_steps, (
                        f"length of data should be {self._acc_steps}, but it is {len(data)}"
                    )
                    output_dict[key] = (
                        data[micro_step].detach()
                        if data[micro_step] is not None
                        else None
                    )
                elif data is not None:
                    self._check_data_valid(data)
                    output_dict[key] = data[begin:end, :].detach()
                else:
                    output_dict[key] = None
            return output_dict
        elif isinstance(inputs, list):
            assert len(inputs) == self._acc_steps, (
                f"length of data should be {self._acc_steps}, but it is {len(inputs)}"
@@ -264,6 +281,8 @@ def __init__(self, layers, hcg, strategy):
            self._hcg.get_moe_sharding_parallel_world_size() > 1
        )

        self.use_dict_in_pp = True

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
@@ -1306,6 +1325,9 @@ def _check_micro_batch_data_valid(self, micro_batch_data):
        if isinstance(micro_batch_data, (tuple, list)):
            for data in micro_batch_data:
                self._check_micro_batch_data_valid(data)
        elif isinstance(micro_batch_data, dict):
            for value in micro_batch_data.values():
                self._check_micro_batch_data_valid(value)
        elif micro_batch_data is not None:
            assert isinstance(micro_batch_data, paddle.Tensor)

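For illustration, a minimal standalone version of the extended validator (an assumption-level sketch; check mirrors _check_micro_batch_data_valid), showing that nested tuples, lists, and dicts are all walked down to paddle.Tensor or None leaves:

import paddle

def check(data):
    if isinstance(data, (tuple, list)):
        for item in data:
            check(item)
    elif isinstance(data, dict):
        for value in data.values():
            check(value)
    elif data is not None:
        assert isinstance(data, paddle.Tensor)

check({"x": paddle.ones([2, 2]), "y": [paddle.zeros([2]), None]})  # passes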
@@ -3482,16 +3504,30 @@ def dict_to_tuple_helper(output_tensor):


def convert_tensor_dict_to_tuple(output_tensor_dict):
    output_tensor = []
    for key, tensor in output_tensor_dict.items():
        tensor.key = key
        if isinstance(tensor, (list, tuple)):
            for idx, t in enumerate(tensor):
                t.key = key + " " + str(idx)
                output_tensor.append(t)
        else:  # single tensor
            tensor.key = key
            output_tensor.append(tensor)

    return tuple(output_tensor_dict.values())
    return tuple(output_tensor)


def convert_tensor_tuple_to_dict(input_tensor_tuple):
    input_tensor_dict = {}
    for tensor in input_tensor_tuple:
        key = tensor.key
        input_tensor_dict[key] = tensor
        if " " in key:
            real_key, _ = key.split(" ")
            if real_key in input_tensor_dict.keys():
                input_tensor_dict[real_key].append(tensor)
            else:
                input_tensor_dict[real_key] = [tensor]
        else:
            input_tensor_dict[key] = tensor
        delattr(tensor, "key")
    return input_tensor_dict
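A hypothetical round-trip through the two helpers above (assuming they are importable from pipeline_parallel.py; the tensor values are illustrative): a list-valued entry is flattened with space-separated "key idx" tags before the p2p send, then regrouped from those tags on the receiving stage.

import paddle
from paddle.distributed.fleet.meta_parallel.pipeline_parallel import (
    convert_tensor_dict_to_tuple,
    convert_tensor_tuple_to_dict,
)

x = paddle.ones([2, 2])
y0 = paddle.zeros([2, 2])
y1 = paddle.full([2, 2], 2.0)

flat = convert_tensor_dict_to_tuple({"x": x, "y": [y0, y1]})
# flat == (x, y0, y1), tagged x.key == "x", y0.key == "y 0", y1.key == "y 1"

rebuilt = convert_tensor_tuple_to_dict(flat)
# rebuilt == {"x": x, "y": [y0, y1]}, with the temporary .key tags removed

Note that the regrouping relies on the tuple preserving send order, so "y 0" is appended back before "y 1".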
116 changes: 95 additions & 21 deletions test/collective/fleet/hybrid_parallel_pp_send_recv_dict.py
@@ -54,7 +54,7 @@ def __len__(self):
        return self.num_samples


class LinearPipe(nn.Linear):
class FirstLinearPipe(nn.Linear):
    def __init__(
        self,
        in_features,
@@ -70,26 +70,93 @@ def __init__(self,
        self.use_dict = use_dict

    def forward(self, input):
        if isinstance(input, list):
            input = input[0]
        if self.use_dict:
            if isinstance(input, dict):
                input = input['x']
            x = paddle.matmul(input, self.weight)
            return {"x": x}
            y0 = 2 * x
            y1 = 2 * x
            return {"x": x, "y": [y0, y1]}
        else:
            return paddle.matmul(input, self.weight)
            x = paddle.matmul(input, self.weight)
            y0 = 2 * x
            y1 = 2 * x
            return (x, y0, y1)

    def build_schedule_node(self):
        return ScheduleNode(self.forward)


class SecondLinearPipe(nn.Linear):
    def __init__(
        self,
        in_features,
        out_features,
        weight_attr=None,
        bias_attr=None,
        name=None,
        use_dict=False,
    ):
        super().__init__(
            in_features, out_features, weight_attr, bias_attr, name
        )
        self.use_dict = use_dict

    def forward(self, input):
        if self.use_dict:
            if isinstance(input, dict):
                y0 = input['y'][0]
                y1 = input['y'][1]
                input = input['x']
            x = paddle.matmul(input, self.weight)
            return {"x": x, "y": [y0, y1]}
        else:
            x = paddle.matmul(input[0], self.weight)
            y0 = input[1]
            y1 = input[2]
            return (x, y0, y1)

    def build_schedule_node(self):
        return ScheduleNode(self.forward)


class ThirdLinearPipe(nn.Linear):
    def __init__(
        self,
        in_features,
        out_features,
        weight_attr=None,
        bias_attr=None,
        name=None,
        use_dict=False,
    ):
        super().__init__(
            in_features, out_features, weight_attr, bias_attr, name
        )
        self.use_dict = use_dict

    def forward(self, input):
        if self.use_dict:
            if isinstance(input, dict):
                x = input['x']
                y0, y1 = input['y']
            out = paddle.matmul(x + y0 + y1, self.weight)
            return {"out": out}
        else:
            x = input[0]
            y0, y1 = input[1], input[2]
            return paddle.matmul(x + y0 + y1, self.weight)

    def build_schedule_node(self):
        return ScheduleNode(self.forward)


class CrossEntropyLossPipe(nn.loss.CrossEntropyLoss):
    def forward(self, logits, label):
        if isinstance(logits, list):
            logits = logits[0]
        if isinstance(logits, dict):
            logits = logits["x"]
            logits = logits["out"]
        if isinstance(label, dict):
            label = label["label"]
        return super().forward(logits, label)

    def build_schedule_node(self):
@@ -115,13 +182,25 @@ class SimpleNetPipeDesc(PipelineLayer):
    def __init__(self, **kwargs):
        decs = [
            LayerDesc(
                LinearPipe, 5, 5, bias_attr=False, use_dict=kwargs["use_dict"]
                FirstLinearPipe,
                5,
                5,
                bias_attr=False,
                use_dict=kwargs["use_dict"],
            ),
            LayerDesc(
                LinearPipe, 5, 5, bias_attr=False, use_dict=kwargs["use_dict"]
                SecondLinearPipe,
                5,
                5,
                bias_attr=False,
                use_dict=kwargs["use_dict"],
            ),
            LayerDesc(
                LinearPipe, 5, 5, bias_attr=False, use_dict=kwargs["use_dict"]
                ThirdLinearPipe,
                5,
                5,
                bias_attr=False,
                use_dict=kwargs["use_dict"],
            ),
        ]
        kwargs.pop("use_dict")
@@ -219,19 +298,14 @@ def test_pp_model(self):
            if i >= 5:
                return True

            loss_a = model_a(img, label)
            loss_a.backward()
            optimizer_a.step()
            optimizer_a.clear_grad()
            scheduler_a.step()

            loss_b = model_b.train_batch([img, label], optimizer_b, scheduler_b)

            loss_c = model_c.train_batch([img, label], optimizer_c, scheduler_c)

            np.testing.assert_allclose(
                loss_a.numpy(), loss_b.numpy(), rtol=5e-5
            loss_c = model_c.train_batch(
                [{"x": img, "z": None}, {"label": label}],
                optimizer_c,
                scheduler_c,
            )

            np.testing.assert_equal(loss_b.numpy(), loss_c.numpy())

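For reference, a single-process composition of the three test stages (no pipeline or p2p communication; identity weights are an illustrative assumption), tracing the dataflow the dict carries between stages:

import paddle

w = paddle.eye(5)          # stand-in for each stage's weight
x = paddle.ones([2, 5])

s1 = {"x": paddle.matmul(x, w)}
s1["y"] = [2 * s1["x"], 2 * s1["x"]]                  # FirstLinearPipe output
s2 = {"x": paddle.matmul(s1["x"], w), "y": s1["y"]}   # SecondLinearPipe output
out = paddle.matmul(s2["x"] + s2["y"][0] + s2["y"][1], w)  # ThirdLinearPipe

In the pipeline run, each dict crosses a stage boundary via the tuple round-trip shown earlier, which is why loss_b (tuple path) and loss_c (dict path) are asserted to match exactly.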
