Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 40 additions & 5 deletions maint/gen_coll.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def dump_coll(name, blocking_type):
dump_allcomm_sched_auto(name)
dump_sched_impl(name)
dump_mpir_impl_nonblocking(name)
dump_mpir_nonblocking_tag(name)
elif blocking_type == "persistent":
dump_mpir_impl_persistent(name)
else:
Expand Down Expand Up @@ -426,6 +427,36 @@ def dump_mpir_impl_nonblocking(name):
G.out.append("goto fn_exit;")
dump_close("}")

def dump_mpir_nonblocking_tag(name):
blocking_type = "nonblocking"
func = G.FUNCS["mpi_" + name]
params, args = get_params_and_args(func)
func_params = get_func_params(params, name, "tag")

func_name = get_func_name(name, blocking_type)
Name = func_name.capitalize()
NAME = func_name.upper()

G.out.append("")
add_prototype("int MPIR_%s_tag(%s)" % (Name, func_params))
dump_split(0, "int MPIR_%s_tag(%s)" % (Name, func_params))
dump_open('{')
G.out.append("int mpi_errno = MPI_SUCCESS;")
G.out.append("enum MPIR_sched_type sched_type;")
G.out.append("void *sched;")
G.out.append("")
G.out.append("*request = NULL;")
func_args = get_func_args(args, name, "mpir_impl_tag")
dump_split(1, "mpi_errno = MPIR_%s_sched_impl(%s);" % (Name, func_args))
G.out.append("MPIR_ERR_CHECK(mpi_errno);")
G.out.append("MPII_SCHED_START(sched_type, sched, comm_ptr, request);")
G.out.append("")
G.out.append("fn_exit:")
G.out.append("return mpi_errno;")
G.out.append("fn_fail:")
G.out.append("goto fn_exit;")
dump_close("}")

def dump_mpir_impl_persistent(name):
blocking_type = "persistent"
func = G.FUNCS["mpi_" + name]
Expand Down Expand Up @@ -693,14 +724,16 @@ def get_func_params(params, name, kind):
func_params += ", int coll_attr"
elif kind == "nonblocking":
func_params += ", MPIR_Request ** request"
elif kind == "tag":
func_params += ", int tag, MPIR_Request ** request"
elif kind == "persistent":
func_params += ", MPIR_Info * info_ptr, MPIR_Request ** request"
elif kind == "sched_auto":
func_params += ", MPIR_Sched_t s"
elif kind == "allcomm_sched_auto":
func_params += ", bool is_persistent, void **sched_p, enum MPIR_sched_type *sched_type_p"
func_params += ", int tag, bool is_persistent, void **sched_p, enum MPIR_sched_type *sched_type_p"
elif kind == "sched_impl":
func_params += ", bool is_persistent, void **sched_p, enum MPIR_sched_type *sched_type_p"
func_params += ", int tag, bool is_persistent, void **sched_p, enum MPIR_sched_type *sched_type_p"
else:
raise Exception("get_func_params - unexpected kind = %s" % kind)

Expand All @@ -716,11 +749,13 @@ def get_func_args(args, name, kind):
elif kind == "persistent":
func_args += ", info_ptr, request"
elif kind == "allcomm_sched_auto":
func_args += ", is_persistent, sched_p, sched_type_p"
func_args += ", tag, is_persistent, sched_p, sched_type_p"
elif kind == "mpir_impl_nonblocking":
func_args += ", false, &sched, &sched_type"
func_args += ", 0, false, &sched, &sched_type"
elif kind == "mpir_impl_tag":
func_args += ", tag, false, &sched, &sched_type"
elif kind == "mpir_impl_persistent":
func_args += ", true, &req->u.persist_coll.sched, &req->u.persist_coll.sched_type"
func_args += ", 0, true, &req->u.persist_coll.sched, &req->u.persist_coll.sched_type"
else:
raise Exception("get_func_args - unexpected kind = %s" % kind)

Expand Down
3 changes: 2 additions & 1 deletion src/include/mpir_tags.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
#define MPIR_REDUCE_SCATTER_BLOCK_TAG 28
#define MPIR_SHRINK_TAG 29
#define MPIR_AGREE_TAG 30
#define MPIR_FIRST_HCOLL_TAG 31
#define MPIR_CTXID_TAG 31
#define MPIR_FIRST_HCOLL_TAG 32
#define MPIR_LAST_HCOLL_TAG (MPIR_FIRST_HCOLL_TAG + 255)
#define MPIR_FIRST_NBC_TAG (MPIR_LAST_HCOLL_TAG + 1)

Expand Down
7 changes: 4 additions & 3 deletions src/mpi/coll/include/coll_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,10 @@ int MPII_Coll_finalize(void);
} \
mpi_errno = MPIR_Sched_create(&s, sched_kind); \
MPIR_ERR_CHECK(mpi_errno); \
int tag = -1; \
mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); \
MPIR_ERR_CHECK(mpi_errno); \
if (!tag) { \
mpi_errno = MPIR_Sched_next_tag(comm_ptr, &tag); \
MPIR_ERR_CHECK(mpi_errno); \
} \
MPIR_Sched_set_tag(s, tag); \
*sched_type_p = MPIR_SCHED_NORMAL; \
*sched_p = s; \
Expand Down
Loading