Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 63 additions & 79 deletions src/fastmcp/server/auth/oauth_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,26 @@
import httpx
from authlib.common.security import generate_token
from authlib.integrations.httpx_client import AsyncOAuth2Client
from mcp.server.auth.handlers.token import TokenErrorResponse, TokenSuccessResponse
from mcp.server.auth.handlers.token import TokenHandler as _SDKTokenHandler
from mcp.server.auth.json_response import PydanticJSONResponse
from mcp.server.auth.middleware.client_auth import ClientAuthenticator
from mcp.server.auth.provider import (
AccessToken,
AuthorizationCode,
AuthorizationParams,
RefreshToken,
TokenError,
)
from mcp.server.auth.routes import cors_middleware
from mcp.server.auth.settings import (
ClientRegistrationOptions,
RevocationOptions,
)
from mcp.shared.auth import OAuthClientInformationFull, OAuthToken
from pydantic import AnyHttpUrl, AnyUrl, SecretStr
from starlette.requests import Request
from starlette.responses import JSONResponse, RedirectResponse
from starlette.responses import RedirectResponse
from starlette.routing import Route

import fastmcp
Expand Down Expand Up @@ -122,6 +127,55 @@ def validate_redirect_uri(self, redirect_uri: AnyUrl | None) -> AnyUrl:
HTTP_TIMEOUT_SECONDS: Final[int] = 30


class TokenHandler(_SDKTokenHandler):
"""TokenHandler that returns OAuth 2.1 compliant error responses.

The MCP SDK always returns HTTP 400 for all client authentication issues.
However, OAuth 2.1 Section 5.3 and the MCP specification require that
invalid or expired tokens MUST receive a HTTP 401 response.

This handler extends the base MCP SDK TokenHandler to transform client
authentication failures into OAuth 2.1 compliant responses:
- Changes 'unauthorized_client' to 'invalid_client' error code
- Returns HTTP 401 status code instead of 400 for client auth failures

Per OAuth 2.1 Section 5.3: "The authorization server MAY return an HTTP 401
(Unauthorized) status code to indicate which HTTP authentication schemes
are supported."

Per MCP spec: "Invalid or expired tokens MUST receive a HTTP 401 response."
"""

def response(self, obj: TokenSuccessResponse | TokenErrorResponse):
"""Override response method to provide OAuth 2.1 compliant error handling."""
# Check if this is a client authentication failure (not just unauthorized for grant type)
# unauthorized_client can mean two things:
# 1. Client authentication failed (client_id not found or wrong credentials) -> invalid_client 401
# 2. Client not authorized for this grant type -> unauthorized_client 400 (correct per spec)
if (
isinstance(obj, TokenErrorResponse)
and obj.error == "unauthorized_client"
and obj.error_description
and "Invalid client_id" in obj.error_description
):
# Transform client auth failure to OAuth 2.1 compliant response
return PydanticJSONResponse(
content=TokenErrorResponse(
error="invalid_client",
error_description=obj.error_description,
error_uri=obj.error_uri,
),
status_code=401,
headers={
"Cache-Control": "no-store",
"Pragma": "no-cache",
},
)

# Otherwise use default behavior from parent class
return super().response(obj)


class OAuthProxy(OAuthProvider):
"""OAuth provider that presents a DCR-compliant interface while proxying to non-DCR IDPs.

Expand Down Expand Up @@ -852,12 +906,17 @@ def get_routes(
and "POST" in route.methods
):
token_route_found = True
# Replace with our custom token handler
# Replace with our OAuth 2.1 compliant token handler
token_handler = TokenHandler(
provider=self, client_authenticator=ClientAuthenticator(self)
)
custom_routes.append(
Route(
path="/token",
endpoint=self._handle_token_request,
methods=["POST"],
endpoint=cors_middleware(
token_handler.handle, ["POST", "OPTIONS"]
),
methods=["POST", "OPTIONS"],
)
)
else:
Expand All @@ -878,81 +937,6 @@ def get_routes(
)
return custom_routes

# -------------------------------------------------------------------------
# Custom Token Endpoint Handler
# -------------------------------------------------------------------------

async def _handle_token_request(self, request: Request) -> JSONResponse:
"""Handle token requests with proper OAuth 2.1 error handling.

This custom handler wraps the standard MCP SDK token handler but provides
OAuth 2.1 compliant error responses for client authentication failures:
- Returns HTTP 401 status code for client authentication failures
- Uses 'invalid_client' error code instead of 'unauthorized_client'

Per OAuth 2.1 spec: "The authorization server MAY return an HTTP 401
(Unauthorized) status code to indicate which HTTP authentication schemes
are supported. If the client attempted to authenticate via the Authorization
request header field, the authorization server MUST respond with an HTTP 401
(Unauthorized) status code and include the WWW-Authenticate response header
field matching the authentication scheme used by the client."
"""
from mcp.server.auth.handlers.token import TokenHandler
from mcp.server.auth.middleware.client_auth import ClientAuthenticator

# Create the standard token handler and client authenticator
token_handler = TokenHandler(
provider=self, client_authenticator=ClientAuthenticator(self)
)

# Handle the request normally
response = await token_handler.handle(request)

# Check if the response is an error response for client authentication failure
if (
hasattr(response, "body")
and hasattr(response, "status_code")
and response.status_code == 400
):
try:
import json

# Parse the response body to check for client authentication errors
body_content = (
response.body.decode("utf-8")
if hasattr(response.body, "decode")
else str(response.body)
)
error_data = json.loads(body_content)

# Check if this is an unauthorized_client error (which means invalid client_id)
if error_data.get(
"error"
) == "unauthorized_client" and "Invalid client_id" in str(
error_data.get("error_description", "")
):
logger.debug(
"Client authentication failed - client not found, returning OAuth 2.1 compliant error"
)

# Return the correct OAuth 2.1 response
return JSONResponse(
content={
"error": "invalid_client",
"error_description": error_data.get("error_description"),
},
status_code=401,
headers={
"Cache-Control": "no-store",
"Pragma": "no-cache",
},
)
except (json.JSONDecodeError, AttributeError, KeyError):
# If we can't parse the response, return it as-is
pass

return response

# -------------------------------------------------------------------------
# IdP Callback Forwarding
# -------------------------------------------------------------------------
Expand Down
67 changes: 67 additions & 0 deletions tests/server/auth/test_oauth_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,3 +1029,70 @@ async def test_token_endpoint_invalid_client_error(self, jwt_verifier):
# Verify proper cache headers are set
assert response.headers.get("Cache-Control") == "no-store"
assert response.headers.get("Pragma") == "no-cache"


class TestTokenHandlerErrorTransformation:
"""Tests for TokenHandler's OAuth 2.1 compliant error transformation."""

def test_transforms_client_auth_failure_to_invalid_client_401(self):
"""Test that client authentication failures return invalid_client with 401."""
from mcp.server.auth.handlers.token import TokenErrorResponse

from fastmcp.server.auth.oauth_proxy import TokenHandler

handler = TokenHandler(provider=Mock(), client_authenticator=Mock())

# Simulate error from ClientAuthenticator.authenticate() failure
error_response = TokenErrorResponse(
error="unauthorized_client",
error_description="Invalid client_id 'test-client-id'",
)

response = handler.response(error_response)

# Should transform to OAuth 2.1 compliant response
assert response.status_code == 401
assert b'"error":"invalid_client"' in response.body
assert (
b'"error_description":"Invalid client_id \'test-client-id\'"'
in response.body
)

def test_does_not_transform_grant_type_unauthorized_to_invalid_client(self):
"""Test that grant type authorization errors stay as unauthorized_client with 400."""
from mcp.server.auth.handlers.token import TokenErrorResponse

from fastmcp.server.auth.oauth_proxy import TokenHandler

handler = TokenHandler(provider=Mock(), client_authenticator=Mock())

# Simulate error from grant_type not in client_info.grant_types
error_response = TokenErrorResponse(
error="unauthorized_client",
error_description="Client not authorized for this grant type",
)

response = handler.response(error_response)

# Should NOT transform - keep as 400 unauthorized_client
assert response.status_code == 400
assert b'"error":"unauthorized_client"' in response.body

def test_does_not_transform_other_errors(self):
"""Test that other error types pass through unchanged."""
from mcp.server.auth.handlers.token import TokenErrorResponse

from fastmcp.server.auth.oauth_proxy import TokenHandler

handler = TokenHandler(provider=Mock(), client_authenticator=Mock())

error_response = TokenErrorResponse(
error="invalid_grant",
error_description="Authorization code has expired",
)

response = handler.response(error_response)

# Should pass through unchanged
assert response.status_code == 400
assert b'"error":"invalid_grant"' in response.body
Loading