Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions src/mpid/ch4/netmod/ofi/ofi_events.c
Original file line number Diff line number Diff line change
Expand Up @@ -568,9 +568,8 @@ int MPIDI_OFI_handle_cq_error(int vci, int nic, ssize_t ret)
MPIR_STATUS_SET_CANCEL_BIT(req->status, TRUE);
MPIR_STATUS_SET_COUNT(req->status, 0);
MPIR_Datatype_release_if_not_builtin(MPIDI_OFI_REQUEST(req, datatype));
if ((event_id == MPIDI_OFI_EVENT_RECV_PACK) &&
MPIDI_OFI_REQUEST(req, noncontig.pack.pack_buffer)) {
MPL_free(MPIDI_OFI_REQUEST(req, noncontig.pack.pack_buffer));
if (event_id == MPIDI_OFI_EVENT_RECV_PACK) {
MPIDI_OFI_free_pack_buffer(req);
} else if (event_id == MPIDI_OFI_EVENT_RECV_NOPACK) {
MPL_free(MPIDI_OFI_REQUEST(req, noncontig.nopack.iovs));
}
Expand Down
7 changes: 3 additions & 4 deletions src/mpid/ch4/netmod/ofi/ofi_events.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,8 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send_event(int vci,
MPIR_FUNC_ENTER;

/* free the packing buffers and datatype */
if ((event_id == MPIDI_OFI_EVENT_SEND_PACK) &&
(MPIDI_OFI_REQUEST(sreq, noncontig.pack.pack_buffer))) {
MPL_free(MPIDI_OFI_REQUEST(sreq, noncontig.pack.pack_buffer));
if (event_id == MPIDI_OFI_EVENT_SEND_PACK) {
MPIDI_OFI_free_pack_buffer(sreq);
} else if (MPIDI_OFI_ENABLE_PT2PT_NOPACK && (event_id == MPIDI_OFI_EVENT_SEND_NOPACK)) {
MPL_free(MPIDI_OFI_REQUEST(sreq, noncontig.nopack.iovs));
}
Expand Down Expand Up @@ -97,7 +96,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_recv_complete(MPIR_Request * rreq, int ev
if (mpi_errno) {
MPIR_ERR_SET(rreq->status.MPI_ERROR, MPI_ERR_TYPE, "**dtypemismatch");
}
MPL_free(MPIDI_OFI_REQUEST(rreq, noncontig.pack.pack_buffer));
MPIDI_OFI_free_pack_buffer(rreq);
} else if (event_id == MPIDI_OFI_EVENT_RECV_NOPACK) {
#ifdef HAVE_ERROR_CHECKING
MPI_Count elements;
Expand Down
43 changes: 26 additions & 17 deletions src/mpid/ch4/netmod/ofi/ofi_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -833,27 +833,36 @@ MPL_STATIC_INLINE_PREFIX void MPIDI_OFI_gpu_rma_register(const void *buffer, siz
#undef CQ_D_HEAD
#undef CQ_D_TAIL

MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_gpu_malloc_pack_buffer(void **ptr, size_t pack_size)
{
if (MPIDI_OFI_ENABLE_HMEM) {
return MPL_gpu_malloc_host(ptr, pack_size);
} else {
#ifdef MPL_DEFINE_ALIGNED_ALLOC
*ptr = MPL_aligned_alloc(256, pack_size, MPL_MEM_BUFFER);
#else
*ptr = MPL_malloc(pack_size, MPL_MEM_BUFFER);
#endif
return 0;
MPL_STATIC_INLINE_PREFIX void *MPIDI_OFI_malloc_pack_buffer(MPIR_Request * req, MPI_Aint pack_size)
{
void *pack_buf;
bool is_genq;
if (pack_size <= MPIR_CVAR_CH4_OFI_PIPELINE_CHUNK_SZ) {
int vci = MPIR_REQUEST_POOL_FROM_HANDLE(req->handle);
MPIDU_genq_private_pool_alloc_cell(MPIDI_OFI_global.per_vci[vci].pipeline_pool, &pack_buf);
is_genq = true;
}
if (!pack_buf) {
pack_buf = MPL_aligned_alloc(64, pack_size, MPL_MEM_OTHER);
is_genq = false;
}
if (pack_buf) {
MPIDI_OFI_REQUEST(req, noncontig.pack.pack_buffer) = pack_buf;
MPIDI_OFI_REQUEST(req, noncontig.pack.is_genq) = is_genq;
}
return pack_buf;
}

MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_gpu_free_pack_buffer(void *ptr)
MPL_STATIC_INLINE_PREFIX void MPIDI_OFI_free_pack_buffer(MPIR_Request * req)
{
if (MPIDI_OFI_ENABLE_HMEM) {
return MPL_gpu_free_host(ptr);
} else {
MPL_free(ptr);
return 0;
if (MPIDI_OFI_REQUEST(req, noncontig.pack.pack_buffer)) {
if (MPIDI_OFI_REQUEST(req, noncontig.pack.is_genq)) {
int vci = MPIR_REQUEST_POOL_FROM_HANDLE(req->handle);
MPIDU_genq_private_pool_free_cell(MPIDI_OFI_global.per_vci[vci].pipeline_pool,
MPIDI_OFI_REQUEST(req, noncontig.pack.pack_buffer));
} else {
MPL_free(MPIDI_OFI_REQUEST(req, noncontig.pack.pack_buffer));
}
}
}

Expand Down
1 change: 1 addition & 0 deletions src/mpid/ch4/netmod/ofi/ofi_pre.h
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ typedef struct {
enum MPIDI_OFI_req_kind kind;
union {
struct {
bool is_genq;
char *pack_buffer;
} pack;
struct {
Expand Down
4 changes: 1 addition & 3 deletions src/mpid/ch4/netmod/ofi/ofi_recv.h
Original file line number Diff line number Diff line change
Expand Up @@ -224,11 +224,9 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_do_irecv(void *buf,

/* Unpack */
MPIDI_OFI_REQUEST(rreq, event_id) = MPIDI_OFI_EVENT_RECV_PACK;
MPIDI_OFI_REQUEST(rreq, noncontig.pack.pack_buffer) =
MPL_aligned_alloc(64, data_sz, MPL_MEM_OTHER);
recv_buf = MPIDI_OFI_malloc_pack_buffer(rreq, data_sz);
MPIR_ERR_CHKANDJUMP1(MPIDI_OFI_REQUEST(rreq, noncontig.pack.pack_buffer) == NULL, mpi_errno,
MPI_ERR_OTHER, "**nomem", "**nomem %s", "Recv Pack Buffer alloc");
recv_buf = MPIDI_OFI_REQUEST(rreq, noncontig.pack.pack_buffer);
} else {
MPIDI_OFI_REQUEST(rreq, noncontig.pack.pack_buffer) = NULL;
}
Expand Down
2 changes: 1 addition & 1 deletion src/mpid/ch4/netmod/ofi/ofi_rndv.c
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ int MPIDI_OFI_recv_rndv_event(int vci, struct fi_cq_tagged_entry *wc, MPIR_Reque
/* if we were expecting an eager send, free the unneeded pack_buffer or iovs array */
switch (MPIDI_OFI_REQUEST(rreq, event_id)) {
case MPIDI_OFI_EVENT_RECV_PACK:
MPL_free(MPIDI_OFI_REQUEST(rreq, noncontig.pack.pack_buffer));
MPIDI_OFI_free_pack_buffer(rreq);
break;
case MPIDI_OFI_EVENT_RECV_NOPACK:
MPL_free(MPIDI_OFI_REQUEST(rreq, noncontig.nopack.iovs));
Expand Down
3 changes: 1 addition & 2 deletions src/mpid/ch4/netmod/ofi/ofi_send.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send(const void *buf, MPI_Aint count, MPI

void *data = NULL;
if (need_pack) {
void *pack_buf = MPL_aligned_alloc(64, data_sz, MPL_MEM_OTHER);
void *pack_buf = MPIDI_OFI_malloc_pack_buffer(sreq, data_sz);
MPIR_ERR_CHKANDJUMP1(pack_buf == NULL, mpi_errno,
MPI_ERR_OTHER, "**nomem", "**nomem %s", "Send Pack buffer alloc");

Expand All @@ -475,7 +475,6 @@ MPL_STATIC_INLINE_PREFIX int MPIDI_OFI_send(const void *buf, MPI_Aint count, MPI
MPIR_ERR_CHECK(mpi_errno);

data = pack_buf;
MPIDI_OFI_REQUEST(sreq, noncontig.pack.pack_buffer) = pack_buf;
} else {
data = MPIR_get_contig_ptr(buf, dt_true_lb);
MPIDI_OFI_REQUEST(sreq, noncontig.pack.pack_buffer) = NULL;
Expand Down