Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@


class AscendIRParser(object):
def __init__(self):
def __init__(self, auto_dp=False, world_rank_size=1):
self.graph_idx = 0
self.hcom_endpoints = {}
self.groups_to_create = []
self._auto_dp = auto_dp
self._world_rank_size = world_rank_size

def _construct_input_map(self, input_varlist):
ret_map = {}
Expand Down Expand Up @@ -91,13 +93,12 @@ def parse_op(self, op):
print("append to create group: %s, with rank_ids: %s" %
(group_name, global_rank_ids))
elif op.type in ascend_parser.registerd_op:
print("Op[%s] has been registered, begin to parse it" % (op.type))
op_parser = self.parser_factory.create_parse(
ascend_parser.registerd_op[op.type])
op_parser.apply(op)
else:
print("Op[%s] has not been registered, so we have to skip it" %
(op.type))
assert False, "Op[%s] has not been registered, so we have to skip it" % (
op.type)

def _parse_program(self,
graph_name,
Expand Down Expand Up @@ -161,6 +162,17 @@ def parse_program(self, startup_program, main_program, input_varlist,
startup_graph = self._parse_program("startup", startup_program)
main_graph = self._parse_program("main", main_program, input_varlist,
fetch_list)
if self._auto_dp and self._world_rank_size > 1:
assert len(self.groups_to_create
) == 0, "can't parse program under auto_dp mode"

from paddle.distributed import fleet
self.groups_to_create.append(
HcomGroupConfig(
name="hcom_group_0",
nranks=fleet.world_size(),
rank_ids=[x for x in range(fleet.world_size())]))

return startup_graph, main_graph


Expand Down Expand Up @@ -196,7 +208,8 @@ def minimize(self,
startup_program=None,
parameter_list=None,
no_grad_set=None,
auto_dp=False):
auto_dp=False,
rank_table_file=None):
minimized = None
if self.inner_opt:
minimized = self.inner_opt.minimize(
Expand All @@ -205,42 +218,44 @@ def minimize(self,
self.ascend_instance = core.AscendInstance()

from paddle.distributed import fleet
if auto_dp and fleet.worker_num() > 1:
if auto_dp and fleet.world_size() > 1:
from paddle.fluid.transpiler import ascend_transpiler
t = ascend_transpiler.AscendTranspiler(startup_program,
loss.block.program)
t.transpile()
print(loss.block.program)
#print(loss.block.program)

# Config about Graph Engine can be found in https://support.huaweicloud.com/
config = {
"ge.exec.deviceId": str(fleet.local_device_ids()),
"ge.graphRunMode": "1",
"ge.exec.precision_mode": "must_keep_origin_dtype",
# if multi mode
"ge.exec.rankTableFile": os.getenv("RANK_TABLE_FILE"),
"ge.exec.rankId": str(fleet.worker_index()),
"ge.exec.isUseHcom": "1",
"ge.exec.deployMode": "0",
}
# if multi trainers
if rank_table_file and fleet.world_size() > 1:
config["ge.exec.rankTableFile"] = rank_table_file
config["ge.exec.rankId"] = str(fleet.worker_index())
config["ge.exec.isUseHcom"] = "1"
config["ge.exec.deployMode"] = "0"
print("ge_initialize config:", config)
core.ge_initialize(config)

# Init Session
self.ascend_instance.init_global_resources()

main_block = loss.block
self.parser = AscendIRParser()
self.parser = AscendIRParser(
auto_dp=auto_dp, world_rank_size=fleet.world_size())

input_varlist = self._get_input_varlist(main_block.program)

startup_graph, main_graph = self.parser.parse_program(
startup_program, main_block.program, input_varlist, self.fetch_list)

for cfg in self.parser.groups_to_create:
hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)
print("create group (%s), nranks: %d, rank_ids: %s" %
(cfg.name, cfg.nranks, cfg.rank_ids))
hccl.create_group(cfg.name, cfg.nranks, cfg.rank_ids)

self.ascend_instance.add_ascend_subgraph(0, startup_graph)
self.ascend_instance.add_ascend_subgraph(1, main_graph)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,16 @@ def update_output(self, geop_list, index_list):
self.parser_name, len(index_list), output_num)
for output_id in range(output_num):
arguments = self.op.output(self.op.output_names[output_id])
#print("%d argument: %s" % (output_id, str(arguments)))
if len(arguments) > 0:
assert len(arguments) == len(
index_list[output_id]
), "Parser[%s]'s %dth argument number[%d] is not equal to paddle's number[%d]" % (
self.parser_name, output_id, len(index_list[output_id]),
len(arguments))
for i in range(len(arguments)):
#print("assgin index_list[%d][%d] to %s" %
# (output_id, i, arguments[i]))
self.var2geop[arguments[i]] = geop_list[index_list[
output_id][i]]

Expand Down Expand Up @@ -789,6 +792,8 @@ def _apply(self):
"Const").set_attr_tensor("value", tensor)
self._mark_as_input(const)
if self.op.block.var(self.op.output('Out')[0]).persistable:
#print("%s is Persistable in fill_constant" %
# (self.op.output('Out')[0]))
var = core.GEOperatorFactory.create_operator(
self.op.output('Out')[0], "Variable")
var.update_output_desc("y",
Expand All @@ -800,6 +805,10 @@ def _apply(self):
"assign" + self._accumulated_op_id(), "Assign").set_input(
"value", const).set_input("ref", var)
return [const], [[0]]
#else:
# print(
# "self.op.output('Out')[0]: %s is not persistable in fill_constant"
# % (self.op.output('Out')[0]))
return [const], [[0]]


Expand Down Expand Up @@ -853,6 +862,8 @@ def _apply(self):

## wirte the output of truncatedNormal from startup_program to main_program
if self.op.block.var(self.op.output('Out')[0]).persistable:
#print("%s is Persistable in truncated_normal" %
# (self.op.output('Out')[0]))
var = core.GEOperatorFactory.create_operator(
self.op.output('Out')[0], "Variable")
var.update_output_desc("y",
Expand All @@ -867,6 +878,10 @@ def _apply(self):
shape_tensor, mean_tensor, std_tensor, min_tensor, max_tensor,
truncated_normal
], [[-1]]
#else:
# print(
# "self.op.output('Out')[0] is not persistable in truncated_noraml"
# )
return [truncated_normal], [[0]]


Expand Down Expand Up @@ -1366,7 +1381,7 @@ def _apply(self):

tensor1 = self._create_ge_tensor([len(shape)], 2, shape)
shape_tensor = core.GEOperatorFactory.create_operator(
"const" + self._accumulated_op_id(),
"const" + self._accumulated_op_id(),
"Const").set_attr_tensor("value", tensor1)

ge_ur = core.GEOperatorFactory.create_operator(
Expand All @@ -1379,9 +1394,9 @@ def _apply(self):
scale = max_v - min_v

scale_value = core.GEOperatorFactory.create_operator(
"scale" + self._accumulated_op_id(), "Power").set_input(
"x", ge_ur).set_attr_float("power", 1.0).set_attr_float(
"scale", scale).set_attr_float("shift", min_v)
"scale" + self._accumulated_op_id(), "Power").set_input(
"x", ge_ur).set_attr_float("power", 1.0).set_attr_float(
"scale", scale).set_attr_float("shift", min_v)

return [scale_value], [[0]]

Expand Down Expand Up @@ -1429,14 +1444,15 @@ def __init__(self, graph, var2geop):

def _apply(self):
tensor = self._get_ge_input(self.op.input_arg_names[0])
axes = self.op.attr("axes")
axes = self.op.attr("axes")

data_squeezed = core.GEOperatorFactory\
.create_operator("squeeze" + self._accumulated_op_id(), "Squeeze")\
.set_input("x", tensor)\
.set_attr_vec_int32("axes", axes)
shape = core.GEOperatorFactory.create_operator(
"shape" + self._accumulated_op_id(), "Shape").set_input("x", data_squeezed)
"shape" + self._accumulated_op_id(),
"Shape").set_input("x", data_squeezed)
return [shape, data_squeezed], [[1], [0]]


Expand Down Expand Up @@ -2172,4 +2188,3 @@ def _apply(self):
"epsilon", epsilon).set_input("grad", grad)

return [adam], [[0]]