diff --git a/common/zmqserver.cpp b/common/zmqserver.cpp index a6383866a..353ac509d 100644 --- a/common/zmqserver.cpp +++ b/common/zmqserver.cpp @@ -13,18 +13,26 @@ using namespace std; namespace swss { ZmqServer::ZmqServer(const std::string& endpoint) - : ZmqServer(endpoint, "") + : ZmqServer(endpoint, "", false) { } ZmqServer::ZmqServer(const std::string& endpoint, const std::string& vrf) - : m_endpoint(endpoint), - m_vrf(vrf) + : ZmqServer(endpoint, vrf, false) { - connect(); - m_buffer.resize(MQ_RESPONSE_MAX_COUNT); - m_runThread = true; - m_mqPollThread = std::make_shared(&ZmqServer::mqPollThread, this); +} + +ZmqServer::ZmqServer(const std::string& endpoint, const std::string& vrf, bool lazyBind) + : m_mqPollThread(nullptr), + m_endpoint(endpoint), + m_vrf(vrf), + m_context(nullptr), + m_socket(nullptr) +{ + if (!lazyBind) + { + bind(); + } SWSS_LOG_DEBUG("ZmqServer ctor endpoint: %s", endpoint.c_str()); } @@ -32,15 +40,30 @@ ZmqServer::ZmqServer(const std::string& endpoint, const std::string& vrf) ZmqServer::~ZmqServer() { m_runThread = false; - m_mqPollThread->join(); + if (m_mqPollThread) + { + m_mqPollThread->join(); + } + + if (m_socket) + { + zmq_close(m_socket); + } - zmq_close(m_socket); - zmq_ctx_destroy(m_context); + if (m_context) + { + zmq_ctx_destroy(m_context); + } } -void ZmqServer::connect() +void ZmqServer::bind() { SWSS_LOG_ENTER(); + if (m_socket) + { + SWSS_LOG_THROW("ZmqServer has already been bound to the endpoint: %s", m_endpoint.c_str()); + } + m_context = zmq_ctx_new(); m_socket = zmq_socket(m_context, ZMQ_PULL); @@ -60,6 +83,10 @@ void ZmqServer::connect() m_endpoint.c_str(), zmq_errno()); } + + SWSS_LOG_DEBUG("ZmqServer bind to endpoint: %s", m_endpoint.c_str()); + + startMqPollThread(); } void ZmqServer::registerMessageHandler( @@ -114,6 +141,13 @@ void ZmqServer::handleReceivedData(const char* buffer, const size_t size) handler->handleReceivedData(kcos); } +void ZmqServer::startMqPollThread() +{ + m_buffer.resize(MQ_RESPONSE_MAX_COUNT); + m_runThread = true; + m_mqPollThread = std::make_shared(&ZmqServer::mqPollThread, this); +} + void ZmqServer::mqPollThread() { SWSS_LOG_ENTER(); diff --git a/common/zmqserver.h b/common/zmqserver.h index 2b8d1bac0..9a65e0212 100644 --- a/common/zmqserver.h +++ b/common/zmqserver.h @@ -32,6 +32,7 @@ class ZmqServer ZmqServer(const std::string& endpoint); ZmqServer(const std::string& endpoint, const std::string& vrf); + ZmqServer(const std::string& endpoint, const std::string& vrf, bool lazyBind); ~ZmqServer(); void registerMessageHandler( @@ -42,12 +43,13 @@ class ZmqServer void sendMsg(const std::string& dbName, const std::string& tableName, const std::vector& values); -private: - - void connect(); + void bind(); +private: void handleReceivedData(const char* buffer, const size_t size); + void startMqPollThread(); + void mqPollThread(); ZmqMessageHandler* findMessageHandler(const std::string dbName, const std::string tableName); diff --git a/tests/zmq_state_ut.cpp b/tests/zmq_state_ut.cpp index 2b0b60d73..6b2d9a215 100644 --- a/tests/zmq_state_ut.cpp +++ b/tests/zmq_state_ut.cpp @@ -569,3 +569,46 @@ TEST(ZmqWithResponseClientError, test) // Wait will timeout without server reply. EXPECT_FALSE(p.wait(dbName, tableName, kcosPtr)); } + +TEST(ZmqServerLazzyBind, test) +{ + std::string testTableName = "ZMQ_PROD_CONS_UT"; + std::string pushEndpoint = "tcp://localhost:1234"; + std::string pullEndpoint = "tcp://*:1234"; + DBConnector db(TEST_DB, 0, true); + ZmqClient client(pushEndpoint, 3000); + ZmqProducerStateTable p(&db, testTableName, client, true); + std::vector kcos; + auto testKey = "testkey"; + kcos.push_back(KeyOpFieldsValuesTuple{testKey, SET_COMMAND, std::vector{}}); + std::vector> kcosPtr; + p.send(kcos); + + // initialize ZMQ server with lazzy bind + DBConnector server_db(TEST_DB, 0, true); + ZmqServer server(pullEndpoint, "", true); + ZmqConsumerStateTable c(&db, testTableName, server, 128, 0, false); + server.bind(); + + std::deque vkco; + int received = 0; + while (received < 1) + { + c.pops(vkco); + while (!vkco.empty()) + { + auto &kco = vkco.front(); + auto key = kfvKey(kco); + auto op = kfvOp(kco); + auto fvs = kfvFieldsValues(kco); + + EXPECT_EQ(key, testKey); + + received += 1; + vkco.pop_front(); + } + } + + EXPECT_EQ(received, 1); +} +