diff --git a/CHANGELOG.md b/CHANGELOG.md index ce2b6a63b..a6b9f13ee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ # Confluent Python Client for Apache Kafka - CHANGELOG + +### Fixes + +- Fixed `Consumer.poll()`, `Consumer.consume()`, `Producer.poll()`, and `Producer.flush()` blocking indefinitely and not responding to Ctrl+C (KeyboardInterrupt) signals. The implementation now uses a "wakeable poll" pattern that breaks long blocking calls into smaller chunks (200ms) and periodically re-acquires the Python GIL to check for pending signals. This allows Ctrl+C to properly interrupt blocking operations. Fixes Issues [#209](https://github.com/confluentinc/confluent-kafka-python/issues/209) and [#807](https://github.com/confluentinc/confluent-kafka-python/issues/807). + + ## v2.12.1 - 2025-10-21 v2.12.1 is a maintenance release with the following fixes: diff --git a/src/confluent_kafka/src/Consumer.c b/src/confluent_kafka/src/Consumer.c index 2a6ec2433..8773bfe48 100644 --- a/src/confluent_kafka/src/Consumer.c +++ b/src/confluent_kafka/src/Consumer.c @@ -105,7 +105,6 @@ static int Consumer_traverse(Handle *self, visitproc visit, void *arg) { } - static PyObject * Consumer_subscribe(Handle *self, PyObject *args, PyObject *kwargs) { @@ -958,13 +957,36 @@ Consumer_offsets_for_times(Handle *self, PyObject *args, PyObject *kwargs) { #endif } - +/** + * @brief Poll for a single message from the subscribed topics. + * + * Instead of a single blocking call to rd_kafka_consumer_poll() with the + * full timeout, this function: + * 1. Splits the timeout into 200ms chunks + * 2. Calls rd_kafka_consumer_poll() with chunk timeout + * 3. Between chunks, re-acquires GIL and calls PyErr_CheckSignals() + * 4. If signal detected, returns NULL (raises KeyboardInterrupt) + * 5. Continues until message received, timeout expired, or signal detected + * + * + * @param self Consumer handle + * @param args Positional arguments (unused) + * @param kwargs Keyword arguments: + * - timeout (float, optional): Timeout in seconds. + * Default: -1.0 (infinite timeout) + * @return PyObject* Message object, None if timeout, or NULL on error + * (raises KeyboardInterrupt if signal detected) + */ static PyObject *Consumer_poll(Handle *self, PyObject *args, PyObject *kwargs) { - double tmout = -1.0f; - static char *kws[] = {"timeout", NULL}; - rd_kafka_message_t *rkm; + double tmout = -1.0f; + static char *kws[] = {"timeout", NULL}; + rd_kafka_message_t *rkm = NULL; PyObject *msgobj; CallState cs; + const int CHUNK_TIMEOUT_MS = 200; /* 200ms chunks for signal checking */ + int total_timeout_ms; + int chunk_timeout_ms; + int chunk_count = 0; if (!self->rk) { PyErr_SetString(PyExc_RuntimeError, ERR_MSG_CONSUMER_CLOSED); @@ -974,16 +996,53 @@ static PyObject *Consumer_poll(Handle *self, PyObject *args, PyObject *kwargs) { if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|d", kws, &tmout)) return NULL; + total_timeout_ms = cfl_timeout_ms(tmout); + CallState_begin(self, &cs); - rkm = rd_kafka_consumer_poll(self->rk, cfl_timeout_ms(tmout)); + /* Skip wakeable poll pattern for non-blocking or very short timeouts. + * This avoids unnecessary GIL re-acquisition that can interfere with + * ThreadPool. Only use wakeable poll for + * blocking calls that need to be interruptible. */ + if (total_timeout_ms >= 0 && total_timeout_ms < CHUNK_TIMEOUT_MS) { + rkm = rd_kafka_consumer_poll(self->rk, total_timeout_ms); + } else { + while (1) { + /* Calculate timeout for this chunk */ + chunk_timeout_ms = calculate_chunk_timeout( + total_timeout_ms, chunk_count, CHUNK_TIMEOUT_MS); + if (chunk_timeout_ms == 0) { + /* Timeout expired */ + break; + } + + /* Poll with chunk timeout */ + rkm = + rd_kafka_consumer_poll(self->rk, chunk_timeout_ms); + /* If we got a message, exit the loop */ + if (rkm) { + break; + } + + chunk_count++; + + /* Check for signals between chunks */ + if (check_signals_between_chunks(self, &cs)) { + return NULL; + } + } + } + + /* Final GIL restore and signal check */ if (!CallState_end(self, &cs)) { - if (rkm) + if (rkm) { rd_kafka_message_destroy(rkm); + } return NULL; } + /* Handle the message */ if (!rkm) Py_RETURN_NONE; @@ -1024,7 +1083,27 @@ Consumer_memberid(Handle *self, PyObject *args, PyObject *kwargs) { return memberidobj; } - +/** + * @brief Consume a batch of messages from the subscribed topics. + * + * Instead of a single blocking call to rd_kafka_consume_batch_queue() with the + * full timeout, this function: + * 1. Splits the timeout into 200ms chunks + * 2. Calls rd_kafka_consume_batch_queue() with chunk timeout + * 3. Between chunks, re-acquires GIL and calls PyErr_CheckSignals() + * 4. If signal detected, returns NULL (raises KeyboardInterrupt) + * 5. Continues until messages received, timeout expired, or signal detected. + * + * @param self Consumer handle + * @param args Positional arguments (unused) + * @param kwargs Keyword arguments: + * - num_messages (int, optional): Maximum number of messages to + * consume per call. Default: 1. Maximum: 1000000. + * - timeout (float, optional): Timeout in seconds. + * Default: -1.0 (infinite timeout) + * @return PyObject* List of Message objects, empty list if timeout, or NULL on + * error (raises KeyboardInterrupt if signal detected) + */ static PyObject * Consumer_consume(Handle *self, PyObject *args, PyObject *kwargs) { unsigned int num_messages = 1; @@ -1034,7 +1113,11 @@ Consumer_consume(Handle *self, PyObject *args, PyObject *kwargs) { PyObject *msglist; rd_kafka_queue_t *rkqu = self->u.Consumer.rkqu; CallState cs; - Py_ssize_t i, n; + Py_ssize_t i, n = 0; + const int CHUNK_TIMEOUT_MS = 200; /* 200ms chunks for signal checking */ + int total_timeout_ms; + int chunk_timeout_ms; + int chunk_count = 0; if (!self->rk) { PyErr_SetString(PyExc_RuntimeError, ERR_MSG_CONSUMER_CLOSED); @@ -1052,13 +1135,74 @@ Consumer_consume(Handle *self, PyObject *args, PyObject *kwargs) { return NULL; } - CallState_begin(self, &cs); + total_timeout_ms = cfl_timeout_ms(tmout); rkmessages = malloc(num_messages * sizeof(rd_kafka_message_t *)); + if (!rkmessages) { + PyErr_NoMemory(); + return NULL; + } + + CallState_begin(self, &cs); + + /* Skip wakeable poll pattern for non-blocking or very short timeouts. + * This avoids unnecessary GIL re-acquisition that can interfere with + * ThreadPool. Only use wakeable poll for + * blocking calls that need to be interruptible. */ + if (total_timeout_ms >= 0 && total_timeout_ms < CHUNK_TIMEOUT_MS) { + n = (Py_ssize_t)rd_kafka_consume_batch_queue( + rkqu, total_timeout_ms, rkmessages, num_messages); + + if (n < 0) { + /* Error - need to restore GIL before setting error */ + PyEval_RestoreThread(cs.thread_state); + free(rkmessages); + cfl_PyErr_Format( + rd_kafka_last_error(), "%s", + rd_kafka_err2str(rd_kafka_last_error())); + return NULL; + } + } else { + while (1) { + /* Calculate timeout for this chunk */ + chunk_timeout_ms = calculate_chunk_timeout( + total_timeout_ms, chunk_count, CHUNK_TIMEOUT_MS); + if (chunk_timeout_ms == 0) { + /* Timeout expired */ + break; + } - n = (Py_ssize_t)rd_kafka_consume_batch_queue( - rkqu, cfl_timeout_ms(tmout), rkmessages, num_messages); + /* Consume with chunk timeout */ + n = (Py_ssize_t)rd_kafka_consume_batch_queue( + rkqu, chunk_timeout_ms, rkmessages, num_messages); + + if (n < 0) { + /* Error - need to restore GIL before setting + * error */ + PyEval_RestoreThread(cs.thread_state); + free(rkmessages); + cfl_PyErr_Format( + rd_kafka_last_error(), "%s", + rd_kafka_err2str(rd_kafka_last_error())); + return NULL; + } + + /* If we got messages, exit the loop */ + if (n > 0) { + break; + } + + chunk_count++; + + /* Check for signals between chunks */ + if (check_signals_between_chunks(self, &cs)) { + free(rkmessages); + return NULL; + } + } + } + /* Final GIL restore and signal check */ if (!CallState_end(self, &cs)) { for (i = 0; i < n; i++) { rd_kafka_message_destroy(rkmessages[i]); @@ -1067,13 +1211,7 @@ Consumer_consume(Handle *self, PyObject *args, PyObject *kwargs) { return NULL; } - if (n < 0) { - free(rkmessages); - cfl_PyErr_Format(rd_kafka_last_error(), "%s", - rd_kafka_err2str(rd_kafka_last_error())); - return NULL; - } - + /* Create Python list from messages */ msglist = PyList_New(n); for (i = 0; i < n; i++) { diff --git a/src/confluent_kafka/src/Producer.c b/src/confluent_kafka/src/Producer.c index b09bad47a..365ca7c94 100644 --- a/src/confluent_kafka/src/Producer.c +++ b/src/confluent_kafka/src/Producer.c @@ -338,17 +338,64 @@ Producer_produce(Handle *self, PyObject *args, PyObject *kwargs) { /** - * @brief Call rd_kafka_poll() and keep track of crashing callbacks. - * @returns -1 if callback crashed (or poll() failed), else the number - * of events served. + * @brief Poll for producer events with wakeable pattern for interruptibility. + * + * This function: + * 1. Splits the timeout into 200ms chunks + * 2. Calls rd_kafka_poll() with chunk timeout + * 3. Between chunks, re-acquires GIL and calls PyErr_CheckSignals() + * 4. If signal detected, returns -1 (raises KeyboardInterrupt) + * 5. Continues until events processed, timeout expired, or signal detected + * + * @param self Producer handle + * @param tmout Timeout in milliseconds (-1 for infinite) + * @returns -1 if callback crashed, signal detected, or poll() failed, else the + * number of events served. */ static int Producer_poll0(Handle *self, int tmout) { - int r; + int r = 0; CallState cs; + const int CHUNK_TIMEOUT_MS = 200; /* 200ms chunks for signal checking */ + int total_timeout_ms = tmout; + int chunk_timeout_ms; + int chunk_count = 0; CallState_begin(self, &cs); - r = rd_kafka_poll(self->rk, tmout); + /* Skip wakeable poll pattern for non-blocking or very short timeouts. + * This avoids unnecessary GIL re-acquisition that can interfere with + * ThreadPool. Only use wakeable poll for + * blocking calls that need to be interruptible. */ + if (total_timeout_ms >= 0 && total_timeout_ms < CHUNK_TIMEOUT_MS) { + r = rd_kafka_poll(self->rk, total_timeout_ms); + } else { + while (1) { + /* Calculate timeout for this chunk */ + chunk_timeout_ms = calculate_chunk_timeout( + total_timeout_ms, chunk_count, CHUNK_TIMEOUT_MS); + if (chunk_timeout_ms == 0) { + /* Timeout expired */ + break; + } + + /* Poll with chunk timeout */ + int chunk_result = + rd_kafka_poll(self->rk, chunk_timeout_ms); + /* Error from poll */ + if (chunk_result < 0) { + r = chunk_result; + break; + } + r += chunk_result; /* Accumulate events processed */ + + chunk_count++; + + /* Check for signals between chunks */ + if (check_signals_between_chunks(self, &cs)) { + return -1; /* Signal detected */ + } + } + } if (!CallState_end(self, &cs)) { return -1; @@ -379,6 +426,26 @@ static PyObject *Producer_poll(Handle *self, PyObject *args, PyObject *kwargs) { } +/** + * @brief Flush all messages in the producer queue with wakeable pattern for + * interruptibility. + * + * Instead of a single blocking call to rd_kafka_flush() with the + * full timeout, this function: + * 1. Splits the timeout into 200ms chunks + * 2. Calls rd_kafka_flush() with chunk timeout + * 3. Between chunks, re-acquires GIL and calls PyErr_CheckSignals() + * 4. If signal detected, returns NULL (raises KeyboardInterrupt) + * 5. Continues until all messages flushed, timeout expired, or signal detected + * + * @param self Producer handle + * @param args Positional arguments (unused) + * @param kwargs Keyword arguments: + * - timeout (float, optional): Timeout in seconds. + * Default: -1.0 (infinite timeout) + * @return PyObject* Number of messages remaining in queue, or NULL on error + * (raises KeyboardInterrupt if signal detected) + */ static PyObject * Producer_flush(Handle *self, PyObject *args, PyObject *kwargs) { double tmout = -1; @@ -386,6 +453,10 @@ Producer_flush(Handle *self, PyObject *args, PyObject *kwargs) { static char *kws[] = {"timeout", NULL}; rd_kafka_resp_err_t err; CallState cs; + const int CHUNK_TIMEOUT_MS = 200; /* 200ms chunks for signal checking */ + int total_timeout_ms; + int chunk_timeout_ms; + int chunk_count = 0; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|d", kws, &tmout)) return NULL; @@ -395,8 +466,56 @@ Producer_flush(Handle *self, PyObject *args, PyObject *kwargs) { return NULL; } + total_timeout_ms = cfl_timeout_ms(tmout); CallState_begin(self, &cs); - err = rd_kafka_flush(self->rk, cfl_timeout_ms(tmout)); + + /* Skip wakeable poll pattern for non-blocking or very short timeouts. + * This avoids unnecessary GIL re-acquisition that can interfere with + * ThreadPool. Only use wakeable poll for + * blocking calls that need to be interruptible. */ + if (total_timeout_ms >= 0 && total_timeout_ms < CHUNK_TIMEOUT_MS) { + err = rd_kafka_flush(self->rk, total_timeout_ms); + } else { + /* For infinite timeout, we need to keep looping and checking + * for signals. rd_kafka_flush() waits for messages that were in + * the queue when it's called. When flush() returns NO_ERROR, it + * means all messages that were queued at that point have been + * delivered. Note: Messages produced after flush() starts are + * not included in the current flush. */ + while (1) { + /* Calculate timeout for this chunk */ + chunk_timeout_ms = calculate_chunk_timeout( + total_timeout_ms, chunk_count, CHUNK_TIMEOUT_MS); + if (chunk_timeout_ms == 0) { + /* Timeout expired */ + err = RD_KAFKA_RESP_ERR__TIMED_OUT; + break; + } + + /* Flush with chunk timeout */ + err = rd_kafka_flush(self->rk, chunk_timeout_ms); + + /* Always check for signals between chunks (critical for + * interruptibility) */ + chunk_count++; + if (check_signals_between_chunks(self, &cs)) { + return NULL; /* Signal detected */ + } + + if (err == RD_KAFKA_RESP_ERR_NO_ERROR) { + break; + } + + /* If timeout error, continue to next chunk */ + if (err == RD_KAFKA_RESP_ERR__TIMED_OUT) { + continue; + } + + /* Other error - break and return it */ + break; + } + } + if (!CallState_end(self, &cs)) return NULL; diff --git a/src/confluent_kafka/src/confluent_kafka.h b/src/confluent_kafka/src/confluent_kafka.h index 32414c1d8..2053f8d7c 100644 --- a/src/confluent_kafka/src/confluent_kafka.h +++ b/src/confluent_kafka/src/confluent_kafka.h @@ -536,6 +536,83 @@ static CFL_UNUSED CFL_INLINE int cfl_timeout_ms(double tmout) { return -1; return (int)(tmout * 1000); } + + +/** + * @brief Calculate the timeout for the current chunk in wakeable poll pattern. + * + * This function calculates how long each chunk should wait, ensuring: + * - Infinite timeouts (-1) use the chunk size repeatedly + * - Finite timeouts are properly divided and don't exceed the total + * - The final chunk uses any remaining time (may be < chunk_size) + + * + * @param total_timeout_ms Total timeout in milliseconds (-1 for infinite) + * @param chunk_count Current chunk iteration count (0-based) + * @param chunk_timeout_ms Chunk size in milliseconds (200ms by default) + * @return int Chunk timeout in milliseconds, or 0 if total timeout expired + */ +static CFL_UNUSED CFL_INLINE int calculate_chunk_timeout(int total_timeout_ms, + int chunk_count, + int chunk_timeout_ms) { + if (total_timeout_ms < 0) { + /* Infinite timeout - use chunk size */ + return chunk_timeout_ms; + } else { + /* Finite timeout - calculate remaining */ + int remaining_ms = + total_timeout_ms - (chunk_count * chunk_timeout_ms); + if (remaining_ms <= 0) { + /* Timeout expired */ + return 0; + } + return (remaining_ms < chunk_timeout_ms) ? remaining_ms + : chunk_timeout_ms; + } +} + +/** + * @brief Check for pending signals between poll chunks. + * + * Re-acquires GIL, checks for signals, and handles cleanup if signal detected. + * + * Signal Handling Details: + * ----------------------- + * This function uses PyErr_CheckSignals(), which checks for ALL signals that + * Python has registered handlers for. By default, + * Python registers handlers for: + * - SIGINT (Ctrl+C): Raises KeyboardInterrupt exception + * - SIGTERM: Can raise SystemExit or be handled by user code + * - Other signals: If the user has registered handlers via Python's `signal` + * module, those will also be checked (e.g., signal.signal(signal.SIGUSR1, + * handler)). User code will need to handle these signals accordingly. + * + * + * @param self Handle (Producer or Consumer) + * @param cs CallState structure (thread state will be updated) + * @return int 0 if no signal detected (continue), 1 if signal detected (should + * return NULL) + */ +static CFL_UNUSED CFL_INLINE int check_signals_between_chunks(Handle *self, + CallState *cs) { + /* Re-acquire GIL */ + PyEval_RestoreThread(cs->thread_state); + + /* Check for pending signals (KeyboardInterrupt, etc.) */ + if (PyErr_CheckSignals() == -1) { + /* Signal detected - end the call state (cleanup TLS, etc.) + * Note: GIL is already held, but CallState_end expects to + * restore it, so save thread state again */ + cs->thread_state = PyEval_SaveThread(); + CallState_end(self, cs); + return 1; + } else { + /* No signal detected - re-release GIL for next iteration */ + cs->thread_state = PyEval_SaveThread(); + return 0; + } +} + /**************************************************************************** * * diff --git a/tests/common/__init__.py b/tests/common/__init__.py index 5d3f5b614..2c9119fd4 100644 --- a/tests/common/__init__.py +++ b/tests/common/__init__.py @@ -17,6 +17,8 @@ # import os +import signal +import time from confluent_kafka import Consumer @@ -29,6 +31,19 @@ def _trivup_cluster_type_kraft(): class TestUtils: + @staticmethod + def send_sigint_after_delay(delay_seconds): + """Send SIGINT to current process after delay. + + Utility function for testing interruptible poll/flush/consume operations. + Used to simulate Ctrl+C in automated tests. + + Args: + delay_seconds: Delay in seconds before sending SIGINT + """ + time.sleep(delay_seconds) + os.kill(os.getpid(), signal.SIGINT) + @staticmethod def broker_version(): return '4.0.0' if TestUtils.use_group_protocol_consumer() else '3.9.0' diff --git a/tests/integration/consumer/test_consumer_wakeable_poll_consume.py b/tests/integration/consumer/test_consumer_wakeable_poll_consume.py new file mode 100644 index 000000000..8ca5aea86 --- /dev/null +++ b/tests/integration/consumer/test_consumer_wakeable_poll_consume.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2024 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import time + +from tests.common import TestConsumer + +# ============================================================================ +# Consumer Wakeability Integration Testing +# ============================================================================ +# +# These integration tests verify that the wakeable pattern works correctly +# with actual Kafka clusters and real message delivery scenarios. +# +# How We Test Consumer Wakeability in Integration: +# ----------------------------------------------- +# 1. Message Availability Testing: +# - Produce messages to Kafka topics using a producer +# - Create consumers with wakeable pattern settings (timeouts >= 200ms) +# - Call poll()/consume() with timeouts that trigger chunking +# - Verify messages are returned correctly despite chunking +# - Measure elapsed time to ensure wakeable pattern doesn't delay delivery +# +# 2. Testing Methodology: +# - Setup: Create topics, produce messages, create consumers with proper config +# - Execution: Call poll()/consume() with timeouts >= 200ms (triggers chunking) +# - Verification: Check messages are returned, values are correct, timing is reasonable +# - Cleanup: Close consumers and verify no resource leaks +# +# 3. What We Verify: +# - Messages are correctly returned when available (wakeable pattern doesn't block delivery) +# - Message values and metadata are preserved through chunking +# - Timing remains reasonable (messages return quickly when available) +# - Consumer state remains consistent after operations complete + + +def test_poll_message_delivery_with_wakeable_pattern(kafka_cluster): + """Test that poll() correctly returns messages when available. + + This integration test verifies that the wakeable poll pattern doesn't + interfere with normal message delivery. + """ + topic = kafka_cluster.create_topic_and_wait_propogation('test-poll-message-delivery') + + # Produce a test message (use cimpl_producer for raw bytes) + producer = kafka_cluster.cimpl_producer() + producer.produce(topic, value=b'test-message') + producer.flush(timeout=1.0) + + # Create consumer with wakeable poll pattern settings + consumer_conf = kafka_cluster.client_conf( + { + 'group.id': 'test-poll-message-available', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 6000, + 'auto.offset.reset': 'earliest', + } + ) + consumer = TestConsumer(consumer_conf) + consumer.subscribe([topic]) + + # Wait for subscription and message availability + time.sleep(2.0) + + # Poll for message - should return immediately when available + start = time.time() + msg = consumer.poll(timeout=2.0) + elapsed = time.time() - start + + # Verify message was returned correctly + assert msg is not None, "Expected message, got None" + assert not msg.error(), f"Message has error: {msg.error()}" + # Allow more time for initial consumer setup, but once ready, should return quickly + assert elapsed < 2.5, f"Message available but took {elapsed:.2f}s, expected < 2.5s" + assert msg.value() == b'test-message', "Message value mismatch" + + consumer.close() + + +def test_consume_message_delivery_with_wakeable_pattern(kafka_cluster): + """Test that consume() correctly returns messages when available. + + This integration test verifies that the wakeable poll pattern doesn't + interfere with normal batch message delivery. + """ + topic = kafka_cluster.create_topic_and_wait_propogation('test-consume-message-delivery') + + # Produce multiple test messages (use cimpl_producer for raw bytes) + producer = kafka_cluster.cimpl_producer() + for i in range(3): + producer.produce(topic, value=f'test-message-{i}'.encode()) + producer.flush(timeout=1.0) + + # Create consumer with wakeable poll pattern settings + consumer_conf = kafka_cluster.client_conf( + { + 'group.id': 'test-consume-messages-available', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 6000, + 'auto.offset.reset': 'earliest', + } + ) + consumer = TestConsumer(consumer_conf) + consumer.subscribe([topic]) + + # Wait for subscription and message availability + time.sleep(2.0) + + # Consume messages - should return immediately when available + start = time.time() + msglist = consumer.consume(num_messages=5, timeout=2.0) + elapsed = time.time() - start + + # Verify messages were returned correctly + assert len(msglist) > 0, "Expected messages, got empty list" + assert len(msglist) <= 5, f"Should return at most 5 messages, got {len(msglist)}" + # Allow more time for initial consumer setup, but once ready, should return quickly + assert elapsed < 2.5, f"Messages available but took {elapsed:.2f}s, expected < 2.5s" + + # Verify message values + for i, msg in enumerate(msglist): + assert not msg.error(), f"Message {i} has error: {msg.error()}" + assert msg.value() is not None, f"Message {i} has no value" + # Verify we got the expected messages + expected_value = f'test-message-{i}'.encode() + expected_msg = f"Message {i} value mismatch: expected {expected_value}, " f"got {msg.value()}" + assert msg.value() == expected_value, expected_msg + + consumer.close() diff --git a/tests/integration/producer/test_producer_wakeable_poll_flush.py b/tests/integration/producer/test_producer_wakeable_poll_flush.py new file mode 100644 index 000000000..49d1d47d0 --- /dev/null +++ b/tests/integration/producer/test_producer_wakeable_poll_flush.py @@ -0,0 +1,201 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright 2024 Confluent Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import time + +from tests.common import TestConsumer + +# ============================================================================ +# Approach to Producer Wakeability Integration Testing +# ============================================================================ +# +# These integration tests verify that the wakeable pattern works correctly +# with actual Kafka clusters and real message delivery scenarios. +# +# How We Test Producer Wakeability in Integration: +# ------------------------------------------------ +# 1. Message Delivery Testing: +# - Create producers with wakeable pattern settings (timeouts >= 200ms) +# - Produce messages with delivery callbacks +# - Call poll()/flush() with timeouts that trigger chunking +# - Verify delivery callbacks are invoked correctly despite chunking +# - Measure elapsed time to ensure wakeable pattern doesn't delay delivery +# +# 2. Testing Methodology: +# - Setup: Create topics, configure producers with delivery callbacks +# - Execution: Produce messages, call poll()/flush() with timeouts >= 200ms (triggers chunking) +# - Verification: Check callbacks are called, messages are delivered to Kafka, timing is reasonable +# - End-to-End: Consume messages to verify they were actually committed to Kafka +# - Cleanup: Close producers/consumers and verify no resource leaks +# +# 3. What We Verify: +# - Delivery callbacks are correctly invoked (wakeable pattern doesn't block callbacks) +# - Messages are successfully delivered to Kafka brokers +# - Timing remains reasonable (delivery completes quickly when possible) +# - Producer state remains consistent after operations complete + + +def test_poll_message_delivery_with_wakeable_pattern(kafka_cluster): + """Test that poll() correctly delivers messages when using wakeable pattern. + + This integration test verifies that the wakeable poll pattern doesn't + interfere with normal message delivery callbacks. + """ + topic = kafka_cluster.create_topic_and_wait_propogation('test-poll-message-delivery') + + # Track delivery callbacks + delivery_called = [] + delivery_errors = [] + + def delivery_callback(err, msg): + if err: + delivery_errors.append(err) + else: + delivery_called.append(msg) + + # Create producer with wakeable poll pattern settings + producer_conf = kafka_cluster.client_conf( + { + 'socket.timeout.ms': 100, + 'message.timeout.ms': 10000, + } + ) + producer = kafka_cluster.cimpl_producer(producer_conf) + + # Produce a test message with delivery callback + producer.produce(topic, value=b'test-message', on_delivery=delivery_callback) + + # Poll with wakeable pattern - should trigger delivery callback + start = time.time() + events_handled = producer.poll(timeout=2.0) + elapsed = time.time() - start + + # Verify delivery callback was called + assert len(delivery_called) > 0, "Expected delivery callback to be called" + assert len(delivery_errors) == 0, f"Unexpected delivery errors: {delivery_errors}" + assert events_handled >= 0, "poll() should return non-negative int" + # Allow time for delivery callback, but should complete reasonably quickly + assert elapsed < 2.5, f"Poll took {elapsed:.2f}s, expected < 2.5s" + + # Flush to ensure message is committed to Kafka + producer.flush(timeout=1.0) + + # Verify message was actually delivered by consuming it + consumer_conf = kafka_cluster.client_conf( + { + 'group.id': 'test-poll-verify-delivery', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 6000, + 'auto.offset.reset': 'earliest', + } + ) + consumer = TestConsumer(consumer_conf) + consumer.subscribe([topic]) + + # Wait for subscription and message availability + time.sleep(2.0) + + msg = consumer.poll(timeout=2.0) + assert msg is not None, "Expected message to be delivered" + assert not msg.error(), f"Message has error: {msg.error()}" + assert msg.value() == b'test-message', "Message value mismatch" + + producer.close() + consumer.close() + + +def test_flush_message_delivery_with_wakeable_pattern(kafka_cluster): + """Test that flush() correctly delivers messages when using wakeable pattern. + + This integration test verifies that the wakeable flush pattern doesn't + interfere with normal message delivery. + """ + topic = kafka_cluster.create_topic_and_wait_propogation('test-flush-message-delivery') + + # Track delivery callbacks + delivery_called = [] + delivery_errors = [] + + def delivery_callback(err, msg): + if err: + delivery_errors.append(err) + else: + delivery_called.append(msg) + + # Create producer with wakeable flush pattern settings + producer_conf = kafka_cluster.client_conf( + { + 'socket.timeout.ms': 100, + 'message.timeout.ms': 10000, + } + ) + producer = kafka_cluster.cimpl_producer(producer_conf) + + # Produce multiple test messages with delivery callbacks + num_messages = 5 + for i in range(num_messages): + producer.produce(topic, value=f'test-message-{i}'.encode(), on_delivery=delivery_callback) + + # Flush with wakeable pattern - should trigger all delivery callbacks + start = time.time() + remaining = producer.flush(timeout=2.0) + elapsed = time.time() - start + + # Verify all delivery callbacks were called + assert ( + len(delivery_called) == num_messages + ), f"Expected {num_messages} delivery callbacks, got {len(delivery_called)}" + assert len(delivery_errors) == 0, f"Unexpected delivery errors: {delivery_errors}" + assert remaining == 0, f"Expected 0 remaining messages after flush, got {remaining}" + # Allow time for flush, but should complete reasonably quickly + assert elapsed < 2.5, f"Flush took {elapsed:.2f}s, expected < 2.5s" + + # Verify messages were actually delivered by consuming them + consumer_conf = kafka_cluster.client_conf( + { + 'group.id': 'test-flush-verify-delivery', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 6000, + 'auto.offset.reset': 'earliest', + } + ) + consumer = TestConsumer(consumer_conf) + consumer.subscribe([topic]) + + # Wait for subscription and message availability + time.sleep(2.0) + + # Consume all messages + msglist = [] + start = time.time() + while len(msglist) < num_messages and (time.time() - start) < 5.0: + msg = consumer.poll(timeout=1.0) + if msg is not None and not msg.error(): + msglist.append(msg) + + assert len(msglist) == num_messages, f"Expected {num_messages} messages, got {len(msglist)}" + + # Verify message values + for i, msg in enumerate(msglist): + expected_value = f'test-message-{i}'.encode() + assert ( + msg.value() == expected_value + ), f"Message {i} value mismatch: expected {expected_value}, got {msg.value()}" + + producer.close() + consumer.close() diff --git a/tests/test_Wakeable.py b/tests/test_Wakeable.py new file mode 100644 index 000000000..87cf05ddc --- /dev/null +++ b/tests/test_Wakeable.py @@ -0,0 +1,1439 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Tests for wakeable poll/flush/consume functionality. + +These tests verify the interruptibility of blocking operations (poll, flush, consume) +using the wakeable pattern with signal checking between chunks. + +Includes: +- Utility function tests (calculate_chunk_timeout, check_signals_between_chunks) +- Producer wakeable tests (poll, flush) +- Consumer wakeable tests (poll, consume) +""" +import threading +import time + +import pytest + +from confluent_kafka import Producer +from tests.common import TestConsumer, TestUtils + +# Timing constants for wakeable poll/flush/consume pattern tests +# For timeouts < 200ms, the wakeable pattern is NOT used (see Producer.c/Consumer.c), +# so those timeouts can complete faster. For timeouts >= 200ms, chunking is used. +CHUNK_TIMEOUT_MS = 200 # Chunk size in milliseconds +WAKEABLE_POLL_TIMEOUT_MIN = 0.2 # Minimum timeout for chunked operations (seconds) +WAKEABLE_POLL_TIMEOUT_MAX = 2.0 # Maximum timeout (seconds) + + +# ============================================================================ +# Approach to Wakeability Testing +# ============================================================================ +# +# The wakeable pattern is implemented using shared C utility functions that are +# used by both Producer and Consumer. Our testing strategy mirrors this architecture: +# +# High level Wakeability Implementation: +# ------------ +# Shared Utilities (confluent_kafka.h): +# - calculate_chunk_timeout(): Splits long timeouts into 200ms chunks +# - check_signals_between_chunks(): Re-acquires GIL, checks signals, handles cleanup +# +# Producer Implementation (Producer.c): +# - Producer.poll() uses wakeable pattern for timeouts >= 200ms +# - Producer.flush() uses wakeable pattern for timeouts >= 200ms +# +# Consumer Implementation (Consumer.c): +# - Consumer.poll() uses wakeable pattern for timeouts >= 200ms +# - Consumer.consume() uses wakeable pattern for timeouts >= 200ms +# +# How We Test Wakeability: +# ------------------------ +# Since Producer and Consumer share the same C utility functions but have different +# Python APIs, we test them using a layered approach: +# +# 1. Testing Producer Wakeability: +# - Create Producer instances and call poll()/flush() with various timeouts +# - Inject signals at different times (immediate, after chunks, during finite timeout) +# - Verify KeyboardInterrupt is raised and Producer-specific behavior (return types, cleanup) +# - Test both infinite and finite timeouts to cover all code paths +# +# 2. Testing Consumer Wakeability: +# - Create Consumer instances and call poll()/consume() with various timeouts +# - Inject signals at different times using the same pattern as Producer tests +# - Verify KeyboardInterrupt is raised and Consumer-specific behavior (return types, message handling) +# - Test both infinite and finite timeouts, including edge cases like num_messages=0 +# +# 3. Testing Methodology: +# - Signal Injection: Use TestUtils.send_sigint_after_delay() in a background thread +# to simulate KeyboardInterrupt at specific times during blocking operations +# - Chunking Verification: Measure elapsed time to verify >= 200ms timeouts use chunking +# (multiple 200ms intervals) while < 200ms timeouts bypass chunking entirely +# - Interruptibility Verification: Wrap blocking calls in try/except to catch +# KeyboardInterrupt and verify operations abort cleanly +# - State Verification: Check that objects are properly cleaned up (closed state, +# no resource leaks) after interrupted operations + + +# ============================================================================ +# Producer wakeable tests +# ============================================================================ + + +def test_producer_wakeable_poll_utility_functions_interaction(): + """Test interaction between calculate_chunk_timeout() and check_signals_between_chunks().""" + # Assert: Chunk calculation and signal check work together + producer1 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.4)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + producer1.poll(timeout=1.0) + except KeyboardInterrupt: + interrupted = True + finally: + producer1.close() + + assert interrupted, "Should have raised KeyboardInterrupt" + + # Assert: Multiple chunks before signal detection + producer2 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.6)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + producer2.poll(timeout=WAKEABLE_POLL_TIMEOUT_MAX) + except KeyboardInterrupt: + interrupted = True + finally: + producer2.close() + + assert interrupted, "Should have raised KeyboardInterrupt" + + +def test_producer_wakeable_poll_interruptibility_and_messages(): + """Test poll() interruptibility and message handling.""" + # Assert: Infinite timeout can be interrupted + producer1 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.1)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + producer1.poll(timeout=WAKEABLE_POLL_TIMEOUT_MAX) + except KeyboardInterrupt: + interrupted = True + finally: + producer1.close() + + assert interrupted, "Should have raised KeyboardInterrupt" + + # Assert: Finite timeout can be interrupted before timeout expires + producer2 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.3)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + producer2.poll(timeout=WAKEABLE_POLL_TIMEOUT_MAX) + except KeyboardInterrupt: + interrupted = True + finally: + producer2.close() + + assert interrupted, "Should have raised KeyboardInterrupt" + + # Assert: Signal sent after multiple chunks still interrupts + producer3 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.6)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + producer3.poll(timeout=WAKEABLE_POLL_TIMEOUT_MAX) + except KeyboardInterrupt: + interrupted = True + finally: + producer3.close() + + assert interrupted, "Should have raised KeyboardInterrupt" + + # Assert: No signal - timeout works normally + producer4 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + start = time.time() + result = producer4.poll(timeout=0.5) + elapsed = time.time() - start + + assert isinstance(result, int), "poll() should return int" + assert ( + WAKEABLE_POLL_TIMEOUT_MIN <= elapsed <= WAKEABLE_POLL_TIMEOUT_MAX + ), f"Timeout took {elapsed:.2f}s, expected ~0.5s" + producer4.close() + + +def test_producer_wakeable_poll_edge_cases(): + """Test poll() edge cases.""" + # Assert: Zero timeout returns immediately + producer1 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + start = time.time() + result = producer1.poll(timeout=0.0) + elapsed = time.time() - start + + assert elapsed < WAKEABLE_POLL_TIMEOUT_MAX, f"Zero timeout took {elapsed:.2f}s" + assert isinstance(result, int) + producer1.close() + + # Assert: Closed producer raises RuntimeError + producer2 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + producer2.close() + + with pytest.raises(RuntimeError) as exc_info: + producer2.poll(timeout=0.1) + assert 'Producer has been closed' in str(exc_info.value) + + # Assert: Short timeout works correctly + producer3 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + start = time.time() + result = producer3.poll(timeout=0.1) + elapsed = time.time() - start + + assert isinstance(result, int) + # Short timeouts don't use chunking + assert elapsed <= WAKEABLE_POLL_TIMEOUT_MAX, f"Short timeout took {elapsed:.2f}s" + producer3.close() + + # Assert: Very short timeout works + producer4 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + start = time.time() + result = producer4.poll(timeout=0.05) + elapsed = time.time() - start + + assert isinstance(result, int) + assert elapsed < WAKEABLE_POLL_TIMEOUT_MAX, f"Very short timeout took {elapsed:.2f}s" + producer4.close() + + +def test_producer_wakeable_flush_interruptibility_and_messages(): + """Test flush() interruptibility and message handling.""" + # Assert: Infinite timeout can be interrupted + producer1 = Producer( + { + 'bootstrap.servers': 'localhost:9092', + 'socket.timeout.ms': 60000, + 'message.timeout.ms': 30000, + 'acks': 'all', + 'batch.num.messages': 100, + 'linger.ms': 100, + 'queue.buffering.max.messages': 100000, + 'queue.buffering.max.kbytes': 104857600, + 'max.in.flight.requests.per.connection': 1, + 'request.timeout.ms': 30000, + 'delivery.timeout.ms': 30000, + } + ) + + messages_produced = False + stop_producing = threading.Event() + production_stats = {'count': 0, 'errors': 0} + + def continuous_producer(): + message_num = 0 + while not stop_producing.is_set(): + try: + producer1.produce( + 'test-topic', value=f'continuous-{message_num}'.encode(), key=f'key-{message_num}'.encode() + ) + production_stats['count'] += 1 + message_num += 1 + except Exception as e: + production_stats['errors'] += 1 + if "QUEUE_FULL" in str(e): + time.sleep(0.001) + else: + time.sleep(0.01) + + try: + for i in range(1000): + try: + producer1.produce('test-topic', value=f'initial-{i}'.encode()) + messages_produced = True + except Exception as e: + if "QUEUE_FULL" in str(e): + time.sleep(0.01) + continue + break + + if not messages_produced: + producer1.close() + pytest.skip("Broker not available, cannot test flush() interruptibility") + + poll_start = time.time() + while time.time() - poll_start < 0.5: + producer1.poll(timeout=0.1) + + producer_thread = threading.Thread(target=continuous_producer, daemon=True) + producer_thread.start() + time.sleep(0.1) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.1)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + producer1.flush() + except KeyboardInterrupt: + interrupted = True + finally: + stop_producing.set() + time.sleep(0.1) + producer1.close() + + assert interrupted, "Should have raised KeyboardInterrupt" + except Exception: + stop_producing.set() + producer1.close() + raise + + # Assert: Finite timeout can be interrupted before timeout expires + producer2 = Producer( + { + 'bootstrap.servers': 'localhost:9092', + 'socket.timeout.ms': 60000, + 'message.timeout.ms': 30000, + 'acks': 'all', + 'batch.num.messages': 100, + 'linger.ms': 100, + 'queue.buffering.max.messages': 100000, + 'queue.buffering.max.kbytes': 104857600, + 'max.in.flight.requests.per.connection': 1, + 'request.timeout.ms': 30000, + 'delivery.timeout.ms': 30000, + } + ) + + stop_producing2 = threading.Event() + production_stats2 = {'count': 0, 'errors': 0} + + def continuous_producer2(): + message_num = 0 + while not stop_producing2.is_set(): + try: + producer2.produce( + 'test-topic', value=f'continuous2-{message_num}'.encode(), key=f'key2-{message_num}'.encode() + ) + production_stats2['count'] += 1 + message_num += 1 + except Exception as e: + production_stats2['errors'] += 1 + if "QUEUE_FULL" in str(e): + time.sleep(0.001) + else: + time.sleep(0.01) + + try: + for i in range(1000): + try: + producer2.produce('test-topic', value=f'initial2-{i}'.encode()) + except Exception as e: + if "QUEUE_FULL" in str(e): + time.sleep(0.01) + continue + break + + poll_start = time.time() + while time.time() - poll_start < 0.5: + producer2.poll(timeout=0.1) + + producer_thread2 = threading.Thread(target=continuous_producer2, daemon=True) + producer_thread2.start() + time.sleep(0.1) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.3)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + producer2.flush(timeout=WAKEABLE_POLL_TIMEOUT_MAX) + except KeyboardInterrupt: + interrupted = True + finally: + stop_producing2.set() + time.sleep(0.1) + producer2.close() + + assert interrupted, "Should have raised KeyboardInterrupt" + except Exception: + stop_producing2.set() + producer2.close() + raise + + # Assert: Signal sent after multiple chunks still interrupts + producer3 = Producer( + { + 'bootstrap.servers': 'localhost:9092', + 'socket.timeout.ms': 60000, + 'message.timeout.ms': 30000, + 'acks': 'all', + 'batch.num.messages': 100, + 'linger.ms': 100, + 'queue.buffering.max.messages': 100000, + 'queue.buffering.max.kbytes': 104857600, + 'max.in.flight.requests.per.connection': 1, + 'request.timeout.ms': 30000, + 'delivery.timeout.ms': 30000, + } + ) + + stop_producing3 = threading.Event() + production_stats3 = {'count': 0, 'errors': 0} + + def continuous_producer3(): + message_num = 0 + while not stop_producing3.is_set(): + try: + producer3.produce( + 'test-topic', value=f'continuous3-{message_num}'.encode(), key=f'key3-{message_num}'.encode() + ) + production_stats3['count'] += 1 + message_num += 1 + except Exception as e: + production_stats3['errors'] += 1 + if "QUEUE_FULL" in str(e): + time.sleep(0.001) + else: + time.sleep(0.01) + + try: + for i in range(1000): + try: + producer3.produce('test-topic', value=f'initial3-{i}'.encode()) + except Exception as e: + if "QUEUE_FULL" in str(e): + time.sleep(0.01) + continue + break + + poll_start = time.time() + while time.time() - poll_start < 0.5: + producer3.poll(timeout=0.1) + + producer_thread3 = threading.Thread(target=continuous_producer3, daemon=True) + producer_thread3.start() + time.sleep(0.1) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.6)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + producer3.flush() + except KeyboardInterrupt: + interrupted = True + finally: + stop_producing3.set() + time.sleep(0.1) + producer3.close() + + assert interrupted, "Should have raised KeyboardInterrupt" + except Exception: + stop_producing3.set() + producer3.close() + raise + + # Assert: No signal - timeout works normally + producer4 = Producer( + { + 'bootstrap.servers': 'localhost:9092', + 'socket.timeout.ms': 100, + 'message.timeout.ms': 10, + 'acks': 'all', + 'max.in.flight.requests.per.connection': 1, + } + ) + + try: + for i in range(100): + producer4.produce('test-topic', value=f'timeout-test-{i}'.encode()) + except Exception: + pass + + start = time.time() + qlen = producer4.flush(timeout=0.5) + elapsed = time.time() - start + + assert isinstance(qlen, int), "flush() should return int" + assert elapsed <= WAKEABLE_POLL_TIMEOUT_MAX, f"Timeout took {elapsed:.2f}s" + producer4.close() + + +def test_producer_wakeable_flush_edge_cases(): + """Test flush() edge cases.""" + # Assert: Zero timeout returns immediately + producer1 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + start = time.time() + qlen = producer1.flush(timeout=0.0) + elapsed = time.time() - start + + assert elapsed < WAKEABLE_POLL_TIMEOUT_MAX, f"Zero timeout took {elapsed:.2f}s" + assert isinstance(qlen, int) + producer1.close() + + # Assert: Closed producer raises RuntimeError + producer2 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + producer2.close() + + with pytest.raises(RuntimeError) as exc_info: + producer2.flush(timeout=0.1) + assert 'Producer has been closed' in str(exc_info.value) + + # Assert: Short timeout works correctly + producer3 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + start = time.time() + qlen = producer3.flush(timeout=0.1) + elapsed = time.time() - start + + assert isinstance(qlen, int) + # Short timeouts don't use chunking + assert elapsed <= WAKEABLE_POLL_TIMEOUT_MAX, f"Short timeout took {elapsed:.2f}s" + producer3.close() + + # Assert: Very short timeout works + producer4 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + start = time.time() + qlen = producer4.flush(timeout=0.05) + elapsed = time.time() - start + + assert isinstance(qlen, int) + assert elapsed < WAKEABLE_POLL_TIMEOUT_MAX, f"Very short timeout took {elapsed:.2f}s" + producer4.close() + + # Assert: Empty queue flush returns immediately + producer5 = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + start = time.time() + qlen = producer5.flush(timeout=1.0) + elapsed = time.time() - start + + assert qlen == 0 + assert elapsed < WAKEABLE_POLL_TIMEOUT_MAX, f"Empty flush took {elapsed:.2f}s" + producer5.close() + + +# ============================================================================ +# Consumer wakeable tests +# ============================================================================ + + +def test_consumer_wakeable_poll_utility_functions_interaction(): + """Test interaction between calculate_chunk_timeout() and check_signals_between_chunks().""" + # Assertion 1: Both functions work together - chunk calculation + signal check + consumer1 = TestConsumer( + { + 'group.id': 'test-interaction-chunk-signal', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer1.subscribe(['test-topic']) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.4)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + consumer1.poll(timeout=1.0) # 1 second timeout, interrupt after 0.4s + except KeyboardInterrupt: + interrupted = True + finally: + consumer1.close() + + assert interrupted, "Assertion 1 failed: Should have raised KeyboardInterrupt" + + # Assertion 2: Multiple chunks before signal - both functions work over multiple iterations + consumer2 = TestConsumer( + { + 'group.id': 'test-interaction-multiple-chunks', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer2.subscribe(['test-topic']) + + # Send signal after 0.6 seconds (3 chunks should have passed: 0.2s, 0.4s, 0.6s) + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.6)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + consumer2.poll() # Infinite timeout + except KeyboardInterrupt: + interrupted = True + finally: + consumer2.close() + + assert interrupted, "Assertion 2 failed: Should have raised KeyboardInterrupt" + + +def test_consumer_wakeable_poll_interruptibility_and_messages(): + """Test poll() interruptibility (main fix) and message handling.""" + topic = 'test-poll-interrupt-topic' + + # Assertion 1: Infinite timeout can be interrupted immediately + consumer1 = TestConsumer( + { + 'group.id': 'test-poll-infinite-immediate', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer1.subscribe([topic]) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.1)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + consumer1.poll() # Infinite timeout + except KeyboardInterrupt: + interrupted = True + finally: + consumer1.close() + + assert interrupted, "Assertion 1 failed: Should have raised KeyboardInterrupt" + + # Assertion 2: Finite timeout can be interrupted before timeout expires + consumer2 = TestConsumer( + { + 'group.id': 'test-poll-finite-interrupt', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer2.subscribe([topic]) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.3)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + timeout_value = WAKEABLE_POLL_TIMEOUT_MAX # Use constant instead of hardcoded 2.0 + try: + consumer2.poll(timeout=timeout_value) # Use constant for timeout + except KeyboardInterrupt: + interrupted = True + finally: + consumer2.close() + + assert interrupted, "Assertion 2 failed: Should have raised KeyboardInterrupt" + + # Assertion 3: Signal sent after multiple chunks still interrupts quickly + consumer3 = TestConsumer( + { + 'group.id': 'test-poll-multiple-chunks', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer3.subscribe([topic]) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.6)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + consumer3.poll() # Infinite timeout + except KeyboardInterrupt: + interrupted = True + finally: + consumer3.close() + + assert interrupted, "Assertion 3 failed: Should have raised KeyboardInterrupt" + + # Assertion 4: No signal - timeout works normally + consumer4 = TestConsumer( + { + 'group.id': 'test-poll-timeout-normal', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer4.subscribe([topic]) + + start = time.time() + msg = consumer4.poll(timeout=0.5) # 500ms, no signal + elapsed = time.time() - start + + assert msg is None, "Assertion 4 failed: Expected None (timeout), no signal should not interrupt" + assert ( + WAKEABLE_POLL_TIMEOUT_MIN <= elapsed <= WAKEABLE_POLL_TIMEOUT_MAX + ), f"Assertion 4 failed: Normal timeout took {elapsed:.2f}s, expected ~0.5s" + consumer4.close() + + +def test_consumer_wakeable_poll_edge_cases(): + """Test poll() edge cases.""" + topic = 'test-poll-edge-topic' + + # Assertion 1: Zero timeout returns immediately (non-blocking) + consumer1 = TestConsumer( + { + 'group.id': 'test-poll-zero-timeout', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer1.subscribe([topic]) + + start = time.time() + msg = consumer1.poll(timeout=0.0) # Zero timeout + elapsed = time.time() - start + + assert ( + elapsed < WAKEABLE_POLL_TIMEOUT_MAX + ), f"Assertion 1 failed: Zero timeout took {elapsed:.2f}s, expected < {WAKEABLE_POLL_TIMEOUT_MAX}s" + assert msg is None, "Assertion 1 failed: Zero timeout with no messages should return None" + consumer1.close() + + # Assertion 2: Closed consumer raises RuntimeError + consumer2 = TestConsumer({'group.id': 'test-poll-closed', 'socket.timeout.ms': 100, 'session.timeout.ms': 1000}) + consumer2.close() + + with pytest.raises(RuntimeError) as exc_info: + consumer2.poll(timeout=0.1) + msg = f"Assertion 2 failed: Expected 'Consumer closed' error, " f"got: {exc_info.value}" + assert 'Consumer closed' in str(exc_info.value), msg + + # Assertion 3: Short timeout works correctly (no signal) + consumer3 = TestConsumer( + { + 'group.id': 'test-poll-short-timeout', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer3.subscribe([topic]) + + start = time.time() + msg = consumer3.poll(timeout=0.1) # 100ms timeout + elapsed = time.time() - start + + assert msg is None, "Assertion 3 failed: Short timeout with no messages should return None" + # Short timeouts (< 200ms) don't use chunking, so they can complete faster than WAKEABLE_POLL_TIMEOUT_MIN + # Only check upper bound to allow for actual timeout duration + assert ( + elapsed <= WAKEABLE_POLL_TIMEOUT_MAX + ), f"Assertion 3 failed: Short timeout took {elapsed:.2f}s, expected <= {WAKEABLE_POLL_TIMEOUT_MAX}s" + consumer3.close() + + # Assertion 4: Very short timeout (less than chunk size) works + consumer4 = TestConsumer( + { + 'group.id': 'test-poll-very-short', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer4.subscribe([topic]) + + start = time.time() + msg = consumer4.poll(timeout=0.05) # 50ms timeout (less than 200ms chunk) + elapsed = time.time() - start + + assert msg is None, "Assertion 4 failed: Very short timeout should return None" + assert ( + elapsed < WAKEABLE_POLL_TIMEOUT_MAX + ), f"Assertion 4 failed: Very short timeout took {elapsed:.2f}s, expected < {WAKEABLE_POLL_TIMEOUT_MAX}s" + consumer4.close() + + +def test_consumer_wakeable_consume_interruptibility_and_messages(): + """Test consume() interruptibility (main fix) and message handling.""" + topic = 'test-consume-interrupt-topic' + + # Assertion 1: Infinite timeout can be interrupted immediately + consumer1 = TestConsumer( + { + 'group.id': 'test-consume-infinite-immediate', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer1.subscribe([topic]) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.1)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + consumer1.consume() # Infinite timeout, default num_messages=1 + except KeyboardInterrupt: + interrupted = True + finally: + consumer1.close() + + assert interrupted, "Assertion 1 failed: Should have raised KeyboardInterrupt" + + # Assertion 2: Finite timeout can be interrupted before timeout expires + consumer2 = TestConsumer( + { + 'group.id': 'test-consume-finite-interrupt', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer2.subscribe([topic]) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.3)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + timeout_value = WAKEABLE_POLL_TIMEOUT_MAX # Use constant instead of hardcoded 2.0 + try: + consumer2.consume(num_messages=10, timeout=timeout_value) # Use constant for timeout + except KeyboardInterrupt: + interrupted = True + finally: + consumer2.close() + + assert interrupted, "Assertion 2 failed: Should have raised KeyboardInterrupt" + + # Assertion 3: Signal sent after multiple chunks still interrupts quickly + consumer3 = TestConsumer( + { + 'group.id': 'test-consume-multiple-chunks', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer3.subscribe([topic]) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.6)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + consumer3.consume(num_messages=5) # Infinite timeout + except KeyboardInterrupt: + interrupted = True + finally: + consumer3.close() + + assert interrupted, "Assertion 3 failed: Should have raised KeyboardInterrupt" + + # Assertion 4: No signal - timeout works normally, returns empty list + consumer4 = TestConsumer( + { + 'group.id': 'test-consume-timeout-normal', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer4.subscribe([topic]) + + start = time.time() + msglist = consumer4.consume(num_messages=10, timeout=0.5) # 500ms, no signal + elapsed = time.time() - start + + assert isinstance(msglist, list), "Assertion 4 failed: consume() should return a list" + assert len(msglist) == 0, f"Assertion 4 failed: Expected empty list (timeout), got {len(msglist)} messages" + assert ( + WAKEABLE_POLL_TIMEOUT_MIN <= elapsed <= WAKEABLE_POLL_TIMEOUT_MAX + ), f"Assertion 4 failed: Normal timeout took {elapsed:.2f}s, expected ~0.5s" + consumer4.close() + + # Assertion 5: num_messages=0 returns empty list immediately + consumer5 = TestConsumer( + { + 'group.id': 'test-consume-zero-messages', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer5.subscribe([topic]) + + start = time.time() + msglist = consumer5.consume(num_messages=0, timeout=1.0) + elapsed = time.time() - start + + assert isinstance(msglist, list), "Assertion 5 failed: consume() should return a list" + assert len(msglist) == 0, "Assertion 5 failed: num_messages=0 should return empty list" + assert ( + elapsed < WAKEABLE_POLL_TIMEOUT_MAX + ), f"Assertion 5 failed: num_messages=0 took {elapsed:.2f}s, expected < {WAKEABLE_POLL_TIMEOUT_MAX}s" + consumer5.close() + + +def test_consumer_wakeable_consume_edge_cases(): + """Test consume() wakeable edge cases.""" + topic = 'test-consume-edge-topic' + + # Assertion 1: Zero timeout returns immediately (non-blocking) + consumer1 = TestConsumer( + { + 'group.id': 'test-consume-zero-timeout', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer1.subscribe([topic]) + + start = time.time() + msglist = consumer1.consume(num_messages=10, timeout=0.0) # Zero timeout + elapsed = time.time() - start + + assert ( + elapsed < WAKEABLE_POLL_TIMEOUT_MAX + ), f"Assertion 1 failed: Zero timeout took {elapsed:.2f}s, expected < {WAKEABLE_POLL_TIMEOUT_MAX}s" + assert isinstance(msglist, list), "Assertion 1 failed: consume() should return a list" + assert len(msglist) == 0, "Assertion 1 failed: Zero timeout with no messages should return empty list" + consumer1.close() + + # Assertion 2: Closed consumer raises RuntimeError + consumer2 = TestConsumer({'group.id': 'test-consume-closed', 'socket.timeout.ms': 100, 'session.timeout.ms': 1000}) + consumer2.close() + + with pytest.raises(RuntimeError) as exc_info: + consumer2.consume(num_messages=10, timeout=0.1) + msg = f"Assertion 2 failed: Expected 'Consumer closed' error, " f"got: {exc_info.value}" + assert 'Consumer closed' in str(exc_info.value), msg + + # Assertion 3: Invalid num_messages (negative) raises ValueError + consumer3 = TestConsumer( + { + 'group.id': 'test-consume-invalid-negative', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer3.subscribe([topic]) + + with pytest.raises(ValueError) as exc_info: + consumer3.consume(num_messages=-1, timeout=0.1) + msg = f"Assertion 3 failed: Expected num_messages range error, " f"got: {exc_info.value}" + assert 'num_messages must be between 0 and 1000000' in str(exc_info.value), msg + consumer3.close() + + # Assertion 4: Invalid num_messages (too large) raises ValueError + consumer4 = TestConsumer( + { + 'group.id': 'test-consume-invalid-large', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer4.subscribe([topic]) + + with pytest.raises(ValueError) as exc_info: + consumer4.consume(num_messages=1000001, timeout=0.1) + msg = f"Assertion 4 failed: Expected num_messages range error, " f"got: {exc_info.value}" + assert 'num_messages must be between 0 and 1000000' in str(exc_info.value), msg + consumer4.close() + + # Assertion 5: Short timeout works correctly (no signal) + consumer5 = TestConsumer( + { + 'group.id': 'test-consume-short-timeout', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer5.subscribe([topic]) + + start = time.time() + msglist = consumer5.consume(num_messages=10, timeout=0.1) # 100ms timeout + elapsed = time.time() - start + + assert isinstance(msglist, list), "Assertion 5 failed: consume() should return a list" + assert len(msglist) == 0, "Assertion 5 failed: Short timeout with no messages should return empty list" + # Only check upper bound to allow for actual timeout duration + assert ( + elapsed <= WAKEABLE_POLL_TIMEOUT_MAX + ), f"Assertion 5 failed: Short timeout took {elapsed:.2f}s, expected <= {WAKEABLE_POLL_TIMEOUT_MAX}s" + consumer5.close() + + # Assertion 6: Very short timeout (less than chunk size) works + consumer6 = TestConsumer( + { + 'group.id': 'test-consume-very-short', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + consumer6.subscribe([topic]) + + start = time.time() + msglist = consumer6.consume(num_messages=5, timeout=0.05) # 50ms timeout (less than 200ms chunk) + elapsed = time.time() - start + + assert isinstance(msglist, list), "Assertion 6 failed: consume() should return a list" + assert len(msglist) == 0, "Assertion 6 failed: Very short timeout should return empty list" + assert ( + elapsed < WAKEABLE_POLL_TIMEOUT_MAX + ), f"Assertion 6 failed: Very short timeout took {elapsed:.2f}s, expected < {WAKEABLE_POLL_TIMEOUT_MAX}s" + consumer6.close() + + +# ============================================================================ +# Utility function tests +# ============================================================================ + + +@pytest.mark.parametrize("api_type", ["producer", "consumer"]) +def test_calculate_chunk_timeout_utility_function(api_type): + """Comprehensive test of calculate_chunk_timeout() utility function through poll() API. + + Tests all timeout scenarios: infinite, exact multiple, not multiple, very short, + zero timeout, large timeout, and interruption during finite timeout. + """ + + # Helper to create API object and blocking call + def create_api_obj(group_id_suffix=""): + if api_type == "producer": + obj = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + return obj, lambda t=None: obj.poll() if t is None else obj.poll(timeout=t) + else: + group_id = f'test-chunk-{group_id_suffix}' if group_id_suffix else 'test-chunk' + obj = TestConsumer( + { + 'group.id': group_id, + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + obj.subscribe(['test-topic']) + return obj, lambda t=None: obj.poll() if t is None else obj.poll(timeout=t) + + # Assertion 1: Infinite timeout chunks forever with 200ms intervals + obj1, blocking_call1 = create_api_obj("infinite") + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.3)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + blocking_call1() # Infinite timeout - should chunk every 200ms + except KeyboardInterrupt: + interrupted = True + finally: + obj1.close() + + assert interrupted, "Assertion 1 failed: Should have raised KeyboardInterrupt" + + # Assertion 2: Finite timeout exact multiple (1.0s = 5 chunks of 200ms) + obj2, blocking_call2 = create_api_obj("exact-multiple") + start = time.time() + result = blocking_call2(1.0) # Exactly 1000ms (5 chunks) + elapsed = time.time() - start + + if api_type == "producer": + assert isinstance(result, int), "Assertion 2 failed: poll() should return int" + else: + assert result is None, "Assertion 2 failed: Expected None (timeout)" + assert ( + WAKEABLE_POLL_TIMEOUT_MIN <= elapsed <= WAKEABLE_POLL_TIMEOUT_MAX + ), f"Assertion 2 failed: Timeout took {elapsed:.2f}s, expected ~1.0s" + obj2.close() + + # Assertion 3: Finite timeout not multiple (0.35s = 1 chunk + 150ms partial) + obj3, blocking_call3 = create_api_obj("not-multiple") + start = time.time() + result = blocking_call3(0.35) # 350ms (1 full chunk + 150ms partial) + elapsed = time.time() - start + + if api_type == "producer": + assert isinstance(result, int), "Assertion 3 failed: poll() should return int" + else: + assert result is None, "Assertion 3 failed: Expected None (timeout)" + assert ( + WAKEABLE_POLL_TIMEOUT_MIN <= elapsed <= WAKEABLE_POLL_TIMEOUT_MAX + ), f"Assertion 3 failed: Timeout took {elapsed:.2f}s, expected ~0.35s" + obj3.close() + + # Assertion 4: Very short timeout (< 200ms chunk size) + obj4, blocking_call4 = create_api_obj("very-short") + start = time.time() + result = blocking_call4(0.05) # 50ms (less than 200ms chunk) + elapsed = time.time() - start + + if api_type == "producer": + assert isinstance(result, int), "Assertion 4 failed: poll() should return int" + else: + assert result is None, "Assertion 4 failed: Expected None (timeout)" + # Short timeouts don't use chunking, so only check upper bound + assert elapsed < WAKEABLE_POLL_TIMEOUT_MAX, ( + f"Assertion 4 failed: Very short timeout took {elapsed:.2f}s, " f"expected < {WAKEABLE_POLL_TIMEOUT_MAX}s" + ) + obj4.close() + + # Assertion 5: Zero timeout (non-blocking) + obj5, blocking_call5 = create_api_obj("zero") + start = time.time() + result = blocking_call5(0.0) # Non-blocking + elapsed = time.time() - start + + assert elapsed < WAKEABLE_POLL_TIMEOUT_MAX, ( + f"Assertion 5 failed: Zero timeout took {elapsed:.2f}s, " f"expected < {WAKEABLE_POLL_TIMEOUT_MAX}s" + ) + obj5.close() + + # Assertion 6: Large finite timeout (10s = 50 chunks) + obj6, blocking_call6 = create_api_obj("large") + start = time.time() + result = blocking_call6(10.0) # 10 seconds (50 chunks) + elapsed = time.time() - start + + if api_type == "producer": + assert isinstance(result, int), "Assertion 6 failed: poll() should return int" + else: + assert result is None, "Assertion 6 failed: Expected None (timeout)" + # Use constants for bounds - verify timeout happened (loose bounds for large timeout) + assert elapsed >= WAKEABLE_POLL_TIMEOUT_MIN, ( + f"Assertion 6 failed: Timeout took {elapsed:.2f}s, " f"expected >= {WAKEABLE_POLL_TIMEOUT_MIN}s" + ) + assert elapsed <= 10.0 * 2.0, ( + f"Assertion 6 failed: Timeout took {elapsed:.2f}s, " f"expected <= {10.0 * 2.0}s (relative check)" + ) + obj6.close() + + # Assertion 7: Finite timeout with interruption (chunk calculation continues correctly) + obj7, blocking_call7 = create_api_obj("interrupt") + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.4)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + timeout_value = WAKEABLE_POLL_TIMEOUT_MAX # Use constant instead of hardcoded 1.0 + try: + blocking_call7(timeout_value) # Use constant for timeout + except KeyboardInterrupt: + interrupted = True + finally: + obj7.close() + + assert interrupted, "Assertion 7 failed: Should have raised KeyboardInterrupt" + + +@pytest.mark.parametrize("api_type", ["producer", "consumer"]) +def test_check_signals_between_chunks_utility_function(api_type): + """Comprehensive test of check_signals_between_chunks() utility function through poll() API. + + Tests signal detection on first chunk, later chunk, no signal case, every chunk check, + and signal during finite timeout. + """ + + # Helper to create API object and blocking call + def create_api_obj(group_id_suffix=""): + if api_type == "producer": + obj = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + return obj, lambda t=None: obj.poll() if t is None else obj.poll(timeout=t) + else: + group_id = f'test-signal-{group_id_suffix}' if group_id_suffix else 'test-signal' + obj = TestConsumer( + { + 'group.id': group_id, + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + obj.subscribe(['test-topic']) + return obj, lambda t=None: obj.poll() if t is None else obj.poll(timeout=t) + + # Assertion 1: Signal detected on first chunk check + obj1, blocking_call1 = create_api_obj("first-chunk") + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.05)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + blocking_call1() # Infinite timeout + except KeyboardInterrupt: + interrupted = True + finally: + obj1.close() + + assert interrupted, "Assertion 1 failed: Should have raised KeyboardInterrupt" + + # Assertion 2: Signal detected on later chunk check + obj2, blocking_call2 = create_api_obj("later-chunk") + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.5)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + blocking_call2() # Infinite timeout + except KeyboardInterrupt: + interrupted = True + finally: + obj2.close() + + assert interrupted, "Assertion 2 failed: Should have raised KeyboardInterrupt" + + # Assertion 3: No signal - continues polling + obj3, blocking_call3 = create_api_obj("no-signal") + start = time.time() + result = blocking_call3(0.5) # 500ms, no signal + elapsed = time.time() - start + + if api_type == "producer": + assert isinstance(result, int), "Assertion 3 failed: poll() should return int" + else: + assert result is None, "Assertion 3 failed: Expected None (timeout), no signal should not interrupt" + assert ( + WAKEABLE_POLL_TIMEOUT_MIN <= elapsed <= WAKEABLE_POLL_TIMEOUT_MAX + ), f"Assertion 3 failed: No signal timeout took {elapsed:.2f}s, expected ~0.5s" + obj3.close() + + # Assertion 4: Signal checked every chunk (not just once) + obj4, blocking_call4 = create_api_obj("every-chunk") + # Send signal after 0.6 seconds (3 chunks should have passed) + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.6)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + blocking_call4() # Infinite timeout + except KeyboardInterrupt: + interrupted = True + finally: + obj4.close() + + assert interrupted, "Assertion 4 failed: Should have raised KeyboardInterrupt" + + # Assertion 5: Signal check works during finite timeout + obj5, blocking_call5 = create_api_obj("finite-timeout") + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.3)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + timeout_value = WAKEABLE_POLL_TIMEOUT_MAX # Use constant instead of hardcoded 2.0 + try: + blocking_call5(timeout_value) # Use constant for timeout + except KeyboardInterrupt: + interrupted = True + finally: + obj5.close() + + assert interrupted, "Assertion 5 failed: Should have raised KeyboardInterrupt" + + +@pytest.mark.parametrize("api_type", ["producer", "consumer"]) +def test_utilities_interaction(api_type): + """Test that chunking and signal checking work together.""" + if api_type == "producer": + obj = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + def blocking_call(t): + return obj.poll(timeout=t) + + else: + obj = TestConsumer( + { + 'group.id': 'test-interaction', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + obj.subscribe(['test-topic']) + + def blocking_call(t): + return obj.poll(timeout=t) + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.4)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + timeout_value = WAKEABLE_POLL_TIMEOUT_MAX # Use constant instead of hardcoded 1.0 + try: + blocking_call(timeout_value) # Use constant for timeout + except KeyboardInterrupt: + interrupted = True + finally: + time.sleep(0.5) # Wait for signal thread + obj.close() + + # Key assertion: interrupted before full timeout + assert interrupted, "Should have been interrupted" + + +@pytest.mark.parametrize( + "api_type,method", + [ + ("producer", "poll"), + ("producer", "flush"), + ("consumer", "poll"), + ("consumer", "consume"), + ], +) +def test_can_be_interrupted(api_type, method): + """Test that blocking operations can be interrupted.""" + if api_type == "producer": + obj = Producer({'bootstrap.servers': 'localhost:9092', 'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + if method == "poll": + + def blocking_call(): + return obj.poll() + + else: # flush + obj.produce('test-topic', value='test', callback=lambda err, msg: None) + + def blocking_call(): + return obj.flush() + + else: # consumer + obj = TestConsumer( + { + 'group.id': 'test-interrupt', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + obj.subscribe(['test-topic']) + if method == "poll": + + def blocking_call(): + return obj.poll() + + else: # consume + + def blocking_call(): + return obj.consume() + + interrupt_thread = threading.Thread(target=lambda: TestUtils.send_sigint_after_delay(0.1)) + interrupt_thread.daemon = True + interrupt_thread.start() + + interrupted = False + try: + blocking_call() + except KeyboardInterrupt: + interrupted = True + finally: + # Wait for signal thread to complete + time.sleep(0.2) + obj.close() + + # Key assertion: operation was interruptible + assert interrupted, f"{api_type}.{method}() should be interruptible" + + +@pytest.mark.parametrize( + "api_type,method", + [ + ("producer", "poll"), + ("consumer", "poll"), + ("consumer", "consume"), + ], +) +def test_short_timeout_not_chunked(api_type, method): + """Test that short timeouts use non-chunked path.""" + if api_type == "producer": + obj = Producer({'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + def blocking_call(t): + return obj.poll(timeout=t) + + else: + obj = TestConsumer( + { + 'group.id': 'test-short', + 'socket.timeout.ms': 100, + 'session.timeout.ms': 1000, + 'auto.offset.reset': 'latest', + } + ) + obj.subscribe(['test-topic']) + if method == "poll": + + def blocking_call(t): + return obj.poll(timeout=t) + + else: # consume + + def blocking_call(t): + return obj.consume(timeout=t) + + start = time.time() + if method == "consume": + result = blocking_call(0.1) + assert isinstance(result, list) + else: + result = blocking_call(0.1) + elapsed = time.time() - start + + obj.close() + + # Key assertion: short timeout completes quickly (use constant) + assert elapsed < WAKEABLE_POLL_TIMEOUT_MAX, f"Short timeout should complete quickly, took {elapsed:.2f}s" + + +def test_flush_empty_queue_returns_immediately(): + """Test that flush() with no messages returns immediately.""" + producer = Producer({'bootstrap.servers': 'localhost:9092', 'socket.timeout.ms': 100, 'message.timeout.ms': 10}) + + start = time.time() + qlen = producer.flush(timeout=0.5) + elapsed = time.time() - start + + producer.close() + + # Key assertion: empty flush is fast + assert qlen == 0, "Empty queue should return 0" + assert elapsed < WAKEABLE_POLL_TIMEOUT_MAX, f"Empty flush should return quickly, took {elapsed:.2f}s"