Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions src/Common/HttpResponseMessageExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using System.Net;
using System.Net.Http;

namespace ModelContextProtocol;

/// <summary>
/// Extension methods for <see cref="HttpResponseMessage"/>.
/// </summary>
internal static class HttpResponseMessageExtensions
{
/// <summary>
/// Throws an <see cref="HttpRequestException"/> if the <see cref="HttpResponseMessage.IsSuccessStatusCode"/> property is <see langword="false"/>.
/// Unlike <see cref="HttpResponseMessage.EnsureSuccessStatusCode"/>, this method includes the response body in the exception message
/// to help diagnose issues when the server returns error details in the response body.
/// </summary>
/// <param name="response">The HTTP response message to check.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
/// <returns>A task that represents the asynchronous operation.</returns>
/// <exception cref="HttpRequestException">The response status code does not indicate success.</exception>
public static async Task EnsureSuccessStatusCodeWithResponseBodyAsync(this HttpResponseMessage response, CancellationToken cancellationToken = default)
{
if (!response.IsSuccessStatusCode)
{
string? responseBody = null;
try
{
// Add a timeout to prevent hanging if the server sends an error response but doesn't end the request.
using var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
cts.CancelAfter(TimeSpan.FromSeconds(5));
responseBody = await response.Content.ReadAsStringAsync(cts.Token).ConfigureAwait(false);
}
catch (Exception ex) when (ex is not OperationCanceledException || !cancellationToken.IsCancellationRequested)
{
// Ignore errors reading the response body (e.g., stream closed, timeout) - we'll throw without it.
// Allow cancellation exceptions to propagate only if the original token was cancelled.
}

throw CreateHttpRequestException(response, responseBody);
}
}

/// <summary>
/// Creates an <see cref="HttpRequestException"/> for a non-success response, including the response body in the message.
/// </summary>
/// <param name="response">The HTTP response message.</param>
/// <param name="responseBody">The response body content, if available.</param>
/// <returns>An <see cref="HttpRequestException"/> with the response details.</returns>
public static HttpRequestException CreateHttpRequestException(HttpResponseMessage response, string? responseBody)
{
int statusCodeInt = (int)response.StatusCode;
string message = string.IsNullOrEmpty(responseBody)
? $"Response status code does not indicate success: {statusCodeInt} ({response.ReasonPhrase})."
: $"Response status code does not indicate success: {statusCodeInt} ({response.ReasonPhrase}). Response body: {responseBody}";

#if NET
return new HttpRequestException(message, inner: null, response.StatusCode);
#else
return new HttpRequestException(message);
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,7 @@ private async Task<string> ExchangeCodeForTokenAsync(
};

using var httpResponse = await _httpClient.SendAsync(request, cancellationToken).ConfigureAwait(false);
httpResponse.EnsureSuccessStatusCode();
await httpResponse.EnsureSuccessStatusCodeWithResponseBodyAsync(cancellationToken).ConfigureAwait(false);

var tokens = await HandleSuccessfulTokenResponseAsync(httpResponse, cancellationToken).ConfigureAwait(false);
LogOAuthAuthorizationCompleted();
Expand Down Expand Up @@ -544,7 +544,7 @@ private async Task<TokenContainer> HandleSuccessfulTokenResponseAsync(HttpRespon
using var httpResponse = await _httpClient.GetAsync(metadataUrl, cancellationToken).ConfigureAwait(false);
if (requireSuccess)
{
httpResponse.EnsureSuccessStatusCode();
await httpResponse.EnsureSuccessStatusCodeWithResponseBodyAsync(cancellationToken).ConfigureAwait(false);
}
else if (!httpResponse.IsSuccessStatusCode)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,19 @@ public override async Task SendMessageAsync(

if (!response.IsSuccessStatusCode)
{
// Read the response body once to include in both logging and exception
string responseBody = await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false);

if (_logger.IsEnabled(LogLevel.Trace))
{
LogRejectedPostSensitive(Name, messageId, await response.Content.ReadAsStringAsync(cancellationToken).ConfigureAwait(false));
LogRejectedPostSensitive(Name, messageId, responseBody);
}
else
{
LogRejectedPost(Name, messageId);
}

response.EnsureSuccessStatusCode();
throw HttpResponseMessageExtensions.CreateHttpRequestException(response, responseBody);
}
}

Expand Down Expand Up @@ -146,7 +149,7 @@ private async Task ReceiveMessagesAsync(CancellationToken cancellationToken)

using var response = await _httpClient.SendAsync(request, message: null, cancellationToken).ConfigureAwait(false);

response.EnsureSuccessStatusCode();
await response.EnsureSuccessStatusCodeWithResponseBodyAsync(cancellationToken).ConfigureAwait(false);

using var stream = await response.Content.ReadAsStreamAsync(cancellationToken).ConfigureAwait(false);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public override async Task SendMessageAsync(JsonRpcMessage message, Cancellation
{
// Immediately dispose the response. SendHttpRequestAsync only returns the response so the auto transport can look at it.
using var response = await SendHttpRequestAsync(message, cancellationToken).ConfigureAwait(false);
response.EnsureSuccessStatusCode();
await response.EnsureSuccessStatusCodeWithResponseBodyAsync(cancellationToken).ConfigureAwait(false);
}

// This is used by the auto transport so it can fall back and try SSE given a non-200 response without catching an exception.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
<Compile Include="..\Common\Throw.cs" Link="Throw.cs" />
<Compile Include="..\Common\Obsoletions.cs" Link="Obsoletions.cs" />
<Compile Include="..\Common\Experimentals.cs" Link="Experimentals.cs" />
<Compile Include="..\Common\HttpResponseMessageExtensions.cs" Link="HttpResponseMessageExtensions.cs" />
<Compile Include="..\Common\ServerSentEvents\**\*.cs" Link="ServerSentEvents\%(RecursiveDir)%(FileName)%(Extension)" />
</ItemGroup>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,32 @@ public async Task ConnectAsync_Throws_Exception_On_Failure()
Assert.Equal(1, retries);
}

[Fact]
public async Task ConnectAsync_Throws_HttpRequestException_With_ResponseBody_On_ErrorStatusCode()
{
using var mockHttpHandler = new MockHttpHandler();
using var httpClient = new HttpClient(mockHttpHandler);
await using var transport = new HttpClientTransport(_transportOptions, httpClient, LoggerFactory);

const string errorDetails = "Bad request: Invalid MCP protocol version";
mockHttpHandler.RequestHandler = (request) =>
{
return Task.FromResult(new HttpResponseMessage
{
StatusCode = HttpStatusCode.BadRequest,
ReasonPhrase = "Bad Request",
Content = new StringContent(errorDetails)
});
};

var httpException = await Assert.ThrowsAsync<HttpRequestException>(() => transport.ConnectAsync(TestContext.Current.CancellationToken));
Assert.Contains(errorDetails, httpException.Message);
Assert.Contains("400", httpException.Message);
#if NET
Assert.Equal(HttpStatusCode.BadRequest, httpException.StatusCode);
#endif
}

[Fact]
public async Task SendMessageAsync_Handles_Accepted_Response()
{
Expand Down Expand Up @@ -120,6 +146,53 @@ public async Task SendMessageAsync_Handles_Accepted_Response()
Assert.True(true);
}

[Fact]
public async Task SendMessageAsync_Throws_HttpRequestException_With_ResponseBody_On_ErrorStatusCode()
{
using var mockHttpHandler = new MockHttpHandler();
using var httpClient = new HttpClient(mockHttpHandler);
await using var transport = new HttpClientTransport(_transportOptions, httpClient, LoggerFactory);

var firstCall = true;
const string errorDetails = "Invalid JSON-RPC message format: missing 'id' field";

mockHttpHandler.RequestHandler = (request) =>
{
if (request.Method == HttpMethod.Post && request.RequestUri?.AbsoluteUri == "http://localhost:8080/sseendpoint")
{
return Task.FromResult(new HttpResponseMessage
{
StatusCode = HttpStatusCode.BadRequest,
ReasonPhrase = "Bad Request",
Content = new StringContent(errorDetails)
});
}
else
{
if (!firstCall)
throw new IOException("Abort");
else
firstCall = false;

return Task.FromResult(new HttpResponseMessage
{
StatusCode = HttpStatusCode.OK,
Content = new StringContent("event: endpoint\r\ndata: /sseendpoint\r\n\r\n")
});
}
};

await using var session = await transport.ConnectAsync(TestContext.Current.CancellationToken);
var httpException = await Assert.ThrowsAsync<HttpRequestException>(() =>
session.SendMessageAsync(new JsonRpcRequest { Method = RequestMethods.Initialize, Id = new RequestId(44) }, CancellationToken.None));

Assert.Contains(errorDetails, httpException.Message);
Assert.Contains("400", httpException.Message);
#if NET
Assert.Equal(HttpStatusCode.BadRequest, httpException.StatusCode);
#endif
}

[Fact]
public async Task ReceiveMessagesAsync_Handles_Messages()
{
Expand Down
Loading