Skip to content

Commit a0cd28b

Browse files
committed
fix: propagate upstream_claims in load_access_token #3723
1 parent d41bcb2 commit a0cd28b

2 files changed

Lines changed: 208 additions & 3 deletions

File tree

src/fastmcp/server/auth/oauth_proxy/proxy.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1571,6 +1571,7 @@ async def load_access_token(self, token: str) -> AccessToken | None: # type: ig
15711571
# 1. Verify FastMCP JWT signature and claims
15721572
payload = self.jwt_issuer.verify_token(token)
15731573
jti = payload["jti"]
1574+
upstream_claims = payload.get("upstream_claims")
15741575

15751576
# 2. Look up upstream token via JTI mapping
15761577
jti_mapping = await self._jti_mapping_store.get(key=jti)
@@ -1693,6 +1694,17 @@ async def load_access_token(self, token: str) -> AccessToken | None: # type: ig
16931694
}
16941695
)
16951696

1697+
# Propagate upstream claims from the verified FastMCP JWT into the
1698+
# final AccessToken object. This allows subclasses to access custom
1699+
# identity data extracted during the initial authorization flow.
1700+
# We perform a model copy to avoid mutating a potentially cached
1701+
# reference shared across concurrent requests.
1702+
if validated and upstream_claims:
1703+
validated = validated.model_copy(deep=True)
1704+
if validated.claims is None:
1705+
validated.claims = {}
1706+
validated.claims["upstream_claims"] = upstream_claims
1707+
16961708
logger.debug(
16971709
"Token swap successful for JTI=%s (upstream validated)", jti[:8]
16981710
)

tests/server/auth/test_oidc_proxy_token.py

Lines changed: 196 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
1-
"""Tests for OIDC Proxy verify_id_token functionality."""
1+
"""Tests for OIDC Proxy token management and propagation.
22
3-
from unittest.mock import patch
3+
These tests cover the OIDCProxy's ability to issue, verify, and swap tokens
4+
between FastMCP and upstream identity providers.
5+
"""
6+
7+
import time
8+
from typing import cast
9+
from unittest.mock import AsyncMock, MagicMock, patch
410

511
import pytest
612
from pydantic import AnyHttpUrl
713

8-
from fastmcp.server.auth.oauth_proxy.models import UpstreamTokenSet
14+
from fastmcp.server.auth.auth import AccessToken
15+
from fastmcp.server.auth.oauth_proxy.models import JTIMapping, UpstreamTokenSet
916
from fastmcp.server.auth.oidc_proxy import OIDCConfiguration, OIDCProxy
1017
from fastmcp.server.auth.providers.introspection import IntrospectionTokenVerifier
1118
from fastmcp.server.auth.providers.jwt import JWTVerifier
@@ -423,3 +430,189 @@ def test_scope_patch_applied_when_tokens_identical(
423430
assert proxy._get_verification_token(token_set) == same_jwt
424431
# The key point: even though the tokens are equal, the intent
425432
# flag ensures load_access_token will patch scopes
433+
434+
435+
class TestUpstreamClaimsPropagation:
436+
"""Tests for upstream claims propagation in load_access_token."""
437+
438+
@pytest.mark.asyncio
439+
async def test_load_access_token_preserves_upstream_claims(
440+
self, valid_oidc_configuration_dict
441+
):
442+
"""Test that upstream_claims in FastMCP JWT are merged into AccessToken.claims."""
443+
with patch(
444+
"fastmcp.server.auth.oidc_proxy.OIDCConfiguration.get_oidc_configuration"
445+
) as mock_get:
446+
oidc_config = OIDCConfiguration.model_validate(
447+
valid_oidc_configuration_dict
448+
)
449+
mock_get.return_value = oidc_config
450+
451+
proxy = OIDCProxy(
452+
config_url=TEST_CONFIG_URL,
453+
client_id=TEST_CLIENT_ID,
454+
client_secret=TEST_CLIENT_SECRET,
455+
base_url=TEST_BASE_URL,
456+
jwt_signing_key="test-secret",
457+
)
458+
# Initialize JWT issuer
459+
proxy.set_mcp_path("/mcp")
460+
461+
# 1. Issue a token with upstream_claims
462+
upstream_claims = {"sub": "idp-user-123", "email": "user@example.com"}
463+
fastmcp_jwt = proxy.jwt_issuer.issue_access_token(
464+
client_id=TEST_CLIENT_ID,
465+
scopes=["openid"],
466+
jti="test-jti",
467+
upstream_claims=upstream_claims,
468+
)
469+
470+
# 2. Mock storage and upstream verification
471+
# Mock the JTI mapping lookup
472+
proxy._jti_mapping_store = MagicMock()
473+
jti_mapping = JTIMapping(
474+
jti="test-jti",
475+
upstream_token_id="test-upstream-id",
476+
created_at=time.time(),
477+
)
478+
proxy._jti_mapping_store.get = AsyncMock(return_value=jti_mapping)
479+
480+
proxy._upstream_token_store = MagicMock()
481+
token_set = UpstreamTokenSet(
482+
upstream_token_id="test-upstream-id",
483+
access_token="idp-access-token",
484+
refresh_token=None,
485+
refresh_token_expires_at=None,
486+
expires_at=time.time() + 3600,
487+
token_type="Bearer",
488+
scope="openid",
489+
client_id=TEST_CLIENT_ID,
490+
created_at=time.time(),
491+
raw_token_data={"access_token": "idp-access-token"},
492+
)
493+
proxy._upstream_token_store.get = AsyncMock(return_value=token_set)
494+
495+
# Mock the actual upstream token verification
496+
upstream_access_token = AccessToken(
497+
token="idp-access-token",
498+
client_id="idp-client-id",
499+
scopes=["openid"],
500+
expires_at=int(time.time() + 3600),
501+
claims={"provider_id": "999"},
502+
)
503+
proxy._token_validator.verify_token = AsyncMock( # ty: ignore[invalid-assignment]
504+
return_value=upstream_access_token
505+
)
506+
507+
# 3. Call load_access_token
508+
result = await proxy.load_access_token(fastmcp_jwt)
509+
510+
# 4. Verify results
511+
assert result is not None
512+
if result is not None:
513+
result = cast(AccessToken, result)
514+
# Original upstream claims should be there
515+
assert result.claims["provider_id"] == "999"
516+
# Propagated upstream_claims should NOW be there (the fix)
517+
assert "upstream_claims" in result.claims
518+
assert result.claims["upstream_claims"] == upstream_claims
519+
520+
@pytest.mark.asyncio
521+
async def test_load_access_token_does_not_mutate_cached_token(
522+
self, valid_oidc_configuration_dict
523+
):
524+
"""Test that load_access_token does not mutate the original AccessToken from verifier."""
525+
with patch(
526+
"fastmcp.server.auth.oidc_proxy.OIDCConfiguration.get_oidc_configuration"
527+
) as mock_get:
528+
oidc_config = OIDCConfiguration.model_validate(
529+
valid_oidc_configuration_dict
530+
)
531+
mock_get.return_value = oidc_config
532+
533+
proxy = OIDCProxy(
534+
config_url=TEST_CONFIG_URL,
535+
client_id=TEST_CLIENT_ID,
536+
client_secret=TEST_CLIENT_SECRET,
537+
base_url=TEST_BASE_URL,
538+
jwt_signing_key="test-secret",
539+
)
540+
proxy.set_mcp_path("/mcp")
541+
542+
# 1. Setup shared upstream token
543+
upstream_claims = {"user": "alice"}
544+
shared_claims = {"base": "claim"}
545+
# The original token returned by a verifier (potentially from cache)
546+
original_validated = AccessToken(
547+
token="shared-token",
548+
client_id="idp-client-id",
549+
scopes=["openid"],
550+
expires_at=int(time.time() + 3600),
551+
claims=shared_claims,
552+
)
553+
554+
# 2. Mock storage for first request
555+
proxy._jti_mapping_store = MagicMock()
556+
jti_mapping = JTIMapping(
557+
jti="jti-1",
558+
upstream_token_id="up-1",
559+
created_at=time.time(),
560+
)
561+
proxy._jti_mapping_store.get = AsyncMock(return_value=jti_mapping)
562+
563+
proxy._upstream_token_store = MagicMock()
564+
token_set = UpstreamTokenSet(
565+
upstream_token_id="up-1",
566+
access_token="shared-token",
567+
refresh_token=None,
568+
refresh_token_expires_at=None,
569+
expires_at=time.time() + 3600,
570+
token_type="Bearer",
571+
scope="openid",
572+
client_id=TEST_CLIENT_ID,
573+
created_at=time.time(),
574+
raw_token_data={"access_token": "shared-token"},
575+
)
576+
proxy._upstream_token_store.get = AsyncMock(return_value=token_set)
577+
578+
# Verifier returns the SHARED instance
579+
proxy._token_validator.verify_token = AsyncMock( # ty: ignore[invalid-assignment]
580+
return_value=original_validated
581+
)
582+
583+
# 3. First request with upstream_claims
584+
fastmcp_jwt_1 = proxy.jwt_issuer.issue_access_token(
585+
client_id=TEST_CLIENT_ID,
586+
scopes=["openid"],
587+
jti="jti-1",
588+
upstream_claims=upstream_claims,
589+
)
590+
result_1 = await proxy.load_access_token(fastmcp_jwt_1)
591+
assert result_1 is not None
592+
assert (
593+
cast(AccessToken, result_1).claims["upstream_claims"] == upstream_claims
594+
)
595+
596+
# 4. CRITICAL CHECK: The original object must NOT have been mutated
597+
assert "upstream_claims" not in original_validated.claims
598+
599+
# 5. Second request WITHOUT upstream_claims using same shared token
600+
jti_mapping_2 = JTIMapping(
601+
jti="jti-2",
602+
upstream_token_id="up-1", # Same upstream token ID
603+
created_at=time.time(),
604+
)
605+
proxy._jti_mapping_store.get = AsyncMock(return_value=jti_mapping_2)
606+
607+
fastmcp_jwt_2 = proxy.jwt_issuer.issue_access_token(
608+
client_id=TEST_CLIENT_ID,
609+
scopes=["openid"],
610+
jti="jti-2",
611+
# NO upstream_claims here
612+
)
613+
result_2 = await proxy.load_access_token(fastmcp_jwt_2)
614+
615+
assert result_2 is not None
616+
# If fix works, result_2.claims should NOT have "upstream_claims" leakage
617+
assert "upstream_claims" not in cast(AccessToken, result_2).claims
618+
assert cast(AccessToken, result_2).claims == shared_claims

0 commit comments

Comments
 (0)