diff --git a/src/fastmcp/server/auth/oauth_proxy.py b/src/fastmcp/server/auth/oauth_proxy.py index 55a19d766..5d0dd217d 100644 --- a/src/fastmcp/server/auth/oauth_proxy.py +++ b/src/fastmcp/server/auth/oauth_proxy.py @@ -28,6 +28,10 @@ import httpx from authlib.common.security import generate_token from authlib.integrations.httpx_client import AsyncOAuth2Client +from mcp.server.auth.handlers.token import TokenErrorResponse, TokenSuccessResponse +from mcp.server.auth.handlers.token import TokenHandler as _SDKTokenHandler +from mcp.server.auth.json_response import PydanticJSONResponse +from mcp.server.auth.middleware.client_auth import ClientAuthenticator from mcp.server.auth.provider import ( AccessToken, AuthorizationCode, @@ -35,6 +39,7 @@ RefreshToken, TokenError, ) +from mcp.server.auth.routes import cors_middleware from mcp.server.auth.settings import ( ClientRegistrationOptions, RevocationOptions, @@ -42,7 +47,7 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthToken from pydantic import AnyHttpUrl, AnyUrl, SecretStr from starlette.requests import Request -from starlette.responses import JSONResponse, RedirectResponse +from starlette.responses import RedirectResponse from starlette.routing import Route import fastmcp @@ -122,6 +127,55 @@ def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl: HTTP_TIMEOUT_SECONDS: Final[int] = 30 +class TokenHandler(_SDKTokenHandler): + """TokenHandler that returns OAuth 2.1 compliant error responses. + + The MCP SDK always returns HTTP 400 for all client authentication issues. + However, OAuth 2.1 Section 5.3 and the MCP specification require that + invalid or expired tokens MUST receive a HTTP 401 response. + + This handler extends the base MCP SDK TokenHandler to transform client + authentication failures into OAuth 2.1 compliant responses: + - Changes 'unauthorized_client' to 'invalid_client' error code + - Returns HTTP 401 status code instead of 400 for client auth failures + + Per OAuth 2.1 Section 5.3: "The authorization server MAY return an HTTP 401 + (Unauthorized) status code to indicate which HTTP authentication schemes + are supported." + + Per MCP spec: "Invalid or expired tokens MUST receive a HTTP 401 response." + """ + + def response(self, obj: TokenSuccessResponse | TokenErrorResponse): + """Override response method to provide OAuth 2.1 compliant error handling.""" + # Check if this is a client authentication failure (not just unauthorized for grant type) + # unauthorized_client can mean two things: + # 1. Client authentication failed (client_id not found or wrong credentials) -> invalid_client 401 + # 2. Client not authorized for this grant type -> unauthorized_client 400 (correct per spec) + if ( + isinstance(obj, TokenErrorResponse) + and obj.error == "unauthorized_client" + and obj.error_description + and "Invalid client_id" in obj.error_description + ): + # Transform client auth failure to OAuth 2.1 compliant response + return PydanticJSONResponse( + content=TokenErrorResponse( + error="invalid_client", + error_description=obj.error_description, + error_uri=obj.error_uri, + ), + status_code=401, + headers={ + "Cache-Control": "no-store", + "Pragma": "no-cache", + }, + ) + + # Otherwise use default behavior from parent class + return super().response(obj) + + class OAuthProxy(OAuthProvider): """OAuth provider that presents a DCR-compliant interface while proxying to non-DCR IDPs. @@ -852,12 +906,17 @@ def get_routes( and "POST" in route.methods ): token_route_found = True - # Replace with our custom token handler + # Replace with our OAuth 2.1 compliant token handler + token_handler = TokenHandler( + provider=self, client_authenticator=ClientAuthenticator(self) + ) custom_routes.append( Route( path="/token", - endpoint=self._handle_token_request, - methods=["POST"], + endpoint=cors_middleware( + token_handler.handle, ["POST", "OPTIONS"] + ), + methods=["POST", "OPTIONS"], ) ) else: @@ -878,81 +937,6 @@ def get_routes( ) return custom_routes - # ------------------------------------------------------------------------- - # Custom Token Endpoint Handler - # ------------------------------------------------------------------------- - - async def _handle_token_request(self, request: Request) -> JSONResponse: - """Handle token requests with proper OAuth 2.1 error handling. - - This custom handler wraps the standard MCP SDK token handler but provides - OAuth 2.1 compliant error responses for client authentication failures: - - Returns HTTP 401 status code for client authentication failures - - Uses 'invalid_client' error code instead of 'unauthorized_client' - - Per OAuth 2.1 spec: "The authorization server MAY return an HTTP 401 - (Unauthorized) status code to indicate which HTTP authentication schemes - are supported. If the client attempted to authenticate via the Authorization - request header field, the authorization server MUST respond with an HTTP 401 - (Unauthorized) status code and include the WWW-Authenticate response header - field matching the authentication scheme used by the client." - """ - from mcp.server.auth.handlers.token import TokenHandler - from mcp.server.auth.middleware.client_auth import ClientAuthenticator - - # Create the standard token handler and client authenticator - token_handler = TokenHandler( - provider=self, client_authenticator=ClientAuthenticator(self) - ) - - # Handle the request normally - response = await token_handler.handle(request) - - # Check if the response is an error response for client authentication failure - if ( - hasattr(response, "body") - and hasattr(response, "status_code") - and response.status_code == 400 - ): - try: - import json - - # Parse the response body to check for client authentication errors - body_content = ( - response.body.decode("utf-8") - if hasattr(response.body, "decode") - else str(response.body) - ) - error_data = json.loads(body_content) - - # Check if this is an unauthorized_client error (which means invalid client_id) - if error_data.get( - "error" - ) == "unauthorized_client" and "Invalid client_id" in str( - error_data.get("error_description", "") - ): - logger.debug( - "Client authentication failed - client not found, returning OAuth 2.1 compliant error" - ) - - # Return the correct OAuth 2.1 response - return JSONResponse( - content={ - "error": "invalid_client", - "error_description": error_data.get("error_description"), - }, - status_code=401, - headers={ - "Cache-Control": "no-store", - "Pragma": "no-cache", - }, - ) - except (json.JSONDecodeError, AttributeError, KeyError): - # If we can't parse the response, return it as-is - pass - - return response - # ------------------------------------------------------------------------- # IdP Callback Forwarding # ------------------------------------------------------------------------- diff --git a/tests/server/auth/test_oauth_proxy.py b/tests/server/auth/test_oauth_proxy.py index 3bc82d7b9..2cd866938 100644 --- a/tests/server/auth/test_oauth_proxy.py +++ b/tests/server/auth/test_oauth_proxy.py @@ -1029,3 +1029,70 @@ async def test_token_endpoint_invalid_client_error(self, jwt_verifier): # Verify proper cache headers are set assert response.headers.get("Cache-Control") == "no-store" assert response.headers.get("Pragma") == "no-cache" + + +class TestTokenHandlerErrorTransformation: + """Tests for TokenHandler's OAuth 2.1 compliant error transformation.""" + + def test_transforms_client_auth_failure_to_invalid_client_401(self): + """Test that client authentication failures return invalid_client with 401.""" + from mcp.server.auth.handlers.token import TokenErrorResponse + + from fastmcp.server.auth.oauth_proxy import TokenHandler + + handler = TokenHandler(provider=Mock(), client_authenticator=Mock()) + + # Simulate error from ClientAuthenticator.authenticate() failure + error_response = TokenErrorResponse( + error="unauthorized_client", + error_description="Invalid client_id 'test-client-id'", + ) + + response = handler.response(error_response) + + # Should transform to OAuth 2.1 compliant response + assert response.status_code == 401 + assert b'"error":"invalid_client"' in response.body + assert ( + b'"error_description":"Invalid client_id \'test-client-id\'"' + in response.body + ) + + def test_does_not_transform_grant_type_unauthorized_to_invalid_client(self): + """Test that grant type authorization errors stay as unauthorized_client with 400.""" + from mcp.server.auth.handlers.token import TokenErrorResponse + + from fastmcp.server.auth.oauth_proxy import TokenHandler + + handler = TokenHandler(provider=Mock(), client_authenticator=Mock()) + + # Simulate error from grant_type not in client_info.grant_types + error_response = TokenErrorResponse( + error="unauthorized_client", + error_description="Client not authorized for this grant type", + ) + + response = handler.response(error_response) + + # Should NOT transform - keep as 400 unauthorized_client + assert response.status_code == 400 + assert b'"error":"unauthorized_client"' in response.body + + def test_does_not_transform_other_errors(self): + """Test that other error types pass through unchanged.""" + from mcp.server.auth.handlers.token import TokenErrorResponse + + from fastmcp.server.auth.oauth_proxy import TokenHandler + + handler = TokenHandler(provider=Mock(), client_authenticator=Mock()) + + error_response = TokenErrorResponse( + error="invalid_grant", + error_description="Authorization code has expired", + ) + + response = handler.response(error_response) + + # Should pass through unchanged + assert response.status_code == 400 + assert b'"error":"invalid_grant"' in response.body