Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion common/configdb.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,11 @@ class ConfigDBConnector_Native : public SonicV2Connector_Native
self.pubsub = self.get_redis_client(self.db_name).pubsub()
self.pubsub.psubscribe("__keyspace@{}__:*".format(self.get_dbid(self.db_name)))
while True:
item = self.pubsub.listen_message()
item = self.pubsub.listen_message(interrupt_on_signal=True)
if 'type' not in item:
# When timeout or interrupted, item will not contains 'type'
continue

if item['type'] == 'pmessage':
key = item['channel'].split(':', 1)[1]
try:
Expand Down
34 changes: 23 additions & 11 deletions common/pubsub.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,22 +77,31 @@ bool PubSub::hasCachedData()
return m_keyspace_event_buffer.size() > 1;
}

map<string, string> PubSub::get_message(double timeout)
map<string, string> PubSub::get_message(double timeout, bool interrupt_on_signal)
{
map<string, string> ret;
return get_message_internal(timeout, interrupt_on_signal).second;
}

MessageResultPair PubSub::get_message_internal(double timeout, bool interrupt_on_signal)
{
MessageResultPair ret;

if (!m_subscribe)
{
ret.first = Select::ERROR;
return ret;
}

Selectable *selected;
int rc = m_select.select(&selected, int(timeout));
int rc = m_select.select(&selected, int(timeout), interrupt_on_signal);
ret.first = rc;
switch (rc)
{
case Select::ERROR:
throw RedisError("Failed to select", m_subscribe->getContext());

case Select::TIMEOUT:
case Select::SIGNALINT:
return ret;

case Select::OBJECT:
Expand All @@ -110,26 +119,29 @@ map<string, string> PubSub::get_message(double timeout)
}

auto message = event->getReply<RedisMessage>();
ret["type"] = message.type;
ret["pattern"] = message.pattern;
ret["channel"] = message.channel;
ret["data"] = message.data;
ret.second["type"] = message.type;
ret.second["pattern"] = message.pattern;
ret.second["channel"] = message.channel;
ret.second["data"] = message.data;
return ret;
}

// Note: it is not straightforward to implement redis-py PubSub.listen() directly in c++
// due to the `yield` syntax, so we implement this function for blocking listen one message
std::map<std::string, std::string> PubSub::listen_message()
std::map<std::string, std::string> PubSub::listen_message(bool interrupt_on_signal)
{
const double GET_MESSAGE_INTERVAL = 600.0; // in seconds
MessageResultPair ret;
for (;;)
{
auto ret = get_message(GET_MESSAGE_INTERVAL);
if (!ret.empty())
ret = get_message_internal(GET_MESSAGE_INTERVAL, interrupt_on_signal);
if (!ret.second.empty() || ret.first == Select::SIGNALINT)
{
return ret;
break;
}
}

return ret.second;
}

shared_ptr<RedisReply> PubSub::popEventBuffer()
Expand Down
8 changes: 6 additions & 2 deletions common/pubsub.h
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
#pragma once
#include <map>
#include <deque>
#include <utility>

#include "dbconnector.h"
#include "select.h"
#include "redisselect.h"

namespace swss {

typedef std::pair<int, std::map<std::string, std::string> > MessageResultPair;

// This class is to emulate python redis-py class PubSub
// After SWIG wrapping, it should be used in the same way
class PubSub : protected RedisSelect
{
public:
explicit PubSub(DBConnector *other);

std::map<std::string, std::string> get_message(double timeout = 0.0);
std::map<std::string, std::string> listen_message();
std::map<std::string, std::string> get_message(double timeout = 0.0, bool interrupt_on_signal = false);
std::map<std::string, std::string> listen_message(bool interrupt_on_signal = false);

void psubscribe(const std::string &pattern);
void punsubscribe(const std::string &pattern);
Expand All @@ -29,6 +32,7 @@ class PubSub : protected RedisSelect
private:
/* Pop keyspace event from event buffer. Caller should free resources. */
std::shared_ptr<RedisReply> popEventBuffer();
MessageResultPair get_message_internal(double timeout = 0.0, bool interrupt_on_signal = false);

DBConnector *m_parentConnector;
Select m_select;
Expand Down
28 changes: 23 additions & 5 deletions common/select.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <unistd.h>
#include <string.h>


using namespace std;

namespace swss {
Expand Down Expand Up @@ -87,20 +88,34 @@ void Select::addSelectables(vector<Selectable *> selectables)
}
}

int Select::poll_descriptors(Selectable **c, unsigned int timeout)
int Select::poll_descriptors(Selectable **c, unsigned int timeout, bool interrupt_on_signal = false)
{
int sz_selectables = static_cast<int>(m_objects.size());
std::vector<struct epoll_event> events(sz_selectables);
int ret;

do
while(true)
{
ret = ::epoll_wait(m_epoll_fd, events.data(), sz_selectables, timeout);
// on signal interrupt check if we need to return
if (ret == -1 && errno == EINTR)
{
if (interrupt_on_signal)
{
return Select::SIGNALINT;
}
}
// on all other errors break the loop
else
{
break;
}
}
while(ret == -1 && errno == EINTR); // Retry the select if the process was interrupted by a signal

if (ret < 0)
{
return Select::ERROR;
}

for (int i = 0; i < ret; ++i)
{
Expand Down Expand Up @@ -148,7 +163,7 @@ int Select::poll_descriptors(Selectable **c, unsigned int timeout)
return Select::TIMEOUT;
}

int Select::select(Selectable **c, int timeout)
int Select::select(Selectable **c, int timeout, bool interrupt_on_signal)
{
SWSS_LOG_ENTER();

Expand All @@ -164,7 +179,7 @@ int Select::select(Selectable **c, int timeout)
return ret;

/* wait for data */
ret = poll_descriptors(c, timeout);
ret = poll_descriptors(c, timeout, interrupt_on_signal);

return ret;

Expand All @@ -190,6 +205,9 @@ std::string Select::resultToString(int result)
case swss::Select::TIMEOUT:
return "TIMEOUT";

case swss::Select::SIGNALINT:
return "SIGNALINT";

default:
SWSS_LOG_WARN("unknown select result: %d", result);
return "UNKNOWN";
Expand Down
5 changes: 3 additions & 2 deletions common/select.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ class Select
OBJECT = 0,
ERROR = 1,
TIMEOUT = 2,
SIGNALINT = 3,// Read operation interrupted by a signal
};

int select(Selectable **c, int timeout = -1);
int select(Selectable **c, int timeout = -1, bool interrupt_on_signal = false);
bool isQueueEmpty();

/**
Expand Down Expand Up @@ -65,7 +66,7 @@ class Select
}
};

int poll_descriptors(Selectable **c, unsigned int timeout);
int poll_descriptors(Selectable **c, unsigned int timeout, bool interrupt_on_signal);

int m_epoll_fd;
std::unordered_map<int, Selectable *> m_objects;
Expand Down
34 changes: 34 additions & 0 deletions tests/test_signalhandler_ut.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import signal
import os
import pytest
import multiprocessing
import time
from swsscommon import swsscommon

def test_config_db_listen_while_signal_received():
""" Test performs ConfigDBConnector.listen() while signal is received,
checks that the listen() call is interrupted and the regular KeyboardInterrupt is raised.
"""
c=swsscommon.ConfigDBConnector()
c.subscribe('A', lambda a: None)
c.connect(wait_for_init=False)
event = multiprocessing.Event()

def signal_handler(signum, frame):
event.set()
sys.exit(0)

signal.signal(signal.SIGUSR1, signal_handler)

def listen():
c.listen()

thr = multiprocessing.Process(target=listen)
thr.start()

time.sleep(5)
os.kill(thr.pid, signal.SIGUSR1)

thr.join()

assert event.is_set()