Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions src/rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -143,12 +143,34 @@ static void serverRdmaError(char *err, const char *fmt, ...) {
va_end(ap);
}

static inline int connRdmaAllowCommand(void) {
/* RDMA MR is not accessible in a child process, avoid segment fault due to
* invalid MR access, close it rather than server random crash */
if (server.in_fork_child != CHILD_TYPE_NONE) {
return C_ERR;
}

return C_OK;
}

static inline int connRdmaAllowRW(connection *conn) {
if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
return C_ERR;
}

return connRdmaAllowCommand();
}

static int rdmaPostRecv(RdmaContext *ctx, struct rdma_cm_id *cm_id, ValkeyRdmaCmd *cmd) {
struct ibv_sge sge;
size_t length = sizeof(ValkeyRdmaCmd);
struct ibv_recv_wr recv_wr, *bad_wr;
int ret;

if (connRdmaAllowCommand()) {
return C_ERR;
}

sge.addr = (uint64_t)cmd;
sge.length = length;
sge.lkey = ctx->cmd_mr->lkey;
Expand Down Expand Up @@ -1214,6 +1236,10 @@ static size_t connRdmaSend(connection *conn, const void *data, size_t data_len)
char *remote_addr = ctx->tx_addr + ctx->tx.offset;
int ret;

if (connRdmaAllowCommand()) {
return C_ERR;
}

memcpy(addr, data, data_len);

sge.addr = (uint64_t)addr;
Expand Down Expand Up @@ -1247,7 +1273,7 @@ static int connRdmaWrite(connection *conn, const void *data, size_t data_len) {
RdmaContext *ctx = cm_id->context;
uint32_t towrite;

if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
if (connRdmaAllowRW(conn)) {
return C_ERR;
}

Expand Down Expand Up @@ -1290,7 +1316,7 @@ static int connRdmaRead(connection *conn, void *buf, size_t buf_len) {
struct rdma_cm_id *cm_id = rdma_conn->cm_id;
RdmaContext *ctx = cm_id->context;

if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
if (connRdmaAllowRW(conn)) {
return C_ERR;
}

Expand All @@ -1312,7 +1338,7 @@ static ssize_t connRdmaSyncWrite(connection *conn, char *ptr, ssize_t size, long
long long start = mstime();
uint32_t towrite;

if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
if (connRdmaAllowRW(conn)) {
return C_ERR;
}

Expand Down Expand Up @@ -1355,7 +1381,7 @@ static ssize_t connRdmaSyncRead(connection *conn, char *ptr, ssize_t size, long
long long start = mstime();
uint32_t toread;

if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
if (connRdmaAllowRW(conn)) {
return C_ERR;
}

Expand Down Expand Up @@ -1390,7 +1416,7 @@ static ssize_t connRdmaSyncReadLine(connection *conn, char *ptr, ssize_t size, l
char *c;
char nl = 0;

if (conn->state == CONN_STATE_ERROR || conn->state == CONN_STATE_CLOSED) {
if (connRdmaAllowRW(conn)) {
return C_ERR;
}

Expand Down
Loading