Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,10 @@ await SendGetSseRequestWithRetriesAsync(
SseStreamState state,
CancellationToken cancellationToken)
{
int attempt = 0;
// When LastEventId is null, the first attempt is the initial GET SSE connection (not a reconnection),
// so we start at -1 to avoid counting it against MaxReconnectionAttempts.
// When LastEventId is already set, all attempts are true reconnections, so we start at 0.
int attempt = state.LastEventId is null ? -1 : 0;

// Delay before first attempt if we're reconnecting (have a Last-Event-ID)
bool shouldDelay = state.LastEventId is not null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using ModelContextProtocol.Protocol;
using ModelContextProtocol.Tests.Utils;
using System.Net;
using System.Text;

namespace ModelContextProtocol.Tests.Transport;

Expand Down Expand Up @@ -246,4 +247,83 @@ public async Task DisposeAsync_Should_Dispose_Resources()
var transportBase = Assert.IsAssignableFrom<TransportBase>(session);
Assert.False(transportBase.IsConnected);
}

[Fact]
public async Task StreamableHttp_InitialGetSseConnection_DoesNotCountAgainstMaxReconnectionAttempts()
{
// Arrange: The initial GET SSE connection (with no Last-Event-ID) is the initial connection,
// not a reconnection. It should not count against MaxReconnectionAttempts.
// With MaxReconnectionAttempts=2, we expect 1 initial + 2 reconnection = 3 total GET requests.
const int MaxReconnectionAttempts = 2;

var getRequestCount = 0;
var allGetRequestsDone = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);

var options = new HttpClientTransportOptions
{
Endpoint = new Uri("http://localhost:8080"),
TransportMode = HttpTransportMode.StreamableHttp,
MaxReconnectionAttempts = MaxReconnectionAttempts,
DefaultReconnectionInterval = TimeSpan.FromMilliseconds(1),
};

using var mockHttpHandler = new MockHttpHandler();
using var httpClient = new HttpClient(mockHttpHandler);
await using var transport = new HttpClientTransport(options, httpClient, LoggerFactory);

mockHttpHandler.RequestHandler = (request) =>
{
if (request.Method == HttpMethod.Post)
{
// Return a successful initialize response with a session-id header.
// This triggers ReceiveUnsolicitedMessagesAsync which starts the GET SSE stream.
var response = new HttpResponseMessage
{
StatusCode = HttpStatusCode.OK,
Content = new StringContent(
"""{"jsonrpc":"2.0","id":1,"result":{"protocolVersion":"2025-03-26","capabilities":{},"serverInfo":{"name":"TestServer","version":"1.0.0"}}}""",
Encoding.UTF8,
"application/json"),
};
response.Headers.Add("Mcp-Session-Id", "test-session");
return Task.FromResult(response);
}

if (request.Method == HttpMethod.Get)
{
// Return 500 for all GET SSE requests to force the retry loop to exhaust all attempts.
var count = Interlocked.Increment(ref getRequestCount);
if (count == 1 + MaxReconnectionAttempts)
{
allGetRequestsDone.TrySetResult(true);
}
return Task.FromResult(new HttpResponseMessage
{
StatusCode = HttpStatusCode.InternalServerError,
});
}

if (request.Method == HttpMethod.Delete)
{
return Task.FromResult(new HttpResponseMessage
{
StatusCode = HttpStatusCode.OK,
});
}

throw new InvalidOperationException($"Unexpected request: {request.Method}");
};

// Act - Connect and send the initialize request, which starts the background GET SSE task.
await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken);
await session.SendMessageAsync(
new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(1) },
TestContext.Current.CancellationToken);

// Wait for all expected GET requests to be made before disposing.
await allGetRequestsDone.Task.WaitAsync(TimeSpan.FromSeconds(10), TestContext.Current.CancellationToken);

// Assert - Total GET requests = 1 initial connection + MaxReconnectionAttempts reconnections.
Assert.Equal(1 + MaxReconnectionAttempts, getRequestCount);
}
}