Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion python/paddle/distributed/passes/cpp_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .pass_base import CPPPassWrapper, register_pass
from .pass_base import PassType, CPPPassWrapper, register_pass


@register_pass("fuse_elewise_add_act")
Expand All @@ -23,3 +23,6 @@ def __init__(self):
@property
def cpp_name(self):
return "fuse_elewise_add_act_pass"

def _type(self):
return PassType.FUSION_OPT
7 changes: 5 additions & 2 deletions python/paddle/distributed/passes/fuse_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from paddle.framework import core
from paddle.fluid import unique_name
from .pass_base import CommOptPass, register_pass
from .pass_base import PassBase, PassType, register_pass
from collections import OrderedDict
import numpy as np

Expand Down Expand Up @@ -329,7 +329,7 @@ def insert_fuse_all_reduce_by_memory_size(block, groups, max_memory_size):


@register_pass("fuse_all_reduce")
class FuseAllReducePass(CommOptPass):
class FuseAllReducePass(PassBase):
def __init__(self):
super(FuseAllReducePass, self).__init__()
self.set_attr("max_memory_size", -1)
Expand All @@ -341,6 +341,9 @@ def _check_self(self):
def _check_conflict(self, other_pass):
return True

def _type(self):
return PassType.COMM_OPT

# NOTE: why FuseAllReducePass can override apply_single_impl instead of
# apply_impl? AllReduce is a collective operation, so the program of each
# rank inside the same communication group should have the same
Expand Down
50 changes: 21 additions & 29 deletions python/paddle/distributed/passes/pass_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,20 @@ def _pop_pass(self):
del self._applied_passes[-1]


class PassType:
UNKNOWN = 0
COMM_OPT = 1
CALC_OPT = 2
PARALLEL_OPT = 3
FUSION_OPT = 4


class PassBase(ABC):
_REGISTERED_PASSES = {}
_COMMON_RULES = []
# TODO(zengjinle): add white/black list

name = None

@staticmethod
def _register(pass_name, pass_class):
Expand All @@ -67,6 +78,9 @@ def _check_self(self):
def _check_conflict(self, other_pass):
pass

def _type(self):
return PassType.UNKNOWN

def _check_conflict_including_common_rules(self, other_pass):
return self._check_conflict(other_pass) and all(
[r(other_pass, self) for r in PassBase._COMMON_RULES])
Expand Down Expand Up @@ -142,40 +156,18 @@ def _apply_single_impl(self, main_program, startup_program, context):
self._attrs, self.cpp_attr_types)


# Like AutoParallel/HybridParallel, etc.
class ParallelOptPass(PassBase):
def __init__(self):
super(ParallelOptPass, self).__init__()


# Like AMP, Recompute, etc.
class CalcOptPass(PassBase):
def __init__(self):
super(CalcOptPass, self).__init__()


# Like FuseAllReduce, FuseGradientMerge, etc.
class CommOptPass(PassBase):
def __init__(self):
super(CommOptPass, self).__init__()


def _make_pass_order_rule(pass_class_before, pass_class_after):
def impl(pass_obj_before, pass_obj_after):
if isinstance(pass_obj_before, pass_class_after) \
and isinstance(pass_obj_after, pass_class_before):
return False
def _fusion_opt_last_rule(pass_before, pass_after):
if pass_before._type() == PassType.FUSION_OPT and pass_after._type(
) != PassType.FUSION_OPT:
return False
else:
return True

return impl


PassBase._COMMON_RULES = [
_make_pass_order_rule(CalcOptPass, CommOptPass),
_make_pass_order_rule(ParallelOptPass, CPPPassWrapper),
_make_pass_order_rule(CalcOptPass, CPPPassWrapper),
_make_pass_order_rule(CommOptPass, CPPPassWrapper),
_fusion_opt_last_rule,
lambda pass_before, pass_after: type(pass_before) != type(pass_after),
# Add more common rules here
]


Expand Down