Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,9 @@ typedef enum {
CONN_STATE_ERROR
} ConnectionState;

#define CONN_FLAG_CLOSE_SCHEDULED (1 << 0) /* Closed scheduled by a handler */
#define CONN_FLAG_WRITE_BARRIER (1 << 1) /* Write barrier requested */
#define CONN_FLAG_CLOSE_SCHEDULED (1 << 0) /* Closed scheduled by a handler */
#define CONN_FLAG_WRITE_BARRIER (1 << 1) /* Write barrier requested */
#define CONN_FLAG_ALLOW_ACCEPT_OFFLOAD (1 << 2) /* Connection accept can be offloaded to IO threads. */

#define CONN_TYPE_SOCKET "tcp"
#define CONN_TYPE_UNIX "unix"
Expand Down
52 changes: 52 additions & 0 deletions src/io_threads.c
Original file line number Diff line number Diff line change
Expand Up @@ -554,3 +554,55 @@ void trySendPollJobToIOThreads(void) {
aeSetPollProtect(server.el, 1);
IOJobQueue_push(jq, IOThreadPoll, server.el);
}

static void ioThreadAccept(void *data) {
client *c = (client *)data;
connAccept(c->conn, NULL);
c->io_read_state = CLIENT_COMPLETED_IO;
}

/*
* Attempts to offload an Accept operation (currently used for TLS accept) for a client
* connection to I/O threads.
*
* Returns:
* C_OK - If the accept operation was successfully queued for processing
* C_ERR - If the connection is not eligible for offloading
*
* Parameters:
* conn - The connection object to perform the accept operation on
*/
int trySendAcceptToIOThreads(connection *conn) {
if (server.io_threads_num <= 1) {
return C_ERR;
}

if (!(conn->flags & CONN_FLAG_ALLOW_ACCEPT_OFFLOAD)) {
return C_ERR;
}

client *c = connGetPrivateData(conn);
if (c->io_read_state != CLIENT_IDLE) {
return C_OK;
}

if (server.active_io_threads_num <= 1) {
return C_ERR;
}

size_t thread_id = (c->id % (server.active_io_threads_num - 1)) + 1;
IOJobQueue *job_queue = &io_jobs[thread_id];

if (IOJobQueue_isFull(job_queue)) {
return C_ERR;
}

c->io_read_state = CLIENT_PENDING_IO;
c->flag.pending_read = 1;
listLinkNodeTail(server.clients_pending_io_read, &c->pending_read_list_node);
connSetPostponeUpdateState(c->conn, 1);
server.stat_io_accept_offloaded++;
IOJobQueue_push(job_queue, ioThreadAccept, c);

return C_OK;
}
1 change: 1 addition & 0 deletions src/io_threads.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ int tryOffloadFreeArgvToIOThreads(client *c);
void adjustIOThreadsByEventLoad(int numevents, int increase_only);
void drainIOThreadsQueue(void);
void trySendPollJobToIOThreads(void);
int trySendAcceptToIOThreads(connection *conn);

#endif /* IO_THREADS_H */
6 changes: 6 additions & 0 deletions src/networking.c
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ client *createClient(connection *conn) {
if (server.tcpkeepalive) connKeepAlive(conn, server.tcpkeepalive);
connSetReadHandler(conn, readQueryFromClient);
connSetPrivateData(conn, c);
conn->flags |= CONN_FLAG_ALLOW_ACCEPT_OFFLOAD;
}
c->buf = zmalloc_usable(PROTO_REPLY_CHUNK_BYTES, &c->buf_usable_size);
selectDb(c, 0);
Expand Down Expand Up @@ -4722,9 +4723,14 @@ int processIOThreadsReadDone(void) {
processed++;
server.stat_io_reads_processed++;

/* Save the current conn state, as connUpdateState may modify it */
int in_accept_state = (connGetState(c->conn) == CONN_STATE_ACCEPTING);
connSetPostponeUpdateState(c->conn, 0);
connUpdateState(c->conn);

/* In accept state, no client's data was read - stop here. */
if (in_accept_state) continue;

/* On read error - stop here. */
if (handleReadResult(c) == C_ERR) {
continue;
Expand Down
2 changes: 2 additions & 0 deletions src/server.c
Original file line number Diff line number Diff line change
Expand Up @@ -2604,6 +2604,7 @@ void resetServerStats(void) {
server.stat_total_reads_processed = 0;
server.stat_io_writes_processed = 0;
server.stat_io_freed_objects = 0;
server.stat_io_accept_offloaded = 0;
server.stat_poll_processed_by_io_threads = 0;
server.stat_total_writes_processed = 0;
server.stat_client_qbuf_limit_disconnections = 0;
Expand Down Expand Up @@ -5862,6 +5863,7 @@ sds genValkeyInfoString(dict *section_dict, int all_sections, int everything) {
"io_threaded_reads_processed:%lld\r\n", server.stat_io_reads_processed,
"io_threaded_writes_processed:%lld\r\n", server.stat_io_writes_processed,
"io_threaded_freed_objects:%lld\r\n", server.stat_io_freed_objects,
"io_threaded_accept:%lld\r\n", server.stat_io_accept_offloaded,
"io_threaded_poll_processed:%lld\r\n", server.stat_poll_processed_by_io_threads,
"io_threaded_total_prefetch_batches:%lld\r\n", server.stat_total_prefetch_batches,
"io_threaded_total_prefetch_entries:%lld\r\n", server.stat_total_prefetch_entries,
Expand Down
1 change: 1 addition & 0 deletions src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -1841,6 +1841,7 @@ struct valkeyServer {
long long stat_io_reads_processed; /* Number of read events processed by IO threads */
long long stat_io_writes_processed; /* Number of write events processed by IO threads */
long long stat_io_freed_objects; /* Number of objects freed by IO threads */
long long stat_io_accept_offloaded; /* Number of offloaded accepts */
long long stat_poll_processed_by_io_threads; /* Total number of poll jobs processed by IO */
long long stat_total_reads_processed; /* Total number of read events processed */
long long stat_total_writes_processed; /* Total number of write events processed */
Expand Down
141 changes: 72 additions & 69 deletions src/tls.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "server.h"
#include "connhelpers.h"
#include "adlist.h"
#include "io_threads.h"

#if (USE_OPENSSL == 1 /* BUILD_YES */) || ((USE_OPENSSL == 2 /* BUILD_MODULE */) && (BUILD_TLS_MODULE == 2))

Expand Down Expand Up @@ -437,16 +438,13 @@ static ConnectionType CT_TLS;
*
*/

typedef enum {
WANT_READ = 1,
WANT_WRITE
} WantIOType;

#define TLS_CONN_FLAG_READ_WANT_WRITE (1 << 0)
#define TLS_CONN_FLAG_WRITE_WANT_READ (1 << 1)
#define TLS_CONN_FLAG_FD_SET (1 << 2)
#define TLS_CONN_FLAG_POSTPONE_UPDATE_STATE (1 << 3)
#define TLS_CONN_FLAG_HAS_PENDING (1 << 4)
#define TLS_CONN_FLAG_ACCEPT_ERROR (1 << 5)
#define TLS_CONN_FLAG_ACCEPT_SUCCESS (1 << 6)

typedef struct tls_connection {
connection c;
Expand Down Expand Up @@ -514,20 +512,26 @@ static connection *connCreateAcceptedTLS(int fd, void *priv) {
return (connection *)conn;
}

static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler);
static void tlsEventHandler(struct aeEventLoop *el, int fd, void *clientData, int mask);
static void updateSSLEvent(tls_connection *conn);

static void clearTLSWantFlags(tls_connection *conn) {
conn->flags &= ~(TLS_CONN_FLAG_WRITE_WANT_READ | TLS_CONN_FLAG_READ_WANT_WRITE);
}

/* Process the return code received from OpenSSL>
* Update the want parameter with expected I/O.
* Update the conn flags with the WANT_READ/WANT_WRITE flags.
* Update the connection's error state if a real error has occurred.
* Returns an SSL error code, or 0 if no further handling is required.
*/
static int handleSSLReturnCode(tls_connection *conn, int ret_value, WantIOType *want) {
static int handleSSLReturnCode(tls_connection *conn, int ret_value) {
clearTLSWantFlags(conn);
if (ret_value <= 0) {
int ssl_err = SSL_get_error(conn->ssl, ret_value);
switch (ssl_err) {
case SSL_ERROR_WANT_WRITE: *want = WANT_WRITE; return 0;
case SSL_ERROR_WANT_READ: *want = WANT_READ; return 0;
case SSL_ERROR_WANT_WRITE: conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE; return 0;
case SSL_ERROR_WANT_READ: conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ; return 0;
case SSL_ERROR_SYSCALL:
conn->c.last_errno = errno;
if (conn->ssl_error) zfree(conn->ssl_error);
Expand Down Expand Up @@ -563,11 +567,8 @@ static int updateStateAfterSSLIO(tls_connection *conn, int ret_value, int update
}

if (ret_value <= 0) {
WantIOType want = 0;
int ssl_err;
if (!(ssl_err = handleSSLReturnCode(conn, ret_value, &want))) {
if (want == WANT_READ) conn->flags |= TLS_CONN_FLAG_WRITE_WANT_READ;
if (want == WANT_WRITE) conn->flags |= TLS_CONN_FLAG_READ_WANT_WRITE;
if (!(ssl_err = handleSSLReturnCode(conn, ret_value))) {
if (update_event) updateSSLEvent(conn);
errno = EAGAIN;
return -1;
Expand All @@ -585,19 +586,17 @@ static int updateStateAfterSSLIO(tls_connection *conn, int ret_value, int update
return ret_value;
}

static void registerSSLEvent(tls_connection *conn, WantIOType want) {
static void registerSSLEvent(tls_connection *conn) {
int mask = aeGetFileEvents(server.el, conn->c.fd);

switch (want) {
case WANT_READ:
if (conn->flags & TLS_CONN_FLAG_WRITE_WANT_READ) {
if (mask & AE_WRITABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE);
if (!(mask & AE_READABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_READABLE, tlsEventHandler, conn);
break;
case WANT_WRITE:
} else if (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE) {
if (mask & AE_READABLE) aeDeleteFileEvent(server.el, conn->c.fd, AE_READABLE);
if (!(mask & AE_WRITABLE)) aeCreateFileEvent(server.el, conn->c.fd, AE_WRITABLE, tlsEventHandler, conn);
break;
default: serverAssert(0); break;
} else {
serverAssert(0);
}
}

Expand Down Expand Up @@ -650,12 +649,47 @@ static void updateSSLEvent(tls_connection *conn) {
if (!need_write && (mask & AE_WRITABLE)) aeDeleteFileEvent(server.el, conn->c.fd, AE_WRITABLE);
}

static int TLSHandleAcceptResult(tls_connection *conn, int call_handler_on_error) {
if (conn->flags & TLS_CONN_FLAG_ACCEPT_SUCCESS) {
conn->c.state = CONN_STATE_CONNECTED;
} else if (conn->flags & TLS_CONN_FLAG_ACCEPT_ERROR) {
conn->c.state = CONN_STATE_ERROR;
if (!call_handler_on_error) return C_ERR;
} else {
/* Still pending accept */
registerSSLEvent(conn);
return C_OK;
}

/* call accept handler */
if (!callHandler((connection *)conn, conn->c.conn_handler)) return C_ERR;
conn->c.conn_handler = NULL;
return C_OK;
}

static void updateSSLState(connection *conn_) {
tls_connection *conn = (tls_connection *)conn_;

if (conn->c.state == CONN_STATE_ACCEPTING) {
TLSHandleAcceptResult(conn, 1);
return;
}

updateSSLEvent(conn);
updatePendingData(conn);
}

static void TLSAccept(void *_conn) {
tls_connection *conn = (tls_connection *)_conn;
ERR_clear_error();
int ret = SSL_accept(conn->ssl);
if (ret > 0) {
conn->flags |= TLS_CONN_FLAG_ACCEPT_SUCCESS;
} else if (handleSSLReturnCode(conn, ret)) {
conn->flags |= TLS_CONN_FLAG_ACCEPT_ERROR;
}
}

static void tlsHandleEvent(tls_connection *conn, int mask) {
int ret, conn_error;

Expand All @@ -676,10 +710,8 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
}
ret = SSL_connect(conn->ssl);
if (ret <= 0) {
WantIOType want = 0;
if (!handleSSLReturnCode(conn, ret, &want)) {
registerSSLEvent(conn, want);

if (!handleSSLReturnCode(conn, ret)) {
registerSSLEvent(conn);
/* Avoid hitting UpdateSSLEvent, which knows nothing
* of what SSL_connect() wants and instead looks at our
* R/W handlers.
Expand All @@ -698,28 +730,8 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
conn->c.conn_handler = NULL;
break;
case CONN_STATE_ACCEPTING:
ERR_clear_error();
ret = SSL_accept(conn->ssl);
if (ret <= 0) {
WantIOType want = 0;
if (!handleSSLReturnCode(conn, ret, &want)) {
/* Avoid hitting UpdateSSLEvent, which knows nothing
* of what SSL_connect() wants and instead looks at our
* R/W handlers.
*/
registerSSLEvent(conn, want);
return;
}

/* If not handled, it's an error */
conn->c.state = CONN_STATE_ERROR;
} else {
conn->c.state = CONN_STATE_CONNECTED;
}

if (!callHandler((connection *)conn, conn->c.conn_handler)) return;
conn->c.conn_handler = NULL;
break;
connTLSAccept((connection *)conn, NULL);
return;
case CONN_STATE_CONNECTED: {
int call_read = ((mask & AE_READABLE) && conn->c.read_handler) ||
((mask & AE_WRITABLE) && (conn->flags & TLS_CONN_FLAG_READ_WANT_WRITE));
Expand All @@ -740,20 +752,17 @@ static void tlsHandleEvent(tls_connection *conn, int mask) {
int invert = conn->c.flags & CONN_FLAG_WRITE_BARRIER;

if (!invert && call_read) {
conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE;
if (!callHandler((connection *)conn, conn->c.read_handler)) return;
}

/* Fire the writable event. */
if (call_write) {
conn->flags &= ~TLS_CONN_FLAG_WRITE_WANT_READ;
if (!callHandler((connection *)conn, conn->c.write_handler)) return;
}

/* If we have to invert the call, fire the readable event now
* after the writable one. */
if (invert && call_read) {
conn->flags &= ~TLS_CONN_FLAG_READ_WANT_WRITE;
if (!callHandler((connection *)conn, conn->c.read_handler)) return;
}
updatePendingData(conn);
Expand Down Expand Up @@ -841,31 +850,25 @@ static void connTLSClose(connection *conn_) {

static int connTLSAccept(connection *_conn, ConnectionCallbackFunc accept_handler) {
tls_connection *conn = (tls_connection *)_conn;
int ret;

if (conn->c.state != CONN_STATE_ACCEPTING) return C_ERR;
ERR_clear_error();

int call_handler_on_error = 1;
/* Try to accept */
conn->c.conn_handler = accept_handler;
ret = SSL_accept(conn->ssl);

if (ret <= 0) {
WantIOType want = 0;
if (!handleSSLReturnCode(conn, ret, &want)) {
registerSSLEvent(conn, want); /* We'll fire back */
return C_OK;
} else {
conn->c.state = CONN_STATE_ERROR;
return C_ERR;
}
if (accept_handler) {
conn->c.conn_handler = accept_handler;
call_handler_on_error = 0;
}

conn->c.state = CONN_STATE_CONNECTED;
if (!callHandler((connection *)conn, conn->c.conn_handler)) return C_OK;
conn->c.conn_handler = NULL;
/* We're in IO thread - just call accept and return, the main thread will handle the rest */
if (!inMainThread()) {
TLSAccept(conn);
return C_OK;
}

return C_OK;
/* Try to offload accept to IO threads */
if (trySendAcceptToIOThreads(_conn) == C_OK) return C_OK;

TLSAccept(conn);
return TLSHandleAcceptResult(conn, call_handler_on_error);
}

static int connTLSConnect(connection *conn_,
Expand Down
Loading