diff --git a/src/fastmcp/server/auth/oauth_proxy/proxy.py b/src/fastmcp/server/auth/oauth_proxy/proxy.py index bcad01941..3210f7d6e 100644 --- a/src/fastmcp/server/auth/oauth_proxy/proxy.py +++ b/src/fastmcp/server/auth/oauth_proxy/proxy.py @@ -1571,6 +1571,7 @@ async def load_access_token(self, token: str) -> AccessToken | None: # type: ig # 1. Verify FastMCP JWT signature and claims payload = self.jwt_issuer.verify_token(token) jti = payload["jti"] + upstream_claims = payload.get("upstream_claims") # 2. Look up upstream token via JTI mapping jti_mapping = await self._jti_mapping_store.get(key=jti) @@ -1693,6 +1694,17 @@ async def load_access_token(self, token: str) -> AccessToken | None: # type: ig } ) + # Propagate upstream claims from the verified FastMCP JWT into the + # final AccessToken object. This allows subclasses to access custom + # identity data extracted during the initial authorization flow. + # We perform a model copy to avoid mutating a potentially cached + # reference shared across concurrent requests. + if validated and upstream_claims: + validated = validated.model_copy(deep=True) + if validated.claims is None: + validated.claims = {} + validated.claims["upstream_claims"] = upstream_claims + logger.debug( "Token swap successful for JTI=%s (upstream validated)", jti[:8] ) diff --git a/tests/server/auth/test_oidc_proxy_token.py b/tests/server/auth/test_oidc_proxy_token.py index c261fb02e..dff6b51ae 100644 --- a/tests/server/auth/test_oidc_proxy_token.py +++ b/tests/server/auth/test_oidc_proxy_token.py @@ -1,11 +1,18 @@ -"""Tests for OIDC Proxy verify_id_token functionality.""" +"""Tests for OIDC Proxy token management and propagation. -from unittest.mock import patch +These tests cover the OIDCProxy's ability to issue, verify, and swap tokens +between FastMCP and upstream identity providers. +""" + +import time +from typing import cast +from unittest.mock import AsyncMock, MagicMock, patch import pytest from pydantic import AnyHttpUrl -from fastmcp.server.auth.oauth_proxy.models import UpstreamTokenSet +from fastmcp.server.auth.auth import AccessToken +from fastmcp.server.auth.oauth_proxy.models import JTIMapping, UpstreamTokenSet from fastmcp.server.auth.oidc_proxy import OIDCConfiguration, OIDCProxy from fastmcp.server.auth.providers.introspection import IntrospectionTokenVerifier from fastmcp.server.auth.providers.jwt import JWTVerifier @@ -423,3 +430,189 @@ def test_scope_patch_applied_when_tokens_identical( assert proxy._get_verification_token(token_set) == same_jwt # The key point: even though the tokens are equal, the intent # flag ensures load_access_token will patch scopes + + +class TestUpstreamClaimsPropagation: + """Tests for upstream claims propagation in load_access_token.""" + + @pytest.mark.asyncio + async def test_load_access_token_preserves_upstream_claims( + self, valid_oidc_configuration_dict + ): + """Test that upstream_claims in FastMCP JWT are merged into AccessToken.claims.""" + with patch( + "fastmcp.server.auth.oidc_proxy.OIDCConfiguration.get_oidc_configuration" + ) as mock_get: + oidc_config = OIDCConfiguration.model_validate( + valid_oidc_configuration_dict + ) + mock_get.return_value = oidc_config + + proxy = OIDCProxy( + config_url=TEST_CONFIG_URL, + client_id=TEST_CLIENT_ID, + client_secret=TEST_CLIENT_SECRET, + base_url=TEST_BASE_URL, + jwt_signing_key="test-secret", + ) + # Initialize JWT issuer + proxy.set_mcp_path("/mcp") + + # 1. Issue a token with upstream_claims + upstream_claims = {"sub": "idp-user-123", "email": "user@example.com"} + fastmcp_jwt = proxy.jwt_issuer.issue_access_token( + client_id=TEST_CLIENT_ID, + scopes=["openid"], + jti="test-jti", + upstream_claims=upstream_claims, + ) + + # 2. Mock storage and upstream verification + # Mock the JTI mapping lookup + proxy._jti_mapping_store = MagicMock() + jti_mapping = JTIMapping( + jti="test-jti", + upstream_token_id="test-upstream-id", + created_at=time.time(), + ) + proxy._jti_mapping_store.get = AsyncMock(return_value=jti_mapping) + + proxy._upstream_token_store = MagicMock() + token_set = UpstreamTokenSet( + upstream_token_id="test-upstream-id", + access_token="idp-access-token", + refresh_token=None, + refresh_token_expires_at=None, + expires_at=time.time() + 3600, + token_type="Bearer", + scope="openid", + client_id=TEST_CLIENT_ID, + created_at=time.time(), + raw_token_data={"access_token": "idp-access-token"}, + ) + proxy._upstream_token_store.get = AsyncMock(return_value=token_set) + + # Mock the actual upstream token verification + upstream_access_token = AccessToken( + token="idp-access-token", + client_id="idp-client-id", + scopes=["openid"], + expires_at=int(time.time() + 3600), + claims={"provider_id": "999"}, + ) + proxy._token_validator.verify_token = AsyncMock( # ty: ignore[invalid-assignment] + return_value=upstream_access_token + ) + + # 3. Call load_access_token + result = await proxy.load_access_token(fastmcp_jwt) + + # 4. Verify results + assert result is not None + if result is not None: + result = cast(AccessToken, result) + # Original upstream claims should be there + assert result.claims["provider_id"] == "999" + # Propagated upstream_claims should NOW be there (the fix) + assert "upstream_claims" in result.claims + assert result.claims["upstream_claims"] == upstream_claims + + @pytest.mark.asyncio + async def test_load_access_token_does_not_mutate_cached_token( + self, valid_oidc_configuration_dict + ): + """Test that load_access_token does not mutate the original AccessToken from verifier.""" + with patch( + "fastmcp.server.auth.oidc_proxy.OIDCConfiguration.get_oidc_configuration" + ) as mock_get: + oidc_config = OIDCConfiguration.model_validate( + valid_oidc_configuration_dict + ) + mock_get.return_value = oidc_config + + proxy = OIDCProxy( + config_url=TEST_CONFIG_URL, + client_id=TEST_CLIENT_ID, + client_secret=TEST_CLIENT_SECRET, + base_url=TEST_BASE_URL, + jwt_signing_key="test-secret", + ) + proxy.set_mcp_path("/mcp") + + # 1. Setup shared upstream token + upstream_claims = {"user": "alice"} + shared_claims = {"base": "claim"} + # The original token returned by a verifier (potentially from cache) + original_validated = AccessToken( + token="shared-token", + client_id="idp-client-id", + scopes=["openid"], + expires_at=int(time.time() + 3600), + claims=shared_claims, + ) + + # 2. Mock storage for first request + proxy._jti_mapping_store = MagicMock() + jti_mapping = JTIMapping( + jti="jti-1", + upstream_token_id="up-1", + created_at=time.time(), + ) + proxy._jti_mapping_store.get = AsyncMock(return_value=jti_mapping) + + proxy._upstream_token_store = MagicMock() + token_set = UpstreamTokenSet( + upstream_token_id="up-1", + access_token="shared-token", + refresh_token=None, + refresh_token_expires_at=None, + expires_at=time.time() + 3600, + token_type="Bearer", + scope="openid", + client_id=TEST_CLIENT_ID, + created_at=time.time(), + raw_token_data={"access_token": "shared-token"}, + ) + proxy._upstream_token_store.get = AsyncMock(return_value=token_set) + + # Verifier returns the SHARED instance + proxy._token_validator.verify_token = AsyncMock( # ty: ignore[invalid-assignment] + return_value=original_validated + ) + + # 3. First request with upstream_claims + fastmcp_jwt_1 = proxy.jwt_issuer.issue_access_token( + client_id=TEST_CLIENT_ID, + scopes=["openid"], + jti="jti-1", + upstream_claims=upstream_claims, + ) + result_1 = await proxy.load_access_token(fastmcp_jwt_1) + assert result_1 is not None + assert ( + cast(AccessToken, result_1).claims["upstream_claims"] == upstream_claims + ) + + # 4. CRITICAL CHECK: The original object must NOT have been mutated + assert "upstream_claims" not in original_validated.claims + + # 5. Second request WITHOUT upstream_claims using same shared token + jti_mapping_2 = JTIMapping( + jti="jti-2", + upstream_token_id="up-1", # Same upstream token ID + created_at=time.time(), + ) + proxy._jti_mapping_store.get = AsyncMock(return_value=jti_mapping_2) + + fastmcp_jwt_2 = proxy.jwt_issuer.issue_access_token( + client_id=TEST_CLIENT_ID, + scopes=["openid"], + jti="jti-2", + # NO upstream_claims here + ) + result_2 = await proxy.load_access_token(fastmcp_jwt_2) + + assert result_2 is not None + # If fix works, result_2.claims should NOT have "upstream_claims" leakage + assert "upstream_claims" not in cast(AccessToken, result_2).claims + assert cast(AccessToken, result_2).claims == shared_claims