diff --git a/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py
index 89e6c20ad03c97..e1e4514b60d24d 100644
--- a/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py
+++ b/python/paddle/distributed/passes/allreduce_matmul_grad_overlapping.py
@@ -17,10 +17,9 @@
 
 from ..auto_parallel.static.utils import (
     get_logger,
-    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
 )
 from .pass_base import PassBase, register_pass
-from .pass_utils import AutoParallelStreamType
+from .pass_utils import split_matmul_grad_to_matmul
 
 logger = get_logger(logging.INFO)
 
@@ -84,44 +83,6 @@ def _get_all_matmul_grad_and_allreduce_pairs(self, block):
                         matmul_grad_id_to_allreduce_id[i] = j
         return matmul_grad_id_to_allreduce_id
 
-    def _insert_reshape_op(self, block, index, x, shape, op_role, out=None):
-        var_x = block.var(x[0])
-        x_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(var_x)
-
-        if out is None:
-            out = block.create_var(
-                name=f"{x[0]}@reshape.out",
-                dtype=var_x.dtype,
-                persistable=False,
-            )
-            self.dist_context.set_tensor_dist_attr_for_program(out, x_dist_attr)
-
-        x_shape = block.create_var(
-            name=f"{x[0]}@reshape.xshape", dtype=var_x.dtype
-        )
-        self.dist_context.set_tensor_dist_attr_for_program(x_shape, x_dist_attr)
-
-        reshape_op = block._insert_op_without_sync(
-            index=index,
-            type="reshape2",
-            inputs={"X": x},
-            outputs={"Out": out, "XShape": x_shape},
-            attrs={
-                "shape": shape,
-                "op_role": op_role,
-                'op_namescope': self.op_namescope,
-            },
-        )
-        naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
-            reshape_op,
-            process_mesh=x_dist_attr.process_mesh,
-            ref_mapping=x_dist_attr.dims_mapping,
-            ctx=self.dist_context,
-            chunk_id=x_dist_attr.chunk_id,
-        )
-
-        return out
-
     def _split_matmul_grad_and_multi_streaming_allreduce(
         self, block, matmul_grad_id_to_allreduce_id
     ):
@@ -133,20 +94,15 @@ def _split_matmul_grad_and_multi_streaming_allreduce(
             matmul_grad_op = ops[matmul_grad_id]
             allreduce_op = ops[allreduce_id]
 
-            # NOTE(Sonder): Why move those operations to the back of matmul_v2?
-            # When using amp_master_grad, the cast operation is inserted after matmul_grad.
-            # However, when employing allreduce_matmul_grad_overlapping, the matmul_grad is
-            # split into two matmul operations. In this case, some operations would access
-            # uninitialized tensors. Therefore, we move the cast operation to the back of the
-            # second matmul operation to avoid this problem.
+            # NOTE(Sonder): When there are ops between matmul_grad and allreduce, we should check whether any
+            # of them relies on the outputs of the intermediate ops (those that consume matmul_grad's output).
+            # If so, we should not split the matmul_grad; otherwise those ops would read wrong results.
             skip_overlapping = False
-            moved_ops_idx = []
             moved_ops_output = []
             matmul_grad_output = matmul_grad_op.output('Y@GRAD')[0]
 
             for idx in range(matmul_grad_id + 1, allreduce_id):
                 if matmul_grad_output in ops[idx].desc.input_arg_names():
-                    moved_ops_idx.append(idx)
                     moved_ops_output.extend(ops[idx].desc.output_arg_names())
                 else:
                     for input_name in ops[idx].desc.input_arg_names():
@@ -156,137 +112,40 @@ def _split_matmul_grad_and_multi_streaming_allreduce(
             if skip_overlapping:
                 continue
 
-            for i, idx in enumerate(moved_ops_idx):
-                op = ops[idx]
-                dist_attr = self.dist_context.get_op_dist_attr_for_program(op)
-
-                op_inputs = op.desc.input_names()
-                op_outputs = op.desc.output_names()
-
-                op_inputs = {name: op.input(name) for name in op_inputs}
-                op_outputs = {name: op.output(name) for name in op_outputs}
-
-                op = block._insert_op_without_sync(
-                    index=allreduce_id + 1 + i,
-                    type=op.type,
-                    inputs=op_inputs,
-                    outputs=op_outputs,
-                    attrs=op.all_attrs(),
-                )
-
-                self.dist_context.set_op_dist_attr_for_program(op, dist_attr)
-
-            for i, idx in enumerate(moved_ops_idx):
-                block._remove_op(idx - i, sync=False)
-                allreduce_id -= 1
-
-            tran_x = matmul_grad_op.attr("trans_x")
-            assert (
-                not tran_x
-            ), f"matmul_grad(id={matmul_grad_id}) with tran_x == True is not supported for column parallel linear backward overlapping"
-            tran_y = matmul_grad_op.attr("trans_y")
-            assert (
-                not tran_y
-            ), f"matmul_grad(id={matmul_grad_id}) with tran_y == True is not supported for column parallel linear backward overlapping"
-
-            allreduce_op.dist_attr.execution_stream = (
-                AutoParallelStreamType.MP_STREAM.value
+            # matmul_grad_op => matmul_v2 + reshape + reshape + matmul_v2 + reshape
+            split_matmul_grad_to_matmul(
+                block, matmul_grad_id, self.dist_context, self.op_namescope
             )
 
-            x = matmul_grad_op.input("X")
-            y = matmul_grad_op.input("Y")
-            out_grad = matmul_grad_op.input("Out@GRAD")
-            x_grad = matmul_grad_op.output("X@GRAD")
-            y_grad = matmul_grad_op.output("Y@GRAD")
-            op_role = matmul_grad_op.attr("op_role")
-
             # NOTE(Ruibiao): Required OP scheduling order: matmul(dOut, Y^T) -> c_allreduce_sum(dX) -> matmul(X^T, dOut).
             # c_allreduce_sum(dX) and matmul(X^T, dOut) cannot be swapped. Otherwise, after buffer_shared_inplace_pass
             # adding share_buffer OP before c_allreduce_sum, c_allreduce_sum will synchronous with comp-stream, and then
             # the matmul op before it cannot be overlapped.
-            var_x = block.var(x[0])
-            var_out_grad = block.var(out_grad[0])
-            var_y_grad = block.var(y_grad[0])
-
-            x_dims = var_x.shape
-            out_grad_dims = var_out_grad.shape
-            y_grad_dims = var_y_grad.shape
-
-            assert len(x_dims) == len(
-                out_grad_dims
-            ), f"The rank of x must be equal to that of out_grad, but got x rank = {len(x_dims)} and out_grad rank = {len(out_grad_dims)}."
-            if len(x_dims) > 2:
-                assert (
-                    x_dims[0:2] == out_grad_dims[0:2]
-                ), f"The first two dimensions of x must be equal to that of out_grad, but got x_dims:{x_dims} and out_grad_dims:{out_grad_dims}."
-                new_x_dims = [x_dims[0] * x_dims[1]] + list(x_dims[2:])
-                new_out_grad_dims = [
-                    out_grad_dims[0] * out_grad_dims[1]
-                ] + list(out_grad_dims[2:])
-
-            # NOTE(Ruibiao): Why insert reshape op here?
-            # When the rank of input matrix is 3, MatmulGradKernel use reshape to fold the first two dimensions of x and out_grad (see FoldInitDims in matmul_grad_kernel_impl.h), and then calls blas.Matmul to calculate y_grad.
-            # If we directly append matmul op to calculate y_grad without FoldInitDims, blas.BatchedGEMM is actually called in MatmulKernel, which has a larger cost than using blas.Matmul after dimension folding.
-            # Therefore, we imitate MatmulGradKernel here by inserting reshape op before matmul.
-            new_x = self._insert_reshape_op(
-                block, allreduce_id + 1, x, new_x_dims, op_role
-            )
-            new_out_grad = self._insert_reshape_op(
-                block, allreduce_id + 2, out_grad, new_out_grad_dims, op_role
-            )
-            new_y_grad = block.create_var(
-                name=f"{y_grad[0]}@reshape.out",
-                dtype=var_y_grad.dtype,
-                persistable=False,
-            )
-            self.dist_context.set_tensor_dist_attr_for_program(
-                new_y_grad,
-                self.dist_context.get_tensor_dist_attr_for_program(var_y_grad),
-            )
-
-            matmul_grad_dist_attr = (
-                self.dist_context.get_op_dist_attr_for_program(matmul_grad_op)
-            )
-            matmul_op = block._insert_op_without_sync(
-                index=allreduce_id + 3,
-                type="matmul_v2",
-                inputs={"X": new_x, "Y": new_out_grad},
-                outputs={"Out": new_y_grad},
-                attrs={
-                    "trans_x": True,
-                    "trans_y": False,
-                    "op_role": op_role,
-                    'op_namescope': self.op_namescope,
-                },
-            )
-            self.dist_context.set_op_dist_attr_for_program(
-                matmul_op, matmul_grad_dist_attr
-            )
-
-            self._insert_reshape_op(
-                block,
-                allreduce_id + 4,
-                [new_y_grad.name],
-                y_grad_dims,
-                op_role,
-                y_grad,
+            allreduce_op_dist_attr = (
+                self.dist_context.get_op_dist_attr_for_program(allreduce_op)
             )
-            matmul_op = block._insert_op_without_sync(
-                index=matmul_grad_id + 1,
-                type="matmul_v2",
-                inputs={"X": out_grad, "Y": y},
-                outputs={"Out": x_grad},
-                attrs={
-                    "trans_x": False,
-                    "trans_y": True,
-                    "op_role": op_role,
-                    'op_namescope': self.op_namescope,
-                },
+            allreduce_op_inputs = allreduce_op.desc.input_names()
+            allreduce_op_outputs = allreduce_op.desc.output_names()
+
+            allreduce_op_inputs = {
+                name: allreduce_op.input(name) for name in allreduce_op_inputs
+            }
+            allreduce_op_outputs = {
+                name: allreduce_op.output(name) for name in allreduce_op_outputs
+            }
+
+            allreduce_op = block._insert_op_without_sync(
+                index=allreduce_id + 1,
+                type=allreduce_op.type,
+                inputs=allreduce_op_inputs,
+                outputs=allreduce_op_outputs,
+                attrs=allreduce_op.all_attrs(),
             )
             self.dist_context.set_op_dist_attr_for_program(
-                matmul_op, matmul_grad_dist_attr
+                allreduce_op, allreduce_op_dist_attr
             )
+            # Remove the original allreduce op
+            block._remove_op(allreduce_id + 5, sync=False)
 
-            block._remove_op(matmul_grad_id, sync=False)
 
         block._sync_with_cpp()
diff --git a/python/paddle/distributed/passes/pass_utils.py b/python/paddle/distributed/passes/pass_utils.py
index f1dcc8a7ffd797..a8064e90535203 100644
--- a/python/paddle/distributed/passes/pass_utils.py
+++ b/python/paddle/distributed/passes/pass_utils.py
@@ -26,6 +26,7 @@
     is_backward_op,
     is_forward_op,
     is_optimize_op,
+    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
     use_new_executor,
 )
 from paddle.distributed.fleet.meta_optimizers.common import OpRole
@@ -785,3 +786,172 @@ def _add_event_dependency(recorder_op, waiter_op):
     if recorder_op.dist_attr.event_to_record not in waiter_wait_list:
         waiter_wait_list.append(recorder_op.dist_attr.event_to_record)
     waiter_op.dist_attr.events_to_wait = waiter_wait_list
+
+
+def _insert_reshape_op(
+    block,
+    index,
+    x,
+    shape,
+    op_role,
+    dist_context,
+    out=None,
+    op_namescope="/",
+):
+    var_x = block.var(x[0])
+    x_dist_attr = dist_context.get_tensor_dist_attr_for_program(var_x)
+
+    if out is None:
+        out = block.create_var(
+            name=f"{x[0]}@reshape.out",
+            dtype=var_x.dtype,
+            persistable=False,
+        )
+        dist_context.set_tensor_dist_attr_for_program(out, x_dist_attr)
+
+    x_shape = block.create_var(name=f"{x[0]}@reshape.xshape", dtype=var_x.dtype)
+    dist_context.set_tensor_dist_attr_for_program(x_shape, x_dist_attr)
+
+    reshape_op = block._insert_op_without_sync(
+        index=index,
+        type="reshape2",
+        inputs={"X": x},
+        outputs={"Out": out, "XShape": x_shape},
+        attrs={
+            "shape": shape,
+            "op_role": op_role,
+            'op_namescope': op_namescope,
+        },
+    )
+
+    naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
+        reshape_op,
+        process_mesh=x_dist_attr.process_mesh,
+        ref_mapping=x_dist_attr.dims_mapping,
+        ctx=dist_context,
+        chunk_id=x_dist_attr.chunk_id,
+    )
+
+    return out
+
+
+def split_matmul_grad_to_matmul(
+    block, matmul_grad_id, dist_context, op_namescope="/"
+):
+    ops = block.ops
+    matmul_grad_op = ops[matmul_grad_id]
+
+    tran_x = matmul_grad_op.attr("trans_x")
+    assert (
+        not tran_x
+    ), f"matmul_grad(id={matmul_grad_id}) with tran_x == True is not supported for splitting matmul_grad to matmul"
+    tran_y = matmul_grad_op.attr("trans_y")
+    assert (
+        not tran_y
+    ), f"matmul_grad(id={matmul_grad_id}) with tran_y == True is not supported for splitting matmul_grad to matmul"
+
+    x = matmul_grad_op.input("X")
+    y = matmul_grad_op.input("Y")
+    out_grad = matmul_grad_op.input("Out@GRAD")
+    x_grad = matmul_grad_op.output("X@GRAD")
+    y_grad = matmul_grad_op.output("Y@GRAD")
+    op_role = matmul_grad_op.attr("op_role")
+
+    var_x = block.var(x[0])
+    var_out_grad = block.var(out_grad[0])
+    var_y_grad = block.var(y_grad[0])
+
+    x_dims = var_x.shape
+    out_grad_dims = var_out_grad.shape
+    y_grad_dims = var_y_grad.shape
+
+    assert len(x_dims) == len(
+        out_grad_dims
+    ), f"The rank of x must be equal to that of out_grad, but got x rank = {len(x_dims)} and out_grad rank = {len(out_grad_dims)}."
+    if len(x_dims) > 2:
+        assert (
+            x_dims[0:2] == out_grad_dims[0:2]
+        ), f"The first two dimensions of x must be equal to that of out_grad, but got x_dims:{x_dims} and out_grad_dims:{out_grad_dims}."
+        new_x_dims = [x_dims[0] * x_dims[1]] + list(x_dims[2:])
+        new_out_grad_dims = [out_grad_dims[0] * out_grad_dims[1]] + list(
+            out_grad_dims[2:]
+        )
+
+    # NOTE(Ruibiao): Why insert reshape op here?
+    # When the rank of input matrix is 3, MatmulGradKernel uses reshape to fold the first two dimensions of x and out_grad (see FoldInitDims in matmul_grad_kernel_impl.h), and then calls blas.Matmul to calculate y_grad.
+    # If we directly append a matmul op to calculate y_grad without FoldInitDims, blas.BatchedGEMM is actually called in MatmulKernel, which has a larger cost than using blas.Matmul after dimension folding.
+    # Therefore, we imitate MatmulGradKernel here by inserting reshape ops before matmul.
+    new_x = _insert_reshape_op(
+        block,
+        matmul_grad_id + 1,
+        x,
+        new_x_dims,
+        op_role,
+        dist_context=dist_context,
+        op_namescope=op_namescope,
+    )
+    new_out_grad = _insert_reshape_op(
+        block,
+        matmul_grad_id + 2,
+        out_grad,
+        new_out_grad_dims,
+        op_role,
+        dist_context=dist_context,
+        op_namescope=op_namescope,
+    )
+    new_y_grad = block.create_var(
+        name=f"{y_grad[0]}@reshape.out",
+        dtype=var_y_grad.dtype,
+        persistable=False,
+    )
+
+    dist_context.set_tensor_dist_attr_for_program(
+        new_y_grad,
+        dist_context.get_tensor_dist_attr_for_program(var_y_grad),
+    )
+
+    matmul_grad_dist_attr = dist_context.get_op_dist_attr_for_program(
+        matmul_grad_op
+    )
+
+    matmul_op = block._insert_op_without_sync(
+        index=matmul_grad_id + 3,
+        type="matmul_v2",
+        inputs={"X": new_x, "Y": new_out_grad},
+        outputs={"Out": new_y_grad},
+        attrs={
+            "trans_x": True,
+            "trans_y": False,
+            "op_role": op_role,
+            'op_namescope': op_namescope,
+        },
+    )
+
+    dist_context.set_op_dist_attr_for_program(matmul_op, matmul_grad_dist_attr)
+    _insert_reshape_op(
+        block,
+        matmul_grad_id + 4,
+        [new_y_grad.name],
+        y_grad_dims,
+        op_role,
+        dist_context=dist_context,
+        out=y_grad,
+        op_namescope=op_namescope,
+    )
+
+    matmul_op = block._insert_op_without_sync(
+        index=matmul_grad_id + 1,
+        type="matmul_v2",
+        inputs={"X": out_grad, "Y": y},
+        outputs={"Out": x_grad},
+        attrs={
+            "trans_x": False,
+            "trans_y": True,
+            "op_role": op_role,
+            'op_namescope': op_namescope,
+        },
+    )
+
+    dist_context.set_op_dist_attr_for_program(matmul_op, matmul_grad_dist_attr)
+
+    block._remove_op(matmul_grad_id, sync=False)
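
Reviewer note (not part of the patch): the identity that the new split_matmul_grad_to_matmul helper relies on can be checked in isolation. The sketch below uses plain numpy rather than the pass or the inserted matmul_v2/reshape2 ops, and the shapes B, S, K, N are illustrative assumptions. It mirrors the scheme described in the NOTE(Ruibiao) comments: dX = matmul(dOut, Y^T) with trans_y=True, and dY computed by folding the first two dimensions of X and dOut (the inserted reshape2 ops) followed by a single 2-D matmul with trans_x=True.

# Minimal numpy sketch (assumption: trans_x == trans_y == False, rank-3 x/out_grad).
import numpy as np

B, S, K, N = 2, 3, 4, 5  # illustrative shapes
x = np.random.rand(B, S, K)          # input X of the forward matmul_v2
y = np.random.rand(K, N)             # weight Y
out_grad = np.random.rand(B, S, N)   # dOut

# dX = matmul(dOut, Y^T): keeps the rank of x; this is the matmul_v2 inserted
# at matmul_grad_id + 1, whose output feeds the overlapped c_allreduce_sum.
x_grad = out_grad @ y.T              # shape [B, S, K]

# dY: fold the leading dims (what the inserted reshape2 ops do), then one 2-D
# matmul with trans_x=True, so blas.Matmul is used instead of blas.BatchedGEMM.
folded_x = x.reshape(B * S, K)
folded_out_grad = out_grad.reshape(B * S, N)
y_grad = folded_x.T @ folded_out_grad  # shape [K, N]; the pass reshapes it back to y_grad_dims

# Reference: what matmul_grad computes directly (sum over the batch dims for dY).
ref_x_grad = np.einsum("bsn,kn->bsk", out_grad, y)
ref_y_grad = np.einsum("bsk,bsn->kn", x, out_grad)
assert np.allclose(x_grad, ref_x_grad)
assert np.allclose(y_grad, ref_y_grad)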