diff --git a/nemoguardrails/library/privateai/actions.py b/nemoguardrails/library/privateai/actions.py index ade2e3abc..1c31efc2a 100644 --- a/nemoguardrails/library/privateai/actions.py +++ b/nemoguardrails/library/privateai/actions.py @@ -17,6 +17,7 @@ import logging import os +from urllib.parse import urlparse from nemoguardrails import RailsConfig from nemoguardrails.actions import action @@ -44,7 +45,8 @@ async def detect_pii(source: str, text: str, config: RailsConfig): server_endpoint = pai_config.server_endpoint enabled_entities = getattr(pai_config, source).entities - if "api.private-ai.com" in server_endpoint and not pai_api_key: + parsed_url = urlparse(server_endpoint) + if parsed_url.hostname == "api.private-ai.com" and not pai_api_key: raise ValueError( "PAI_API_KEY environment variable required for Private AI cloud API." ) diff --git a/nemoguardrails/library/privateai/request.py b/nemoguardrails/library/privateai/request.py index 9662c8856..dfa586bee 100644 --- a/nemoguardrails/library/privateai/request.py +++ b/nemoguardrails/library/privateai/request.py @@ -18,6 +18,7 @@ import json import logging from typing import Any, Dict, List, Optional +from urllib.parse import urlparse import aiohttp @@ -42,7 +43,8 @@ async def private_ai_detection_request( Returns: True if PII is detected, False otherwise. """ - if "api.private-ai.com" in server_endpoint and not api_key: + parsed_url = urlparse(server_endpoint) + if parsed_url.hostname == "api.private-ai.com" and not api_key: raise ValueError("'api_key' is required for Private AI cloud API.") payload: Dict[str, Any] = {