Skip to content
Merged
4 changes: 3 additions & 1 deletion rstream/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,9 @@ async def store_offset(self, stream: str, reference: str, offset: int) -> None:
)
)

async def declare_publisher(self, stream: str, reference: str, publisher_id: int) -> None:
async def declare_publisher(self, stream: str, reference: Optional[str], publisher_id: int) -> None:
if reference is None:
reference = ""
await self.sync_request(
schema.DeclarePublisher(
self._corr_id_seq.next(),
Expand Down
19 changes: 11 additions & 8 deletions rstream/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,12 @@ async def _create_subscriber(
logger.debug("_create_subscriber(): Create subscriber")
# need to check if the current subscribers for this stream reached the max limit

client = await self._get_or_create_client(stream)
await client.inc_available_id()
# We can have multiple subscribers sharing same connection, so their ids must be distinct
subscription_id = await self.get_available_id()

client = await self._get_or_create_client(stream)
await client.inc_available_id()

decoder = decoder or (lambda x: x)
# the ID is unique per connection
subscriber = self._subscribers[subscription_id] = _Subscriber(
Expand Down Expand Up @@ -320,12 +322,13 @@ async def subscribe(
return subscriber.subscription_id

async def get_available_id(self) -> int:
# ok = True
# for _client in self._clients.keys():
# ok = ok and await self._clients[_client].get_count_available_ids()> 0
#
# if not ok:
# raise exceptions.MaxConsumersPerConnectionReached("Max consumers per connection reached")
ok = True
for _client in self._clients.keys():
v = await self._clients[_client].get_count_available_ids()
ok = ok and v > 0

if not ok:
raise exceptions.MaxConsumersPerConnectionReached("Max consumers per connection reached")

for subscribing_id in range(0, self._max_subscribers_by_connection):
if subscribing_id not in self._subscribers:
Expand Down
81 changes: 47 additions & 34 deletions rstream/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import inspect
import logging
import ssl
import uuid
from collections import defaultdict
from dataclasses import dataclass
from functools import partial
Expand All @@ -33,8 +32,11 @@
CompressionType,
ICompressionCodec,
)
from .constants import Key, SlasMechanism
from .exceptions import StreamDoesNotExist
from .constants import MAX_ITEM_ALLOWED, Key, SlasMechanism
from .exceptions import (
MaxPublishersPerConnectionReached,
StreamDoesNotExist,
)
from .utils import OnClosedErrorInfo, RawMessage

MessageT = TypeVar("MessageT", _MessageProtocol, bytes)
Expand All @@ -47,7 +49,7 @@
@dataclass
class _Publisher:
id: int
reference: str
reference: Optional[str]
stream: str
sequence: utils.MonotonicSeq
client: Client
Expand Down Expand Up @@ -84,7 +86,7 @@ def __init__(
heartbeat: int = 60,
load_balancer_mode: bool = False,
max_retries: int = 20,
max_publishers_by_connection=256,
max_publishers_by_connection=MAX_ITEM_ALLOWED,
default_batch_publishing_delay: float = 3,
default_context_switch_value: int = 1000,
connection_name: str = "",
Expand All @@ -108,7 +110,7 @@ def __init__(
self._clients: dict[str, Client] = {}
self._publishers: dict[str, _Publisher] = {}
self._waiting_for_confirm: dict[
str, dict[asyncio.Future[None] | CB[ConfirmationStatus], set[int]]
int, dict[asyncio.Future[None] | CB[ConfirmationStatus], set[int]]
] = defaultdict(dict)
self._lock = asyncio.Lock()
# dictionary [stream][list] of buffered messages to send asynchronously
Expand All @@ -124,7 +126,7 @@ def __init__(
self._close_called = False
self._connection_name = connection_name
self._filter_value_extractor: Optional[CB_F[Any]] = filter_value_extractor
self.publisher_id = 0
# self.publisher_id = 0
Comment thread
Gsantomaggio marked this conversation as resolved.
Outdated
self._max_publishers_by_connection = max_publishers_by_connection

if self._connection_name is None or self._connection_name == "":
Expand Down Expand Up @@ -185,9 +187,9 @@ async def close(self) -> None:
logger.warning("timeout when closing producer and deleting publisher")
except BaseException:
logger.exception("exception in delete_publisher in Producer.close")
publisher.client.remove_handler(schema.PublishConfirm, publisher.reference)
publisher.client.remove_handler(schema.PublishError, publisher.reference)
publisher.client.remove_handler(schema.MetadataUpdate, publisher.reference)
publisher.client.remove_handler(schema.PublishConfirm, str(publisher.id))
publisher.client.remove_handler(schema.PublishError, str(publisher.id))
publisher.client.remove_handler(schema.MetadataUpdate, str(publisher.id))

self._publishers.clear()

Expand Down Expand Up @@ -230,17 +232,17 @@ async def _get_or_create_publisher(
) -> _Publisher:
if stream in self._publishers:
publisher = self._publishers[stream]
if publisher_name is not None:
assert publisher.reference == publisher_name
return publisher
try:
logger.debug("_get_or_create_publisher(): Getting/Creating new publisher")
client = await self._get_or_create_client(stream)

# We can have multiple publishers sharing same connection, so their ids must be distinct
publisher_id = await client.inc_available_id()
publisher_id = self._get_next_available_id()
await client.inc_available_id()

reference = publisher_name or f"{stream}_publisher_{publisher_id}_{str(uuid.uuid4())}"
reference = publisher_name
# reference = publisher_name or f"{stream}_publisher_{publisher_id}_{str(uuid.uuid4())}"
Comment thread
Gsantomaggio marked this conversation as resolved.
Outdated
Comment thread
Gsantomaggio marked this conversation as resolved.
Outdated
publisher = self._publishers[stream] = _Publisher(
id=publisher_id,
stream=stream,
Expand All @@ -256,11 +258,12 @@ async def _get_or_create_publisher(
publisher_id=publisher.id,
)

sequence = await client.query_publisher_sequence(
stream=stream,
reference=reference,
)
publisher.sequence.set(sequence + 1)
if reference is not None:
sequence = await client.query_publisher_sequence(
stream=stream,
reference=reference,
)
publisher.sequence.set(sequence + 1)

except StreamDoesNotExist as e:
await self._maybe_clean_up_during_lost_connection(stream)
Expand All @@ -275,21 +278,30 @@ async def _get_or_create_publisher(
client.add_handler(
schema.PublishConfirm,
partial(self._on_publish_confirm, publisher=publisher),
name=publisher.reference,
name=str(publisher.id),
)
client.add_handler(
schema.PublishError,
partial(self._on_publish_error, publisher=publisher),
name=publisher.reference,
name=str(publisher.id),
)
client.add_handler(
schema.MetadataUpdate,
partial(self._on_metadata_update),
name=publisher.reference,
name=str(publisher.id),
)

return publisher

def _get_next_available_id(self) -> int:
# given this list self._publishers we need to find the next available id
# we need to loop the list of publishers and find the first available id
for i in range(0, self._max_publishers_by_connection):
if all(p.id != i for p in self._publishers.values()):
return i

raise MaxPublishersPerConnectionReached("Max publishers per connection reached")

async def send_batch(
self,
stream: str,
Expand Down Expand Up @@ -378,16 +390,16 @@ async def _send_batch(
if not sync:
logger.debug("_send_batch: Not sync case")
if callback is not None:
if callback not in self._waiting_for_confirm[publisher.reference]:
self._waiting_for_confirm[publisher.reference][callback] = set()
if callback not in self._waiting_for_confirm[publisher.id]:
self._waiting_for_confirm[publisher.id][callback] = set()

self._waiting_for_confirm[publisher.reference][callback].update(publishing_ids)
self._waiting_for_confirm[publisher.id][callback].update(publishing_ids)

# this is just called in case of send_wait
else:
logger.debug("_send_batch: sync case")
future: asyncio.Future[None] = asyncio.Future()
self._waiting_for_confirm[publisher.reference][future] = publishing_ids.copy()
self._waiting_for_confirm[publisher.id][future] = publishing_ids.copy()
await asyncio.wait_for(future, timeout)

return list(publishing_ids)
Expand Down Expand Up @@ -473,6 +485,7 @@ async def _send_batch_async(
)
# publishing_ids.update([m.publishing_id for m in messages])
messages.clear()
publishing_id = 0
Comment thread
Gsantomaggio marked this conversation as resolved.
Outdated
for _ in range(item.entry.messages_count()):
publishing_id = publisher.sequence.next()

Expand Down Expand Up @@ -515,10 +528,10 @@ async def _send_batch_async(
publishing_ids.update([m.publishing_id for m in messages])

for callback in publishing_ids_callback:
if callback not in self._waiting_for_confirm[publisher.reference]:
self._waiting_for_confirm[publisher.reference][callback] = set()
if callback not in self._waiting_for_confirm[publisher.id]:
self._waiting_for_confirm[publisher.id][callback] = set()

self._waiting_for_confirm[publisher.reference][callback].update(publishing_ids_callback[callback])
self._waiting_for_confirm[publisher.id][callback].update(publishing_ids_callback[callback])

return list(publishing_ids)

Expand Down Expand Up @@ -634,7 +647,7 @@ async def _on_publish_confirm(self, frame: schema.PublishConfirm, publisher: _Pu
if frame.publisher_id != publisher.id:
return

waiting = self._waiting_for_confirm[publisher.reference]
waiting = self._waiting_for_confirm[publisher.id]
for confirmation in list(waiting):
logger.debug("_on_publish_confirm: looping over confirmations")
ids = waiting[confirmation]
Expand All @@ -657,7 +670,7 @@ async def _on_publish_error(self, frame: schema.PublishError, publisher: _Publis
if frame.publisher_id != publisher.id:
return

waiting = self._waiting_for_confirm[publisher.reference]
waiting = self._waiting_for_confirm[publisher.id]
for error in frame.errors:
exc = exceptions.ServerError.from_code(error.response_code)
for confirmation in list(waiting):
Expand Down Expand Up @@ -701,9 +714,9 @@ async def clean_up_publishers(self, stream: str):
if stream in self._publishers:
publisher = self._publishers[stream]
await publisher.client.delete_publisher(publisher.id)
publisher.client.remove_handler(schema.PublishConfirm, publisher.reference)
publisher.client.remove_handler(schema.PublishError, publisher.reference)
publisher.client.remove_handler(schema.MetadataUpdate, publisher.reference)
publisher.client.remove_handler(schema.PublishConfirm, str(publisher.id))
publisher.client.remove_handler(schema.PublishError, str(publisher.id))
publisher.client.remove_handler(schema.MetadataUpdate, str(publisher.id))
del self._publishers[stream]

async def delete_stream(self, stream: str, missing_ok: bool = False) -> None:
Expand Down
8 changes: 4 additions & 4 deletions tests/test_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,9 @@ async def test_publish_deduplication(stream: str, producer: Producer, consumer:
# Helper: publish one message per given publishing id, waiting for each
# confirm. All sends use the same publisher_name ("MyProducerName") —
# presumably so the broker tracks the publishing ids under one named
# publisher for deduplication (test name suggests this; confirm).
async def publish_with_ids(*ids):
    for publishing_id in ids:
        await producer.send_wait(
            stream=stream,
            message=RawMessage(f"test_{publishing_id}".encode(), publishing_id=publishing_id),
            publisher_name="MyProducerName",
        )

await publish_with_ids(1, 2, 3)
Expand All @@ -259,8 +260,7 @@ async def test_publish_deduplication_async(stream: str, producer: Producer, cons
async def publish_with_ids(*ids):
for publishing_id in ids:
await producer.send(
stream,
RawMessage(f"test_{publishing_id}".encode(), publishing_id),
stream, RawMessage(f"test_{publishing_id}".encode(), publishing_id), "MyProducerName"
Comment thread
Gsantomaggio marked this conversation as resolved.
Outdated
)

await publish_with_ids(1, 2, 3)
Expand Down
26 changes: 26 additions & 0 deletions tests/test_consumer_validate_id.py → tests/test_validate_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,29 @@ def process_data(message_context: MessageContext):

await sub("my-subscriber-name")
await sub(None)


async def test_validate_publisher_id_to_stream(producer: Producer, pytestconfig) -> None:
    """Publishers sharing a connection must receive distinct ids 0..n-1.

    Creates three streams, publishes to each (which implicitly declares one
    publisher per stream), then verifies the producer tracked exactly one
    publisher per stream with the ids 0, 1 and 2 — each used exactly once.
    Finally checks that close() clears all publishers.
    """
    now = int(time.time())
    streams = [
        "test_test_validate_publisher_id_to_stream_0_{}".format(now),
        "test_test_validate_publisher_id_to_stream_1_{}".format(now),
        "test_test_validate_publisher_id_to_stream_2_{}".format(now),
    ]

    for stream in streams:
        await producer.create_stream(stream)

    # First send to a stream declares that stream's publisher.
    for stream in streams:
        for i in range(2):
            await producer.send_wait(stream, AMQPMessage(body=bytes("hello: {}".format(i), "utf-8")))

    await asyncio.sleep(1)
    assert len(producer._publishers) == 3
    # Membership alone (`id in [0, 1, 2]`) would pass even if all publishers
    # shared one id; require the exact set so duplicates are caught.
    assert {publisher.id for publisher in producer._publishers.values()} == {0, 1, 2}

    await producer.close()
    assert len(producer._publishers) == 0
    # NOTE(review): delete_stream after close() relies on the producer still
    # accepting management calls — confirm this is intended.
    for stream in streams:
        await producer.delete_stream(stream)