-
Notifications
You must be signed in to change notification settings - Fork 1.9k
fix: propagate upstream_claims in load_access_token #3750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
9479e45
a0cd28b
c7c0c1a
24d8f6b
385154b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1571,6 +1571,7 @@ async def load_access_token(self, token: str) -> AccessToken | None: # type: ig | |
| # 1. Verify FastMCP JWT signature and claims | ||
| payload = self.jwt_issuer.verify_token(token) | ||
| jti = payload["jti"] | ||
| upstream_claims = payload.get("upstream_claims") | ||
|
|
||
| # 2. Look up upstream token via JTI mapping | ||
| jti_mapping = await self._jti_mapping_store.get(key=jti) | ||
|
|
@@ -1693,6 +1694,17 @@ async def load_access_token(self, token: str) -> AccessToken | None: # type: ig | |
| } | ||
| ) | ||
|
|
||
| # Propagate upstream claims from the verified FastMCP JWT into the | ||
| # final AccessToken object. This allows subclasses to access custom | ||
| # identity data extracted during the initial authorization flow. | ||
| # We perform a model copy to avoid mutating a potentially cached | ||
| # reference shared across concurrent requests. | ||
| if validated and upstream_claims: | ||
| validated = validated.model_copy(deep=True) | ||
| if validated.claims is None: | ||
| validated.claims = {} | ||
| validated.claims["upstream_claims"] = upstream_claims | ||
|
Comment on lines
+1697
to
+1706
|
||
|
|
||
| logger.debug( | ||
| "Token swap successful for JTI=%s (upstream validated)", jti[:8] | ||
| ) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,11 +1,18 @@ | ||
| """Tests for OIDC Proxy verify_id_token functionality.""" | ||
| """Tests for OIDC Proxy token management and propagation. | ||
|
|
||
| from unittest.mock import patch | ||
| These tests cover the OIDCProxy's ability to issue, verify, and swap tokens | ||
| between FastMCP and upstream identity providers. | ||
| """ | ||
|
|
||
| import time | ||
| from typing import cast | ||
| from unittest.mock import AsyncMock, MagicMock, patch | ||
|
Comment on lines
+7
to
+9
|
||
|
|
||
| import pytest | ||
| from pydantic import AnyHttpUrl | ||
|
|
||
| from fastmcp.server.auth.oauth_proxy.models import UpstreamTokenSet | ||
| from fastmcp.server.auth.auth import AccessToken | ||
| from fastmcp.server.auth.oauth_proxy.models import JTIMapping, UpstreamTokenSet | ||
| from fastmcp.server.auth.oidc_proxy import OIDCConfiguration, OIDCProxy | ||
| from fastmcp.server.auth.providers.introspection import IntrospectionTokenVerifier | ||
| from fastmcp.server.auth.providers.jwt import JWTVerifier | ||
|
|
@@ -423,3 +430,189 @@ def test_scope_patch_applied_when_tokens_identical( | |
| assert proxy._get_verification_token(token_set) == same_jwt | ||
| # The key point: even though the tokens are equal, the intent | ||
| # flag ensures load_access_token will patch scopes | ||
|
|
||
|
|
||
| class TestUpstreamClaimsPropagation: | ||
| """Tests for upstream claims propagation in load_access_token.""" | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_load_access_token_preserves_upstream_claims( | ||
| self, valid_oidc_configuration_dict | ||
| ): | ||
| """Test that upstream_claims in FastMCP JWT are merged into AccessToken.claims.""" | ||
| with patch( | ||
| "fastmcp.server.auth.oidc_proxy.OIDCConfiguration.get_oidc_configuration" | ||
| ) as mock_get: | ||
| oidc_config = OIDCConfiguration.model_validate( | ||
| valid_oidc_configuration_dict | ||
| ) | ||
| mock_get.return_value = oidc_config | ||
|
|
||
| proxy = OIDCProxy( | ||
| config_url=TEST_CONFIG_URL, | ||
| client_id=TEST_CLIENT_ID, | ||
| client_secret=TEST_CLIENT_SECRET, | ||
| base_url=TEST_BASE_URL, | ||
| jwt_signing_key="test-secret", | ||
| ) | ||
| # Initialize JWT issuer | ||
| proxy.set_mcp_path("/mcp") | ||
|
|
||
| # 1. Issue a token with upstream_claims | ||
| upstream_claims = {"sub": "idp-user-123", "email": "[email protected]"} | ||
| fastmcp_jwt = proxy.jwt_issuer.issue_access_token( | ||
| client_id=TEST_CLIENT_ID, | ||
| scopes=["openid"], | ||
| jti="test-jti", | ||
| upstream_claims=upstream_claims, | ||
| ) | ||
|
|
||
| # 2. Mock storage and upstream verification | ||
| # Mock the JTI mapping lookup | ||
| proxy._jti_mapping_store = MagicMock() | ||
| jti_mapping = JTIMapping( | ||
| jti="test-jti", | ||
| upstream_token_id="test-upstream-id", | ||
| created_at=time.time(), | ||
| ) | ||
| proxy._jti_mapping_store.get = AsyncMock(return_value=jti_mapping) | ||
|
|
||
| proxy._upstream_token_store = MagicMock() | ||
| token_set = UpstreamTokenSet( | ||
| upstream_token_id="test-upstream-id", | ||
| access_token="idp-access-token", | ||
| refresh_token=None, | ||
| refresh_token_expires_at=None, | ||
| expires_at=time.time() + 3600, | ||
| token_type="Bearer", | ||
| scope="openid", | ||
| client_id=TEST_CLIENT_ID, | ||
| created_at=time.time(), | ||
| raw_token_data={"access_token": "idp-access-token"}, | ||
| ) | ||
| proxy._upstream_token_store.get = AsyncMock(return_value=token_set) | ||
|
|
||
| # Mock the actual upstream token verification | ||
| upstream_access_token = AccessToken( | ||
| token="idp-access-token", | ||
| client_id="idp-client-id", | ||
| scopes=["openid"], | ||
| expires_at=int(time.time() + 3600), | ||
| claims={"provider_id": "999"}, | ||
| ) | ||
| proxy._token_validator.verify_token = AsyncMock( # ty: ignore[invalid-assignment] | ||
| return_value=upstream_access_token | ||
| ) | ||
|
|
||
| # 3. Call load_access_token | ||
| result = await proxy.load_access_token(fastmcp_jwt) | ||
|
|
||
| # 4. Verify results | ||
| assert result is not None | ||
| if result is not None: | ||
| result = cast(AccessToken, result) | ||
| # Original upstream claims should be there | ||
| assert result.claims["provider_id"] == "999" | ||
| # Propagated upstream_claims should NOW be there (the fix) | ||
| assert "upstream_claims" in result.claims | ||
| assert result.claims["upstream_claims"] == upstream_claims | ||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_load_access_token_does_not_mutate_cached_token( | ||
| self, valid_oidc_configuration_dict | ||
| ): | ||
| """Test that load_access_token does not mutate the original AccessToken from verifier.""" | ||
| with patch( | ||
| "fastmcp.server.auth.oidc_proxy.OIDCConfiguration.get_oidc_configuration" | ||
| ) as mock_get: | ||
| oidc_config = OIDCConfiguration.model_validate( | ||
| valid_oidc_configuration_dict | ||
| ) | ||
| mock_get.return_value = oidc_config | ||
|
|
||
| proxy = OIDCProxy( | ||
| config_url=TEST_CONFIG_URL, | ||
| client_id=TEST_CLIENT_ID, | ||
| client_secret=TEST_CLIENT_SECRET, | ||
| base_url=TEST_BASE_URL, | ||
| jwt_signing_key="test-secret", | ||
| ) | ||
| proxy.set_mcp_path("/mcp") | ||
|
|
||
| # 1. Setup shared upstream token | ||
| upstream_claims = {"user": "alice"} | ||
| shared_claims = {"base": "claim"} | ||
| # The original token returned by a verifier (potentially from cache) | ||
| original_validated = AccessToken( | ||
| token="shared-token", | ||
| client_id="idp-client-id", | ||
| scopes=["openid"], | ||
| expires_at=int(time.time() + 3600), | ||
| claims=shared_claims, | ||
| ) | ||
|
|
||
| # 2. Mock storage for first request | ||
| proxy._jti_mapping_store = MagicMock() | ||
| jti_mapping = JTIMapping( | ||
| jti="jti-1", | ||
| upstream_token_id="up-1", | ||
| created_at=time.time(), | ||
| ) | ||
| proxy._jti_mapping_store.get = AsyncMock(return_value=jti_mapping) | ||
|
|
||
| proxy._upstream_token_store = MagicMock() | ||
| token_set = UpstreamTokenSet( | ||
| upstream_token_id="up-1", | ||
| access_token="shared-token", | ||
| refresh_token=None, | ||
| refresh_token_expires_at=None, | ||
| expires_at=time.time() + 3600, | ||
| token_type="Bearer", | ||
| scope="openid", | ||
| client_id=TEST_CLIENT_ID, | ||
| created_at=time.time(), | ||
| raw_token_data={"access_token": "shared-token"}, | ||
| ) | ||
| proxy._upstream_token_store.get = AsyncMock(return_value=token_set) | ||
|
|
||
| # Verifier returns the SHARED instance | ||
| proxy._token_validator.verify_token = AsyncMock( # ty: ignore[invalid-assignment] | ||
| return_value=original_validated | ||
| ) | ||
|
|
||
| # 3. First request with upstream_claims | ||
| fastmcp_jwt_1 = proxy.jwt_issuer.issue_access_token( | ||
| client_id=TEST_CLIENT_ID, | ||
| scopes=["openid"], | ||
| jti="jti-1", | ||
| upstream_claims=upstream_claims, | ||
| ) | ||
| result_1 = await proxy.load_access_token(fastmcp_jwt_1) | ||
| assert result_1 is not None | ||
| assert ( | ||
| cast(AccessToken, result_1).claims["upstream_claims"] == upstream_claims | ||
| ) | ||
|
|
||
| # 4. CRITICAL CHECK: The original object must NOT have been mutated | ||
| assert "upstream_claims" not in original_validated.claims | ||
|
|
||
| # 5. Second request WITHOUT upstream_claims using same shared token | ||
| jti_mapping_2 = JTIMapping( | ||
| jti="jti-2", | ||
| upstream_token_id="up-1", # Same upstream token ID | ||
| created_at=time.time(), | ||
| ) | ||
| proxy._jti_mapping_store.get = AsyncMock(return_value=jti_mapping_2) | ||
|
|
||
| fastmcp_jwt_2 = proxy.jwt_issuer.issue_access_token( | ||
| client_id=TEST_CLIENT_ID, | ||
| scopes=["openid"], | ||
| jti="jti-2", | ||
| # NO upstream_claims here | ||
| ) | ||
| result_2 = await proxy.load_access_token(fastmcp_jwt_2) | ||
|
|
||
| assert result_2 is not None | ||
| # If fix works, result_2.claims should NOT have "upstream_claims" leakage | ||
| assert "upstream_claims" not in cast(AccessToken, result_2).claims | ||
| assert cast(AccessToken, result_2).claims == shared_claims | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
load_access_tokenmutatesvalidated.claimsin place, but some verifiers (notablyIntrospectionTokenVerifier) return cachedAccessTokeninstances by reference (src/fastmcp/server/auth/providers/introspection.py:197-200,292). In that setup, this write can leak/overwrite claim state across requests sharing the same upstream token object; e.g., a request withupstream_claimscan persist data that is then returned for a later token whereupstream_claimsis absent (the branch is skipped, so stale data remains). Copying the token/claims before mutation avoids cross-request contamination.Useful? React with 👍 / 👎.