Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 1 addition & 39 deletions src/mcp/client/auth/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import httpx
from pydantic import BaseModel, Field, ValidationError

from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
from mcp.client.auth.exceptions import OAuthFlowError, OAuthTokenError
from mcp.client.auth.utils import (
build_oauth_authorization_server_metadata_discovery_urls,
build_protected_resource_metadata_discovery_urls,
Expand Down Expand Up @@ -299,44 +299,6 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
f"Protected Resource Metadata request failed: {response.status_code}"
) # pragma: no cover

async def _register_client(self) -> httpx.Request | None:
"""Build registration request or skip if already registered."""
if self.context.client_info:
return None

if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint:
registration_url = str(self.context.oauth_metadata.registration_endpoint) # pragma: no cover
else:
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
registration_url = urljoin(auth_base_url, "/register")

registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)

# If token_endpoint_auth_method is None, auto-select based on server support
if self.context.client_metadata.token_endpoint_auth_method is None:
preference_order = ["client_secret_basic", "client_secret_post", "none"]

if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint_auth_methods_supported:
supported = self.context.oauth_metadata.token_endpoint_auth_methods_supported
for method in preference_order:
if method in supported:
registration_data["token_endpoint_auth_method"] = method
break
else:
# No compatible methods between client and server
raise OAuthRegistrationError(
f"No compatible authentication methods. "
f"Server supports: {supported}, "
f"Client supports: {preference_order}"
)
else:
# No server metadata available, use our default preference
registration_data["token_endpoint_auth_method"] = preference_order[0]

return httpx.Request(
"POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}
)

async def _perform_authorization(self) -> httpx.Request:
"""Perform the authorization flow."""
auth_code, code_verifier = await self._perform_authorization_code_grant()
Expand Down
95 changes: 1 addition & 94 deletions tests/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
"""

import base64
import json
import time
from unittest import mock
from urllib.parse import unquote
Expand All @@ -13,7 +12,7 @@
from inline_snapshot import Is, snapshot
from pydantic import AnyHttpUrl, AnyUrl

from mcp.client.auth import OAuthClientProvider, OAuthRegistrationError, PKCEParameters
from mcp.client.auth import OAuthClientProvider, PKCEParameters
from mcp.client.auth.utils import (
build_oauth_authorization_server_metadata_discovery_urls,
build_protected_resource_metadata_discovery_urls,
Expand Down Expand Up @@ -581,98 +580,6 @@ async def test_omit_scope_when_no_prm_scopes_or_www_auth(
# Verify that scope is omitted
assert scopes is None

@pytest.mark.anyio
async def test_register_client_request(self, oauth_provider: OAuthClientProvider):
"""Test client registration request building."""
request = await oauth_provider._register_client()

assert request is not None
assert request.method == "POST"
assert str(request.url) == "https://api.example.com/register"
assert request.headers["Content-Type"] == "application/json"

@pytest.mark.anyio
async def test_register_client_skip_if_registered(self, oauth_provider: OAuthClientProvider):
"""Test client registration is skipped if already registered."""
# Set existing client info
client_info = OAuthClientInformationFull(
client_id="existing_client",
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
)
oauth_provider.context.client_info = client_info

# Should return None (skip registration)
request = await oauth_provider._register_client()
assert request is None

@pytest.mark.anyio
async def test_register_client_explicit_auth_method(self, mock_storage: MockTokenStorage):
"""Test that explicitly set token_endpoint_auth_method is used without auto-selection."""

async def redirect_handler(url: str) -> None:
pass # pragma: no cover

async def callback_handler() -> tuple[str, str | None]:
return "test_auth_code", "test_state" # pragma: no cover

# Create client metadata with explicit auth method
explicit_metadata = OAuthClientMetadata(
client_name="Test Client",
client_uri=AnyHttpUrl("https://example.com"),
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
scope="read write",
token_endpoint_auth_method="client_secret_basic",
)
provider = OAuthClientProvider(
server_url="https://api.example.com/v1/mcp",
client_metadata=explicit_metadata,
storage=mock_storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)

request = await provider._register_client()
assert request is not None

body = json.loads(request.content)
# Should use the explicitly set method, not auto-select
assert body["token_endpoint_auth_method"] == "client_secret_basic"

@pytest.mark.anyio
async def test_register_client_none_auth_method_with_server_metadata(self, oauth_provider: OAuthClientProvider):
"""Test that token_endpoint_auth_method=None selects from server's supported methods."""
# Set server metadata with specific supported methods
oauth_provider.context.oauth_metadata = OAuthMetadata(
issuer=AnyHttpUrl("https://auth.example.com"),
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
token_endpoint_auth_methods_supported=["client_secret_post"],
)
# Ensure client_metadata has None for token_endpoint_auth_method

request = await oauth_provider._register_client()
assert request is not None

body = json.loads(request.content)
assert body["token_endpoint_auth_method"] == "client_secret_post"

@pytest.mark.anyio
async def test_register_client_none_auth_method_no_compatible(self, oauth_provider: OAuthClientProvider):
"""Test that registration raises error when no compatible auth methods."""
# Set server metadata with unsupported methods only
oauth_provider.context.oauth_metadata = OAuthMetadata(
issuer=AnyHttpUrl("https://auth.example.com"),
authorization_endpoint=AnyHttpUrl("https://auth.example.com/authorize"),
token_endpoint=AnyHttpUrl("https://auth.example.com/token"),
token_endpoint_auth_methods_supported=["private_key_jwt", "client_secret_jwt"],
)

with pytest.raises(OAuthRegistrationError) as exc_info:
await oauth_provider._register_client()

assert "No compatible authentication methods" in str(exc_info.value)
assert "private_key_jwt" in str(exc_info.value)

@pytest.mark.anyio
async def test_token_exchange_request_authorization_code(self, oauth_provider: OAuthClientProvider):
"""Test token exchange request building."""
Expand Down