|
1 | | -"""Tests for OIDC Proxy verify_id_token functionality.""" |
| 1 | +"""Tests for OIDC Proxy token management and propagation. |
2 | 2 |
|
3 | | -from unittest.mock import patch |
| 3 | +These tests cover the OIDCProxy's ability to issue, verify, and swap tokens |
| 4 | +between FastMCP and upstream identity providers. |
| 5 | +""" |
| 6 | + |
| 7 | +import time |
| 8 | +from typing import cast |
| 9 | +from unittest.mock import AsyncMock, MagicMock, patch |
4 | 10 |
|
5 | 11 | import pytest |
6 | 12 | from pydantic import AnyHttpUrl |
7 | 13 |
|
8 | | -from fastmcp.server.auth.oauth_proxy.models import UpstreamTokenSet |
| 14 | +from fastmcp.server.auth.auth import AccessToken |
| 15 | +from fastmcp.server.auth.oauth_proxy.models import JTIMapping, UpstreamTokenSet |
9 | 16 | from fastmcp.server.auth.oidc_proxy import OIDCConfiguration, OIDCProxy |
10 | 17 | from fastmcp.server.auth.providers.introspection import IntrospectionTokenVerifier |
11 | 18 | from fastmcp.server.auth.providers.jwt import JWTVerifier |
@@ -423,3 +430,189 @@ def test_scope_patch_applied_when_tokens_identical( |
423 | 430 | assert proxy._get_verification_token(token_set) == same_jwt |
424 | 431 | # The key point: even though the tokens are equal, the intent |
425 | 432 | # flag ensures load_access_token will patch scopes |
| 433 | + |
| 434 | + |
| 435 | +class TestUpstreamClaimsPropagation: |
| 436 | + """Tests for upstream claims propagation in load_access_token.""" |
| 437 | + |
| 438 | + @pytest.mark.asyncio |
| 439 | + async def test_load_access_token_preserves_upstream_claims( |
| 440 | + self, valid_oidc_configuration_dict |
| 441 | + ): |
| 442 | + """Test that upstream_claims in FastMCP JWT are merged into AccessToken.claims.""" |
| 443 | + with patch( |
| 444 | + "fastmcp.server.auth.oidc_proxy.OIDCConfiguration.get_oidc_configuration" |
| 445 | + ) as mock_get: |
| 446 | + oidc_config = OIDCConfiguration.model_validate( |
| 447 | + valid_oidc_configuration_dict |
| 448 | + ) |
| 449 | + mock_get.return_value = oidc_config |
| 450 | + |
| 451 | + proxy = OIDCProxy( |
| 452 | + config_url=TEST_CONFIG_URL, |
| 453 | + client_id=TEST_CLIENT_ID, |
| 454 | + client_secret=TEST_CLIENT_SECRET, |
| 455 | + base_url=TEST_BASE_URL, |
| 456 | + jwt_signing_key="test-secret", |
| 457 | + ) |
| 458 | + # Initialize JWT issuer |
| 459 | + proxy.set_mcp_path("/mcp") |
| 460 | + |
| 461 | + # 1. Issue a token with upstream_claims |
| 462 | + upstream_claims = {"sub": "idp-user-123", "email": "user@example.com"} |
| 463 | + fastmcp_jwt = proxy.jwt_issuer.issue_access_token( |
| 464 | + client_id=TEST_CLIENT_ID, |
| 465 | + scopes=["openid"], |
| 466 | + jti="test-jti", |
| 467 | + upstream_claims=upstream_claims, |
| 468 | + ) |
| 469 | + |
| 470 | + # 2. Mock storage and upstream verification |
| 471 | + # Mock the JTI mapping lookup |
| 472 | + proxy._jti_mapping_store = MagicMock() |
| 473 | + jti_mapping = JTIMapping( |
| 474 | + jti="test-jti", |
| 475 | + upstream_token_id="test-upstream-id", |
| 476 | + created_at=time.time(), |
| 477 | + ) |
| 478 | + proxy._jti_mapping_store.get = AsyncMock(return_value=jti_mapping) |
| 479 | + |
| 480 | + proxy._upstream_token_store = MagicMock() |
| 481 | + token_set = UpstreamTokenSet( |
| 482 | + upstream_token_id="test-upstream-id", |
| 483 | + access_token="idp-access-token", |
| 484 | + refresh_token=None, |
| 485 | + refresh_token_expires_at=None, |
| 486 | + expires_at=time.time() + 3600, |
| 487 | + token_type="Bearer", |
| 488 | + scope="openid", |
| 489 | + client_id=TEST_CLIENT_ID, |
| 490 | + created_at=time.time(), |
| 491 | + raw_token_data={"access_token": "idp-access-token"}, |
| 492 | + ) |
| 493 | + proxy._upstream_token_store.get = AsyncMock(return_value=token_set) |
| 494 | + |
| 495 | + # Mock the actual upstream token verification |
| 496 | + upstream_access_token = AccessToken( |
| 497 | + token="idp-access-token", |
| 498 | + client_id="idp-client-id", |
| 499 | + scopes=["openid"], |
| 500 | + expires_at=int(time.time() + 3600), |
| 501 | + claims={"provider_id": "999"}, |
| 502 | + ) |
| 503 | + proxy._token_validator.verify_token = AsyncMock( # ty: ignore[invalid-assignment] |
| 504 | + return_value=upstream_access_token |
| 505 | + ) |
| 506 | + |
| 507 | + # 3. Call load_access_token |
| 508 | + result = await proxy.load_access_token(fastmcp_jwt) |
| 509 | + |
| 510 | + # 4. Verify results |
| 511 | + assert result is not None |
| 512 | + if result is not None: |
| 513 | + result = cast(AccessToken, result) |
| 514 | + # Original upstream claims should be there |
| 515 | + assert result.claims["provider_id"] == "999" |
| 516 | + # Propagated upstream_claims should NOW be there (the fix) |
| 517 | + assert "upstream_claims" in result.claims |
| 518 | + assert result.claims["upstream_claims"] == upstream_claims |
| 519 | + |
| 520 | + @pytest.mark.asyncio |
| 521 | + async def test_load_access_token_does_not_mutate_cached_token( |
| 522 | + self, valid_oidc_configuration_dict |
| 523 | + ): |
| 524 | + """Test that load_access_token does not mutate the original AccessToken from verifier.""" |
| 525 | + with patch( |
| 526 | + "fastmcp.server.auth.oidc_proxy.OIDCConfiguration.get_oidc_configuration" |
| 527 | + ) as mock_get: |
| 528 | + oidc_config = OIDCConfiguration.model_validate( |
| 529 | + valid_oidc_configuration_dict |
| 530 | + ) |
| 531 | + mock_get.return_value = oidc_config |
| 532 | + |
| 533 | + proxy = OIDCProxy( |
| 534 | + config_url=TEST_CONFIG_URL, |
| 535 | + client_id=TEST_CLIENT_ID, |
| 536 | + client_secret=TEST_CLIENT_SECRET, |
| 537 | + base_url=TEST_BASE_URL, |
| 538 | + jwt_signing_key="test-secret", |
| 539 | + ) |
| 540 | + proxy.set_mcp_path("/mcp") |
| 541 | + |
| 542 | + # 1. Setup shared upstream token |
| 543 | + upstream_claims = {"user": "alice"} |
| 544 | + shared_claims = {"base": "claim"} |
| 545 | + # The original token returned by a verifier (potentially from cache) |
| 546 | + original_validated = AccessToken( |
| 547 | + token="shared-token", |
| 548 | + client_id="idp-client-id", |
| 549 | + scopes=["openid"], |
| 550 | + expires_at=int(time.time() + 3600), |
| 551 | + claims=shared_claims, |
| 552 | + ) |
| 553 | + |
| 554 | + # 2. Mock storage for first request |
| 555 | + proxy._jti_mapping_store = MagicMock() |
| 556 | + jti_mapping = JTIMapping( |
| 557 | + jti="jti-1", |
| 558 | + upstream_token_id="up-1", |
| 559 | + created_at=time.time(), |
| 560 | + ) |
| 561 | + proxy._jti_mapping_store.get = AsyncMock(return_value=jti_mapping) |
| 562 | + |
| 563 | + proxy._upstream_token_store = MagicMock() |
| 564 | + token_set = UpstreamTokenSet( |
| 565 | + upstream_token_id="up-1", |
| 566 | + access_token="shared-token", |
| 567 | + refresh_token=None, |
| 568 | + refresh_token_expires_at=None, |
| 569 | + expires_at=time.time() + 3600, |
| 570 | + token_type="Bearer", |
| 571 | + scope="openid", |
| 572 | + client_id=TEST_CLIENT_ID, |
| 573 | + created_at=time.time(), |
| 574 | + raw_token_data={"access_token": "shared-token"}, |
| 575 | + ) |
| 576 | + proxy._upstream_token_store.get = AsyncMock(return_value=token_set) |
| 577 | + |
| 578 | + # Verifier returns the SHARED instance |
| 579 | + proxy._token_validator.verify_token = AsyncMock( # ty: ignore[invalid-assignment] |
| 580 | + return_value=original_validated |
| 581 | + ) |
| 582 | + |
| 583 | + # 3. First request with upstream_claims |
| 584 | + fastmcp_jwt_1 = proxy.jwt_issuer.issue_access_token( |
| 585 | + client_id=TEST_CLIENT_ID, |
| 586 | + scopes=["openid"], |
| 587 | + jti="jti-1", |
| 588 | + upstream_claims=upstream_claims, |
| 589 | + ) |
| 590 | + result_1 = await proxy.load_access_token(fastmcp_jwt_1) |
| 591 | + assert result_1 is not None |
| 592 | + assert ( |
| 593 | + cast(AccessToken, result_1).claims["upstream_claims"] == upstream_claims |
| 594 | + ) |
| 595 | + |
| 596 | + # 4. CRITICAL CHECK: The original object must NOT have been mutated |
| 597 | + assert "upstream_claims" not in original_validated.claims |
| 598 | + |
| 599 | + # 5. Second request WITHOUT upstream_claims using same shared token |
| 600 | + jti_mapping_2 = JTIMapping( |
| 601 | + jti="jti-2", |
| 602 | + upstream_token_id="up-1", # Same upstream token ID |
| 603 | + created_at=time.time(), |
| 604 | + ) |
| 605 | + proxy._jti_mapping_store.get = AsyncMock(return_value=jti_mapping_2) |
| 606 | + |
| 607 | + fastmcp_jwt_2 = proxy.jwt_issuer.issue_access_token( |
| 608 | + client_id=TEST_CLIENT_ID, |
| 609 | + scopes=["openid"], |
| 610 | + jti="jti-2", |
| 611 | + # NO upstream_claims here |
| 612 | + ) |
| 613 | + result_2 = await proxy.load_access_token(fastmcp_jwt_2) |
| 614 | + |
| 615 | + assert result_2 is not None |
| 616 | + # If fix works, result_2.claims should NOT have "upstream_claims" leakage |
| 617 | + assert "upstream_claims" not in cast(AccessToken, result_2).claims |
| 618 | + assert cast(AccessToken, result_2).claims == shared_claims |
0 commit comments