Merged
@@ -29,9 +29,12 @@
 from paddle.fluid import layers

 import logging
-logging.basicConfig(
-    format='%(asctime)s %(levelname)-8s %(message)s',
-    datefmt='%Y-%m-%d %H:%M:%S')
+logger = logging.getLogger(__name__)
+formatter = logging.Formatter(
+    fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+ch = logging.StreamHandler()
+ch.setFormatter(formatter)
+logger.addHandler(ch)
 from functools import reduce

 __all__ = ["ShardingOptimizer"]
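
The core of this change: `logging.basicConfig(...)` configures the process-wide root logger and only takes effect on its first call, so whichever module happens to import first dictates formatting for the entire application. A logger obtained via `logging.getLogger(__name__)` with its own handler keeps that configuration local to the module. A minimal standalone sketch of the pattern (the module name is illustrative, not Paddle code):

    import logging

    # basicConfig() mutates the root logger, and only the first call in a
    # process has any effect, so libraries calling it fight over global state.
    # A named logger with its own handler keeps the choice local:
    logger = logging.getLogger("my_module")  # "my_module" is illustrative
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter(
            fmt='%(asctime)s %(levelname)-8s %(message)s',
            datefmt='%Y-%m-%d %H:%M:%S'))
    logger.addHandler(handler)

    logger.warning("formatted by this module's handler; root config untouched")
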
@@ -136,7 +139,7 @@ def minimize_impl(self,

         # FIXME (JZ-LIANG) deprecated hybrid_dp
         if self.user_defined_strategy.sharding_configs["hybrid_dp"]:
-            logging.warning(
+            logger.warning(
                 "[hybrid_dp] API setting is deprecated. Now when dp_degree >= 2, its will be in hybrid dp mode automatically"
             )
         assert self.dp_degree >= 1
@@ -174,7 +177,7 @@ def minimize_impl(self,
         self._gradient_merge_acc_step = self.user_defined_strategy.pipeline_configs[
             'accumulate_steps']
         if self._gradient_merge_acc_step > 1:
-            logging.info("Gradient merge in [{}], acc step = [{}]".format(
+            logger.info("Gradient merge in [{}], acc step = [{}]".format(
                 self.gradient_merge_mode, self._gradient_merge_acc_step))

         # optimize offload
@@ -338,7 +341,7 @@ def minimize_impl(self,
         # opt offload should be enable while gradient merge is enable && acc_step is quite large (e.g. >> 100)
         # sync its memcpy could not be overlap with calc, otherwise it will slower down training severely.
         if self.optimize_offload:
-            logging.info("Sharding with optimize offload !")
+            logger.info("Sharding with optimize offload !")
             offload_helper = OffloadHelper()
             offload_helper.offload(main_block, startup_block)
             offload_helper.offload_fp32param(main_block, startup_block)
@@ -641,15 +644,15 @@ def _split_program(self, block):
         for varname in sorted(
                 var2broadcast_time, key=var2broadcast_time.get,
                 reverse=True):
-            logging.info("Sharding broadcast: [{}] times [{}]".format(
+            logger.info("Sharding broadcast: [{}] times [{}]".format(
                 var2broadcast_time[varname], varname))
         for idx_ in range(len(self._segments)):
-            logging.info("segment [{}] :".format(idx_))
-            logging.info("start op: [{}] [{}]".format(block.ops[
+            logger.info("segment [{}] :".format(idx_))
+            logger.info("start op: [{}] [{}]".format(block.ops[
                 self._segments[idx_]._start_idx].desc.type(), block.ops[
                 self._segments[idx_]._start_idx].desc.input_arg_names(
                 )))
-            logging.info("end op: [{}] [{}]".format(block.ops[
+            logger.info("end op: [{}] [{}]".format(block.ops[
                 self._segments[idx_]._end_idx].desc.type(), block.ops[
                 self._segments[idx_]._end_idx].desc.input_arg_names()))
         return
@@ -1108,7 +1111,7 @@ def _build_groups(self):
                 self.dp_group_endpoints.append(self.global_endpoints[
                     dp_first_rank_idx + dp_offset * i])
             assert self.current_endpoint in self.dp_group_endpoints
-            logging.info("Hybrid DP mode turn on !")
+            logger.info("Hybrid DP mode turn on !")
         else:
             self.dp_ring_id = -1
             self.dp_rank = -1
@@ -1119,40 +1122,40 @@ def _build_groups(self):
         # NOTE (JZ-LIANG) when use global ring for calc global norm and dp_degree > 1, the allreduce result should be devided by dp_degree
         self.global_ring_id = 3

-        logging.info("global word size: {}".format(self.global_word_size))
-        logging.info("global rank: {}".format(self.global_rank))
-        logging.info("global endpoints: {}".format(self.global_endpoints))
-        logging.info("global ring id: {}".format(self.global_ring_id))
-        logging.info("#####" * 6)
-
-        logging.info("mp group size: {}".format(self.mp_degree))
-        logging.info("mp rank: {}".format(self.mp_rank))
-        logging.info("mp group id: {}".format(self.mp_group_id))
-        logging.info("mp group endpoints: {}".format(self.mp_group_endpoints))
-        logging.info("mp ring id: {}".format(self.mp_ring_id))
-        logging.info("#####" * 6)
-
-        logging.info("sharding group size: {}".format(self.sharding_degree))
-        logging.info("sharding rank: {}".format(self.sharding_rank))
-        logging.info("sharding group id: {}".format(self.sharding_group_id))
-        logging.info("sharding group endpoints: {}".format(
+        logger.info("global word size: {}".format(self.global_word_size))
+        logger.info("global rank: {}".format(self.global_rank))
+        logger.info("global endpoints: {}".format(self.global_endpoints))
+        logger.info("global ring id: {}".format(self.global_ring_id))
+        logger.info("#####" * 6)
+
+        logger.info("mp group size: {}".format(self.mp_degree))
+        logger.info("mp rank: {}".format(self.mp_rank))
+        logger.info("mp group id: {}".format(self.mp_group_id))
+        logger.info("mp group endpoints: {}".format(self.mp_group_endpoints))
+        logger.info("mp ring id: {}".format(self.mp_ring_id))
+        logger.info("#####" * 6)
+
+        logger.info("sharding group size: {}".format(self.sharding_degree))
+        logger.info("sharding rank: {}".format(self.sharding_rank))
+        logger.info("sharding group id: {}".format(self.sharding_group_id))
+        logger.info("sharding group endpoints: {}".format(
             self.sharding_group_endpoints))
-        logging.info("sharding ring id: {}".format(self.sharding_ring_id))
-        logging.info("#####" * 6)
-
-        logging.info("pp group size: {}".format(self.pp_degree))
-        logging.info("pp rank: {}".format(self.pp_rank))
-        logging.info("pp group id: {}".format(self.pp_group_id))
-        logging.info("pp group endpoints: {}".format(self.pp_group_endpoints))
-        logging.info("pp ring id: {}".format(self.pp_ring_id))
-        logging.info("#####" * 6)
-
-        logging.info("pure dp group size: {}".format(self.dp_degree))
-        logging.info("pure dp rank: {}".format(self.dp_rank))
-        logging.info("pure dp group endpoints: {}".format(
+        logger.info("sharding ring id: {}".format(self.sharding_ring_id))
+        logger.info("#####" * 6)
+
+        logger.info("pp group size: {}".format(self.pp_degree))
+        logger.info("pp rank: {}".format(self.pp_rank))
+        logger.info("pp group id: {}".format(self.pp_group_id))
+        logger.info("pp group endpoints: {}".format(self.pp_group_endpoints))
+        logger.info("pp ring id: {}".format(self.pp_ring_id))
+        logger.info("#####" * 6)
+
+        logger.info("pure dp group size: {}".format(self.dp_degree))
+        logger.info("pure dp rank: {}".format(self.dp_rank))
+        logger.info("pure dp group endpoints: {}".format(
             self.dp_group_endpoints))
-        logging.info("pure dp ring id: {}".format(self.dp_ring_id))
-        logging.info("#####" * 6)
+        logger.info("pure dp ring id: {}".format(self.dp_ring_id))
+        logger.info("#####" * 6)

         return

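A detail worth noting about the `logger.info(...)` calls above (an observation about stdlib defaults, not something the diff changes): neither the old `basicConfig(format=..., datefmt=...)` call nor the new module logger sets a level in this file, so the effective level falls back to the root logger's default of WARNING and these INFO records are dropped unless the process lowers the level elsewhere. A standalone sketch of how the effective level resolves:

    import logging

    logger = logging.getLogger("sharding_demo")  # illustrative name
    logger.addHandler(logging.StreamHandler())

    # No level is set on the logger itself, so it delegates upward;
    # the root logger defaults to WARNING.
    assert logger.getEffectiveLevel() == logging.WARNING
    logger.info("dropped: below the effective level")

    logger.setLevel(logging.INFO)  # explicit opt-in
    logger.info("now emitted")
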
python/paddle/distributed/fleet/utils/recompute.py (7 additions, 4 deletions)
@@ -19,9 +19,12 @@
 import contextlib

 import logging
-logging.basicConfig(
-    format='%(asctime)s %(levelname)-8s %(message)s',
-    datefmt='%Y-%m-%d %H:%M:%S')
+logger = logging.getLogger(__name__)
+formatter = logging.Formatter(
+    fmt='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
+ch = logging.StreamHandler()
+ch.setFormatter(formatter)
+logger.addHandler(ch)


 def detach_variable(inputs):
@@ -40,7 +43,7 @@ def detach_variable(inputs):
 def check_recompute_necessary(inputs):
     if not any(input_.stop_gradient == False for input_ in inputs
                if isinstance(input_, paddle.Tensor)):
-        logging.warn(
+        logger.warn(
             "[Recompute]: None of the inputs to current recompute block need grad, "
             "therefore there is NO need to recompute this block in backward !")

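A side note on `check_recompute_necessary` above: `Logger.warn` is a deprecated alias that the standard library keeps only for backward compatibility, so `logger.warning(...)` would be the more future-proof spelling; the hunk preserves the original method name. For illustration:

    import logging

    logger = logging.getLogger("recompute_demo")  # illustrative name

    # warn() forwards to warning() but issues a DeprecationWarning
    # (visible under e.g. `python -W default`).
    logger.warn("works, via the deprecated alias")
    logger.warning("the documented spelling")
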
python/paddle/fluid/incubate/fleet/utils/utils.py (5 additions, 2 deletions)
@@ -34,9 +34,12 @@
"graphviz"
]

logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(message)s')
ch = logging.StreamHandler()
ch.setFormatter(formatter)
logger.addHandler(ch)

persistable_vars_out_fn = "vars_persistable.log"
all_vars_out_fn = "vars_all.log"
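
One caveat that applies to the setup pattern in all four files (a general observation about the pattern, not a claim about this PR's runtime behavior): `addHandler` appends unconditionally and module loggers propagate to the root by default, so records can be printed twice if the same logger is configured again (e.g. via `importlib.reload`) or if the root logger later gains a handler of its own. A common defensive sketch, using a hypothetical helper:

    import logging

    def get_module_logger(name, level=logging.INFO):
        """Hypothetical helper: configure a named logger exactly once."""
        logger = logging.getLogger(name)
        if not logger.handlers:  # guard against duplicate handlers
            handler = logging.StreamHandler()
            handler.setFormatter(
                logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
            logger.addHandler(handler)
            logger.setLevel(level)
            logger.propagate = False  # avoid double-printing via the root logger
        return logger
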
python/paddle/utils/cpp_extension/extension_utils.py (6 additions, 3 deletions)
@@ -32,9 +32,12 @@
 from ...fluid.framework import OpProtoHolder
 from ...sysconfig import get_include, get_lib

-logging.basicConfig(
-    format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)
+logger = logging.getLogger("utils.cpp_extension")
+logger.setLevel(logging.INFO)
+formatter = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(message)s')
+ch = logging.StreamHandler()
+ch.setFormatter(formatter)
+logger.addHandler(ch)

 OS_NAME = sys.platform
 IS_WINDOWS = OS_NAME.startswith('win')
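
Unlike the other three files, this one names its logger "utils.cpp_extension" explicitly instead of using `__name__`. Dotted names place loggers in a hierarchy, so a parent logger named "utils" (hypothetical here) would also receive these records through propagation:

    import logging

    parent = logging.getLogger("utils")  # hypothetical parent logger
    child = logging.getLogger("utils.cpp_extension")
    assert child.parent is parent  # the dotted name builds the hierarchy

    parent.addHandler(logging.StreamHandler())
    child.setLevel(logging.INFO)
    child.info("reaches the parent's handler via propagation")
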
@@ -1125,4 +1128,4 @@ def log_v(info, verbose=True):
     Print log information on stdout.
     """
     if verbose:
-        logging.info(info)
+        logger.info(info)