fix: propagate upstream_claims in load_access_token#3750
fix: propagate upstream_claims in load_access_token#3750kvdhanush06 wants to merge 5 commits intoPrefectHQ:mainfrom
Conversation
There was a problem hiding this comment.
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: f8f89e7b72
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
| if validated and upstream_claims: | ||
| if not hasattr(validated, "claims") or validated.claims is None: | ||
| validated.claims = {} | ||
| validated.claims["upstream_claims"] = upstream_claims |
There was a problem hiding this comment.
Copy claims before injecting upstream_claims
load_access_token mutates validated.claims in place, but some verifiers (notably IntrospectionTokenVerifier) return cached AccessToken instances by reference (src/fastmcp/server/auth/providers/introspection.py:197-200,292). In that setup, this write can leak/overwrite claim state across requests sharing the same upstream token object; e.g., a request with upstream_claims can persist data that is then returned for a later token where upstream_claims is absent (the branch is skipped, so stale data remains). Copying the token/claims before mutation avoids cross-request contamination.
Useful? React with 👍 / 👎.
There was a problem hiding this comment.
Pull request overview
This PR fixes FastMCP OAuth proxy token swapping so that custom upstream_claims embedded (and signed) into the FastMCP JWT are preserved and surfaced on the AccessToken returned by load_access_token, addressing #3723.
Changes:
- Preserve
upstream_claimsfrom the verified FastMCP JWT payload duringload_access_token. - Merge propagated
upstream_claimsinto the final returnedAccessToken.claims. - Add a regression test covering upstream-claims propagation through the token swap.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
src/fastmcp/server/auth/oauth_proxy/proxy.py |
Extracts upstream_claims from verified FastMCP JWT and injects them into the returned AccessToken after upstream validation. |
tests/server/auth/test_oidc_proxy_token.py |
Adds an async regression test ensuring upstream_claims survive the swap and appear on AccessToken.claims. |
| # Propagate upstream claims from the verified FastMCP JWT into the | ||
| # final AccessToken object. This allows subclasses to access custom | ||
| # identity data extracted during the initial authorization flow. | ||
| if validated and upstream_claims: | ||
| if not hasattr(validated, "claims") or validated.claims is None: | ||
| validated.claims = {} | ||
| validated.claims["upstream_claims"] = upstream_claims |
There was a problem hiding this comment.
load_access_token mutates validated.claims in-place when injecting upstream_claims. If the configured _token_validator caches and reuses AccessToken instances (e.g., IntrospectionTokenVerifier caches results), this can leak/retain upstream_claims across requests/tokens. Prefer returning a copied AccessToken with merged claims (e.g., via model_copy/new dict) to avoid mutating shared/cached objects and to ensure the claims dict isn’t shared (shallow copies can still share nested dicts).
| import time | ||
| from typing import cast | ||
| from unittest.mock import AsyncMock, MagicMock, patch |
There was a problem hiding this comment.
The module-level docstring at the top of this test file was removed. Most test modules in this repo start with a brief module docstring; restoring it helps keep documentation/style consistent across the test suite.
| jti_mapping = JTIMapping.model_construct( | ||
| jti="test-jti", | ||
| upstream_token_id="test-upstream-id", | ||
| created_at=time.time(), | ||
| ) |
There was a problem hiding this comment.
Using JTIMapping.model_construct(...) bypasses Pydantic validation, which can mask schema/type issues and make the test less representative. Since all required fields are provided here, prefer constructing JTIMapping(...) normally so the test will fail if the model contract changes.
| token_set = UpstreamTokenSet.model_construct( | ||
| upstream_token_id="test-upstream-id", | ||
| access_token="idp-access-token", | ||
| refresh_token=None, | ||
| refresh_token_expires_at=None, | ||
| expires_at=time.time() + 3600, | ||
| token_type="Bearer", | ||
| scope="openid", | ||
| client_id=TEST_CLIENT_ID, | ||
| created_at=time.time(), | ||
| raw_token_data={"access_token": "idp-access-token"}, | ||
| ) |
There was a problem hiding this comment.
Similarly, UpstreamTokenSet.model_construct(...) bypasses validation and can hide issues in the test setup. Since this test provides all required fields with valid types, prefer UpstreamTokenSet(...) to ensure validation stays exercised and future model changes are caught.
| upstream_access_token = AccessToken.model_construct( | ||
| token="idp-access-token", | ||
| scopes=["openid"], | ||
| expires_at=int(time.time() + 3600), | ||
| claims={"provider_id": "999"}, | ||
| ) | ||
| proxy._token_validator.verify_token = AsyncMock( # ty: ignore[invalid-assignment] |
There was a problem hiding this comment.
AccessToken.model_construct(...) creates an AccessToken without fields that real verifiers typically set (e.g., client_id) and bypasses validation entirely, so the test can pass with an invalid object shape. Prefer constructing a valid AccessToken(...) (including client_id) so the regression test exercises realistic behavior and will fail if the AccessToken contract changes.
f8f89e7 to
2b46bde
Compare
Test Failure AnalysisSummary: The static analysis ( Root Cause: In
Suggested Solution: Remove the two # Line 549 — change to:
proxy._jti_mapping_store = MagicMock()
# Line 557 — change to:
proxy._upstream_token_store = MagicMock()Note: other Detailed Analysis
The
Related Files
Posted by marvin, the triage bot |
2b46bde to
9479e45
Compare
jlowin
left a comment
There was a problem hiding this comment.
Thanks for the fix — the behavior change itself is correct and well-reasoned. The upstream_claims should survive load_access_token, and the model_copy(deep=True) to avoid mutating a cached verifier result is the right call.
However, the tests need to be simplified. The existing TestTransparentUpstreamRefresh class in tests/server/auth/oauth_proxy/test_tokens.py demonstrates the pattern we use for testing load_access_token: a real OAuthProxy with real in-memory stores, populated directly, no mocked stores. Your tests mock every internal store with MagicMock, which is more brittle and harder to read.
Since the propagation logic lives in OAuthProxy.load_access_token (not OIDCProxy), the tests should live in test_tokens.py alongside the other load_access_token tests, not in test_oidc_proxy_token.py. This also eliminates the OIDC configuration boilerplate.
Following the existing pattern, you'd add a setup helper and two concise tests:
async def _setup_session_with_claims(
self, proxy, *, upstream_claims=None
):
"""Set up a proxy JWT pointing at a valid upstream token, with optional upstream_claims."""
upstream_token_id = "upstream-tok-id"
access_jti = "test-claims-jti"
upstream_token_set = UpstreamTokenSet(
upstream_token_id=upstream_token_id,
access_token="valid-upstream-access",
refresh_token=None,
refresh_token_expires_at=None,
expires_at=time.time() + 3600,
token_type="Bearer",
scope="read",
client_id="test-client",
created_at=time.time(),
)
await proxy._upstream_token_store.put(
key=upstream_token_id, value=upstream_token_set, ttl=3600,
)
await proxy._jti_mapping_store.put(
key=access_jti,
value=JTIMapping(
jti=access_jti, upstream_token_id=upstream_token_id, created_at=time.time(),
),
ttl=3600,
)
return proxy.jwt_issuer.issue_access_token(
client_id="test-client", scopes=["read"], jti=access_jti,
expires_in=3600, upstream_claims=upstream_claims,
)Then the tests themselves become straightforward:
async def test_upstream_claims_propagated(self, proxy):
jwt = await self._setup_session_with_claims(
proxy, upstream_claims={"sub": "user-123"}
)
result = await proxy.load_access_token(jwt)
assert result is not None
assert result.claims["upstream_claims"] == {"sub": "user-123"}
async def test_upstream_claims_not_mutated_on_cached_token(self, proxy, mock_verifier):
jwt = await self._setup_session_with_claims(
proxy, upstream_claims={"sub": "user-123"}
)
result = await proxy.load_access_token(jwt)
assert result is not None
assert result.claims["upstream_claims"] == {"sub": "user-123"}
# Original verifier result must not be mutated
for call in mock_verifier.verify_token.call_args_list:
returned = await mock_verifier.verify_token(call.args[0])
if returned:
assert "upstream_claims" not in returned.claimsA few other small things:
- Drop
@pytest.mark.asyncio—asyncio_mode = "auto"is configured globally - The
cast(AccessToken, result)afterassert result is not Noneis redundant — the assert already narrows the type
9479e45 to
a0cd28b
Compare
…kvdhanush06/fastmcp into fix/upstream-claims-propagation
…kvdhanush06/fastmcp into fix/upstream-claims-propagation
Description
Closes #3723
FastMCP's
OAuthProxydiscarded custom claims embedded in the JWT under theupstream_claimskey during theload_access_tokenprocess. While these claims were correctly signed into the JWT at issuance (via_extract_upstream_claims), they were lost during the token swap where the FastMCP JWT is exchanged for the upstream provider's token.Root Cause
The
load_access_tokenmethod verified the FastMCP JWT to retrieve thejtifor mapping lookup but failed to preserve the verified payload. It then validated the upstream token and returned a freshAccessTokenobject that only contained the provider's claims, losing the contextually relevantupstream_claimsthat the server had previously verified and signed.Fix
Updated
load_access_tokeninsrc/fastmcp/server/auth/oauth_proxy/proxy.pyto:upstream_claimsfrom the verified FastMCP JWT payload.AccessToken.claimsdictionary.This ensures that any custom data intended to survive the token swap (e.g. cross-provider user IDs or internal metadata) is accessible to the downstream MCP server.
Verification
Added a regression test suite
TestUpstreamClaimsPropagationintests/server/auth/test_oidc_proxy_token.pycovering:upstream_claimsfrom the FastMCP JWT.AccessTokenafter the upstream token swap.Contribution type
Checklist
uv run prek run --all-filesand all checks pass