Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ public async Task Client_CanResumePostResponseStream_AfterDisconnection()
[Fact]
public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection()
{
var timeout = TimeSpan.FromSeconds(10);
using var faultingStreamHandler = new FaultingStreamHandler()
{
InnerHandler = SocketsHttpHandler,
Expand All @@ -304,12 +305,12 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection()
await using var client = await ConnectClientAsync();

// Get the server instance
var server = await serverTcs.Task.WaitAsync(TestContext.Current.CancellationToken);
var server = await serverTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);

// Set up notification tracking with unique messages
var clientReceivedInitialNotificationTcs = new TaskCompletionSource();
var clientReceivedReplayedNotificationTcs = new TaskCompletionSource();
var clientReceivedReconnectNotificationTcs = new TaskCompletionSource();
var clientReceivedInitialNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
var clientReceivedReplayedNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
var clientReceivedReconnectNotificationTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);

const string CustomNotificationMethod = "test/custom_notification";
const string InitialMessage = "Initial notification";
Expand Down Expand Up @@ -343,11 +344,14 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection()
return default;
});

// Wait for the client's unsolicited message stream to be established before sending notifications
await faultingStreamHandler.WaitForUnsolicitedMessageStreamAsync(TestContext.Current.CancellationToken);

// Send a custom notification to the client on the unsolicited message stream
await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = InitialMessage }, cancellationToken: TestContext.Current.CancellationToken);

// Wait for client to receive the first notification
await clientReceivedInitialNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken);
await clientReceivedInitialNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);

// Fault the unsolicited message stream (GET SSE)
var reconnectAttempt = await faultingStreamHandler.TriggerFaultAsync(TestContext.Current.CancellationToken);
Expand All @@ -359,13 +363,13 @@ public async Task Client_CanResumeUnsolicitedMessageStream_AfterDisconnection()
reconnectAttempt.Continue();

// Wait for client to receive the notification via replay
await clientReceivedReplayedNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken);
await clientReceivedReplayedNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);

// Send a final notification while the client has reconnected - this should be handled by the transport
await server.SendNotificationAsync(CustomNotificationMethod, new JsonObject { ["message"] = ReconnectMessage }, cancellationToken: TestContext.Current.CancellationToken);

// Wait for the client to receive the final notification
await clientReceivedReconnectNotificationTcs.Task.WaitAsync(TestContext.Current.CancellationToken);
await clientReceivedReconnectNotificationTcs.Task.WaitAsync(timeout, TestContext.Current.CancellationToken);

// Assert each notification was received exactly once
Assert.Equal(1, initialNotificationReceivedCount);
Expand Down Expand Up @@ -531,7 +535,7 @@ public async Task PostResponse_EndsAndSseEventStreamWriterIsDisposed_WhenWriteEv
timeoutCts.CancelAfter(TimeSpan.FromSeconds(10));

// The call task should throw an OCE due to cancellation
await Assert.ThrowsAsync<OperationCanceledException>(() => callTask).WaitAsync(timeoutCts.Token);
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => callTask).WaitAsync(timeoutCts.Token);

// Wait for the writer to be disposed
await blockingStore.DisposedTask.WaitAsync(timeoutCts.Token);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ internal sealed class FaultingStreamHandler : DelegatingHandler
{
private FaultingStream? _lastStream;
private TaskCompletionSource? _reconnectTcs;
private TaskCompletionSource _unsolicitedMessageStreamReadyTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

public Task WaitForUnsolicitedMessageStreamAsync(CancellationToken cancellationToken = default)
=> _unsolicitedMessageStreamReadyTcs.Task.WaitAsync(cancellationToken);

internal void SignalUnsolicitedMessageStreamReady() => _unsolicitedMessageStreamReadyTcs.TrySetResult();

public async Task<ReconnectAttempt> TriggerFaultAsync(CancellationToken cancellationToken)
{
Expand All @@ -24,6 +30,9 @@ public async Task<ReconnectAttempt> TriggerFaultAsync(CancellationToken cancella
throw new InvalidOperationException("Cannot trigger a fault while already waiting for reconnection.");
}

// Reset the TCS so we can wait for the reconnected unsolicited message stream
_unsolicitedMessageStreamReadyTcs = new(TaskCreationOptions.RunContinuationsAsynchronously);

_reconnectTcs = new();
await _lastStream.TriggerFaultAsync(cancellationToken);

Expand All @@ -46,6 +55,7 @@ protected override async Task<HttpResponseMessage> SendAsync(
_reconnectTcs = null;
}

var isGetRequest = request.Method == HttpMethod.Get;
var response = await base.SendAsync(request, cancellationToken);

// Only wrap SSE streams (text/event-stream)
Expand All @@ -63,6 +73,13 @@ protected override async Task<HttpResponseMessage> SendAsync(
}

response.Content = newContent;

// For GET requests (unsolicited message stream), set up the stream to signal
// when first data is read. This ensures the server's transport handler is ready.
if (isGetRequest)
{
_lastStream.SetReadyCallback(SignalUnsolicitedMessageStreamReady);
}
}

return response;
Expand All @@ -89,10 +106,14 @@ private sealed class FaultingStream(Stream innerStream) : Stream
{
private readonly CancellationTokenSource _cts = new();
private TaskCompletionSource? _faultTcs;
private Action? _readyCallback;
private bool _readySignaled;
private bool _disposed;

public bool IsDisposed => _disposed;

public void SetReadyCallback(Action callback) => _readyCallback = callback;

public async Task TriggerFaultAsync(CancellationToken cancellationToken)
{
if (_faultTcs is not null)
Expand Down Expand Up @@ -131,6 +152,12 @@ public override async ValueTask<int> ReadAsync(Memory<byte> buffer, Cancellation

_cts.Token.ThrowIfCancellationRequested();

if (bytesRead > 0 && !_readySignaled)
{
_readySignaled = true;
_readyCallback?.Invoke();
}

return bytesRead;
}
catch (OperationCanceledException) when (_cts.IsCancellationRequested)
Expand Down