Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions litellm/litellm_core_utils/get_litellm_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,5 +121,16 @@ def get_litellm_params(
"use_litellm_proxy": use_litellm_proxy,
"litellm_request_debug": litellm_request_debug,
"aws_region_name": kwargs.get("aws_region_name"),
# AWS credentials for Bedrock/Sagemaker
"aws_access_key_id": kwargs.get("aws_access_key_id"),
"aws_secret_access_key": kwargs.get("aws_secret_access_key"),
"aws_session_token": kwargs.get("aws_session_token"),
"aws_session_name": kwargs.get("aws_session_name"),
"aws_profile_name": kwargs.get("aws_profile_name"),
"aws_role_name": kwargs.get("aws_role_name"),
"aws_web_identity_token": kwargs.get("aws_web_identity_token"),
"aws_sts_endpoint": kwargs.get("aws_sts_endpoint"),
"aws_external_id": kwargs.get("aws_external_id"),
"aws_bedrock_runtime_endpoint": kwargs.get("aws_bedrock_runtime_endpoint"),
}
return litellm_params
2 changes: 1 addition & 1 deletion litellm/passthrough/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ def llm_passthrough_route(
model=model,
messages=[],
optional_params={},
litellm_params={},
litellm_params=litellm_params_dict,
api_key=provider_api_key,
api_base=base_target_url,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1179,6 +1179,161 @@ async def test_bedrock_error_handling_returns_actual_error(self):
in str(exc_info.value.detail)
)

@pytest.mark.asyncio
async def test_bedrock_passthrough_uses_model_specific_credentials(self):
    """
    Bedrock passthrough must prefer per-model credentials over the environment.

    Regression test for passthrough endpoints picking up AWS credentials from
    environment variables instead of the model's own ``litellm_params``.
    Three stages, all run while conflicting AWS env vars are set:

    1. ``get_litellm_params`` returns the AWS credential kwargs it is given.
    2. A ``Router`` deployment keeps the model-specific credentials.
    3. ``handle_bedrock_passthrough_router_model`` delegates to the (mocked)
       passthrough request processor.
    """
    from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
        handle_bedrock_passthrough_router_model,
    )
    from litellm import Router
    from litellm.litellm_core_utils.get_litellm_params import get_litellm_params

    # Credentials attached to the model itself; these are the ones that must win.
    model_creds = {
        "aws_access_key_id": "MODEL_SPECIFIC_ACCESS_KEY",
        "aws_secret_access_key": "MODEL_SPECIFIC_SECRET_KEY",
        "aws_region_name": "us-west-2",
        "aws_session_token": "MODEL_SESSION_TOKEN",
    }

    # Deliberately different values exported via the environment; must NOT be used.
    env_creds = {
        "aws_access_key_id": "ENV_ACCESS_KEY",
        "aws_secret_access_key": "ENV_SECRET_KEY",
        "aws_region_name": "us-east-1",
    }

    def _ok_response():
        # Minimal 200 response whose async body read yields a small JSON payload.
        response = MagicMock()
        response.status_code = 200
        response.aread = AsyncMock(
            return_value=b'{"content": [{"text": "Hello"}]}'
        )
        return response

    with patch.dict(
        os.environ,
        {
            "AWS_ACCESS_KEY_ID": env_creds["aws_access_key_id"],
            "AWS_SECRET_ACCESS_KEY": env_creds["aws_secret_access_key"],
            "AWS_REGION_NAME": env_creds["aws_region_name"],
        },
    ):
        # Stage 1: get_litellm_params surfaces the AWS credentials from kwargs.
        litellm_params = get_litellm_params(
            **{**model_creds, "model": "bedrock/test-model"}
        )
        for key, expected in model_creds.items():
            assert litellm_params.get(key) == expected

        # Stage 2: the router deployment carries the model-specific credentials.
        router = Router(
            model_list=[
                {
                    "model_name": "claude-opus-4-1",
                    "litellm_params": {
                        "model": "bedrock/us.anthropic.claude-opus-4-20250514-v1:0",
                        **model_creds,
                        "custom_llm_provider": "bedrock",
                    },
                }
            ]
        )

        deployments = router.get_model_list(model_name="claude-opus-4-1")
        assert len(deployments) > 0
        deployment_litellm_params = deployments[0].get("litellm_params", {})

        for key, expected in model_creds.items():
            assert deployment_litellm_params.get(key) == expected

        # The env values must not have leaked into the deployment.
        for key, env_value in env_creds.items():
            assert deployment_litellm_params.get(key) != env_value

        # Stage 3: drive the passthrough handler with both the route and the
        # request processor mocked out.
        # NOTE(review): seen_route_kwargs is only filled if the patched route
        # is actually invoked, and nothing below asserts on it.
        seen_route_kwargs = {}

        async def _record_route_call(**kwargs):
            seen_route_kwargs.update(kwargs)
            return _ok_response()

        fake_request = MagicMock(spec=Request)
        fake_request.method = "POST"
        fake_request.headers = {"content-type": "application/json"}
        fake_request.query_params = {}
        fake_request.url = MagicMock()
        fake_request.url.path = "/bedrock/model/claude-opus-4-1/converse"

        payload = {
            "messages": [{"role": "user", "content": [{"text": "Hello"}]}]
        }

        auth = Mock()
        auth.api_key = "test-key"
        logging_obj = Mock()
        logging_obj.post_call_failure_hook = AsyncMock()

        with patch(
            "litellm.passthrough.main.llm_passthrough_route",
            new_callable=AsyncMock,
            side_effect=_record_route_call,
        ), patch(
            "litellm.proxy.common_request_processing.ProxyBaseLLMRequestProcessing.base_passthrough_process_llm_request",
            new_callable=AsyncMock,
        ) as mock_process:
            mock_process.return_value = _ok_response()

            await handle_bedrock_passthrough_router_model(
                model="claude-opus-4-1",
                endpoint="model/claude-opus-4-1/converse",
                request=fake_request,
                request_body=payload,
                llm_router=router,
                user_api_key_dict=auth,
                proxy_logging_obj=logging_obj,
                general_settings={},
                proxy_config=None,
                select_data_generator=None,
                user_model=None,
                user_temperature=None,
                user_request_timeout=None,
                user_max_tokens=None,
                user_api_base=None,
                version=None,
            )

            # The handler must have delegated to the mocked request processor.
            assert mock_process.called


class TestLLMPassthroughFactoryProxyRoute:
@pytest.mark.asyncio
Expand Down
Loading