From 622ff356ee36ee7b5e910a2c8a1100e89590dc51 Mon Sep 17 00:00:00 2001 From: WenjiaoYue Date: Mon, 16 Jun 2025 10:03:24 +0800 Subject: [PATCH 1/7] add Upstreaming E-RAG LLM Input/Output Guardrails --- .../opea_guardrails_microservice.py | 98 +- .../utils/llm_guard_input_guardrail.py | 146 ++ .../utils/llm_guard_input_scanners.py | 952 +++++++++++++ .../utils/llm_guard_output_guardrail.py | 98 ++ .../utils/llm_guard_output_scanners.py | 1214 +++++++++++++++++ 5 files changed, 2484 insertions(+), 24 deletions(-) create mode 100644 comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py create mode 100644 comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py create mode 100644 comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py create mode 100644 comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py diff --git a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py index 90ec1c5bc9..72d7600031 100644 --- a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py +++ b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py @@ -3,14 +3,23 @@ import os import time +import asyncio from typing import Union +from dotenv import dotenv_values +from fastapi import HTTPException -from integrations.llamaguard import OpeaGuardrailsLlamaGuard -from integrations.wildguard import OpeaGuardrailsWildGuard +from utils.llm_guard_input_guardrail import ( + OPEALLMGuardInputGuardrail +) +from utils.llm_guard_output_guardrail import ( + OPEALLMGuardOutputGuardrail +) from comps import ( CustomLogger, GeneratedDoc, + LLMParamsDoc, + SearchedDoc, OpeaComponentLoader, ServiceType, TextDoc, @@ -20,9 +29,16 @@ statistics_dict, ) +from comps.cores.proto.api_protocol import ChatCompletionRequest, DocSumChatCompletionRequest + logger = CustomLogger("opea_guardrails_microservice") logflag = os.getenv("LOGFLAG", False) +usvc_config = { + **dotenv_values(".env"), + **os.environ +} + guardrails_component_name = os.getenv("GUARDRAILS_COMPONENT_NAME", "OPEA_LLAMA_GUARD") # Initialize OpeaComponentLoader loader = OpeaComponentLoader( @@ -31,6 +47,8 @@ description=f"OPEA Guardrails Component: {guardrails_component_name}", ) +input_guardrail = OPEALLMGuardInputGuardrail(usvc_config) +output_guardrail = OPEALLMGuardOutputGuardrail(usvc_config) @register_microservice( name="opea_service@guardrails", @@ -38,33 +56,65 @@ endpoint="/v1/guardrails", host="0.0.0.0", port=9090, - input_datatype=Union[GeneratedDoc, TextDoc], - output_datatype=TextDoc, + input_datatype=Union[LLMParamsDoc, GeneratedDoc, ChatCompletionRequest, SearchedDoc, ChatCompletionRequest, DocSumChatCompletionRequest], + output_datatype=Union[LLMParamsDoc, GeneratedDoc, ChatCompletionRequest, SearchedDoc, ChatCompletionRequest, DocSumChatCompletionRequest], ) @register_statistics(names=["opea_service@guardrails"]) -async def safety_guard(input: Union[GeneratedDoc, TextDoc]) -> TextDoc: - start = time.time() - - # Log the input if logging is enabled +async def safety_guard(input: Union[LLMParamsDoc, GeneratedDoc, ChatCompletionRequest, SearchedDoc, ChatCompletionRequest, DocSumChatCompletionRequest]) -> Union[LLMParamsDoc, GeneratedDoc, ChatCompletionRequest, SearchedDoc, ChatCompletionRequest, DocSumChatCompletionRequest]: + start_time = time.time() + if logflag: - logger.info(f"Input received: {input}") - + logger.info(f"Received input: {input}") + try: - # Use the loader to invoke the component - guardrails_response = await loader.invoke(input) - - # Log the result if logging is enabled - if logflag: - logger.info(f"Output received: {guardrails_response}") - - # Record statistics - statistics_dict["opea_service@guardrails"].append_latency(time.time() - start, None) - return guardrails_response - + if isinstance(input, LLMParamsDoc): + processed = input_guardrail.scan_llm_input(input) + + statistics_dict["opea_service@guardrails"].append_latency( + time.time() - start_time, + f"input_guard:{type(input).__name__}" + ) + + if logflag: + logger.info(f"Input guard passed: {processed}") + return processed + + elif isinstance(input, GeneratedDoc): + processed = output_guardrail.scan_llm_output(input) + + if os.getenv("APPLY_CONTENT_GUARD", "true").lower() == "true": + text_doc = TextDoc(text=processed.text) + content_guard_result = await loader.invoke(text_doc) + processed.text = content_guard_result.text + + statistics_dict["opea_service@guardrails"].append_latency( + time.time() - start_time, + f"output_guard:{type(input).__name__}" + ) + + if logflag: + logger.info(f"Output guard passed: {processed}") + return processed + + except HTTPException as e: + if e.status_code == 466: + logger.warning(f"Security rejection: {e.detail}") + statistics_dict["opea_service@guardrails"].append_latency( + time.time() - start_time, + f"rejection:{e.status_code}" + ) + raise e + except Exception as e: - logger.error(f"Error during guardrails invocation: {e}") - raise - + logger.error(f"Unexpected error: {str(e)}") + statistics_dict["opea_service@guardrails"].append_latency( + time.time() - start_time, + f"error:{type(e).__name__}" + ) + raise HTTPException( + status_code=500, + detail=f"Internal server error: {str(e)}" + ) if __name__ == "__main__": opea_microservices["opea_service@guardrails"].start() diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py b/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py new file mode 100644 index 0000000000..7e72c358fa --- /dev/null +++ b/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py @@ -0,0 +1,146 @@ +# Copyright (C) 2024-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from llm_guard import scan_prompt +from fastapi import HTTPException + +from utils.llm_guard_input_scanners import InputScannersConfig +from comps import get_opea_logger, LLMParamsDoc + +logger = get_opea_logger("opea_llm_guard_input_guardrail_microservice") + + +class OPEALLMGuardInputGuardrail: + """ + OPEALLMGuardInputGuardrail is responsible for scanning and sanitizing LLM input prompts + using various input scanners provided by LLM Guard. + + This class initializes the input scanners based on the provided configuration and + scans the input prompts to ensure they meet the required guardrail criteria. + + Attributes: + _scanners (list): A list of enabled scanners. + + Methods: + __init__(usv_config: dict): + Initializes the OPEALLMGuardInputGuardrail with the provided configuration. + + scan_llm_input(input_doc: LLMParamsDoc) -> tuple[str, dict[str, bool], dict[str, float]]: + Scans the prompt from an LLMParamsDoc object and returns the sanitized prompt, + validation results, and scores. + """ + + def __init__(self, usv_config: dict): + """ + Initializes the OPEALLMGuardInputGuardrail with the provided configuration. + + Args: + usv_config (dict): The configuration dictionary for initializing the input scanners. + + Raises: + Exception: If an unexpected error occurs during the initialization of scanners. + """ + try: + self._scanners_config = InputScannersConfig(usv_config) + self._scanners = self._scanners_config.create_enabled_input_scanners() + except ValueError as e: + logger.exception(f"Value Error occured while initializing LLM Guard Input Guardrail scanners: {e}") + raise + except Exception as e: + logger.exception( + f"An unexpected error occured during initializing \ + LLM Guard Input Guardrail scanners: {e}" + ) + raise + + def _get_anonymize_vault(self): + anon = [item for item in self._scanners if type(item).__name__ == "Anonymize"] + if len(anon) > 0: + return anon[0]._vault.get() + return None + + def _recreate_anonymize_scanner_if_exists(self): + anon = [item for item in self._scanners if type(item).__name__ == "Anonymize"] + if len(anon) > 0: + logger.info(f"Anonymize scanner found: {len(anon)}") + self._scanners.remove(anon[0]) + self._scanners.append(self._scanners_config._create_anonymize_scanner()) + + def _analyze_scan_outputs(self, prompt, results_valid, results_score): + filtered_results_valid_no_redacted = {} + scanners_with_redact = ["BanCompetitors", "BanSubstrings", "OPEABanSubstrings", "Regex", "OPEARegexScanner"] + + for key, value in results_valid.items(): + if_redacted = False + redacted_scanner = [item for item in self._scanners if type(item).__name__ in scanners_with_redact and type(item).__name__ == key] + + if len(redacted_scanner) > 0: + if_redacted = redacted_scanner[0]._redact + + if key != 'Anonymize' and not if_redacted: + filtered_results_valid_no_redacted[key] = value + + if False in filtered_results_valid_no_redacted.values(): + msg = f"Prompt {prompt} is not valid, scores: {results_score}" + logger.error(f"{msg}") + usr_msg = "I'm sorry, I cannot assist you with your prompt." + raise HTTPException(status_code=466, detail=f"{usr_msg}") + + def scan_llm_input(self, input_doc: LLMParamsDoc) -> LLMParamsDoc: + """ + Scan the prompt from an LLMParamsDoc object. + + Args: + input_doc (LLMParamsDoc): The input document containing the prompt to be scanned. + + Returns: + tuple[str, dict[str, bool], dict[str, float]]: A tuple containing the sanitized prompt, + a dictionary of validation results, and a dictionary of scores. + + Raises: + HTTPException: If the prompt is not valid based on the scanner results. + """ + fresh_scanners = False + if input_doc.input_guardrail_params is not None: + if self._scanners_config.changed(input_doc.input_guardrail_params.dict()): + self._scanners = self._scanners_config.create_enabled_input_scanners() + fresh_scanners = True + else: + logger.warning("Input guardrail params not found in input document.") + if self._scanners: + if not fresh_scanners: + logger.info("Recreating anonymize scanner if exists to clear the Vault.") + self._recreate_anonymize_scanner_if_exists() + system_prompt = input_doc.messages.system + user_prompt = input_doc.messages.user + + # We want to block only user question with a TokenLimit Scanner + scanners_without_token_limit = [item for item in self._scanners if type(item).__name__ != "TokenLimit"] + if len(self._scanners) != scanners_without_token_limit: + sanitized_system_prompt, system_results_valid, system_results_score = scan_prompt(scanners_without_token_limit, system_prompt) + else: + sanitized_system_prompt, system_results_valid, system_results_score = scan_prompt(self._scanners, system_prompt) + + if "### Question:" in user_prompt: + # Default template is used + prefix = "### Question: " + suffix = " \n ### Answer:" + user_prompt_to_scan = user_prompt.split(prefix)[1].split(suffix)[0] + sanitized_user_prompt, user_results_valid, user_results_score = scan_prompt(self._scanners, user_prompt_to_scan) + sanitized_user_prompt = prefix + sanitized_user_prompt + suffix + else: + sanitized_user_prompt, user_results_valid, user_results_score = scan_prompt(self._scanners, user_prompt) + + self._analyze_scan_outputs(system_prompt, system_results_valid, system_results_score) + self._analyze_scan_outputs(user_prompt, user_results_valid, user_results_score) + + input_doc.messages.system = sanitized_system_prompt + input_doc.messages.user = sanitized_user_prompt + if input_doc.output_guardrail_params is not None and 'Anonymize' in user_results_valid: + input_doc.output_guardrail_params.anonymize_vault = self._get_anonymize_vault() + elif input_doc.output_guardrail_params is None and 'Anonymize' in user_results_valid: + logger.warning("No output guardrails params, could not append the vault for Anonymize scanner.") + return input_doc + else: + logger.info("No input scanners enabled. Skipping scanning.") + return input_doc diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py b/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py new file mode 100644 index 0000000000..60092710ea --- /dev/null +++ b/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py @@ -0,0 +1,952 @@ +# ruff: noqa: F401 +# Copyright (C) 2024-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +from llm_guard.vault import Vault +from llm_guard.input_scanners import ( + Anonymize, + BanCode, + BanCompetitors, + BanTopics, + Code, + Gibberish, + InvisibleText, + Language, + PromptInjection, + Secrets, + Sentiment, + TokenLimit, + Toxicity + ) + +# import models definition +from llm_guard.input_scanners.ban_code import ( + MODEL_SM as BANCODE_MODEL_SM, + MODEL_TINY as BANCODE_MODEL_TINY +) + +from llm_guard.input_scanners.ban_competitors import ( + MODEL_V1 as BANCOMPETITORS_MODEL_V1 +) + +from llm_guard.input_scanners.ban_topics import ( + MODEL_DEBERTA_LARGE_V2 as BANTOPICS_MODEL_DEBERTA_LARGE_V2, + MODEL_DEBERTA_BASE_V2 as BANTOPICS_MODEL_DEBERTA_BASE_V2, + MODEL_BGE_M3_V2 as BANTOPICS_MODEL_BGE_M3_V2, + MODEL_ROBERTA_LARGE_C_V2 as BANTOPICS_MODEL_ROBERTA_LARGE_C_V2, + MODEL_ROBERTA_BASE_C_V2 as BANTOPICS_MODEL_ROBERTA_BASE_C_V2 +) + +from llm_guard.input_scanners.code import ( + DEFAULT_MODEL as CODE_DEFAULT_MODEL +) + +from llm_guard.input_scanners.gibberish import ( + DEFAULT_MODEL as GIBBERISH_DEFAULT_MODEL, +) + +from llm_guard.input_scanners.language import ( + DEFAULT_MODEL as LANGUAGE_DEFAULT_MODEL, +) + +from llm_guard.input_scanners.prompt_injection import ( + V1_MODEL as PROMPTINJECTION_V1_MODEL, + V2_MODEL as PROMPTINJECTION_V2_MODEL, + V2_SMALL_MODEL as PROMPTINJECTION_V2_SMALL_MODEL, +) + +from llm_guard.input_scanners.toxicity import ( + DEFAULT_MODEL as TOXICITY_DEFAULT_MODEL +) + +ENABLED_SCANNERS = [ + 'anonymize', + 'ban_code', + 'ban_competitors', + 'ban_substrings', + 'ban_topics', + 'code', + 'gibberish', + 'invisible_text', + 'language', + 'prompt_injection', + 'regex', + 'secrets', + 'sentiment', + 'token_limit', + 'toxicity' +] + +from comps.guardrails.utils.scanners import OPEABanSubstrings, OPEARegexScanner +from comps import get_opea_logger, sanitize_env +logger = get_opea_logger("opea_llm_guard_input_guardrail_microservice") + + +class InputScannersConfig: + + def __init__(self, config_dict): + self._input_scanners_config = { + **self._get_anonymize_config_from_env(config_dict), + **self._get_ban_code_config_from_env(config_dict), + **self._get_ban_competitors_config_from_env(config_dict), + **self._get_ban_substrings_config_from_env(config_dict), + **self._get_ban_topics_config_from_env(config_dict), + **self._get_code_config_from_env(config_dict), + **self._get_gibberish_config_from_env(config_dict), + **self._get_invisible_text_config_from_env(config_dict), + **self._get_language_config_from_env(config_dict), + **self._get_prompt_injection_config_from_env(config_dict), + **self._get_regex_config_from_env(config_dict), + **self._get_secrets_config_from_env(config_dict), + **self._get_sentiment_config_from_env(config_dict), + **self._get_token_limit_config_from_env(config_dict), + **self._get_toxicity_config_from_env(config_dict) + } + +#### METHODS FOR VALIDATING CONFIGS + + def _validate_value(self, value): + """ + Validate and convert the input value. + + Args: + value (str): The value to be validated and converted. + + Returns: + bool | int | str: The validated and converted value. + """ + if value is None: + return None + elif value.isdigit(): + return float(value) + elif value.lower() == "true": + return True + elif value.lower() == "false": + return False + return value + + def _get_anonymize_config_from_env(self, config_dict): + """ + Get the Anonymize scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The anonymize scanner configuration. + """ + return { + "anonymize": { + k.replace("ANONYMIZE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("ANONYMIZE_") + } + } + + def _get_ban_code_config_from_env(self, config_dict): + """ + Get the BanCode scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanCode scanner configuration. + """ + return { + "ban_code": { + k.replace("BAN_CODE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("BAN_CODE_") + } + } + + def _get_ban_competitors_config_from_env(self, config_dict): + """ + Get the BanCompetitors scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanCompetitors scanner configuration. + """ + return { + "ban_competitors": { + k.replace("BAN_COMPETITORS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("BAN_COMPETITORS_") + } + } + + def _get_ban_substrings_config_from_env(self, config_dict): + """ + Get the BanSubstrings scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanSubstrings scanner configuration. + """ + return { + "ban_substrings": { + k.replace("BAN_SUBSTRINGS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("BAN_SUBSTRINGS_") + } + } + + def _get_ban_topics_config_from_env(self, config_dict): + """ + Get the BanTopics scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanTopics scanner configuration. + """ + return { + "ban_topics": { + k.replace("BAN_TOPICS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("BAN_TOPICS_") + } + } + + def _get_code_config_from_env(self, config_dict): + """ + Get the Code scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Code scanner configuration. + """ + return { + "code": { + k.replace("CODE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("CODE_") + } + } + + def _get_gibberish_config_from_env(self, config_dict): + """ + Get the Gibberish scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Gibberish scanner configuration. + """ + return { + "gibberish": { + k.replace("GIBBERISH_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("GIBBERISH_") + } + } + def _get_invisible_text_config_from_env(self, config_dict): + """ + Get the InvisibleText scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The InvisibleText scanner configuration. + """ + return { + "invisible_text": { + k.replace("INVISIBLE_TEXT_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("INVISIBLE_TEXT_") + } + } + + def _get_language_config_from_env(self, config_dict): + """ + Get the Language scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Language scanner configuration. + """ + return { + "language": { + k.replace("LANGUAGE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("LANGUAGE_") + } + } + + def _get_prompt_injection_config_from_env(self, config_dict): + """ + Get the PromptInjection scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The PromptInjection scanner configuration. + """ + return { + "prompt_injection": { + k.replace("PROMPT_INJECTION_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("PROMPT_INJECTION_") + } + } + + def _get_regex_config_from_env(self, config_dict): + """ + Get the Regex scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Regex scanner configuration. + """ + return { + "regex": { + k.replace("REGEX_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("REGEX_") + } + } + + def _get_secrets_config_from_env(self, config_dict): + """ + Get the Secrets scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Secrets scanner configuration. + """ + return { + "secrets": { + k.replace("SECRETS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("SECRETS_") + } + } + + def _get_sentiment_config_from_env(self, config_dict): + """ + Get the Secrets scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Sentiment scanner configuration. + """ + return { + "sentiment": { + k.replace("SENTIMENT_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("SENTIMENT_") + } + } + + def _get_token_limit_config_from_env(self, config_dict): + """ + Get the TokenLimit scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The TokenLimit scanner configuration. + """ + return { + "token_limit": { + k.replace("TOKEN_LIMIT_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("TOKEN_LIMIT_") + } + } + + def _get_toxicity_config_from_env(self, config_dict): + """ + Get the Toxicity scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Toxicity scanner configuration. + """ + return { + "toxicity": { + k.replace("TOXICITY_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("TOXICITY_") + } + } + +#### METHODS FOR CREATING SCANNERS + + def _create_anonymize_scanner(self, scanner_config=None): + if scanner_config is None: + logger.warning("_create_anonymize_scanner was invoked without scanner_config. Recreating with saved config to clear the Vault.") + if hasattr(self, "_anonymize_params") and self._anonymize_params is not None: + scanner_config = self._anonymize_params + else: + raise ValueError("_create_anonymize_scanner was invoked without scanner_config but no self._anonymize_params were saved. Such action is not allowed.") + vault = Vault() + anonymize_params = {'vault': vault, 'use_onnx': scanner_config.get('use_onnx', False)} + hidden_names = scanner_config.get('hidden_names', None) + allowed_names = scanner_config.get('allowed_names', None) + entity_types = scanner_config.get('entity_types', None) + preamble = scanner_config.get('preamble', None) + regex_patterns = scanner_config.get('regex_patterns', None) + use_faker = scanner_config.get('use_faker', None) + recognizer_conf = scanner_config.get('recognizer_conf', None) + threshold = scanner_config.get('threshold', None) + language = scanner_config.get('language', None) + + if isinstance(hidden_names, str): + hidden_names = sanitize_env(hidden_names) + + if isinstance(allowed_names, str): + allowed_names = sanitize_env(allowed_names) + + if isinstance(entity_types, str): + entity_types = sanitize_env(entity_types) + + if isinstance(regex_patterns, str): + regex_patterns = sanitize_env(regex_patterns) + + if hidden_names is not None: + if isinstance(hidden_names, str): + artifacts = set([',', '', '.']) + anonymize_params['hidden_names'] = list(set(hidden_names.split(',')) - artifacts) + elif isinstance(hidden_names, list): + anonymize_params['hidden_names'] = hidden_names + else: + logger.error("Provided type is not valid for Anonymize scanner") + raise ValueError("Provided type is not valid for Anonymize scanner") + if allowed_names is not None: + if isinstance(allowed_names, str): + artifacts = set([',', '', '.']) + anonymize_params['allowed_names'] = list(set(allowed_names.split(',')) - artifacts) + elif isinstance(hidden_names, list): + anonymize_params['allowed_names'] = allowed_names + else: + logger.error("Provided type is not valid for Anonymize scanner") + raise ValueError("Provided type is not valid for Anonymize scanner") + if entity_types is not None: + if isinstance(entity_types, str): + artifacts = set([',', '', '.']) + anonymize_params['entity_types'] = list(set(entity_types.split(',')) - artifacts) + elif isinstance(hidden_names, list): + anonymize_params['entity_types'] = entity_types + else: + logger.error("Provided type is not valid for Anonymize scanner") + raise ValueError("Provided type is not valid for Anonymize scanner") + if preamble is not None: + anonymize_params['preamble'] = preamble + if regex_patterns is not None: + if isinstance(regex_patterns, str): + artifacts = set([',', '', '.']) + anonymize_params['regex_patterns'] = list(set(regex_patterns.split(',')) - artifacts) + elif isinstance(hidden_names, list): + anonymize_params['regex_patterns'] = regex_patterns + else: + logger.error("Provided type is not valid for Anonymize scanner") + raise ValueError("Provided type is not valid for Anonymize scanner") + if use_faker is not None: + anonymize_params['use_faker'] = use_faker + if recognizer_conf is not None: + anonymize_params['recognizer_conf'] = recognizer_conf + if threshold is not None: + anonymize_params['threshold'] = threshold + if language is not None: + anonymize_params['language'] = language + logger.info(f"Creating Anonymize scanner with params: {anonymize_params}") + self._anonymize_params = {key: value for key, value in anonymize_params.items() if key != 'vault'} + return Anonymize(**anonymize_params) + + def _create_ban_code_scanner(self, scanner_config): + enabled_models = {'MODEL_SM': BANCODE_MODEL_SM, 'MODEL_TINY': BANCODE_MODEL_TINY} + bancode_params = {'use_onnx': scanner_config.get('use_onnx', False)} # by default we dont't want to use onnx + + model_name = scanner_config.get('model', None) + threshold = scanner_config.get('threshold', None) + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for BanCode scanner: {model_name}") + bancode_params['model'] = enabled_models[model_name] # Model class from LLM Guard + else: + err_msg = f"Model name is not valid for BanCode scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + bancode_params['threshold'] = threshold # float + logger.info(f"Creating BanCode scanner with params: {bancode_params}") + return BanCode(**bancode_params) + + def _create_ban_competitors_scanner(self, scanner_config): + enabled_models = {'MODEL_V1': BANCOMPETITORS_MODEL_V1} + ban_competitors_params = {'use_onnx': scanner_config.get('use_onnx', False)} # by default we don't want to use onnx + + competitors = scanner_config.get('competitors', None) + threshold = scanner_config.get('threshold', None) + redact = scanner_config.get('redact', None) + model_name = scanner_config.get('model', None) + + if isinstance(competitors, str): + competitors = sanitize_env(competitors) + + if competitors: + if isinstance(competitors, str): + artifacts = set([',', '', '.']) + ban_competitors_params['competitors'] = list(set(competitors.split(',')) - artifacts) # list + elif isinstance(competitors, list): + ban_competitors_params['competitors'] = competitors + else: + logger.error("Provided type is not valid for BanCompetitors scanner") + raise ValueError("Provided type is not valid for BanCompetitors scanner") + else: + logger.error("Competitors list is required for BanCompetitors scanner. Please provide a list of competitors.") + raise TypeError("Competitors list is required for BanCompetitors scanner. Please provide a list of competitors.") + if threshold is not None: + ban_competitors_params['threshold'] = threshold # float + if redact is not None: + ban_competitors_params['redact'] = redact + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for BanCompetitors scanner: {model_name}") + ban_competitors_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for BanCompetitors scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + logger.info(f"Creating BanCompetitors scanner with params: {ban_competitors_params}") + return BanCompetitors(**ban_competitors_params) + + def _create_ban_substrings_scanner(self, scanner_config): + available_match_types = ['str', 'word'] + ban_substrings_params = {} + + substrings = scanner_config.get('substrings', None) + match_type = scanner_config.get('match_type', None) + case_sensitive = scanner_config.get('case_sensitive', None) + redact = scanner_config.get('redact', None) + contains_all = scanner_config.get('contains_all', None) + + if isinstance(substrings, str): + substrings = sanitize_env(substrings) + + if substrings: + if isinstance(substrings, str): + artifacts = set([',', '', '.']) + ban_substrings_params['substrings'] = list(set(substrings.split(',')) - artifacts)# list + elif substrings and isinstance(substrings, list): + ban_substrings_params['substrings'] = substrings + else: + logger.error("Provided type is not valid for BanSubstrings scanner") + raise ValueError("Provided type is not valid for BanSubstrings scanner") + else: + logger.error("Substrings list is required for BanSubstrings scanner") + raise TypeError("Substrings list is required for BanSubstrings scanner") + if match_type is not None and match_type in available_match_types: + ban_substrings_params['match_type'] = match_type # MatchType + if case_sensitive is not None: + ban_substrings_params['case_sensitive'] = case_sensitive # bool + if redact is not None: + ban_substrings_params['redact'] = redact # bool + if contains_all is not None: + ban_substrings_params['contains_all'] = contains_all # bool + logger.info(f"Creating BanSubstrings scanner with params: {ban_substrings_params}") + return OPEABanSubstrings(**ban_substrings_params) + + def _create_ban_topics_scanner(self, scanner_config): + enabled_models = { + 'MODEL_DEBERTA_LARGE_V2': BANTOPICS_MODEL_DEBERTA_LARGE_V2, + 'MODEL_DEBERTA_BASE_V2': BANTOPICS_MODEL_DEBERTA_BASE_V2, + 'MODEL_BGE_M3_V2': BANTOPICS_MODEL_BGE_M3_V2, + 'MODEL_ROBERTA_LARGE_C_V2': BANTOPICS_MODEL_ROBERTA_LARGE_C_V2, + 'MODEL_ROBERTA_BASE_C_V2': BANTOPICS_MODEL_ROBERTA_BASE_C_V2 + } + ban_topics_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + topics = scanner_config.get('topics', None) + threshold = scanner_config.get('threshold', None) + model_name = scanner_config.get('model', None) + + if isinstance(topics, str): + topics = sanitize_env(topics) + + if topics: + if isinstance(topics, str): + artifacts = set([',', '', '.']) + ban_topics_params['topics'] = list(set(topics.split(',')) - artifacts) + elif isinstance(topics, list): + ban_topics_params['topics'] = topics + else: + logger.error("Provided type is not valid for BanTopics scanner") + raise ValueError("Provided type is not valid for BanTopics scanner") + else: + logger.error("Topics list is required for BanTopics scanner") + raise TypeError("Topics list is required for BanTopics scanner") + if threshold is not None: + ban_topics_params['threshold'] = threshold + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for BanTopics scanner: {model_name}") + ban_topics_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for BanTopics scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + logger.info(f"Creating BanTopics scanner with params: {ban_topics_params}") + return BanTopics(**ban_topics_params) + + def _create_code_scanner(self, scanner_config): + enabled_models = {'DEFAULT_MODEL': CODE_DEFAULT_MODEL} + code_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + languages = scanner_config.get('languages', None) + model_name = scanner_config.get('model', None) + is_blocked = scanner_config.get('is_blocked', None) + threshold = scanner_config.get('threshold', None) + + if isinstance(languages, str): + languages = sanitize_env(languages) + + if languages: + if isinstance(languages, str): + artifacts = set([',', '', '.']) + code_params['languages'] = list(set(languages.split(',')) - artifacts) + elif isinstance(languages, list): + code_params['languages'] = languages + else: + logger.error("Provided type is not valid for Code scanner") + raise ValueError("Provided type is not valid for Code scanner") + else: + logger.error("Languages list is required for Code scanner") + raise TypeError("Languages list is required for Code scanner") + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Code scanner: {model_name}") + code_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Code scanner. Please provide a valid model name. Provided model: {model_name}" + logger.error(err_msg) + raise ValueError(err_msg) + if is_blocked is not None: + code_params['is_blocked'] = is_blocked + if threshold is not None: + code_params['threshold'] = threshold + logger.info(f"Creating Code scanner with params: {code_params}") + return Code(**code_params) + + def _create_gibberish_scanner(self, scanner_config): + enabled_models = {'DEFAULT_MODEL': GIBBERISH_DEFAULT_MODEL} + enabled_match_types = ['sentence', 'full'] + gibberish_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + model_name = scanner_config.get('model', None) + threshold = scanner_config.get('threshold', None) + match_type = scanner_config.get('match_type', None) + + if match_type == "sentence": + import nltk + nltk.download('punkt_tab') + + if threshold is not None: + gibberish_params['threshold'] = threshold + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Gibberish scanner: {model_name}") + gibberish_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Gibberish scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if match_type is not None and match_type in enabled_match_types: + gibberish_params['match_type'] = match_type + + logger.info(f"Creating Gibberish scanner with params: {gibberish_params}") + return Gibberish(**gibberish_params) + + def _create_invisible_text_scanner(self): + return InvisibleText() + + def _create_language_scanner(self, scanner_config): + enabled_models = {'DEFAULT_MODEL': LANGUAGE_DEFAULT_MODEL} + enabled_match_types = ['sentence', 'full'] + language_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + valid_languages = scanner_config.get('valid_languages', None) + model_name = scanner_config.get('model', None) + threshold = scanner_config.get('threshold', None) + match_type = scanner_config.get('match_type', None) + + if isinstance(valid_languages, str): + valid_languages = sanitize_env(valid_languages) + + if valid_languages: + if isinstance(valid_languages, str): + artifacts = set([',', '', '.']) + language_params['valid_languages'] = list(set(valid_languages.split(',')) - artifacts) + elif isinstance(valid_languages, list): + language_params['valid_languages'] = valid_languages + else: + logger.error("Provided type is not valid for Language scanner") + raise ValueError("Provided type is not valid for Language scanner") + else: + logger.error("Valid languages list is required for Language scanner") + raise TypeError("Valid languages list is required for Language scanner") + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Language scanner: {model_name}") + language_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Language scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + language_params['threshold'] = threshold + if match_type is not None and match_type in enabled_match_types: + language_params['match_type'] = match_type + logger.info(f"Creating Language scanner with params: {language_params}") + return Language(**language_params) + + def _create_prompt_injection_scanner(self, scanner_config): + enabled_models = { + 'V1_MODEL': PROMPTINJECTION_V1_MODEL, + 'V2_MODEL': PROMPTINJECTION_V2_MODEL, + 'V2_SMALL_MODEL': PROMPTINJECTION_V2_SMALL_MODEL + } + enabled_match_types = ['sentence', 'full', "truncate_token_head_tail", "truncate_head_tail", "chunks"] + prompt_injection_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + model_name = scanner_config.get('model', None) + threshold = scanner_config.get('threshold', None) + match_type = scanner_config.get('match_type', None) + + if match_type == "sentence": + import nltk + nltk.download('punkt_tab') + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for PromptInjection scanner: {model_name}") + prompt_injection_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for PromptInjection scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + prompt_injection_params['threshold'] = threshold + if match_type is not None and match_type in enabled_match_types: + prompt_injection_params['match_type'] = match_type + logger.info(f"Creating PromptInjection scanner with params: {prompt_injection_params}") + return PromptInjection(**prompt_injection_params) + + def _create_regex_scanner(self, scanner_config): + enabled_match_types = ['search', 'fullmatch'] + regex_params = {} + + patterns = scanner_config.get('patterns', None) + is_blocked = scanner_config.get('is_blocked', None) + match_type = scanner_config.get('match_type', None) + redact = scanner_config.get('redact', None) + + if isinstance(patterns, str): + patterns = sanitize_env(patterns) + + if patterns: + if isinstance(patterns, str): + artifacts = set([',', '', '.']) + regex_params['patterns'] = list(set(patterns.split(',')) - artifacts) + elif isinstance(patterns, list): + regex_params['patterns'] = patterns + else: + logger.error("Provided type is not valid for Regex scanner") + raise ValueError("Provided type is not valid for Regex scanner") + else: + logger.error("Patterns list is required for Regex scanner") + raise TypeError("Patterns list is required for Regex scanner") + if is_blocked is not None: + regex_params['is_blocked'] = is_blocked + if match_type is not None and match_type in enabled_match_types: + regex_params['match_type'] = match_type + if redact is not None: + regex_params['redact'] = redact + + logger.info(f"Creating Regex scanner with params: {regex_params}") + return OPEARegexScanner(**regex_params) + + def _create_secrets_scanner(self, scanner_config): + enabled_redact_types = ['partial', 'all', 'hash'] + secrets_params = {} + + redact = scanner_config.get('redact', None) + + if redact is not None and redact in enabled_redact_types: + secrets_params['redact'] = redact + + logger.info(f"Creating Secrets scanner with params: {secrets_params}") + return Secrets(**secrets_params) + + def _create_sentiment_scanner(self, scanner_config): + enabled_lexicons = ["vader_lexicon"] + sentiment_params = {} + + threshold = scanner_config.get('threshold', None) + lexicon = scanner_config.get('lexicon', None) + + if threshold is not None: + sentiment_params['threshold'] = threshold + if lexicon is not None and lexicon in enabled_lexicons: + sentiment_params['lexicon'] = lexicon + + logger.info(f"Creating Sentiment scanner with params: {sentiment_params}") + return Sentiment(**sentiment_params) + + def _create_token_limit_scanner(self, scanner_config): + enabled_encodings = ['cl100k_base'] # TODO: test more encoding from tiktoken + token_limit_params = {} + + limit = int(scanner_config.get('limit', None)) + encoding_name = scanner_config.get('encoding', None) + + if limit is not None: + token_limit_params['limit'] = limit + if encoding_name is not None and encoding_name in enabled_encodings: + token_limit_params['encoding_name'] = encoding_name + + logger.info(f"Creating TokenLimit scanner with params: {token_limit_params}") + return TokenLimit(**token_limit_params) + + def _create_toxicity_scanner(self, scanner_config): + enabled_models = {'DEFAULT_MODEL': TOXICITY_DEFAULT_MODEL} + enabled_match_types = ['sentence', 'full'] + toxicity_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + model_name = scanner_config.get('model', None) + threshold = scanner_config.get('threshold', None) + match_type = scanner_config.get('match_type', None) + + if match_type == "sentence": + import nltk + nltk.download('punkt_tab') + + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Toxicity scanner: {model_name}") + toxicity_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Toxicity scanner. Please provide a valid model name. Provided model: {model_name}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + toxicity_params['threshold'] = threshold + if match_type is not None and match_type in enabled_match_types: + toxicity_params['match_type'] = match_type + + logger.info(f"Creating Toxicity scanner with params: {toxicity_params}") + return Toxicity(**toxicity_params) + + def _create_input_scanner(self, scanner_name, scanner_config): + if scanner_name not in ENABLED_SCANNERS: + logger.error(f"Scanner {scanner_name} is not supported. Enabled scanners are: {ENABLED_SCANNERS}") + raise ValueError(f"Scanner {scanner_name} is not supported") + if scanner_name == 'anonymize': + return self._create_anonymize_scanner(scanner_config) + elif scanner_name == 'ban_code': + return self._create_ban_code_scanner(scanner_config) + elif scanner_name == 'ban_competitors': + return self._create_ban_competitors_scanner(scanner_config) + elif scanner_name == 'ban_substrings': + return self._create_ban_substrings_scanner(scanner_config) + elif scanner_name == 'ban_topics': + return self._create_ban_topics_scanner(scanner_config) + elif scanner_name == 'code': + return self._create_code_scanner(scanner_config) + elif scanner_name == 'gibberish': + return self._create_gibberish_scanner(scanner_config) + elif scanner_name == 'invisible_text': + return self._create_invisible_text_scanner() + elif scanner_name == 'language': + return self._create_language_scanner(scanner_config) + elif scanner_name == 'prompt_injection': + return self._create_prompt_injection_scanner(scanner_config) + elif scanner_name == 'regex': + return self._create_regex_scanner(scanner_config) + elif scanner_name == 'secrets': + return self._create_secrets_scanner(scanner_config) + elif scanner_name == 'sentiment': + return self._create_sentiment_scanner(scanner_config) + elif scanner_name == 'token_limit': + return self._create_token_limit_scanner(scanner_config) + elif scanner_name == 'toxicity': + return self._create_toxicity_scanner(scanner_config) + return None + + def create_enabled_input_scanners(self): + """ + Create and return a list of enabled scanners based on the global configuration. + + Returns: + list: A list of enabled scanner instances. + """ + enabled_scanners_names_and_configs = {k: v for k, v in self._input_scanners_config.items() if v.get("enabled")} + enabled_scanners_objects = [] + + err_msgs = {} # list for all erronous scanners + only_validation_errors = True + for scanner_name, scanner_config in enabled_scanners_names_and_configs.items(): + try: + logger.info(f"Attempting to create scanner: {scanner_name}") + scanner_object = self._create_input_scanner(scanner_name, scanner_config) + enabled_scanners_objects.append(scanner_object) + except ValueError as e: + err_msg = f"A ValueError occured during creating input scanner {scanner_name}: {e}" + logger.error(err_msg) + err_msgs[scanner_name] = err_msg + self._input_scanners_config[scanner_name]["enabled"] = False + continue + except TypeError as e: + err_msg = f"A TypeError occured during creating input scanner {scanner_name}: {e}" + logger.error(err_msg) + err_msgs[scanner_name] = err_msg + self._input_scanners_config[scanner_name]["enabled"] = False + continue + except Exception as e: + err_msg = f"An unexpected error occured during creating input scanner {scanner_name}: {e}" + logger.error(err_msg) + err_msgs[scanner_name] = err_msg + only_validation_errors = False + self._input_scanners_config[scanner_name]["enabled"] = False + continue + + if err_msgs: + if only_validation_errors: + raise ValueError(f"Some scanners failed to be created due to validation errors. The details: {err_msgs}") + else: + raise Exception(f"Some scanners failed to be created due to validation or unexpected errors. The details: {err_msgs}") + + return [s for s in enabled_scanners_objects if s is not None] + + def changed(self, new_scanners_config): + """ + Check if the scanners configuration has changed. + + Args: + new_scanners_config (dict): The current scanners configuration. + + Returns: + bool: True if the configuration has changed, False otherwise. + """ + del new_scanners_config['id'] + newly_enabled_scanners = {k: {in_k: in_v for in_k, in_v in v.items() if in_k != 'id'} for k, v in new_scanners_config.items() if v.get("enabled")} + previously_enabled_scanners = {k: v for k, v in self._input_scanners_config.items() if v.get("enabled")} + if newly_enabled_scanners == previously_enabled_scanners: # if the enables scanners are the same we do nothing + logger.info("No changes in list for enabled scanners. Checking configuration changes...") + return False + else: + logger.warning("Sanners configuration has been changed, re-creating scanners") + self._input_scanners_config.clear() + stripped_new_scanners_config = {k: {in_k: in_v for in_k, in_v in v.items() if in_k != 'id'} for k, v in new_scanners_config.items()} + self._input_scanners_config.update(stripped_new_scanners_config) + return True diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py b/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py new file mode 100644 index 0000000000..5176002a29 --- /dev/null +++ b/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py @@ -0,0 +1,98 @@ +# Copyright (C) 2024-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +from llm_guard import scan_output +from fastapi import HTTPException + +from utils.llm_guard_output_scanners import OutputScannersConfig +from comps import get_opea_logger, GeneratedDoc + +logger = get_opea_logger("opea_llm_guard_output_guardrail_microservice") + +class OPEALLMGuardOutputGuardrail: + """ + OPEALLMGuardOutputGuardrail is responsible for scanning and sanitizing LLM output responses + using various output scanners provided by LLM Guard. + + This class initializes the output scanners based on the provided configuration and + scans the output responses to ensure they meet the required guardrail criteria. + + Attributes: + _scanners (list): A list of enabled scanners. + + Methods: + __init__(usv_config: list): + Initializes the OPEALLMGuardOutputGuardrail with the provided configuration. + + scan_llm_output(output_doc: object) -> str: + Scans the output from an LLM output document and returns the sanitized output. + """ + + def __init__(self, usv_config: list): + """ + Initializes the OPEALLMGuardOutputGuardrail with the provided configuration. + + Args: + usv_config (list): The configuration list for initializing the output scanners. + + Raises: + Exception: If an unexpected error occurs during the initialization of scanners. + """ + try: + self._scanners_config = OutputScannersConfig(usv_config) + self._scanners = self._scanners_config.create_enabled_output_scanners() + except Exception as e: + logger.exception( + f"An unexpected error occured during initializing \ + LLM Guard Output Guardrail scanners: {e}" + ) + raise + + + def scan_llm_output(self, output_doc: GeneratedDoc) -> str: + """ + Scans the output from an LLM output document. + + Args: + output_doc (object): The output document containing the response to be scanned. + + Returns: + str: The sanitized output. + + Raises: + HTTPException: If the output is not valid based on the scanner results. + Exception: If an unexpected error occurs during scanning. + """ + try: + if output_doc.output_guardrail_params is not None: + self._scanners_config.vault = output_doc.output_guardrail_params.anonymize_vault + if self._scanners_config.changed(output_doc.output_guardrail_params.dict()): + self._scanners = self._scanners_config.create_enabled_output_scanners() + else: + logger.warning("Output guardrail params not found in input document.") + if self._scanners: + sanitized_output, results_valid, results_score = scan_output( + self._scanners, output_doc.prompt, output_doc.text + ) + if False in results_valid.values(): + msg = f"LLM Output {output_doc.text} is not valid, scores: {results_score}" + logger.error(msg) + usr_msg = "I'm sorry, but the model output is not valid according to the policies." + redact_or_truncated = [c.get('redact', False) or c.get('truncate', False) for _, c in self._scanners_config._output_scanners_config.items()] # to see if sanitized output available + if any(redact_or_truncated): + usr_msg = f"We sanitized the answer due to the guardrails policies: {sanitized_output}" + raise HTTPException(status_code=466, detail=usr_msg) + return sanitized_output + else: + logger.warning("No output scanners enabled. Skipping scanning.") + return output_doc.text + except HTTPException as e: + raise e + except ValueError as e: + error_msg = f"Validation Error occured while initializing LLM Guard Output Guardrail scanners: {e}" + logger.exception(error_msg) + raise HTTPException(status_code=400, detail=error_msg) + except Exception as e: + error_msg = f"An unexpected error occured during scanning prompt with LLM Guard Output Guardrail: {e}" + logger.exception(error_msg) + raise HTTPException(status_code=500, detail=error_msg) diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py b/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py new file mode 100644 index 0000000000..514db7e964 --- /dev/null +++ b/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py @@ -0,0 +1,1214 @@ +# ruff: noqa: F401 +# Copyright (C) 2024-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +from llm_guard.vault import Vault +from llm_guard.output_scanners import ( + BanCode, + BanCompetitors, + BanTopics, + Bias, + Code, + Deanonymize, + JSON, + Language, + LanguageSame, + MaliciousURLs, + NoRefusal, + NoRefusalLight, + ReadingTime, + FactualConsistency, + Gibberish, + Relevance, + Sensitive, + Sentiment, + Toxicity, + URLReachability +) + +# import models definition +from llm_guard.input_scanners.ban_code import ( #input, becasue the same scanner to input and output + MODEL_SM as BANCODE_MODEL_SM, + MODEL_TINY as BANCODE_MODEL_TINY +) + +from llm_guard.input_scanners.ban_competitors import ( #input, becasue the same scanner to input and output + MODEL_V1 as BANCOMPETITORS_MODEL_V1 +) + +from llm_guard.input_scanners.ban_topics import ( #input, becasue the same scanner to input and output + MODEL_DEBERTA_LARGE_V2 as BANTOPICS_MODEL_DEBERTA_LARGE_V2, + MODEL_DEBERTA_BASE_V2 as BANTOPICS_MODEL_DEBERTA_BASE_V2, + MODEL_BGE_M3_V2 as BANTOPICS_MODEL_BGE_M3_V2, + MODEL_ROBERTA_LARGE_C_V2 as BANTOPICS_MODEL_ROBERTA_LARGE_C_V2, + MODEL_ROBERTA_BASE_C_V2 as BANTOPICS_MODEL_ROBERTA_BASE_C_V2 +) + +from llm_guard.output_scanners.bias import ( + DEFAULT_MODEL as BIAS_DEFAULT_MODEL +) + +from llm_guard.input_scanners.code import ( + DEFAULT_MODEL as CODE_DEFAULT_MODEL +) + +from llm_guard.input_scanners.gibberish import ( + DEFAULT_MODEL as GIBBERISH_DEFAULT_MODEL +) + +from llm_guard.input_scanners.language import ( + DEFAULT_MODEL as LANGUAGE_DEFAULT_MODEL, +) + +from llm_guard.output_scanners.malicious_urls import ( + DEFAULT_MODEL as MALICIOUS_URLS_DEFAULT_MODEL +) + +from llm_guard.output_scanners.no_refusal import ( + DEFAULT_MODEL as NO_REFUSAL_DEFAULT_MODEL +) + +from llm_guard.output_scanners.relevance import ( + MODEL_EN_BGE_BASE as RELEVANCE_MODEL_EN_BGE_BASE, + MODEL_EN_BGE_LARGE as RELEVANCE_MODEL_EN_BGE_LARGE, + MODEL_EN_BGE_SMALL as RELEVANCE_MODEL_EN_BGE_SMALL +) + +from llm_guard.input_scanners.toxicity import ( + DEFAULT_MODEL as TOXICITY_DEFAULT_MODEL +) + +ENABLED_SCANNERS = [ + 'ban_code', + 'ban_competitors', + 'ban_substrings', + 'ban_topics', + 'bias', + 'code', + 'deanonymize', + 'json_scanner', + 'language', + 'language_same', + 'malicious_urls', + 'no_refusal', + 'no_refusal_light', + 'reading_time', + 'factual_consistency', + 'gibberish', + 'regex', + 'relevance', + 'sensitive', + 'sentiment', + 'toxicity', + 'url_reachability' +] + +from comps.guardrails.utils.scanners import OPEABanSubstrings, OPEARegexScanner +from comps import get_opea_logger, sanitize_env +logger = get_opea_logger("opea_llm_guard_output_guardrail_microservice") + +class OutputScannersConfig: + def __init__(self, config_dict): + self._output_scanners_config = { + **self._get_ban_code_config_from_env(config_dict), + **self._get_ban_competitors_config_from_env(config_dict), + **self._get_ban_substrings_config_from_env(config_dict), + **self._get_ban_topics_config_from_env(config_dict), + **self._get_bias_config_from_env(config_dict), + **self._get_code_config_from_env(config_dict), + **self._get_deanonymize_config_from_env(config_dict), + **self._get_json_scanner_config_from_env(config_dict), + **self._get_language_config_from_env(config_dict), + **self._get_language_same_config_from_env(config_dict), + **self._get_malicious_urls_config_from_env(config_dict), + **self._get_no_refusal_config_from_env(config_dict), + **self._get_no_refusal_light_config_from_env(config_dict), + **self._get_reading_time_config_from_env(config_dict), + **self._get_factual_consistency_config_from_env(config_dict), + **self._get_gibberish_config_from_env(config_dict), + **self._get_regex_config_from_env(config_dict), + **self._get_relevance_config_from_env(config_dict), + **self._get_sensitive_config_from_env(config_dict), + **self._get_sentiment_config_from_env(config_dict), + **self._get_toxicity_config_from_env(config_dict), + **self._get_url_reachability_config_from_env(config_dict) + } + self.vault = None + +#### METHODS FOR VALIDATING CONFIGS + + def _validate_value(self, value): + """ + Validate and convert the input value. + + Args: + value (str): The value to be validated and converted. + + Returns: + bool | int | str: The validated and converted value. + """ + if value is None: + return None + elif value.isdigit(): + return float(value) + elif value.lower() == "true": + return True + elif value.lower() == "false": + return False + return value + + def _get_ban_code_config_from_env(self, config_dict): + """ + Get the BanCode scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanCode scanner configuration. + """ + return { + "ban_code": { + k.replace("BAN_CODE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("BAN_CODE_") + } + } + + def _get_ban_competitors_config_from_env(self, config_dict): + """ + Get the BanCompetitors scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanCompetitors scanner configuration. + """ + return { + "ban_competitors": { + k.replace("BAN_COMPETITORS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("BAN_COMPETITORS_") + } + } + + def _get_ban_substrings_config_from_env(self, config_dict): + """ + Get the BanSubstrings scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanSubstrings scanner configuration. + """ + return { + "ban_substrings": { + k.replace("BAN_SUBSTRINGS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("BAN_SUBSTRINGS_") + } + } + + def _get_ban_topics_config_from_env(self, config_dict): + """ + Get the BanTopics scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The BanTopics scanner configuration. + """ + return { + "ban_topics": { + k.replace("BAN_TOPICS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("BAN_TOPICS_") + } + } + + def _get_bias_config_from_env(self, config_dict): + """ + Get the Bias scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Bias scanner configuration. + """ + return { + "bias": { + k.replace("BIAS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("BIAS_") + } + } + + def _get_code_config_from_env(self, config_dict): + """ + Get the Code scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Code scanner configuration. + """ + return { + "code": { + k.replace("CODE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("CODE_") + } + } + + def _get_deanonymize_config_from_env(self, config_dict): + """ + Get the Deanonymize scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The deanonymize scanner configuration. + """ + return { + "deanonymize": { + k.replace("DEANONYMIZE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("DEANONYMIZE_") + } + } + + def _get_json_scanner_config_from_env(self, config_dict): + """ + Get the JSON scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The JSON scanner configuration. + """ + return { + "json_scanner": { + k.replace("JSON_SCANNER_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("JSON_SCANNER_") + } + } + + def _get_language_config_from_env(self, config_dict): + """ + Get the Language scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Language scanner configuration. + """ + return { + "language": { + k.replace("LANGUAGE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("LANGUAGE_") + } + } + + def _get_language_same_config_from_env(self, config_dict): + """ + Get the LanguageSame scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The LanguageSame scanner configuration. + """ + return { + "language_same": { + k.replace("LANGUAGE_SAME_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("LANGUAGE_SAME_") + } + } + + def _get_malicious_urls_config_from_env(self, config_dict): + """ + Get the MaliciousURLs scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The MaliciousURLs scanner configuration. + """ + return { + "malicious_urls": { + k.replace("MALICIOUS_URLS_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("MALICIOUS_URLS_") + } + } + + def _get_no_refusal_config_from_env(self, config_dict): + """ + Get the NoRefusal scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The NoRefusal scanner configuration. + """ + return { + "no_refusal": { + k.replace("NO_REFUSAL_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("NO_REFUSAL_") + } + } + + def _get_no_refusal_light_config_from_env(self, config_dict): + """ + Get the NoRefusalLight scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The NoRefusalLight scanner configuration. + """ + return { + "no_refusal_light": { + k.replace("NO_REFUSAL_LIGHT_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("NO_REFUSAL_LIGHT_") + } + } + + def _get_reading_time_config_from_env(self, config_dict): + """ + Get the ReadingTime scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The ReadingTime scanner configuration. + """ + return { + "reading_time": { + k.replace("READING_TIME_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("READING_TIME_") + } + } + + def _get_factual_consistency_config_from_env(self, config_dict): + """ + Get the FactualConsitency scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The FactualConsitency scanner configuration. + """ + return { + "factual_consistency": { + k.replace("FACTUAL_CONSISTENCY_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("FACTUAL_CONSISTENCY_") + } + } + + def _get_gibberish_config_from_env(self, config_dict): + """ + Get the Gibberish scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Gibberish scanner configuration. + """ + return { + "gibberish": { + k.replace("GIBBERISH_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("GIBBERISH_") + } + } + + def _get_regex_config_from_env(self, config_dict): + """ + Get the Regex scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Regex scanner configuration. + """ + return { + "regex": { + k.replace("REGEX_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("REGEX_") + } + } + + def _get_relevance_config_from_env(self, config_dict): + """ + Get the Relevance scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Relevance scanner configuration. + """ + return { + "relevance": { + k.replace("RELEVANCE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("RELEVANCE_") + } + } + + def _get_sensitive_config_from_env(self, config_dict): + """ + Get the Sensitive scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Sensitive scanner configuration. + """ + return { + "sensitive": { + k.replace("SENSITIVE_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("SENSITIVE_") + } + } + + def _get_sentiment_config_from_env(self, config_dict): + """ + Get the Sentiment scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Sentiment scanner configuration. + """ + return { + "sentiment": { + k.replace("SENTIMENT_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("SENTIMENT_") + } + } + + def _get_toxicity_config_from_env(self, config_dict): + """ + Get the Toxicity scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The Toxicity scanner configuration. + """ + return { + "toxicity": { + k.replace("TOXICITY_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("TOXICITY_") + } + } + + def _get_url_reachability_config_from_env(self, config_dict): + """ + Get the URLReachability scanner configuration from the environment. + + Args: + config_dict (dict): The configuration dictionary. + + Returns: + dict: The URLReachability scanner configuration. + """ + return { + "url_reachability": { + k.replace("URL_REACHABILITY_", "").lower(): self._validate_value(v) + for k, v in config_dict.items() if k.startswith("URL_REACHABILITY_") + } + } + +#### METHODS FOR CREATING SCANNERS + + def _create_ban_code_scanner(self, scanner_config): + enabled_models = {'MODEL_SM': BANCODE_MODEL_SM, 'MODEL_TINY': BANCODE_MODEL_TINY} + bancode_params = {'use_onnx': scanner_config.get('use_onnx', False)} # by default we don't want to use onnx + + model_name = scanner_config.get('model', None) + threshold = scanner_config.get('threshold', None) + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for BanCode scanner: {model_name}") + bancode_params['model'] = enabled_models[model_name] # Model class from LLM Guard + else: + err_msg = f"Model name is not valid for BanCode scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + bancode_params['threshold'] = threshold + logger.info(f"Creating BanCode scanner with params: {bancode_params}") + return BanCode(**bancode_params) + + def _create_ban_competitors_scanner(self, scanner_config): + enabled_models = {'MODEL_V1': BANCOMPETITORS_MODEL_V1} + ban_competitors_params = {'use_onnx': scanner_config.get('use_onnx', False)} # by default we want don't to use onnx + + competitors = scanner_config.get('competitors', None) + threshold = scanner_config.get('threshold', None) + redact = scanner_config.get('redact', None) + model_name = scanner_config.get('model', None) + + if isinstance(competitors, str): + competitors = sanitize_env(competitors) + + if competitors: + if isinstance(competitors, str): + artifacts = set([',', '', '.']) + ban_competitors_params['competitors'] = list(set(competitors.split(',')) - artifacts) + elif isinstance(competitors, list): + ban_competitors_params['competitors'] = competitors + else: + logger.error("Provided type is not valid for BanCompetitors scanner") + raise ValueError("Provided type is not valid for BanCompetitors scanner") + else: + logger.error("Competitors list is required for BanCompetitors scanner. Please provide a list of competitors.") + raise TypeError("Competitors list is required for BanCompetitors scanner. Please provide a list of competitors.") + if threshold is not None: + ban_competitors_params['threshold'] = threshold + if redact is not None: + ban_competitors_params['redact'] = redact + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for BanCompetitors scanner: {model_name}") + ban_competitors_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for BanCompetitors scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + logger.info(f"Creating BanCompetitors scanner with params: {ban_competitors_params}") + return BanCompetitors(**ban_competitors_params) + + def _create_ban_substrings_scanner(self, scanner_config): + available_match_types = ['str', 'word'] + ban_substrings_params = {} + + substrings = scanner_config.get('substrings', None) + match_type = scanner_config.get('match_type', None) + case_sensitive = scanner_config.get('case_sensitive', None) + redact = scanner_config.get('redact', None) + contains_all = scanner_config.get('contains_all', None) + + if isinstance(substrings, str): + substrings = sanitize_env(substrings) + + if substrings: + if isinstance(substrings, str): + artifacts = set([',', '', '.']) + ban_substrings_params['substrings'] = list(set(substrings.split(',')) - artifacts) + elif substrings and isinstance(substrings, list): + ban_substrings_params['substrings'] = substrings + else: + logger.error("Provided type is not valid for BanSubstrings scanner") + raise ValueError("Provided type is not valid for BanSubstrings scanner") + else: + logger.error("Substrings list is required for BanSubstrings scanner") + raise TypeError("Substrings list is required for BanSubstrings scanner") + if match_type is not None and match_type in available_match_types: + ban_substrings_params['match_type'] = match_type + if case_sensitive is not None: + ban_substrings_params['case_sensitive'] = case_sensitive + if redact is not None: + ban_substrings_params['redact'] = redact + if contains_all is not None: + ban_substrings_params['contains_all'] = contains_all + logger.info(f"Creating BanSubstrings scanner with params: {ban_substrings_params}") + return OPEABanSubstrings(**ban_substrings_params) + + def _create_ban_topics_scanner(self, scanner_config): + enabled_models = { + 'MODEL_DEBERTA_LARGE_V2': BANTOPICS_MODEL_DEBERTA_LARGE_V2, + 'MODEL_DEBERTA_BASE_V2': BANTOPICS_MODEL_DEBERTA_BASE_V2, + 'MODEL_BGE_M3_V2': BANTOPICS_MODEL_BGE_M3_V2, + 'MODEL_ROBERTA_LARGE_C_V2': BANTOPICS_MODEL_ROBERTA_LARGE_C_V2, + 'MODEL_ROBERTA_BASE_C_V2': BANTOPICS_MODEL_ROBERTA_BASE_C_V2 + } + ban_topics_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + topics = scanner_config.get('topics', None) + threshold = scanner_config.get('threshold', None) + model_name = scanner_config.get('model', None) + + if isinstance(topics, str): + topics = sanitize_env(topics) + + if topics: + if isinstance(topics, str): + artifacts = set([',', '', '.']) + ban_topics_params['topics'] = list(set(topics.split(',')) - artifacts) + elif isinstance(topics, list): + ban_topics_params['topics'] = topics + else: + logger.error("Provided type is not valid for BanTopics scanner") + raise ValueError("Provided type is not valid for BanTopics scanner") + else: + logger.error("Topics list is required for BanTopics scanner") + raise TypeError("Topics list is required for BanTopics scanner") + if threshold is not None: + ban_topics_params['threshold'] = threshold + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for BanTopics scanner: {model_name}") + ban_topics_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for BanTopics scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + logger.info(f"Creating BanTopics scanner with params: {ban_topics_params}") + return BanTopics(**ban_topics_params) + + def _create_bias_scanner(self, scanner_config): + available_match_types = ['str', 'word'] + enabled_models = {'DEFAULT_MODEL': BIAS_DEFAULT_MODEL} + bias_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + threshold = scanner_config.get('threshold', None) + match_type = scanner_config.get('match_type', None) + model_name = scanner_config.get('model', None) + + if threshold is not None: + bias_params['threshold'] = threshold + if match_type is not None and match_type in available_match_types: + bias_params['match_type'] = match_type + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Bias scanner: {model_name}") + bias_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Bias scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + + logger.info(f"Creating Bias scanner with params: {bias_params}") + return Bias(**bias_params) + + def _create_code_scanner(self, scanner_config): + enabled_models = {'DEFAULT_MODEL': CODE_DEFAULT_MODEL} + code_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + languages = scanner_config.get('languages', None) + model_name = scanner_config.get('model', None) + is_blocked = scanner_config.get('is_blocked', None) + threshold = scanner_config.get('threshold', None) + + if isinstance(languages, str): + languages = sanitize_env(languages) + + if languages: + if isinstance(languages, str): + artifacts = set([',', '', '.']) + code_params['languages'] = list(set(languages.split(',')) - artifacts) + elif isinstance(languages, list): + code_params['languages'] = languages + else: + logger.error("Provided type is not valid for Code scanner") + raise ValueError("Provided type is not valid for Code scanner") + else: + logger.error("Languages list is required for Code scanner") + raise TypeError("Languages list is required for Code scanner") + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Code scanner: {model_name}") + code_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Code scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if is_blocked is not None: + code_params['is_blocked'] = is_blocked + if threshold is not None: + code_params['threshold'] = threshold + logger.info(f"Creating Code scanner with params: {code_params}") + return Code(**code_params) + + def _create_deanonymize_scanner(self, scanner_config, vault): + if not vault: + raise Exception("Vault is required for Deanonymize scanner") + deanonymize_params = {'vault': vault} + + matching_strategy = scanner_config.get('matching_strategy', None) + if matching_strategy is not None: + deanonymize_params['matching_strategy'] = matching_strategy + + logger.info(f"Creating Deanonymize scanner with params: {deanonymize_params}") + return Deanonymize(**deanonymize_params) + + def _create_json_scanner(self, scanner_config): + json_scanner_params = {} + + required_elements = scanner_config.get('required_elements', None) + repair = scanner_config.get('repair', None) + + if required_elements is not None: + json_scanner_params['required_elements'] = required_elements + if repair is not None: + json_scanner_params['repair'] = repair + + logger.info(f"Creating JSON scanner with params: {json_scanner_params}") + return JSON(**json_scanner_params) + + def _create_language_scanner(self, scanner_config): + enabled_models = {'DEFAULT_MODEL': LANGUAGE_DEFAULT_MODEL} + enabled_match_types = ['sentence', 'full'] + language_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + valid_languages = scanner_config.get('valid_languages', None) + model_name = scanner_config.get('model', None) + threshold = scanner_config.get('threshold', None) + match_type = scanner_config.get('match_type', None) + + if isinstance(valid_languages, str): + valid_languages = sanitize_env(valid_languages) + + if valid_languages: + if isinstance(valid_languages, str): + artifacts = set([',', '', '.']) + language_params['valid_languages'] = list(set(valid_languages.split(',')) - artifacts) + elif isinstance(valid_languages, list): + language_params['valid_languages'] = valid_languages + else: + logger.error("Provided type is not valid for Language scanner") + raise ValueError("Provided type is not valid for Language scanner") + else: + logger.error("Valid languages list is required for Language scanner") + raise TypeError("Valid languages list is required for Language scanner") + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Language scanner: {model_name}") + language_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Language scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + language_params['threshold'] = threshold + if match_type is not None and match_type in enabled_match_types: + language_params['match_type'] = match_type + logger.info(f"Creating Language scanner with params: {language_params}") + return Language(**language_params) + + def _create_language_same_scanner(self, scanner_config): + enabled_models = {'DEFAULT_MODEL': LANGUAGE_DEFAULT_MODEL} + language_same_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + model_name = scanner_config.get('model', None) + threshold = scanner_config.get('threshold', None) + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for LanguageSame scanner: {model_name}") + language_same_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for LanguageSame scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + language_same_params['threshold'] = threshold + + logger.info(f"Creating LanguageSame scanner with params: {language_same_params}") + return LanguageSame(**language_same_params) + + def _create_malicious_urls_scanner(self, scanner_config): + enabled_models = {'DEFAULT_MODEL': MALICIOUS_URLS_DEFAULT_MODEL} + malicious_urls_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + threshold = scanner_config.get('threshold', None) + model_name = scanner_config.get('model', None) + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for MaliciousURLs scanner: {model_name}") + malicious_urls_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for MaliciousURLs scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + malicious_urls_params['threshold'] = threshold + + logger.info(f"Creating MaliciousURLs scanner with params: {malicious_urls_params}") + return MaliciousURLs(**malicious_urls_params) + + def _create_no_refusal_scanner(self, scanner_config): + enabled_models = {'DEFAULT_MODEL': NO_REFUSAL_DEFAULT_MODEL} + enabled_match_types = ['sentence', 'full'] + no_refusal_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + threshold = scanner_config.get('threshold', None) + model_name = scanner_config.get('model', None) + match_type = scanner_config.get('match_type', None) + + if threshold is not None: + no_refusal_params['threshold'] = threshold + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for NoRefusal scanner: {model_name}") + no_refusal_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for NoRefusal scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if match_type is not None and match_type in enabled_match_types: + no_refusal_params['match_type'] = match_type + + logger.info(f"Creating NoRefusal scanner with params: {no_refusal_params}") + return NoRefusal(**no_refusal_params) + + def _create_no_refusal_light_scanner(self): + logger.info("Creating NoRefusalLight scanner.") + return NoRefusalLight() + + def _create_reading_time_scanner(self, scanner_config): + reading_time_params = {} + + max_time = scanner_config.get('max_time', None) + truncate = scanner_config.get('truncate', None) + + if max_time is not None: + reading_time_params['max_time'] = float(max_time) + else: + logger.error("Max time is required for ReadingTime scanner") + raise TypeError("Max time is required for ReadingTime scanner") + if truncate is not None: + reading_time_params['truncate'] = truncate + + logger.info(f"Creating ReadingTime scanner with params: {reading_time_params}") + return ReadingTime(**reading_time_params) + + def _create_factual_consistency_scanner(self, scanner_config): + enabled_models = {"DEFAULT_MODEL": BANTOPICS_MODEL_DEBERTA_BASE_V2} # BanTopics model is used as deault in FactualConsistency + factual_consistency_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + model_name = scanner_config.get('model_name', None) + minimum_score = scanner_config.get('minimum_score', None) + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for NoRefusal scanner: {model_name}") + factual_consistency_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for NoRefusal scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if minimum_score is not None: + factual_consistency_params['minimum_score'] = minimum_score + + logger.info(f"Creating FactualConsistency scanner with params: {factual_consistency_params}") + return FactualConsistency(**factual_consistency_params) + + def _create_gibberish_scanner(self, scanner_config): + enabled_models = {'DEFAULT_MODEL': GIBBERISH_DEFAULT_MODEL} + enabled_match_types = ['sentence', 'full'] + gibberish_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + model_name = scanner_config.get('model', None) + threshold = scanner_config.get('threshold', None) + match_type = scanner_config.get('match_type', None) + + if match_type == "sentence": + import nltk + nltk.download('punkt_tab') + + if threshold is not None: + gibberish_params['threshold'] = threshold + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Gibberish scanner: {model_name}") + gibberish_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Gibberish scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if match_type is not None and match_type in enabled_match_types: + gibberish_params['match_type'] = match_type + + logger.info(f"Creating Gibberish scanner with params: {gibberish_params}") + return Gibberish(**gibberish_params) + + def _create_regex_scanner(self, scanner_config): + enabled_match_types = ['search', 'fullmatch'] + regex_params = {} + + patterns = scanner_config.get('patterns', None) + is_blocked = scanner_config.get('is_blocked', None) + match_type = scanner_config.get('match_type', None) + redact = scanner_config.get('redact', None) + + if isinstance(patterns, str): + patterns = sanitize_env(patterns) + + if patterns: + if isinstance(patterns, str): + artifacts = set([',', '', '.']) + regex_params['patterns'] = list(set(patterns.split(',')) - artifacts) + elif isinstance(patterns, list): + regex_params['patterns'] = patterns + else: + logger.error("Provided type is not valid for Regex scanner") + raise ValueError("Provided type is not valid for Regex scanner") + else: + logger.error("Patterns list is required for Regex scanner") + raise TypeError("Patterns list is required for Regex scanner") + if is_blocked is not None: + regex_params['is_blocked'] = is_blocked + if match_type is not None and match_type in enabled_match_types: + regex_params['match_type'] = match_type + if redact is not None: + regex_params['redact'] = redact + + logger.info(f"Creating Regex scanner with params: {regex_params}") + return OPEARegexScanner(**regex_params) + + def _create_relevance_scanner(self, scanner_config): + enabled_models = {'MODEL_EN_BGE_BASE': RELEVANCE_MODEL_EN_BGE_BASE, + 'MODEL_EN_BGE_LARGE': RELEVANCE_MODEL_EN_BGE_LARGE, + 'MODEL_EN_BGE_SMALL': RELEVANCE_MODEL_EN_BGE_SMALL} + relevance_params = {'use_onnx': scanner_config.get('use_onnx', False)} # TODO: onnx off, because of bug on LLM Guard side + + model_name = scanner_config.get('model', None) + threshold = scanner_config.get('threshold', None) + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Gibberish scanner: {model_name}") + relevance_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Relevance scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + relevance_params['threshold'] = threshold + + logger.info(f"Creating Relevance scanner with params: {relevance_params}") + return Relevance(**relevance_params) + + def _create_sensitive_scanner(self, scanner_config): + sensitive_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + entity_types = scanner_config.get('entity_types', None) + regex_patterns = scanner_config.get('regex_patterns', None) + redact = scanner_config.get('redact', None) + recognizer_conf = scanner_config.get('recognizer_conf', None) + threshold = scanner_config.get('threshold', None) + + if entity_types is not None: + if isinstance(entity_types, str): + entity_types = sanitize_env(entity_types) + + if entity_types: + if isinstance(entity_types, str): + artifacts = set([',', '', '.']) + sensitive_params['entity_types'] = list(set(entity_types.split(',')) - artifacts) + elif isinstance(entity_types, list): + sensitive_params['entity_types'] = entity_types + else: + logger.error("Provided type is not valid for Sensitive scanner") + raise ValueError("Provided type is not valid for Sensitive scanner") + + if regex_patterns is not None: + sensitive_params['regex_patterns'] = regex_patterns + if redact is not None: + sensitive_params['redact'] = redact + if recognizer_conf is not None: + sensitive_params['recognizer_conf'] = recognizer_conf + if threshold is not None: + sensitive_params['threshold'] = threshold + + logger.info(f"Creating Sensitive scanner with params: {sensitive_params}") + return Sensitive(**sensitive_params) + + def _create_sentiment_scanner(self, scanner_config): + enabled_lexicons = ["vader_lexicon"] + sentiment_params = {} + + threshold = scanner_config.get('threshold', None) + lexicon = scanner_config.get('lexicon', None) + + if threshold is not None: + sentiment_params['threshold'] = threshold + if lexicon is not None and lexicon in enabled_lexicons: + sentiment_params['lexicon'] = lexicon + + logger.info(f"Creating Sentiment scanner with params: {sentiment_params}") + return Sentiment(**sentiment_params) + + def _create_toxicity_scanner(self, scanner_config): + enabled_models = {'DEFAULT_MODEL': TOXICITY_DEFAULT_MODEL} + enabled_match_types = ['sentence', 'full'] + toxicity_params = {'use_onnx': scanner_config.get('use_onnx', False)} + + model_name = scanner_config.get('model', None) + threshold = scanner_config.get('threshold', None) + match_type = scanner_config.get('match_type', None) + + if match_type == "sentence": + import nltk + nltk.download('punkt_tab') + + + if model_name is not None: + if model_name in enabled_models: + logger.info(f"Using selected model for Toxicity scanner: {model_name}") + toxicity_params['model'] = enabled_models[model_name] + else: + err_msg = f"Model name is not valid for Toxicity scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" + logger.error(err_msg) + raise ValueError(err_msg) + if threshold is not None: + toxicity_params['threshold'] = threshold + if match_type is not None and match_type in enabled_match_types: + toxicity_params['match_type'] = match_type + + logger.info(f"Creating Toxicity scanner with params: {toxicity_params}") + return Toxicity(**toxicity_params) + + def _create_url_reachability_scanner(self, scanner_config): + url_reachability_params = {} + + success_status_codes = scanner_config.get('success_status_codes', None) + timeout = scanner_config.get('timeout', None) + + if success_status_codes is not None: + if isinstance(success_status_codes, str): + artifacts = set([',', '', '.']) + url_reachability_params['success_status_codes'] = list(set(success_status_codes.split(',')) - artifacts) + elif isinstance(success_status_codes, list): + url_reachability_params['success_status_codes'] = success_status_codes + else: + logger.error("Provided type is not valid for Language scanner") + raise ValueError("Provided type is not valid for Language scanner") + if timeout is not None: + url_reachability_params['timeout'] = timeout + + logger.info(f"Creating URLReachability scanner with params: {url_reachability_params}") + return URLReachability(**url_reachability_params) + + def _create_output_scanner(self, scanner_name, scanner_config, vault=None): + if scanner_name not in ENABLED_SCANNERS: + logger.error(f"Scanner {scanner_name} is not supported") + raise Exception(f"Scanner {scanner_name} is not supported. Enabled scanners are: {ENABLED_SCANNERS}") + if scanner_name == "ban_code": + return self._create_ban_code_scanner(scanner_config) + elif scanner_name == "ban_competitors": + return self._create_ban_competitors_scanner(scanner_config) + elif scanner_name == "ban_substrings": + return self._create_ban_substrings_scanner(scanner_config) + elif scanner_name == "ban_topics": + return self._create_ban_topics_scanner(scanner_config) + elif scanner_name == "bias": + return self._create_bias_scanner(scanner_config) + elif scanner_name == "code": + return self._create_code_scanner(scanner_config) + elif scanner_name == "deanonymize": + return self._create_deanonymize_scanner(scanner_config, vault) + elif scanner_name == "json_scanner": + return self._create_json_scanner(scanner_config) + elif scanner_name == "language": + return self._create_language_scanner(scanner_config) + elif scanner_name == "language_same": + return self._create_language_same_scanner(scanner_config) + elif scanner_name == "malicious_urls": + return self._create_malicious_urls_scanner(scanner_config) + elif scanner_name == "no_refusal": + return self._create_no_refusal_scanner(scanner_config) + elif scanner_name == "no_refusal_light": + return self._create_no_refusal_light_scanner() + elif scanner_name == "reading_time": + return self._create_reading_time_scanner(scanner_config) + elif scanner_name == "factual_consistency": + return self._create_factual_consistency_scanner(scanner_config) + elif scanner_name == "gibberish": + return self._create_gibberish_scanner(scanner_config) + elif scanner_name == "regex": + return self._create_regex_scanner(scanner_config) + elif scanner_name == "relevance": + return self._create_relevance_scanner(scanner_config) + elif scanner_name == "sensitive": + return self._create_sensitive_scanner(scanner_config) + elif scanner_name == "sentiment": + return self._create_sentiment_scanner(scanner_config) + elif scanner_name == "toxicity": + return self._create_toxicity_scanner(scanner_config) + elif scanner_name == "url_reachability": + return self._create_url_reachability_scanner(scanner_config) + return None + + def create_enabled_output_scanners(self): + """ + Create and return a list of enabled scanners based on the global configuration. + + Returns: + list: A list of enabled scanner instances. + """ + enabled_scanners_names_and_configs = {k: v for k, v in self._output_scanners_config.items() if isinstance(v, dict) and v.get("enabled")} + enabled_scanners_objects = [] + + err_msgs = {} # list for all erronous scanners + only_validation_errors = True + for scanner_name, scanner_config in enabled_scanners_names_and_configs.items(): + try: + logger.info(f"Attempting to create scanner: {scanner_name}") + scanner_object = self._create_output_scanner(scanner_name, scanner_config, vault=self.vault) + enabled_scanners_objects.append(scanner_object) + except ValueError as e: + err_msg = f"A ValueError occured during creating output scanner {scanner_name}: {e}" + logger.error(err_msg) + err_msgs[scanner_name] = err_msg + self._output_scanners_config[scanner_name]["enabled"] = False + continue + except TypeError as e: + err_msg = f"A TypeError occured during creating output scanner {scanner_name}: {e}" + logger.error(err_msg) + err_msgs[scanner_name] = err_msg + self._output_scanners_config[scanner_name]["enabled"] = False + continue + except Exception as e: + err_msg = f"An unexpected error occured during creating output scanner {scanner_name}: {e}" + logger.error(err_msg) + err_msgs[scanner_name] = err_msg + self._output_scanners_config[scanner_name]["enabled"] = False + only_validation_errors = False + continue + + if err_msgs: + if only_validation_errors: + raise ValueError(f"Some scanners failed to be created due to validation errors. The details: {err_msgs}") + else: + raise Exception(f"Some scanners failed to be created due to validation or unexpected errors. The details: {err_msgs}") + + return [s for s in enabled_scanners_objects if s is not None] + + def changed(self, new_scanners_config): + """ + Check if the scanners configuration has changed. + + Args: + new_scanners_config (dict): The current scanners configuration. + + Returns: + bool: True if the configuration has changed, False otherwise. + """ + del new_scanners_config['id'] + newly_enabled_scanners = {k: {in_k: in_v for in_k, in_v in v.items() if in_k != 'id'} for k, v in new_scanners_config.items() if isinstance(v, dict) and v.get("enabled")} + previously_enabled_scanners = {k: v for k, v in self._output_scanners_config.items() if isinstance(v, dict) and v.get("enabled")} + if newly_enabled_scanners == previously_enabled_scanners: # if the enabled scanners are the same we do nothing + logger.info("No changes in list for enabled scanners. Checking configuration changes...") + return False + else: + logger.warning("Sanners configuration has been changed, re-creating scanners") + self._output_scanners_config.clear() + stripped_new_scanners_config = {k: {in_k: in_v for in_k, in_v in v.items() if in_k != 'id'} for k, v in new_scanners_config.items() if isinstance(v, dict)} + self._output_scanners_config.update(stripped_new_scanners_config) + return True From 4bb337992d1e36525f6aa41149606996c8d3f814 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Jun 2025 02:06:10 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../opea_guardrails_microservice.py | 92 +-- .../utils/llm_guard_input_guardrail.py | 43 +- .../utils/llm_guard_input_scanners.py | 586 ++++++++------- .../utils/llm_guard_output_guardrail.py | 30 +- .../utils/llm_guard_output_scanners.py | 697 +++++++++--------- 5 files changed, 732 insertions(+), 716 deletions(-) diff --git a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py index 72d7600031..35e6161e71 100644 --- a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py +++ b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py @@ -1,26 +1,22 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import asyncio import os import time -import asyncio from typing import Union + from dotenv import dotenv_values from fastapi import HTTPException - -from utils.llm_guard_input_guardrail import ( - OPEALLMGuardInputGuardrail -) -from utils.llm_guard_output_guardrail import ( - OPEALLMGuardOutputGuardrail -) +from utils.llm_guard_input_guardrail import OPEALLMGuardInputGuardrail +from utils.llm_guard_output_guardrail import OPEALLMGuardOutputGuardrail from comps import ( CustomLogger, GeneratedDoc, LLMParamsDoc, - SearchedDoc, OpeaComponentLoader, + SearchedDoc, ServiceType, TextDoc, opea_microservices, @@ -28,16 +24,12 @@ register_statistics, statistics_dict, ) - from comps.cores.proto.api_protocol import ChatCompletionRequest, DocSumChatCompletionRequest logger = CustomLogger("opea_guardrails_microservice") logflag = os.getenv("LOGFLAG", False) -usvc_config = { - **dotenv_values(".env"), - **os.environ -} +usvc_config = {**dotenv_values(".env"), **os.environ} guardrails_component_name = os.getenv("GUARDRAILS_COMPONENT_NAME", "OPEA_LLAMA_GUARD") # Initialize OpeaComponentLoader @@ -50,71 +42,89 @@ input_guardrail = OPEALLMGuardInputGuardrail(usvc_config) output_guardrail = OPEALLMGuardOutputGuardrail(usvc_config) + @register_microservice( name="opea_service@guardrails", service_type=ServiceType.GUARDRAIL, endpoint="/v1/guardrails", host="0.0.0.0", port=9090, - input_datatype=Union[LLMParamsDoc, GeneratedDoc, ChatCompletionRequest, SearchedDoc, ChatCompletionRequest, DocSumChatCompletionRequest], - output_datatype=Union[LLMParamsDoc, GeneratedDoc, ChatCompletionRequest, SearchedDoc, ChatCompletionRequest, DocSumChatCompletionRequest], + input_datatype=Union[ + LLMParamsDoc, + GeneratedDoc, + ChatCompletionRequest, + SearchedDoc, + ChatCompletionRequest, + DocSumChatCompletionRequest, + ], + output_datatype=Union[ + LLMParamsDoc, + GeneratedDoc, + ChatCompletionRequest, + SearchedDoc, + ChatCompletionRequest, + DocSumChatCompletionRequest, + ], ) @register_statistics(names=["opea_service@guardrails"]) -async def safety_guard(input: Union[LLMParamsDoc, GeneratedDoc, ChatCompletionRequest, SearchedDoc, ChatCompletionRequest, DocSumChatCompletionRequest]) -> Union[LLMParamsDoc, GeneratedDoc, ChatCompletionRequest, SearchedDoc, ChatCompletionRequest, DocSumChatCompletionRequest]: +async def safety_guard( + input: Union[ + LLMParamsDoc, + GeneratedDoc, + ChatCompletionRequest, + SearchedDoc, + ChatCompletionRequest, + DocSumChatCompletionRequest, + ], +) -> Union[ + LLMParamsDoc, GeneratedDoc, ChatCompletionRequest, SearchedDoc, ChatCompletionRequest, DocSumChatCompletionRequest +]: start_time = time.time() - + if logflag: logger.info(f"Received input: {input}") - + try: if isinstance(input, LLMParamsDoc): processed = input_guardrail.scan_llm_input(input) - + statistics_dict["opea_service@guardrails"].append_latency( - time.time() - start_time, - f"input_guard:{type(input).__name__}" + time.time() - start_time, f"input_guard:{type(input).__name__}" ) - + if logflag: logger.info(f"Input guard passed: {processed}") return processed - + elif isinstance(input, GeneratedDoc): processed = output_guardrail.scan_llm_output(input) - + if os.getenv("APPLY_CONTENT_GUARD", "true").lower() == "true": text_doc = TextDoc(text=processed.text) content_guard_result = await loader.invoke(text_doc) processed.text = content_guard_result.text - + statistics_dict["opea_service@guardrails"].append_latency( - time.time() - start_time, - f"output_guard:{type(input).__name__}" + time.time() - start_time, f"output_guard:{type(input).__name__}" ) - + if logflag: logger.info(f"Output guard passed: {processed}") return processed - + except HTTPException as e: if e.status_code == 466: logger.warning(f"Security rejection: {e.detail}") statistics_dict["opea_service@guardrails"].append_latency( - time.time() - start_time, - f"rejection:{e.status_code}" + time.time() - start_time, f"rejection:{e.status_code}" ) raise e - + except Exception as e: logger.error(f"Unexpected error: {str(e)}") - statistics_dict["opea_service@guardrails"].append_latency( - time.time() - start_time, - f"error:{type(e).__name__}" - ) - raise HTTPException( - status_code=500, - detail=f"Internal server error: {str(e)}" - ) + statistics_dict["opea_service@guardrails"].append_latency(time.time() - start_time, f"error:{type(e).__name__}") + raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") + if __name__ == "__main__": opea_microservices["opea_service@guardrails"].start() diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py b/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py index 7e72c358fa..1d1b724a25 100644 --- a/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py +++ b/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py @@ -1,18 +1,17 @@ # Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from llm_guard import scan_prompt from fastapi import HTTPException - +from llm_guard import scan_prompt from utils.llm_guard_input_scanners import InputScannersConfig -from comps import get_opea_logger, LLMParamsDoc + +from comps import LLMParamsDoc, get_opea_logger logger = get_opea_logger("opea_llm_guard_input_guardrail_microservice") class OPEALLMGuardInputGuardrail: - """ - OPEALLMGuardInputGuardrail is responsible for scanning and sanitizing LLM input prompts + """OPEALLMGuardInputGuardrail is responsible for scanning and sanitizing LLM input prompts using various input scanners provided by LLM Guard. This class initializes the input scanners based on the provided configuration and @@ -31,8 +30,7 @@ class OPEALLMGuardInputGuardrail: """ def __init__(self, usv_config: dict): - """ - Initializes the OPEALLMGuardInputGuardrail with the provided configuration. + """Initializes the OPEALLMGuardInputGuardrail with the provided configuration. Args: usv_config (dict): The configuration dictionary for initializing the input scanners. @@ -44,11 +42,11 @@ def __init__(self, usv_config: dict): self._scanners_config = InputScannersConfig(usv_config) self._scanners = self._scanners_config.create_enabled_input_scanners() except ValueError as e: - logger.exception(f"Value Error occured while initializing LLM Guard Input Guardrail scanners: {e}") + logger.exception(f"Value Error occurred while initializing LLM Guard Input Guardrail scanners: {e}") raise except Exception as e: logger.exception( - f"An unexpected error occured during initializing \ + f"An unexpected error occurred during initializing \ LLM Guard Input Guardrail scanners: {e}" ) raise @@ -72,12 +70,16 @@ def _analyze_scan_outputs(self, prompt, results_valid, results_score): for key, value in results_valid.items(): if_redacted = False - redacted_scanner = [item for item in self._scanners if type(item).__name__ in scanners_with_redact and type(item).__name__ == key] + redacted_scanner = [ + item + for item in self._scanners + if type(item).__name__ in scanners_with_redact and type(item).__name__ == key + ] if len(redacted_scanner) > 0: if_redacted = redacted_scanner[0]._redact - if key != 'Anonymize' and not if_redacted: + if key != "Anonymize" and not if_redacted: filtered_results_valid_no_redacted[key] = value if False in filtered_results_valid_no_redacted.values(): @@ -87,8 +89,7 @@ def _analyze_scan_outputs(self, prompt, results_valid, results_score): raise HTTPException(status_code=466, detail=f"{usr_msg}") def scan_llm_input(self, input_doc: LLMParamsDoc) -> LLMParamsDoc: - """ - Scan the prompt from an LLMParamsDoc object. + """Scan the prompt from an LLMParamsDoc object. Args: input_doc (LLMParamsDoc): The input document containing the prompt to be scanned. @@ -117,16 +118,22 @@ def scan_llm_input(self, input_doc: LLMParamsDoc) -> LLMParamsDoc: # We want to block only user question with a TokenLimit Scanner scanners_without_token_limit = [item for item in self._scanners if type(item).__name__ != "TokenLimit"] if len(self._scanners) != scanners_without_token_limit: - sanitized_system_prompt, system_results_valid, system_results_score = scan_prompt(scanners_without_token_limit, system_prompt) + sanitized_system_prompt, system_results_valid, system_results_score = scan_prompt( + scanners_without_token_limit, system_prompt + ) else: - sanitized_system_prompt, system_results_valid, system_results_score = scan_prompt(self._scanners, system_prompt) + sanitized_system_prompt, system_results_valid, system_results_score = scan_prompt( + self._scanners, system_prompt + ) if "### Question:" in user_prompt: # Default template is used prefix = "### Question: " suffix = " \n ### Answer:" user_prompt_to_scan = user_prompt.split(prefix)[1].split(suffix)[0] - sanitized_user_prompt, user_results_valid, user_results_score = scan_prompt(self._scanners, user_prompt_to_scan) + sanitized_user_prompt, user_results_valid, user_results_score = scan_prompt( + self._scanners, user_prompt_to_scan + ) sanitized_user_prompt = prefix + sanitized_user_prompt + suffix else: sanitized_user_prompt, user_results_valid, user_results_score = scan_prompt(self._scanners, user_prompt) @@ -136,9 +143,9 @@ def scan_llm_input(self, input_doc: LLMParamsDoc) -> LLMParamsDoc: input_doc.messages.system = sanitized_system_prompt input_doc.messages.user = sanitized_user_prompt - if input_doc.output_guardrail_params is not None and 'Anonymize' in user_results_valid: + if input_doc.output_guardrail_params is not None and "Anonymize" in user_results_valid: input_doc.output_guardrail_params.anonymize_vault = self._get_anonymize_vault() - elif input_doc.output_guardrail_params is None and 'Anonymize' in user_results_valid: + elif input_doc.output_guardrail_params is None and "Anonymize" in user_results_valid: logger.warning("No output guardrails params, could not append the vault for Anonymize scanner.") return input_doc else: diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py b/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py index 60092710ea..3fc9e2f0b2 100644 --- a/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py +++ b/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py @@ -1,7 +1,6 @@ # ruff: noqa: F401 # Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from llm_guard.vault import Vault from llm_guard.input_scanners import ( Anonymize, BanCode, @@ -15,69 +14,48 @@ Secrets, Sentiment, TokenLimit, - Toxicity - ) - -# import models definition -from llm_guard.input_scanners.ban_code import ( - MODEL_SM as BANCODE_MODEL_SM, - MODEL_TINY as BANCODE_MODEL_TINY -) - -from llm_guard.input_scanners.ban_competitors import ( - MODEL_V1 as BANCOMPETITORS_MODEL_V1 -) - -from llm_guard.input_scanners.ban_topics import ( - MODEL_DEBERTA_LARGE_V2 as BANTOPICS_MODEL_DEBERTA_LARGE_V2, - MODEL_DEBERTA_BASE_V2 as BANTOPICS_MODEL_DEBERTA_BASE_V2, - MODEL_BGE_M3_V2 as BANTOPICS_MODEL_BGE_M3_V2, - MODEL_ROBERTA_LARGE_C_V2 as BANTOPICS_MODEL_ROBERTA_LARGE_C_V2, - MODEL_ROBERTA_BASE_C_V2 as BANTOPICS_MODEL_ROBERTA_BASE_C_V2 -) - -from llm_guard.input_scanners.code import ( - DEFAULT_MODEL as CODE_DEFAULT_MODEL + Toxicity, ) -from llm_guard.input_scanners.gibberish import ( - DEFAULT_MODEL as GIBBERISH_DEFAULT_MODEL, -) - -from llm_guard.input_scanners.language import ( - DEFAULT_MODEL as LANGUAGE_DEFAULT_MODEL, -) - -from llm_guard.input_scanners.prompt_injection import ( - V1_MODEL as PROMPTINJECTION_V1_MODEL, - V2_MODEL as PROMPTINJECTION_V2_MODEL, - V2_SMALL_MODEL as PROMPTINJECTION_V2_SMALL_MODEL, -) - -from llm_guard.input_scanners.toxicity import ( - DEFAULT_MODEL as TOXICITY_DEFAULT_MODEL -) +# import models definition +from llm_guard.input_scanners.ban_code import MODEL_SM as BANCODE_MODEL_SM +from llm_guard.input_scanners.ban_code import MODEL_TINY as BANCODE_MODEL_TINY +from llm_guard.input_scanners.ban_competitors import MODEL_V1 as BANCOMPETITORS_MODEL_V1 +from llm_guard.input_scanners.ban_topics import MODEL_BGE_M3_V2 as BANTOPICS_MODEL_BGE_M3_V2 +from llm_guard.input_scanners.ban_topics import MODEL_DEBERTA_BASE_V2 as BANTOPICS_MODEL_DEBERTA_BASE_V2 +from llm_guard.input_scanners.ban_topics import MODEL_DEBERTA_LARGE_V2 as BANTOPICS_MODEL_DEBERTA_LARGE_V2 +from llm_guard.input_scanners.ban_topics import MODEL_ROBERTA_BASE_C_V2 as BANTOPICS_MODEL_ROBERTA_BASE_C_V2 +from llm_guard.input_scanners.ban_topics import MODEL_ROBERTA_LARGE_C_V2 as BANTOPICS_MODEL_ROBERTA_LARGE_C_V2 +from llm_guard.input_scanners.code import DEFAULT_MODEL as CODE_DEFAULT_MODEL +from llm_guard.input_scanners.gibberish import DEFAULT_MODEL as GIBBERISH_DEFAULT_MODEL +from llm_guard.input_scanners.language import DEFAULT_MODEL as LANGUAGE_DEFAULT_MODEL +from llm_guard.input_scanners.prompt_injection import V1_MODEL as PROMPTINJECTION_V1_MODEL +from llm_guard.input_scanners.prompt_injection import V2_MODEL as PROMPTINJECTION_V2_MODEL +from llm_guard.input_scanners.prompt_injection import V2_SMALL_MODEL as PROMPTINJECTION_V2_SMALL_MODEL +from llm_guard.input_scanners.toxicity import DEFAULT_MODEL as TOXICITY_DEFAULT_MODEL +from llm_guard.vault import Vault ENABLED_SCANNERS = [ - 'anonymize', - 'ban_code', - 'ban_competitors', - 'ban_substrings', - 'ban_topics', - 'code', - 'gibberish', - 'invisible_text', - 'language', - 'prompt_injection', - 'regex', - 'secrets', - 'sentiment', - 'token_limit', - 'toxicity' + "anonymize", + "ban_code", + "ban_competitors", + "ban_substrings", + "ban_topics", + "code", + "gibberish", + "invisible_text", + "language", + "prompt_injection", + "regex", + "secrets", + "sentiment", + "token_limit", + "toxicity", ] -from comps.guardrails.utils.scanners import OPEABanSubstrings, OPEARegexScanner from comps import get_opea_logger, sanitize_env +from comps.guardrails.utils.scanners import OPEABanSubstrings, OPEARegexScanner + logger = get_opea_logger("opea_llm_guard_input_guardrail_microservice") @@ -99,14 +77,13 @@ def __init__(self, config_dict): **self._get_secrets_config_from_env(config_dict), **self._get_sentiment_config_from_env(config_dict), **self._get_token_limit_config_from_env(config_dict), - **self._get_toxicity_config_from_env(config_dict) + **self._get_toxicity_config_from_env(config_dict), } -#### METHODS FOR VALIDATING CONFIGS + #### METHODS FOR VALIDATING CONFIGS def _validate_value(self, value): - """ - Validate and convert the input value. + """Validate and convert the input value. Args: value (str): The value to be validated and converted. @@ -125,8 +102,7 @@ def _validate_value(self, value): return value def _get_anonymize_config_from_env(self, config_dict): - """ - Get the Anonymize scanner configuration from the environment. + """Get the Anonymize scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -137,13 +113,13 @@ def _get_anonymize_config_from_env(self, config_dict): return { "anonymize": { k.replace("ANONYMIZE_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("ANONYMIZE_") + for k, v in config_dict.items() + if k.startswith("ANONYMIZE_") } } def _get_ban_code_config_from_env(self, config_dict): - """ - Get the BanCode scanner configuration from the environment. + """Get the BanCode scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -154,13 +130,13 @@ def _get_ban_code_config_from_env(self, config_dict): return { "ban_code": { k.replace("BAN_CODE_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("BAN_CODE_") + for k, v in config_dict.items() + if k.startswith("BAN_CODE_") } } def _get_ban_competitors_config_from_env(self, config_dict): - """ - Get the BanCompetitors scanner configuration from the environment. + """Get the BanCompetitors scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -171,13 +147,13 @@ def _get_ban_competitors_config_from_env(self, config_dict): return { "ban_competitors": { k.replace("BAN_COMPETITORS_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("BAN_COMPETITORS_") + for k, v in config_dict.items() + if k.startswith("BAN_COMPETITORS_") } } def _get_ban_substrings_config_from_env(self, config_dict): - """ - Get the BanSubstrings scanner configuration from the environment. + """Get the BanSubstrings scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -188,13 +164,13 @@ def _get_ban_substrings_config_from_env(self, config_dict): return { "ban_substrings": { k.replace("BAN_SUBSTRINGS_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("BAN_SUBSTRINGS_") + for k, v in config_dict.items() + if k.startswith("BAN_SUBSTRINGS_") } } def _get_ban_topics_config_from_env(self, config_dict): - """ - Get the BanTopics scanner configuration from the environment. + """Get the BanTopics scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -205,13 +181,13 @@ def _get_ban_topics_config_from_env(self, config_dict): return { "ban_topics": { k.replace("BAN_TOPICS_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("BAN_TOPICS_") + for k, v in config_dict.items() + if k.startswith("BAN_TOPICS_") } } def _get_code_config_from_env(self, config_dict): - """ - Get the Code scanner configuration from the environment. + """Get the Code scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -222,13 +198,13 @@ def _get_code_config_from_env(self, config_dict): return { "code": { k.replace("CODE_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("CODE_") + for k, v in config_dict.items() + if k.startswith("CODE_") } } def _get_gibberish_config_from_env(self, config_dict): - """ - Get the Gibberish scanner configuration from the environment. + """Get the Gibberish scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -239,12 +215,13 @@ def _get_gibberish_config_from_env(self, config_dict): return { "gibberish": { k.replace("GIBBERISH_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("GIBBERISH_") + for k, v in config_dict.items() + if k.startswith("GIBBERISH_") } } + def _get_invisible_text_config_from_env(self, config_dict): - """ - Get the InvisibleText scanner configuration from the environment. + """Get the InvisibleText scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -255,13 +232,13 @@ def _get_invisible_text_config_from_env(self, config_dict): return { "invisible_text": { k.replace("INVISIBLE_TEXT_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("INVISIBLE_TEXT_") + for k, v in config_dict.items() + if k.startswith("INVISIBLE_TEXT_") } } def _get_language_config_from_env(self, config_dict): - """ - Get the Language scanner configuration from the environment. + """Get the Language scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -272,13 +249,13 @@ def _get_language_config_from_env(self, config_dict): return { "language": { k.replace("LANGUAGE_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("LANGUAGE_") + for k, v in config_dict.items() + if k.startswith("LANGUAGE_") } } def _get_prompt_injection_config_from_env(self, config_dict): - """ - Get the PromptInjection scanner configuration from the environment. + """Get the PromptInjection scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -289,13 +266,13 @@ def _get_prompt_injection_config_from_env(self, config_dict): return { "prompt_injection": { k.replace("PROMPT_INJECTION_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("PROMPT_INJECTION_") + for k, v in config_dict.items() + if k.startswith("PROMPT_INJECTION_") } } def _get_regex_config_from_env(self, config_dict): - """ - Get the Regex scanner configuration from the environment. + """Get the Regex scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -306,13 +283,13 @@ def _get_regex_config_from_env(self, config_dict): return { "regex": { k.replace("REGEX_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("REGEX_") + for k, v in config_dict.items() + if k.startswith("REGEX_") } } def _get_secrets_config_from_env(self, config_dict): - """ - Get the Secrets scanner configuration from the environment. + """Get the Secrets scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -323,13 +300,13 @@ def _get_secrets_config_from_env(self, config_dict): return { "secrets": { k.replace("SECRETS_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("SECRETS_") + for k, v in config_dict.items() + if k.startswith("SECRETS_") } } def _get_sentiment_config_from_env(self, config_dict): - """ - Get the Secrets scanner configuration from the environment. + """Get the Secrets scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -340,13 +317,13 @@ def _get_sentiment_config_from_env(self, config_dict): return { "sentiment": { k.replace("SENTIMENT_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("SENTIMENT_") + for k, v in config_dict.items() + if k.startswith("SENTIMENT_") } } def _get_token_limit_config_from_env(self, config_dict): - """ - Get the TokenLimit scanner configuration from the environment. + """Get the TokenLimit scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -357,13 +334,13 @@ def _get_token_limit_config_from_env(self, config_dict): return { "token_limit": { k.replace("TOKEN_LIMIT_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("TOKEN_LIMIT_") + for k, v in config_dict.items() + if k.startswith("TOKEN_LIMIT_") } } def _get_toxicity_config_from_env(self, config_dict): - """ - Get the Toxicity scanner configuration from the environment. + """Get the Toxicity scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -374,30 +351,35 @@ def _get_toxicity_config_from_env(self, config_dict): return { "toxicity": { k.replace("TOXICITY_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("TOXICITY_") - } + for k, v in config_dict.items() + if k.startswith("TOXICITY_") + } } -#### METHODS FOR CREATING SCANNERS + #### METHODS FOR CREATING SCANNERS def _create_anonymize_scanner(self, scanner_config=None): if scanner_config is None: - logger.warning("_create_anonymize_scanner was invoked without scanner_config. Recreating with saved config to clear the Vault.") + logger.warning( + "_create_anonymize_scanner was invoked without scanner_config. Recreating with saved config to clear the Vault." + ) if hasattr(self, "_anonymize_params") and self._anonymize_params is not None: scanner_config = self._anonymize_params else: - raise ValueError("_create_anonymize_scanner was invoked without scanner_config but no self._anonymize_params were saved. Such action is not allowed.") + raise ValueError( + "_create_anonymize_scanner was invoked without scanner_config but no self._anonymize_params were saved. Such action is not allowed." + ) vault = Vault() - anonymize_params = {'vault': vault, 'use_onnx': scanner_config.get('use_onnx', False)} - hidden_names = scanner_config.get('hidden_names', None) - allowed_names = scanner_config.get('allowed_names', None) - entity_types = scanner_config.get('entity_types', None) - preamble = scanner_config.get('preamble', None) - regex_patterns = scanner_config.get('regex_patterns', None) - use_faker = scanner_config.get('use_faker', None) - recognizer_conf = scanner_config.get('recognizer_conf', None) - threshold = scanner_config.get('threshold', None) - language = scanner_config.get('language', None) + anonymize_params = {"vault": vault, "use_onnx": scanner_config.get("use_onnx", False)} + hidden_names = scanner_config.get("hidden_names", None) + allowed_names = scanner_config.get("allowed_names", None) + entity_types = scanner_config.get("entity_types", None) + preamble = scanner_config.get("preamble", None) + regex_patterns = scanner_config.get("regex_patterns", None) + use_faker = scanner_config.get("use_faker", None) + recognizer_conf = scanner_config.get("recognizer_conf", None) + threshold = scanner_config.get("threshold", None) + language = scanner_config.get("language", None) if isinstance(hidden_names, str): hidden_names = sanitize_env(hidden_names) @@ -413,106 +395,112 @@ def _create_anonymize_scanner(self, scanner_config=None): if hidden_names is not None: if isinstance(hidden_names, str): - artifacts = set([',', '', '.']) - anonymize_params['hidden_names'] = list(set(hidden_names.split(',')) - artifacts) + artifacts = set([",", "", "."]) + anonymize_params["hidden_names"] = list(set(hidden_names.split(",")) - artifacts) elif isinstance(hidden_names, list): - anonymize_params['hidden_names'] = hidden_names + anonymize_params["hidden_names"] = hidden_names else: logger.error("Provided type is not valid for Anonymize scanner") raise ValueError("Provided type is not valid for Anonymize scanner") if allowed_names is not None: if isinstance(allowed_names, str): - artifacts = set([',', '', '.']) - anonymize_params['allowed_names'] = list(set(allowed_names.split(',')) - artifacts) + artifacts = set([",", "", "."]) + anonymize_params["allowed_names"] = list(set(allowed_names.split(",")) - artifacts) elif isinstance(hidden_names, list): - anonymize_params['allowed_names'] = allowed_names + anonymize_params["allowed_names"] = allowed_names else: logger.error("Provided type is not valid for Anonymize scanner") raise ValueError("Provided type is not valid for Anonymize scanner") if entity_types is not None: if isinstance(entity_types, str): - artifacts = set([',', '', '.']) - anonymize_params['entity_types'] = list(set(entity_types.split(',')) - artifacts) + artifacts = set([",", "", "."]) + anonymize_params["entity_types"] = list(set(entity_types.split(",")) - artifacts) elif isinstance(hidden_names, list): - anonymize_params['entity_types'] = entity_types + anonymize_params["entity_types"] = entity_types else: logger.error("Provided type is not valid for Anonymize scanner") raise ValueError("Provided type is not valid for Anonymize scanner") if preamble is not None: - anonymize_params['preamble'] = preamble + anonymize_params["preamble"] = preamble if regex_patterns is not None: if isinstance(regex_patterns, str): - artifacts = set([',', '', '.']) - anonymize_params['regex_patterns'] = list(set(regex_patterns.split(',')) - artifacts) + artifacts = set([",", "", "."]) + anonymize_params["regex_patterns"] = list(set(regex_patterns.split(",")) - artifacts) elif isinstance(hidden_names, list): - anonymize_params['regex_patterns'] = regex_patterns + anonymize_params["regex_patterns"] = regex_patterns else: logger.error("Provided type is not valid for Anonymize scanner") raise ValueError("Provided type is not valid for Anonymize scanner") if use_faker is not None: - anonymize_params['use_faker'] = use_faker + anonymize_params["use_faker"] = use_faker if recognizer_conf is not None: - anonymize_params['recognizer_conf'] = recognizer_conf + anonymize_params["recognizer_conf"] = recognizer_conf if threshold is not None: - anonymize_params['threshold'] = threshold + anonymize_params["threshold"] = threshold if language is not None: - anonymize_params['language'] = language + anonymize_params["language"] = language logger.info(f"Creating Anonymize scanner with params: {anonymize_params}") - self._anonymize_params = {key: value for key, value in anonymize_params.items() if key != 'vault'} + self._anonymize_params = {key: value for key, value in anonymize_params.items() if key != "vault"} return Anonymize(**anonymize_params) def _create_ban_code_scanner(self, scanner_config): - enabled_models = {'MODEL_SM': BANCODE_MODEL_SM, 'MODEL_TINY': BANCODE_MODEL_TINY} - bancode_params = {'use_onnx': scanner_config.get('use_onnx', False)} # by default we dont't want to use onnx + enabled_models = {"MODEL_SM": BANCODE_MODEL_SM, "MODEL_TINY": BANCODE_MODEL_TINY} + bancode_params = {"use_onnx": scanner_config.get("use_onnx", False)} # by default we don't want to use onnx - model_name = scanner_config.get('model', None) - threshold = scanner_config.get('threshold', None) + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for BanCode scanner: {model_name}") - bancode_params['model'] = enabled_models[model_name] # Model class from LLM Guard + bancode_params["model"] = enabled_models[model_name] # Model class from LLM Guard else: err_msg = f"Model name is not valid for BanCode scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if threshold is not None: - bancode_params['threshold'] = threshold # float + bancode_params["threshold"] = threshold # float logger.info(f"Creating BanCode scanner with params: {bancode_params}") return BanCode(**bancode_params) def _create_ban_competitors_scanner(self, scanner_config): - enabled_models = {'MODEL_V1': BANCOMPETITORS_MODEL_V1} - ban_competitors_params = {'use_onnx': scanner_config.get('use_onnx', False)} # by default we don't want to use onnx + enabled_models = {"MODEL_V1": BANCOMPETITORS_MODEL_V1} + ban_competitors_params = { + "use_onnx": scanner_config.get("use_onnx", False) + } # by default we don't want to use onnx - competitors = scanner_config.get('competitors', None) - threshold = scanner_config.get('threshold', None) - redact = scanner_config.get('redact', None) - model_name = scanner_config.get('model', None) + competitors = scanner_config.get("competitors", None) + threshold = scanner_config.get("threshold", None) + redact = scanner_config.get("redact", None) + model_name = scanner_config.get("model", None) if isinstance(competitors, str): competitors = sanitize_env(competitors) if competitors: if isinstance(competitors, str): - artifacts = set([',', '', '.']) - ban_competitors_params['competitors'] = list(set(competitors.split(',')) - artifacts) # list + artifacts = set([",", "", "."]) + ban_competitors_params["competitors"] = list(set(competitors.split(",")) - artifacts) # list elif isinstance(competitors, list): - ban_competitors_params['competitors'] = competitors + ban_competitors_params["competitors"] = competitors else: logger.error("Provided type is not valid for BanCompetitors scanner") raise ValueError("Provided type is not valid for BanCompetitors scanner") else: - logger.error("Competitors list is required for BanCompetitors scanner. Please provide a list of competitors.") - raise TypeError("Competitors list is required for BanCompetitors scanner. Please provide a list of competitors.") + logger.error( + "Competitors list is required for BanCompetitors scanner. Please provide a list of competitors." + ) + raise TypeError( + "Competitors list is required for BanCompetitors scanner. Please provide a list of competitors." + ) if threshold is not None: - ban_competitors_params['threshold'] = threshold # float + ban_competitors_params["threshold"] = threshold # float if redact is not None: - ban_competitors_params['redact'] = redact + ban_competitors_params["redact"] = redact if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for BanCompetitors scanner: {model_name}") - ban_competitors_params['model'] = enabled_models[model_name] + ban_competitors_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for BanCompetitors scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) @@ -521,24 +509,24 @@ def _create_ban_competitors_scanner(self, scanner_config): return BanCompetitors(**ban_competitors_params) def _create_ban_substrings_scanner(self, scanner_config): - available_match_types = ['str', 'word'] + available_match_types = ["str", "word"] ban_substrings_params = {} - substrings = scanner_config.get('substrings', None) - match_type = scanner_config.get('match_type', None) - case_sensitive = scanner_config.get('case_sensitive', None) - redact = scanner_config.get('redact', None) - contains_all = scanner_config.get('contains_all', None) + substrings = scanner_config.get("substrings", None) + match_type = scanner_config.get("match_type", None) + case_sensitive = scanner_config.get("case_sensitive", None) + redact = scanner_config.get("redact", None) + contains_all = scanner_config.get("contains_all", None) if isinstance(substrings, str): substrings = sanitize_env(substrings) if substrings: if isinstance(substrings, str): - artifacts = set([',', '', '.']) - ban_substrings_params['substrings'] = list(set(substrings.split(',')) - artifacts)# list + artifacts = set([",", "", "."]) + ban_substrings_params["substrings"] = list(set(substrings.split(",")) - artifacts) # list elif substrings and isinstance(substrings, list): - ban_substrings_params['substrings'] = substrings + ban_substrings_params["substrings"] = substrings else: logger.error("Provided type is not valid for BanSubstrings scanner") raise ValueError("Provided type is not valid for BanSubstrings scanner") @@ -546,39 +534,39 @@ def _create_ban_substrings_scanner(self, scanner_config): logger.error("Substrings list is required for BanSubstrings scanner") raise TypeError("Substrings list is required for BanSubstrings scanner") if match_type is not None and match_type in available_match_types: - ban_substrings_params['match_type'] = match_type # MatchType + ban_substrings_params["match_type"] = match_type # MatchType if case_sensitive is not None: - ban_substrings_params['case_sensitive'] = case_sensitive # bool + ban_substrings_params["case_sensitive"] = case_sensitive # bool if redact is not None: - ban_substrings_params['redact'] = redact # bool + ban_substrings_params["redact"] = redact # bool if contains_all is not None: - ban_substrings_params['contains_all'] = contains_all # bool + ban_substrings_params["contains_all"] = contains_all # bool logger.info(f"Creating BanSubstrings scanner with params: {ban_substrings_params}") return OPEABanSubstrings(**ban_substrings_params) def _create_ban_topics_scanner(self, scanner_config): enabled_models = { - 'MODEL_DEBERTA_LARGE_V2': BANTOPICS_MODEL_DEBERTA_LARGE_V2, - 'MODEL_DEBERTA_BASE_V2': BANTOPICS_MODEL_DEBERTA_BASE_V2, - 'MODEL_BGE_M3_V2': BANTOPICS_MODEL_BGE_M3_V2, - 'MODEL_ROBERTA_LARGE_C_V2': BANTOPICS_MODEL_ROBERTA_LARGE_C_V2, - 'MODEL_ROBERTA_BASE_C_V2': BANTOPICS_MODEL_ROBERTA_BASE_C_V2 + "MODEL_DEBERTA_LARGE_V2": BANTOPICS_MODEL_DEBERTA_LARGE_V2, + "MODEL_DEBERTA_BASE_V2": BANTOPICS_MODEL_DEBERTA_BASE_V2, + "MODEL_BGE_M3_V2": BANTOPICS_MODEL_BGE_M3_V2, + "MODEL_ROBERTA_LARGE_C_V2": BANTOPICS_MODEL_ROBERTA_LARGE_C_V2, + "MODEL_ROBERTA_BASE_C_V2": BANTOPICS_MODEL_ROBERTA_BASE_C_V2, } - ban_topics_params = {'use_onnx': scanner_config.get('use_onnx', False)} + ban_topics_params = {"use_onnx": scanner_config.get("use_onnx", False)} - topics = scanner_config.get('topics', None) - threshold = scanner_config.get('threshold', None) - model_name = scanner_config.get('model', None) + topics = scanner_config.get("topics", None) + threshold = scanner_config.get("threshold", None) + model_name = scanner_config.get("model", None) if isinstance(topics, str): topics = sanitize_env(topics) if topics: if isinstance(topics, str): - artifacts = set([',', '', '.']) - ban_topics_params['topics'] = list(set(topics.split(',')) - artifacts) + artifacts = set([",", "", "."]) + ban_topics_params["topics"] = list(set(topics.split(",")) - artifacts) elif isinstance(topics, list): - ban_topics_params['topics'] = topics + ban_topics_params["topics"] = topics else: logger.error("Provided type is not valid for BanTopics scanner") raise ValueError("Provided type is not valid for BanTopics scanner") @@ -586,11 +574,11 @@ def _create_ban_topics_scanner(self, scanner_config): logger.error("Topics list is required for BanTopics scanner") raise TypeError("Topics list is required for BanTopics scanner") if threshold is not None: - ban_topics_params['threshold'] = threshold + ban_topics_params["threshold"] = threshold if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for BanTopics scanner: {model_name}") - ban_topics_params['model'] = enabled_models[model_name] + ban_topics_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for BanTopics scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) @@ -599,23 +587,23 @@ def _create_ban_topics_scanner(self, scanner_config): return BanTopics(**ban_topics_params) def _create_code_scanner(self, scanner_config): - enabled_models = {'DEFAULT_MODEL': CODE_DEFAULT_MODEL} - code_params = {'use_onnx': scanner_config.get('use_onnx', False)} + enabled_models = {"DEFAULT_MODEL": CODE_DEFAULT_MODEL} + code_params = {"use_onnx": scanner_config.get("use_onnx", False)} - languages = scanner_config.get('languages', None) - model_name = scanner_config.get('model', None) - is_blocked = scanner_config.get('is_blocked', None) - threshold = scanner_config.get('threshold', None) + languages = scanner_config.get("languages", None) + model_name = scanner_config.get("model", None) + is_blocked = scanner_config.get("is_blocked", None) + threshold = scanner_config.get("threshold", None) if isinstance(languages, str): languages = sanitize_env(languages) if languages: if isinstance(languages, str): - artifacts = set([',', '', '.']) - code_params['languages'] = list(set(languages.split(',')) - artifacts) + artifacts = set([",", "", "."]) + code_params["languages"] = list(set(languages.split(",")) - artifacts) elif isinstance(languages, list): - code_params['languages'] = languages + code_params["languages"] = languages else: logger.error("Provided type is not valid for Code scanner") raise ValueError("Provided type is not valid for Code scanner") @@ -625,43 +613,44 @@ def _create_code_scanner(self, scanner_config): if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for Code scanner: {model_name}") - code_params['model'] = enabled_models[model_name] + code_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for Code scanner. Please provide a valid model name. Provided model: {model_name}" logger.error(err_msg) raise ValueError(err_msg) if is_blocked is not None: - code_params['is_blocked'] = is_blocked + code_params["is_blocked"] = is_blocked if threshold is not None: - code_params['threshold'] = threshold + code_params["threshold"] = threshold logger.info(f"Creating Code scanner with params: {code_params}") return Code(**code_params) def _create_gibberish_scanner(self, scanner_config): - enabled_models = {'DEFAULT_MODEL': GIBBERISH_DEFAULT_MODEL} - enabled_match_types = ['sentence', 'full'] - gibberish_params = {'use_onnx': scanner_config.get('use_onnx', False)} + enabled_models = {"DEFAULT_MODEL": GIBBERISH_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + gibberish_params = {"use_onnx": scanner_config.get("use_onnx", False)} - model_name = scanner_config.get('model', None) - threshold = scanner_config.get('threshold', None) - match_type = scanner_config.get('match_type', None) + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) if match_type == "sentence": import nltk - nltk.download('punkt_tab') + + nltk.download("punkt_tab") if threshold is not None: - gibberish_params['threshold'] = threshold + gibberish_params["threshold"] = threshold if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for Gibberish scanner: {model_name}") - gibberish_params['model'] = enabled_models[model_name] + gibberish_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for Gibberish scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if match_type is not None and match_type in enabled_match_types: - gibberish_params['match_type'] = match_type + gibberish_params["match_type"] = match_type logger.info(f"Creating Gibberish scanner with params: {gibberish_params}") return Gibberish(**gibberish_params) @@ -670,24 +659,24 @@ def _create_invisible_text_scanner(self): return InvisibleText() def _create_language_scanner(self, scanner_config): - enabled_models = {'DEFAULT_MODEL': LANGUAGE_DEFAULT_MODEL} - enabled_match_types = ['sentence', 'full'] - language_params = {'use_onnx': scanner_config.get('use_onnx', False)} + enabled_models = {"DEFAULT_MODEL": LANGUAGE_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + language_params = {"use_onnx": scanner_config.get("use_onnx", False)} - valid_languages = scanner_config.get('valid_languages', None) - model_name = scanner_config.get('model', None) - threshold = scanner_config.get('threshold', None) - match_type = scanner_config.get('match_type', None) + valid_languages = scanner_config.get("valid_languages", None) + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) if isinstance(valid_languages, str): valid_languages = sanitize_env(valid_languages) if valid_languages: if isinstance(valid_languages, str): - artifacts = set([',', '', '.']) - language_params['valid_languages'] = list(set(valid_languages.split(',')) - artifacts) + artifacts = set([",", "", "."]) + language_params["valid_languages"] = list(set(valid_languages.split(",")) - artifacts) elif isinstance(valid_languages, list): - language_params['valid_languages'] = valid_languages + language_params["valid_languages"] = valid_languages else: logger.error("Provided type is not valid for Language scanner") raise ValueError("Provided type is not valid for Language scanner") @@ -697,68 +686,69 @@ def _create_language_scanner(self, scanner_config): if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for Language scanner: {model_name}") - language_params['model'] = enabled_models[model_name] + language_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for Language scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if threshold is not None: - language_params['threshold'] = threshold + language_params["threshold"] = threshold if match_type is not None and match_type in enabled_match_types: - language_params['match_type'] = match_type + language_params["match_type"] = match_type logger.info(f"Creating Language scanner with params: {language_params}") return Language(**language_params) def _create_prompt_injection_scanner(self, scanner_config): enabled_models = { - 'V1_MODEL': PROMPTINJECTION_V1_MODEL, - 'V2_MODEL': PROMPTINJECTION_V2_MODEL, - 'V2_SMALL_MODEL': PROMPTINJECTION_V2_SMALL_MODEL + "V1_MODEL": PROMPTINJECTION_V1_MODEL, + "V2_MODEL": PROMPTINJECTION_V2_MODEL, + "V2_SMALL_MODEL": PROMPTINJECTION_V2_SMALL_MODEL, } - enabled_match_types = ['sentence', 'full', "truncate_token_head_tail", "truncate_head_tail", "chunks"] - prompt_injection_params = {'use_onnx': scanner_config.get('use_onnx', False)} + enabled_match_types = ["sentence", "full", "truncate_token_head_tail", "truncate_head_tail", "chunks"] + prompt_injection_params = {"use_onnx": scanner_config.get("use_onnx", False)} - model_name = scanner_config.get('model', None) - threshold = scanner_config.get('threshold', None) - match_type = scanner_config.get('match_type', None) + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) if match_type == "sentence": import nltk - nltk.download('punkt_tab') + + nltk.download("punkt_tab") if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for PromptInjection scanner: {model_name}") - prompt_injection_params['model'] = enabled_models[model_name] + prompt_injection_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for PromptInjection scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if threshold is not None: - prompt_injection_params['threshold'] = threshold + prompt_injection_params["threshold"] = threshold if match_type is not None and match_type in enabled_match_types: - prompt_injection_params['match_type'] = match_type + prompt_injection_params["match_type"] = match_type logger.info(f"Creating PromptInjection scanner with params: {prompt_injection_params}") return PromptInjection(**prompt_injection_params) def _create_regex_scanner(self, scanner_config): - enabled_match_types = ['search', 'fullmatch'] + enabled_match_types = ["search", "fullmatch"] regex_params = {} - patterns = scanner_config.get('patterns', None) - is_blocked = scanner_config.get('is_blocked', None) - match_type = scanner_config.get('match_type', None) - redact = scanner_config.get('redact', None) + patterns = scanner_config.get("patterns", None) + is_blocked = scanner_config.get("is_blocked", None) + match_type = scanner_config.get("match_type", None) + redact = scanner_config.get("redact", None) if isinstance(patterns, str): patterns = sanitize_env(patterns) if patterns: if isinstance(patterns, str): - artifacts = set([',', '', '.']) - regex_params['patterns'] = list(set(patterns.split(',')) - artifacts) + artifacts = set([",", "", "."]) + regex_params["patterns"] = list(set(patterns.split(",")) - artifacts) elif isinstance(patterns, list): - regex_params['patterns'] = patterns + regex_params["patterns"] = patterns else: logger.error("Provided type is not valid for Regex scanner") raise ValueError("Provided type is not valid for Regex scanner") @@ -766,23 +756,23 @@ def _create_regex_scanner(self, scanner_config): logger.error("Patterns list is required for Regex scanner") raise TypeError("Patterns list is required for Regex scanner") if is_blocked is not None: - regex_params['is_blocked'] = is_blocked + regex_params["is_blocked"] = is_blocked if match_type is not None and match_type in enabled_match_types: - regex_params['match_type'] = match_type + regex_params["match_type"] = match_type if redact is not None: - regex_params['redact'] = redact + regex_params["redact"] = redact logger.info(f"Creating Regex scanner with params: {regex_params}") return OPEARegexScanner(**regex_params) def _create_secrets_scanner(self, scanner_config): - enabled_redact_types = ['partial', 'all', 'hash'] + enabled_redact_types = ["partial", "all", "hash"] secrets_params = {} - redact = scanner_config.get('redact', None) + redact = scanner_config.get("redact", None) if redact is not None and redact in enabled_redact_types: - secrets_params['redact'] = redact + secrets_params["redact"] = redact logger.info(f"Creating Secrets scanner with params: {secrets_params}") return Secrets(**secrets_params) @@ -791,58 +781,58 @@ def _create_sentiment_scanner(self, scanner_config): enabled_lexicons = ["vader_lexicon"] sentiment_params = {} - threshold = scanner_config.get('threshold', None) - lexicon = scanner_config.get('lexicon', None) + threshold = scanner_config.get("threshold", None) + lexicon = scanner_config.get("lexicon", None) if threshold is not None: - sentiment_params['threshold'] = threshold + sentiment_params["threshold"] = threshold if lexicon is not None and lexicon in enabled_lexicons: - sentiment_params['lexicon'] = lexicon + sentiment_params["lexicon"] = lexicon logger.info(f"Creating Sentiment scanner with params: {sentiment_params}") return Sentiment(**sentiment_params) def _create_token_limit_scanner(self, scanner_config): - enabled_encodings = ['cl100k_base'] # TODO: test more encoding from tiktoken + enabled_encodings = ["cl100k_base"] # TODO: test more encoding from tiktoken token_limit_params = {} - limit = int(scanner_config.get('limit', None)) - encoding_name = scanner_config.get('encoding', None) + limit = int(scanner_config.get("limit", None)) + encoding_name = scanner_config.get("encoding", None) if limit is not None: - token_limit_params['limit'] = limit + token_limit_params["limit"] = limit if encoding_name is not None and encoding_name in enabled_encodings: - token_limit_params['encoding_name'] = encoding_name + token_limit_params["encoding_name"] = encoding_name logger.info(f"Creating TokenLimit scanner with params: {token_limit_params}") return TokenLimit(**token_limit_params) def _create_toxicity_scanner(self, scanner_config): - enabled_models = {'DEFAULT_MODEL': TOXICITY_DEFAULT_MODEL} - enabled_match_types = ['sentence', 'full'] - toxicity_params = {'use_onnx': scanner_config.get('use_onnx', False)} + enabled_models = {"DEFAULT_MODEL": TOXICITY_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + toxicity_params = {"use_onnx": scanner_config.get("use_onnx", False)} - model_name = scanner_config.get('model', None) - threshold = scanner_config.get('threshold', None) - match_type = scanner_config.get('match_type', None) + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) if match_type == "sentence": import nltk - nltk.download('punkt_tab') + nltk.download("punkt_tab") if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for Toxicity scanner: {model_name}") - toxicity_params['model'] = enabled_models[model_name] + toxicity_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for Toxicity scanner. Please provide a valid model name. Provided model: {model_name}" logger.error(err_msg) raise ValueError(err_msg) if threshold is not None: - toxicity_params['threshold'] = threshold + toxicity_params["threshold"] = threshold if match_type is not None and match_type in enabled_match_types: - toxicity_params['match_type'] = match_type + toxicity_params["match_type"] = match_type logger.info(f"Creating Toxicity scanner with params: {toxicity_params}") return Toxicity(**toxicity_params) @@ -851,41 +841,40 @@ def _create_input_scanner(self, scanner_name, scanner_config): if scanner_name not in ENABLED_SCANNERS: logger.error(f"Scanner {scanner_name} is not supported. Enabled scanners are: {ENABLED_SCANNERS}") raise ValueError(f"Scanner {scanner_name} is not supported") - if scanner_name == 'anonymize': + if scanner_name == "anonymize": return self._create_anonymize_scanner(scanner_config) - elif scanner_name == 'ban_code': + elif scanner_name == "ban_code": return self._create_ban_code_scanner(scanner_config) - elif scanner_name == 'ban_competitors': + elif scanner_name == "ban_competitors": return self._create_ban_competitors_scanner(scanner_config) - elif scanner_name == 'ban_substrings': + elif scanner_name == "ban_substrings": return self._create_ban_substrings_scanner(scanner_config) - elif scanner_name == 'ban_topics': + elif scanner_name == "ban_topics": return self._create_ban_topics_scanner(scanner_config) - elif scanner_name == 'code': + elif scanner_name == "code": return self._create_code_scanner(scanner_config) - elif scanner_name == 'gibberish': + elif scanner_name == "gibberish": return self._create_gibberish_scanner(scanner_config) - elif scanner_name == 'invisible_text': + elif scanner_name == "invisible_text": return self._create_invisible_text_scanner() - elif scanner_name == 'language': + elif scanner_name == "language": return self._create_language_scanner(scanner_config) - elif scanner_name == 'prompt_injection': + elif scanner_name == "prompt_injection": return self._create_prompt_injection_scanner(scanner_config) - elif scanner_name == 'regex': + elif scanner_name == "regex": return self._create_regex_scanner(scanner_config) - elif scanner_name == 'secrets': + elif scanner_name == "secrets": return self._create_secrets_scanner(scanner_config) - elif scanner_name == 'sentiment': + elif scanner_name == "sentiment": return self._create_sentiment_scanner(scanner_config) - elif scanner_name == 'token_limit': + elif scanner_name == "token_limit": return self._create_token_limit_scanner(scanner_config) - elif scanner_name == 'toxicity': + elif scanner_name == "toxicity": return self._create_toxicity_scanner(scanner_config) return None def create_enabled_input_scanners(self): - """ - Create and return a list of enabled scanners based on the global configuration. + """Create and return a list of enabled scanners based on the global configuration. Returns: list: A list of enabled scanner instances. @@ -893,7 +882,7 @@ def create_enabled_input_scanners(self): enabled_scanners_names_and_configs = {k: v for k, v in self._input_scanners_config.items() if v.get("enabled")} enabled_scanners_objects = [] - err_msgs = {} # list for all erronous scanners + err_msgs = {} # list for all erroneous scanners only_validation_errors = True for scanner_name, scanner_config in enabled_scanners_names_and_configs.items(): try: @@ -901,19 +890,19 @@ def create_enabled_input_scanners(self): scanner_object = self._create_input_scanner(scanner_name, scanner_config) enabled_scanners_objects.append(scanner_object) except ValueError as e: - err_msg = f"A ValueError occured during creating input scanner {scanner_name}: {e}" + err_msg = f"A ValueError occurred during creating input scanner {scanner_name}: {e}" logger.error(err_msg) err_msgs[scanner_name] = err_msg self._input_scanners_config[scanner_name]["enabled"] = False continue except TypeError as e: - err_msg = f"A TypeError occured during creating input scanner {scanner_name}: {e}" + err_msg = f"A TypeError occurred during creating input scanner {scanner_name}: {e}" logger.error(err_msg) err_msgs[scanner_name] = err_msg self._input_scanners_config[scanner_name]["enabled"] = False continue except Exception as e: - err_msg = f"An unexpected error occured during creating input scanner {scanner_name}: {e}" + err_msg = f"An unexpected error occurred during creating input scanner {scanner_name}: {e}" logger.error(err_msg) err_msgs[scanner_name] = err_msg only_validation_errors = False @@ -922,15 +911,18 @@ def create_enabled_input_scanners(self): if err_msgs: if only_validation_errors: - raise ValueError(f"Some scanners failed to be created due to validation errors. The details: {err_msgs}") + raise ValueError( + f"Some scanners failed to be created due to validation errors. The details: {err_msgs}" + ) else: - raise Exception(f"Some scanners failed to be created due to validation or unexpected errors. The details: {err_msgs}") + raise Exception( + f"Some scanners failed to be created due to validation or unexpected errors. The details: {err_msgs}" + ) return [s for s in enabled_scanners_objects if s is not None] def changed(self, new_scanners_config): - """ - Check if the scanners configuration has changed. + """Check if the scanners configuration has changed. Args: new_scanners_config (dict): The current scanners configuration. @@ -938,15 +930,21 @@ def changed(self, new_scanners_config): Returns: bool: True if the configuration has changed, False otherwise. """ - del new_scanners_config['id'] - newly_enabled_scanners = {k: {in_k: in_v for in_k, in_v in v.items() if in_k != 'id'} for k, v in new_scanners_config.items() if v.get("enabled")} + del new_scanners_config["id"] + newly_enabled_scanners = { + k: {in_k: in_v for in_k, in_v in v.items() if in_k != "id"} + for k, v in new_scanners_config.items() + if v.get("enabled") + } previously_enabled_scanners = {k: v for k, v in self._input_scanners_config.items() if v.get("enabled")} - if newly_enabled_scanners == previously_enabled_scanners: # if the enables scanners are the same we do nothing + if newly_enabled_scanners == previously_enabled_scanners: # if the enables scanners are the same we do nothing logger.info("No changes in list for enabled scanners. Checking configuration changes...") return False else: logger.warning("Sanners configuration has been changed, re-creating scanners") self._input_scanners_config.clear() - stripped_new_scanners_config = {k: {in_k: in_v for in_k, in_v in v.items() if in_k != 'id'} for k, v in new_scanners_config.items()} + stripped_new_scanners_config = { + k: {in_k: in_v for in_k, in_v in v.items() if in_k != "id"} for k, v in new_scanners_config.items() + } self._input_scanners_config.update(stripped_new_scanners_config) return True diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py b/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py index 5176002a29..08b79c51f3 100644 --- a/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py +++ b/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py @@ -1,17 +1,17 @@ # Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from llm_guard import scan_output from fastapi import HTTPException - +from llm_guard import scan_output from utils.llm_guard_output_scanners import OutputScannersConfig -from comps import get_opea_logger, GeneratedDoc + +from comps import GeneratedDoc, get_opea_logger logger = get_opea_logger("opea_llm_guard_output_guardrail_microservice") + class OPEALLMGuardOutputGuardrail: - """ - OPEALLMGuardOutputGuardrail is responsible for scanning and sanitizing LLM output responses + """OPEALLMGuardOutputGuardrail is responsible for scanning and sanitizing LLM output responses using various output scanners provided by LLM Guard. This class initializes the output scanners based on the provided configuration and @@ -29,8 +29,7 @@ class OPEALLMGuardOutputGuardrail: """ def __init__(self, usv_config: list): - """ - Initializes the OPEALLMGuardOutputGuardrail with the provided configuration. + """Initializes the OPEALLMGuardOutputGuardrail with the provided configuration. Args: usv_config (list): The configuration list for initializing the output scanners. @@ -43,15 +42,13 @@ def __init__(self, usv_config: list): self._scanners = self._scanners_config.create_enabled_output_scanners() except Exception as e: logger.exception( - f"An unexpected error occured during initializing \ + f"An unexpected error occurred during initializing \ LLM Guard Output Guardrail scanners: {e}" ) raise - def scan_llm_output(self, output_doc: GeneratedDoc) -> str: - """ - Scans the output from an LLM output document. + """Scans the output from an LLM output document. Args: output_doc (object): The output document containing the response to be scanned. @@ -73,12 +70,15 @@ def scan_llm_output(self, output_doc: GeneratedDoc) -> str: if self._scanners: sanitized_output, results_valid, results_score = scan_output( self._scanners, output_doc.prompt, output_doc.text - ) + ) if False in results_valid.values(): msg = f"LLM Output {output_doc.text} is not valid, scores: {results_score}" logger.error(msg) usr_msg = "I'm sorry, but the model output is not valid according to the policies." - redact_or_truncated = [c.get('redact', False) or c.get('truncate', False) for _, c in self._scanners_config._output_scanners_config.items()] # to see if sanitized output available + redact_or_truncated = [ + c.get("redact", False) or c.get("truncate", False) + for _, c in self._scanners_config._output_scanners_config.items() + ] # to see if sanitized output available if any(redact_or_truncated): usr_msg = f"We sanitized the answer due to the guardrails policies: {sanitized_output}" raise HTTPException(status_code=466, detail=usr_msg) @@ -89,10 +89,10 @@ def scan_llm_output(self, output_doc: GeneratedDoc) -> str: except HTTPException as e: raise e except ValueError as e: - error_msg = f"Validation Error occured while initializing LLM Guard Output Guardrail scanners: {e}" + error_msg = f"Validation Error occurred while initializing LLM Guard Output Guardrail scanners: {e}" logger.exception(error_msg) raise HTTPException(status_code=400, detail=error_msg) except Exception as e: - error_msg = f"An unexpected error occured during scanning prompt with LLM Guard Output Guardrail: {e}" + error_msg = f"An unexpected error occurred during scanning prompt with LLM Guard Output Guardrail: {e}" logger.exception(error_msg) raise HTTPException(status_code=500, detail=error_msg) diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py b/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py index 514db7e964..c2de5089c9 100644 --- a/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py +++ b/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py @@ -1,111 +1,86 @@ # ruff: noqa: F401 # Copyright (C) 2024-2025 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from llm_guard.vault import Vault +# import models definition +from llm_guard.input_scanners.ban_code import ( + MODEL_SM as BANCODE_MODEL_SM, # input, because the same scanner to input and output +) +from llm_guard.input_scanners.ban_code import MODEL_TINY as BANCODE_MODEL_TINY +from llm_guard.input_scanners.ban_competitors import ( + MODEL_V1 as BANCOMPETITORS_MODEL_V1, # input, because the same scanner to input and output +) +from llm_guard.input_scanners.ban_topics import MODEL_BGE_M3_V2 as BANTOPICS_MODEL_BGE_M3_V2 +from llm_guard.input_scanners.ban_topics import MODEL_DEBERTA_BASE_V2 as BANTOPICS_MODEL_DEBERTA_BASE_V2 +from llm_guard.input_scanners.ban_topics import ( + MODEL_DEBERTA_LARGE_V2 as BANTOPICS_MODEL_DEBERTA_LARGE_V2, # input, because the same scanner to input and output +) +from llm_guard.input_scanners.ban_topics import MODEL_ROBERTA_BASE_C_V2 as BANTOPICS_MODEL_ROBERTA_BASE_C_V2 +from llm_guard.input_scanners.ban_topics import MODEL_ROBERTA_LARGE_C_V2 as BANTOPICS_MODEL_ROBERTA_LARGE_C_V2 +from llm_guard.input_scanners.code import DEFAULT_MODEL as CODE_DEFAULT_MODEL +from llm_guard.input_scanners.gibberish import DEFAULT_MODEL as GIBBERISH_DEFAULT_MODEL +from llm_guard.input_scanners.language import DEFAULT_MODEL as LANGUAGE_DEFAULT_MODEL +from llm_guard.input_scanners.toxicity import DEFAULT_MODEL as TOXICITY_DEFAULT_MODEL from llm_guard.output_scanners import ( + JSON, BanCode, BanCompetitors, BanTopics, Bias, Code, Deanonymize, - JSON, + FactualConsistency, + Gibberish, Language, LanguageSame, MaliciousURLs, NoRefusal, NoRefusalLight, ReadingTime, - FactualConsistency, - Gibberish, Relevance, Sensitive, Sentiment, Toxicity, - URLReachability -) - -# import models definition -from llm_guard.input_scanners.ban_code import ( #input, becasue the same scanner to input and output - MODEL_SM as BANCODE_MODEL_SM, - MODEL_TINY as BANCODE_MODEL_TINY -) - -from llm_guard.input_scanners.ban_competitors import ( #input, becasue the same scanner to input and output - MODEL_V1 as BANCOMPETITORS_MODEL_V1 -) - -from llm_guard.input_scanners.ban_topics import ( #input, becasue the same scanner to input and output - MODEL_DEBERTA_LARGE_V2 as BANTOPICS_MODEL_DEBERTA_LARGE_V2, - MODEL_DEBERTA_BASE_V2 as BANTOPICS_MODEL_DEBERTA_BASE_V2, - MODEL_BGE_M3_V2 as BANTOPICS_MODEL_BGE_M3_V2, - MODEL_ROBERTA_LARGE_C_V2 as BANTOPICS_MODEL_ROBERTA_LARGE_C_V2, - MODEL_ROBERTA_BASE_C_V2 as BANTOPICS_MODEL_ROBERTA_BASE_C_V2 -) - -from llm_guard.output_scanners.bias import ( - DEFAULT_MODEL as BIAS_DEFAULT_MODEL -) - -from llm_guard.input_scanners.code import ( - DEFAULT_MODEL as CODE_DEFAULT_MODEL -) - -from llm_guard.input_scanners.gibberish import ( - DEFAULT_MODEL as GIBBERISH_DEFAULT_MODEL -) - -from llm_guard.input_scanners.language import ( - DEFAULT_MODEL as LANGUAGE_DEFAULT_MODEL, -) - -from llm_guard.output_scanners.malicious_urls import ( - DEFAULT_MODEL as MALICIOUS_URLS_DEFAULT_MODEL -) - -from llm_guard.output_scanners.no_refusal import ( - DEFAULT_MODEL as NO_REFUSAL_DEFAULT_MODEL -) - -from llm_guard.output_scanners.relevance import ( - MODEL_EN_BGE_BASE as RELEVANCE_MODEL_EN_BGE_BASE, - MODEL_EN_BGE_LARGE as RELEVANCE_MODEL_EN_BGE_LARGE, - MODEL_EN_BGE_SMALL as RELEVANCE_MODEL_EN_BGE_SMALL -) - -from llm_guard.input_scanners.toxicity import ( - DEFAULT_MODEL as TOXICITY_DEFAULT_MODEL + URLReachability, ) +from llm_guard.output_scanners.bias import DEFAULT_MODEL as BIAS_DEFAULT_MODEL +from llm_guard.output_scanners.malicious_urls import DEFAULT_MODEL as MALICIOUS_URLS_DEFAULT_MODEL +from llm_guard.output_scanners.no_refusal import DEFAULT_MODEL as NO_REFUSAL_DEFAULT_MODEL +from llm_guard.output_scanners.relevance import MODEL_EN_BGE_BASE as RELEVANCE_MODEL_EN_BGE_BASE +from llm_guard.output_scanners.relevance import MODEL_EN_BGE_LARGE as RELEVANCE_MODEL_EN_BGE_LARGE +from llm_guard.output_scanners.relevance import MODEL_EN_BGE_SMALL as RELEVANCE_MODEL_EN_BGE_SMALL +from llm_guard.vault import Vault ENABLED_SCANNERS = [ - 'ban_code', - 'ban_competitors', - 'ban_substrings', - 'ban_topics', - 'bias', - 'code', - 'deanonymize', - 'json_scanner', - 'language', - 'language_same', - 'malicious_urls', - 'no_refusal', - 'no_refusal_light', - 'reading_time', - 'factual_consistency', - 'gibberish', - 'regex', - 'relevance', - 'sensitive', - 'sentiment', - 'toxicity', - 'url_reachability' + "ban_code", + "ban_competitors", + "ban_substrings", + "ban_topics", + "bias", + "code", + "deanonymize", + "json_scanner", + "language", + "language_same", + "malicious_urls", + "no_refusal", + "no_refusal_light", + "reading_time", + "factual_consistency", + "gibberish", + "regex", + "relevance", + "sensitive", + "sentiment", + "toxicity", + "url_reachability", ] -from comps.guardrails.utils.scanners import OPEABanSubstrings, OPEARegexScanner from comps import get_opea_logger, sanitize_env +from comps.guardrails.utils.scanners import OPEABanSubstrings, OPEARegexScanner + logger = get_opea_logger("opea_llm_guard_output_guardrail_microservice") + class OutputScannersConfig: def __init__(self, config_dict): self._output_scanners_config = { @@ -130,15 +105,14 @@ def __init__(self, config_dict): **self._get_sensitive_config_from_env(config_dict), **self._get_sentiment_config_from_env(config_dict), **self._get_toxicity_config_from_env(config_dict), - **self._get_url_reachability_config_from_env(config_dict) + **self._get_url_reachability_config_from_env(config_dict), } self.vault = None -#### METHODS FOR VALIDATING CONFIGS + #### METHODS FOR VALIDATING CONFIGS def _validate_value(self, value): - """ - Validate and convert the input value. + """Validate and convert the input value. Args: value (str): The value to be validated and converted. @@ -157,8 +131,7 @@ def _validate_value(self, value): return value def _get_ban_code_config_from_env(self, config_dict): - """ - Get the BanCode scanner configuration from the environment. + """Get the BanCode scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -169,13 +142,13 @@ def _get_ban_code_config_from_env(self, config_dict): return { "ban_code": { k.replace("BAN_CODE_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("BAN_CODE_") + for k, v in config_dict.items() + if k.startswith("BAN_CODE_") } } def _get_ban_competitors_config_from_env(self, config_dict): - """ - Get the BanCompetitors scanner configuration from the environment. + """Get the BanCompetitors scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -186,13 +159,13 @@ def _get_ban_competitors_config_from_env(self, config_dict): return { "ban_competitors": { k.replace("BAN_COMPETITORS_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("BAN_COMPETITORS_") + for k, v in config_dict.items() + if k.startswith("BAN_COMPETITORS_") } } def _get_ban_substrings_config_from_env(self, config_dict): - """ - Get the BanSubstrings scanner configuration from the environment. + """Get the BanSubstrings scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -203,13 +176,13 @@ def _get_ban_substrings_config_from_env(self, config_dict): return { "ban_substrings": { k.replace("BAN_SUBSTRINGS_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("BAN_SUBSTRINGS_") + for k, v in config_dict.items() + if k.startswith("BAN_SUBSTRINGS_") } } def _get_ban_topics_config_from_env(self, config_dict): - """ - Get the BanTopics scanner configuration from the environment. + """Get the BanTopics scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -220,13 +193,13 @@ def _get_ban_topics_config_from_env(self, config_dict): return { "ban_topics": { k.replace("BAN_TOPICS_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("BAN_TOPICS_") + for k, v in config_dict.items() + if k.startswith("BAN_TOPICS_") } } def _get_bias_config_from_env(self, config_dict): - """ - Get the Bias scanner configuration from the environment. + """Get the Bias scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -237,13 +210,13 @@ def _get_bias_config_from_env(self, config_dict): return { "bias": { k.replace("BIAS_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("BIAS_") + for k, v in config_dict.items() + if k.startswith("BIAS_") } } def _get_code_config_from_env(self, config_dict): - """ - Get the Code scanner configuration from the environment. + """Get the Code scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -254,13 +227,13 @@ def _get_code_config_from_env(self, config_dict): return { "code": { k.replace("CODE_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("CODE_") + for k, v in config_dict.items() + if k.startswith("CODE_") } } def _get_deanonymize_config_from_env(self, config_dict): - """ - Get the Deanonymize scanner configuration from the environment. + """Get the Deanonymize scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -271,13 +244,13 @@ def _get_deanonymize_config_from_env(self, config_dict): return { "deanonymize": { k.replace("DEANONYMIZE_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("DEANONYMIZE_") + for k, v in config_dict.items() + if k.startswith("DEANONYMIZE_") } } def _get_json_scanner_config_from_env(self, config_dict): - """ - Get the JSON scanner configuration from the environment. + """Get the JSON scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -288,13 +261,13 @@ def _get_json_scanner_config_from_env(self, config_dict): return { "json_scanner": { k.replace("JSON_SCANNER_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("JSON_SCANNER_") + for k, v in config_dict.items() + if k.startswith("JSON_SCANNER_") } } def _get_language_config_from_env(self, config_dict): - """ - Get the Language scanner configuration from the environment. + """Get the Language scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -305,13 +278,13 @@ def _get_language_config_from_env(self, config_dict): return { "language": { k.replace("LANGUAGE_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("LANGUAGE_") + for k, v in config_dict.items() + if k.startswith("LANGUAGE_") } } def _get_language_same_config_from_env(self, config_dict): - """ - Get the LanguageSame scanner configuration from the environment. + """Get the LanguageSame scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -322,13 +295,13 @@ def _get_language_same_config_from_env(self, config_dict): return { "language_same": { k.replace("LANGUAGE_SAME_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("LANGUAGE_SAME_") + for k, v in config_dict.items() + if k.startswith("LANGUAGE_SAME_") } } def _get_malicious_urls_config_from_env(self, config_dict): - """ - Get the MaliciousURLs scanner configuration from the environment. + """Get the MaliciousURLs scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -339,13 +312,13 @@ def _get_malicious_urls_config_from_env(self, config_dict): return { "malicious_urls": { k.replace("MALICIOUS_URLS_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("MALICIOUS_URLS_") + for k, v in config_dict.items() + if k.startswith("MALICIOUS_URLS_") } } def _get_no_refusal_config_from_env(self, config_dict): - """ - Get the NoRefusal scanner configuration from the environment. + """Get the NoRefusal scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -356,13 +329,13 @@ def _get_no_refusal_config_from_env(self, config_dict): return { "no_refusal": { k.replace("NO_REFUSAL_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("NO_REFUSAL_") + for k, v in config_dict.items() + if k.startswith("NO_REFUSAL_") } } def _get_no_refusal_light_config_from_env(self, config_dict): - """ - Get the NoRefusalLight scanner configuration from the environment. + """Get the NoRefusalLight scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -373,13 +346,13 @@ def _get_no_refusal_light_config_from_env(self, config_dict): return { "no_refusal_light": { k.replace("NO_REFUSAL_LIGHT_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("NO_REFUSAL_LIGHT_") + for k, v in config_dict.items() + if k.startswith("NO_REFUSAL_LIGHT_") } } def _get_reading_time_config_from_env(self, config_dict): - """ - Get the ReadingTime scanner configuration from the environment. + """Get the ReadingTime scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -390,13 +363,13 @@ def _get_reading_time_config_from_env(self, config_dict): return { "reading_time": { k.replace("READING_TIME_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("READING_TIME_") + for k, v in config_dict.items() + if k.startswith("READING_TIME_") } } def _get_factual_consistency_config_from_env(self, config_dict): - """ - Get the FactualConsitency scanner configuration from the environment. + """Get the FactualConsitency scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -407,13 +380,13 @@ def _get_factual_consistency_config_from_env(self, config_dict): return { "factual_consistency": { k.replace("FACTUAL_CONSISTENCY_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("FACTUAL_CONSISTENCY_") + for k, v in config_dict.items() + if k.startswith("FACTUAL_CONSISTENCY_") } } def _get_gibberish_config_from_env(self, config_dict): - """ - Get the Gibberish scanner configuration from the environment. + """Get the Gibberish scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -424,13 +397,13 @@ def _get_gibberish_config_from_env(self, config_dict): return { "gibberish": { k.replace("GIBBERISH_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("GIBBERISH_") + for k, v in config_dict.items() + if k.startswith("GIBBERISH_") } } def _get_regex_config_from_env(self, config_dict): - """ - Get the Regex scanner configuration from the environment. + """Get the Regex scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -441,13 +414,13 @@ def _get_regex_config_from_env(self, config_dict): return { "regex": { k.replace("REGEX_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("REGEX_") + for k, v in config_dict.items() + if k.startswith("REGEX_") } } def _get_relevance_config_from_env(self, config_dict): - """ - Get the Relevance scanner configuration from the environment. + """Get the Relevance scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -458,13 +431,13 @@ def _get_relevance_config_from_env(self, config_dict): return { "relevance": { k.replace("RELEVANCE_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("RELEVANCE_") + for k, v in config_dict.items() + if k.startswith("RELEVANCE_") } } def _get_sensitive_config_from_env(self, config_dict): - """ - Get the Sensitive scanner configuration from the environment. + """Get the Sensitive scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -475,13 +448,13 @@ def _get_sensitive_config_from_env(self, config_dict): return { "sensitive": { k.replace("SENSITIVE_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("SENSITIVE_") + for k, v in config_dict.items() + if k.startswith("SENSITIVE_") } } def _get_sentiment_config_from_env(self, config_dict): - """ - Get the Sentiment scanner configuration from the environment. + """Get the Sentiment scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -492,13 +465,13 @@ def _get_sentiment_config_from_env(self, config_dict): return { "sentiment": { k.replace("SENTIMENT_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("SENTIMENT_") + for k, v in config_dict.items() + if k.startswith("SENTIMENT_") } } def _get_toxicity_config_from_env(self, config_dict): - """ - Get the Toxicity scanner configuration from the environment. + """Get the Toxicity scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -509,13 +482,13 @@ def _get_toxicity_config_from_env(self, config_dict): return { "toxicity": { k.replace("TOXICITY_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("TOXICITY_") + for k, v in config_dict.items() + if k.startswith("TOXICITY_") } } def _get_url_reachability_config_from_env(self, config_dict): - """ - Get the URLReachability scanner configuration from the environment. + """Get the URLReachability scanner configuration from the environment. Args: config_dict (dict): The configuration dictionary. @@ -526,64 +499,71 @@ def _get_url_reachability_config_from_env(self, config_dict): return { "url_reachability": { k.replace("URL_REACHABILITY_", "").lower(): self._validate_value(v) - for k, v in config_dict.items() if k.startswith("URL_REACHABILITY_") + for k, v in config_dict.items() + if k.startswith("URL_REACHABILITY_") } } -#### METHODS FOR CREATING SCANNERS + #### METHODS FOR CREATING SCANNERS def _create_ban_code_scanner(self, scanner_config): - enabled_models = {'MODEL_SM': BANCODE_MODEL_SM, 'MODEL_TINY': BANCODE_MODEL_TINY} - bancode_params = {'use_onnx': scanner_config.get('use_onnx', False)} # by default we don't want to use onnx + enabled_models = {"MODEL_SM": BANCODE_MODEL_SM, "MODEL_TINY": BANCODE_MODEL_TINY} + bancode_params = {"use_onnx": scanner_config.get("use_onnx", False)} # by default we don't want to use onnx - model_name = scanner_config.get('model', None) - threshold = scanner_config.get('threshold', None) + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for BanCode scanner: {model_name}") - bancode_params['model'] = enabled_models[model_name] # Model class from LLM Guard + bancode_params["model"] = enabled_models[model_name] # Model class from LLM Guard else: err_msg = f"Model name is not valid for BanCode scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if threshold is not None: - bancode_params['threshold'] = threshold + bancode_params["threshold"] = threshold logger.info(f"Creating BanCode scanner with params: {bancode_params}") return BanCode(**bancode_params) def _create_ban_competitors_scanner(self, scanner_config): - enabled_models = {'MODEL_V1': BANCOMPETITORS_MODEL_V1} - ban_competitors_params = {'use_onnx': scanner_config.get('use_onnx', False)} # by default we want don't to use onnx + enabled_models = {"MODEL_V1": BANCOMPETITORS_MODEL_V1} + ban_competitors_params = { + "use_onnx": scanner_config.get("use_onnx", False) + } # by default we want don't to use onnx - competitors = scanner_config.get('competitors', None) - threshold = scanner_config.get('threshold', None) - redact = scanner_config.get('redact', None) - model_name = scanner_config.get('model', None) + competitors = scanner_config.get("competitors", None) + threshold = scanner_config.get("threshold", None) + redact = scanner_config.get("redact", None) + model_name = scanner_config.get("model", None) if isinstance(competitors, str): competitors = sanitize_env(competitors) if competitors: if isinstance(competitors, str): - artifacts = set([',', '', '.']) - ban_competitors_params['competitors'] = list(set(competitors.split(',')) - artifacts) + artifacts = set([",", "", "."]) + ban_competitors_params["competitors"] = list(set(competitors.split(",")) - artifacts) elif isinstance(competitors, list): - ban_competitors_params['competitors'] = competitors + ban_competitors_params["competitors"] = competitors else: logger.error("Provided type is not valid for BanCompetitors scanner") raise ValueError("Provided type is not valid for BanCompetitors scanner") else: - logger.error("Competitors list is required for BanCompetitors scanner. Please provide a list of competitors.") - raise TypeError("Competitors list is required for BanCompetitors scanner. Please provide a list of competitors.") + logger.error( + "Competitors list is required for BanCompetitors scanner. Please provide a list of competitors." + ) + raise TypeError( + "Competitors list is required for BanCompetitors scanner. Please provide a list of competitors." + ) if threshold is not None: - ban_competitors_params['threshold'] = threshold + ban_competitors_params["threshold"] = threshold if redact is not None: - ban_competitors_params['redact'] = redact + ban_competitors_params["redact"] = redact if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for BanCompetitors scanner: {model_name}") - ban_competitors_params['model'] = enabled_models[model_name] + ban_competitors_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for BanCompetitors scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) @@ -592,24 +572,24 @@ def _create_ban_competitors_scanner(self, scanner_config): return BanCompetitors(**ban_competitors_params) def _create_ban_substrings_scanner(self, scanner_config): - available_match_types = ['str', 'word'] + available_match_types = ["str", "word"] ban_substrings_params = {} - substrings = scanner_config.get('substrings', None) - match_type = scanner_config.get('match_type', None) - case_sensitive = scanner_config.get('case_sensitive', None) - redact = scanner_config.get('redact', None) - contains_all = scanner_config.get('contains_all', None) + substrings = scanner_config.get("substrings", None) + match_type = scanner_config.get("match_type", None) + case_sensitive = scanner_config.get("case_sensitive", None) + redact = scanner_config.get("redact", None) + contains_all = scanner_config.get("contains_all", None) if isinstance(substrings, str): substrings = sanitize_env(substrings) if substrings: if isinstance(substrings, str): - artifacts = set([',', '', '.']) - ban_substrings_params['substrings'] = list(set(substrings.split(',')) - artifacts) + artifacts = set([",", "", "."]) + ban_substrings_params["substrings"] = list(set(substrings.split(",")) - artifacts) elif substrings and isinstance(substrings, list): - ban_substrings_params['substrings'] = substrings + ban_substrings_params["substrings"] = substrings else: logger.error("Provided type is not valid for BanSubstrings scanner") raise ValueError("Provided type is not valid for BanSubstrings scanner") @@ -617,39 +597,39 @@ def _create_ban_substrings_scanner(self, scanner_config): logger.error("Substrings list is required for BanSubstrings scanner") raise TypeError("Substrings list is required for BanSubstrings scanner") if match_type is not None and match_type in available_match_types: - ban_substrings_params['match_type'] = match_type + ban_substrings_params["match_type"] = match_type if case_sensitive is not None: - ban_substrings_params['case_sensitive'] = case_sensitive + ban_substrings_params["case_sensitive"] = case_sensitive if redact is not None: - ban_substrings_params['redact'] = redact + ban_substrings_params["redact"] = redact if contains_all is not None: - ban_substrings_params['contains_all'] = contains_all + ban_substrings_params["contains_all"] = contains_all logger.info(f"Creating BanSubstrings scanner with params: {ban_substrings_params}") return OPEABanSubstrings(**ban_substrings_params) def _create_ban_topics_scanner(self, scanner_config): enabled_models = { - 'MODEL_DEBERTA_LARGE_V2': BANTOPICS_MODEL_DEBERTA_LARGE_V2, - 'MODEL_DEBERTA_BASE_V2': BANTOPICS_MODEL_DEBERTA_BASE_V2, - 'MODEL_BGE_M3_V2': BANTOPICS_MODEL_BGE_M3_V2, - 'MODEL_ROBERTA_LARGE_C_V2': BANTOPICS_MODEL_ROBERTA_LARGE_C_V2, - 'MODEL_ROBERTA_BASE_C_V2': BANTOPICS_MODEL_ROBERTA_BASE_C_V2 + "MODEL_DEBERTA_LARGE_V2": BANTOPICS_MODEL_DEBERTA_LARGE_V2, + "MODEL_DEBERTA_BASE_V2": BANTOPICS_MODEL_DEBERTA_BASE_V2, + "MODEL_BGE_M3_V2": BANTOPICS_MODEL_BGE_M3_V2, + "MODEL_ROBERTA_LARGE_C_V2": BANTOPICS_MODEL_ROBERTA_LARGE_C_V2, + "MODEL_ROBERTA_BASE_C_V2": BANTOPICS_MODEL_ROBERTA_BASE_C_V2, } - ban_topics_params = {'use_onnx': scanner_config.get('use_onnx', False)} + ban_topics_params = {"use_onnx": scanner_config.get("use_onnx", False)} - topics = scanner_config.get('topics', None) - threshold = scanner_config.get('threshold', None) - model_name = scanner_config.get('model', None) + topics = scanner_config.get("topics", None) + threshold = scanner_config.get("threshold", None) + model_name = scanner_config.get("model", None) if isinstance(topics, str): topics = sanitize_env(topics) if topics: if isinstance(topics, str): - artifacts = set([',', '', '.']) - ban_topics_params['topics'] = list(set(topics.split(',')) - artifacts) + artifacts = set([",", "", "."]) + ban_topics_params["topics"] = list(set(topics.split(",")) - artifacts) elif isinstance(topics, list): - ban_topics_params['topics'] = topics + ban_topics_params["topics"] = topics else: logger.error("Provided type is not valid for BanTopics scanner") raise ValueError("Provided type is not valid for BanTopics scanner") @@ -657,11 +637,11 @@ def _create_ban_topics_scanner(self, scanner_config): logger.error("Topics list is required for BanTopics scanner") raise TypeError("Topics list is required for BanTopics scanner") if threshold is not None: - ban_topics_params['threshold'] = threshold + ban_topics_params["threshold"] = threshold if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for BanTopics scanner: {model_name}") - ban_topics_params['model'] = enabled_models[model_name] + ban_topics_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for BanTopics scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) @@ -670,22 +650,22 @@ def _create_ban_topics_scanner(self, scanner_config): return BanTopics(**ban_topics_params) def _create_bias_scanner(self, scanner_config): - available_match_types = ['str', 'word'] - enabled_models = {'DEFAULT_MODEL': BIAS_DEFAULT_MODEL} - bias_params = {'use_onnx': scanner_config.get('use_onnx', False)} + available_match_types = ["str", "word"] + enabled_models = {"DEFAULT_MODEL": BIAS_DEFAULT_MODEL} + bias_params = {"use_onnx": scanner_config.get("use_onnx", False)} - threshold = scanner_config.get('threshold', None) - match_type = scanner_config.get('match_type', None) - model_name = scanner_config.get('model', None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) + model_name = scanner_config.get("model", None) if threshold is not None: - bias_params['threshold'] = threshold + bias_params["threshold"] = threshold if match_type is not None and match_type in available_match_types: - bias_params['match_type'] = match_type + bias_params["match_type"] = match_type if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for Bias scanner: {model_name}") - bias_params['model'] = enabled_models[model_name] + bias_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for Bias scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) @@ -695,23 +675,23 @@ def _create_bias_scanner(self, scanner_config): return Bias(**bias_params) def _create_code_scanner(self, scanner_config): - enabled_models = {'DEFAULT_MODEL': CODE_DEFAULT_MODEL} - code_params = {'use_onnx': scanner_config.get('use_onnx', False)} + enabled_models = {"DEFAULT_MODEL": CODE_DEFAULT_MODEL} + code_params = {"use_onnx": scanner_config.get("use_onnx", False)} - languages = scanner_config.get('languages', None) - model_name = scanner_config.get('model', None) - is_blocked = scanner_config.get('is_blocked', None) - threshold = scanner_config.get('threshold', None) + languages = scanner_config.get("languages", None) + model_name = scanner_config.get("model", None) + is_blocked = scanner_config.get("is_blocked", None) + threshold = scanner_config.get("threshold", None) if isinstance(languages, str): languages = sanitize_env(languages) if languages: if isinstance(languages, str): - artifacts = set([',', '', '.']) - code_params['languages'] = list(set(languages.split(',')) - artifacts) + artifacts = set([",", "", "."]) + code_params["languages"] = list(set(languages.split(",")) - artifacts) elif isinstance(languages, list): - code_params['languages'] = languages + code_params["languages"] = languages else: logger.error("Provided type is not valid for Code scanner") raise ValueError("Provided type is not valid for Code scanner") @@ -721,26 +701,26 @@ def _create_code_scanner(self, scanner_config): if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for Code scanner: {model_name}") - code_params['model'] = enabled_models[model_name] + code_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for Code scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if is_blocked is not None: - code_params['is_blocked'] = is_blocked + code_params["is_blocked"] = is_blocked if threshold is not None: - code_params['threshold'] = threshold + code_params["threshold"] = threshold logger.info(f"Creating Code scanner with params: {code_params}") return Code(**code_params) def _create_deanonymize_scanner(self, scanner_config, vault): if not vault: raise Exception("Vault is required for Deanonymize scanner") - deanonymize_params = {'vault': vault} + deanonymize_params = {"vault": vault} - matching_strategy = scanner_config.get('matching_strategy', None) + matching_strategy = scanner_config.get("matching_strategy", None) if matching_strategy is not None: - deanonymize_params['matching_strategy'] = matching_strategy + deanonymize_params["matching_strategy"] = matching_strategy logger.info(f"Creating Deanonymize scanner with params: {deanonymize_params}") return Deanonymize(**deanonymize_params) @@ -748,36 +728,36 @@ def _create_deanonymize_scanner(self, scanner_config, vault): def _create_json_scanner(self, scanner_config): json_scanner_params = {} - required_elements = scanner_config.get('required_elements', None) - repair = scanner_config.get('repair', None) + required_elements = scanner_config.get("required_elements", None) + repair = scanner_config.get("repair", None) if required_elements is not None: - json_scanner_params['required_elements'] = required_elements + json_scanner_params["required_elements"] = required_elements if repair is not None: - json_scanner_params['repair'] = repair + json_scanner_params["repair"] = repair logger.info(f"Creating JSON scanner with params: {json_scanner_params}") return JSON(**json_scanner_params) def _create_language_scanner(self, scanner_config): - enabled_models = {'DEFAULT_MODEL': LANGUAGE_DEFAULT_MODEL} - enabled_match_types = ['sentence', 'full'] - language_params = {'use_onnx': scanner_config.get('use_onnx', False)} + enabled_models = {"DEFAULT_MODEL": LANGUAGE_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + language_params = {"use_onnx": scanner_config.get("use_onnx", False)} - valid_languages = scanner_config.get('valid_languages', None) - model_name = scanner_config.get('model', None) - threshold = scanner_config.get('threshold', None) - match_type = scanner_config.get('match_type', None) + valid_languages = scanner_config.get("valid_languages", None) + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) if isinstance(valid_languages, str): valid_languages = sanitize_env(valid_languages) if valid_languages: if isinstance(valid_languages, str): - artifacts = set([',', '', '.']) - language_params['valid_languages'] = list(set(valid_languages.split(',')) - artifacts) + artifacts = set([",", "", "."]) + language_params["valid_languages"] = list(set(valid_languages.split(",")) - artifacts) elif isinstance(valid_languages, list): - language_params['valid_languages'] = valid_languages + language_params["valid_languages"] = valid_languages else: logger.error("Provided type is not valid for Language scanner") raise ValueError("Provided type is not valid for Language scanner") @@ -787,81 +767,81 @@ def _create_language_scanner(self, scanner_config): if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for Language scanner: {model_name}") - language_params['model'] = enabled_models[model_name] + language_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for Language scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if threshold is not None: - language_params['threshold'] = threshold + language_params["threshold"] = threshold if match_type is not None and match_type in enabled_match_types: - language_params['match_type'] = match_type + language_params["match_type"] = match_type logger.info(f"Creating Language scanner with params: {language_params}") return Language(**language_params) def _create_language_same_scanner(self, scanner_config): - enabled_models = {'DEFAULT_MODEL': LANGUAGE_DEFAULT_MODEL} - language_same_params = {'use_onnx': scanner_config.get('use_onnx', False)} + enabled_models = {"DEFAULT_MODEL": LANGUAGE_DEFAULT_MODEL} + language_same_params = {"use_onnx": scanner_config.get("use_onnx", False)} - model_name = scanner_config.get('model', None) - threshold = scanner_config.get('threshold', None) + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for LanguageSame scanner: {model_name}") - language_same_params['model'] = enabled_models[model_name] + language_same_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for LanguageSame scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if threshold is not None: - language_same_params['threshold'] = threshold + language_same_params["threshold"] = threshold logger.info(f"Creating LanguageSame scanner with params: {language_same_params}") return LanguageSame(**language_same_params) def _create_malicious_urls_scanner(self, scanner_config): - enabled_models = {'DEFAULT_MODEL': MALICIOUS_URLS_DEFAULT_MODEL} - malicious_urls_params = {'use_onnx': scanner_config.get('use_onnx', False)} + enabled_models = {"DEFAULT_MODEL": MALICIOUS_URLS_DEFAULT_MODEL} + malicious_urls_params = {"use_onnx": scanner_config.get("use_onnx", False)} - threshold = scanner_config.get('threshold', None) - model_name = scanner_config.get('model', None) + threshold = scanner_config.get("threshold", None) + model_name = scanner_config.get("model", None) if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for MaliciousURLs scanner: {model_name}") - malicious_urls_params['model'] = enabled_models[model_name] + malicious_urls_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for MaliciousURLs scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if threshold is not None: - malicious_urls_params['threshold'] = threshold + malicious_urls_params["threshold"] = threshold logger.info(f"Creating MaliciousURLs scanner with params: {malicious_urls_params}") return MaliciousURLs(**malicious_urls_params) def _create_no_refusal_scanner(self, scanner_config): - enabled_models = {'DEFAULT_MODEL': NO_REFUSAL_DEFAULT_MODEL} - enabled_match_types = ['sentence', 'full'] - no_refusal_params = {'use_onnx': scanner_config.get('use_onnx', False)} + enabled_models = {"DEFAULT_MODEL": NO_REFUSAL_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + no_refusal_params = {"use_onnx": scanner_config.get("use_onnx", False)} - threshold = scanner_config.get('threshold', None) - model_name = scanner_config.get('model', None) - match_type = scanner_config.get('match_type', None) + threshold = scanner_config.get("threshold", None) + model_name = scanner_config.get("model", None) + match_type = scanner_config.get("match_type", None) if threshold is not None: - no_refusal_params['threshold'] = threshold + no_refusal_params["threshold"] = threshold if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for NoRefusal scanner: {model_name}") - no_refusal_params['model'] = enabled_models[model_name] + no_refusal_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for NoRefusal scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if match_type is not None and match_type in enabled_match_types: - no_refusal_params['match_type'] = match_type + no_refusal_params["match_type"] = match_type logger.info(f"Creating NoRefusal scanner with params: {no_refusal_params}") return NoRefusal(**no_refusal_params) @@ -873,88 +853,91 @@ def _create_no_refusal_light_scanner(self): def _create_reading_time_scanner(self, scanner_config): reading_time_params = {} - max_time = scanner_config.get('max_time', None) - truncate = scanner_config.get('truncate', None) + max_time = scanner_config.get("max_time", None) + truncate = scanner_config.get("truncate", None) if max_time is not None: - reading_time_params['max_time'] = float(max_time) + reading_time_params["max_time"] = float(max_time) else: logger.error("Max time is required for ReadingTime scanner") raise TypeError("Max time is required for ReadingTime scanner") if truncate is not None: - reading_time_params['truncate'] = truncate + reading_time_params["truncate"] = truncate logger.info(f"Creating ReadingTime scanner with params: {reading_time_params}") return ReadingTime(**reading_time_params) def _create_factual_consistency_scanner(self, scanner_config): - enabled_models = {"DEFAULT_MODEL": BANTOPICS_MODEL_DEBERTA_BASE_V2} # BanTopics model is used as deault in FactualConsistency - factual_consistency_params = {'use_onnx': scanner_config.get('use_onnx', False)} + enabled_models = { + "DEFAULT_MODEL": BANTOPICS_MODEL_DEBERTA_BASE_V2 + } # BanTopics model is used as default in FactualConsistency + factual_consistency_params = {"use_onnx": scanner_config.get("use_onnx", False)} - model_name = scanner_config.get('model_name', None) - minimum_score = scanner_config.get('minimum_score', None) + model_name = scanner_config.get("model_name", None) + minimum_score = scanner_config.get("minimum_score", None) if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for NoRefusal scanner: {model_name}") - factual_consistency_params['model'] = enabled_models[model_name] + factual_consistency_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for NoRefusal scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if minimum_score is not None: - factual_consistency_params['minimum_score'] = minimum_score + factual_consistency_params["minimum_score"] = minimum_score logger.info(f"Creating FactualConsistency scanner with params: {factual_consistency_params}") return FactualConsistency(**factual_consistency_params) def _create_gibberish_scanner(self, scanner_config): - enabled_models = {'DEFAULT_MODEL': GIBBERISH_DEFAULT_MODEL} - enabled_match_types = ['sentence', 'full'] - gibberish_params = {'use_onnx': scanner_config.get('use_onnx', False)} + enabled_models = {"DEFAULT_MODEL": GIBBERISH_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + gibberish_params = {"use_onnx": scanner_config.get("use_onnx", False)} - model_name = scanner_config.get('model', None) - threshold = scanner_config.get('threshold', None) - match_type = scanner_config.get('match_type', None) + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) if match_type == "sentence": import nltk - nltk.download('punkt_tab') + + nltk.download("punkt_tab") if threshold is not None: - gibberish_params['threshold'] = threshold + gibberish_params["threshold"] = threshold if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for Gibberish scanner: {model_name}") - gibberish_params['model'] = enabled_models[model_name] + gibberish_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for Gibberish scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if match_type is not None and match_type in enabled_match_types: - gibberish_params['match_type'] = match_type + gibberish_params["match_type"] = match_type logger.info(f"Creating Gibberish scanner with params: {gibberish_params}") return Gibberish(**gibberish_params) def _create_regex_scanner(self, scanner_config): - enabled_match_types = ['search', 'fullmatch'] + enabled_match_types = ["search", "fullmatch"] regex_params = {} - patterns = scanner_config.get('patterns', None) - is_blocked = scanner_config.get('is_blocked', None) - match_type = scanner_config.get('match_type', None) - redact = scanner_config.get('redact', None) + patterns = scanner_config.get("patterns", None) + is_blocked = scanner_config.get("is_blocked", None) + match_type = scanner_config.get("match_type", None) + redact = scanner_config.get("redact", None) if isinstance(patterns, str): patterns = sanitize_env(patterns) if patterns: if isinstance(patterns, str): - artifacts = set([',', '', '.']) - regex_params['patterns'] = list(set(patterns.split(',')) - artifacts) + artifacts = set([",", "", "."]) + regex_params["patterns"] = list(set(patterns.split(",")) - artifacts) elif isinstance(patterns, list): - regex_params['patterns'] = patterns + regex_params["patterns"] = patterns else: logger.error("Provided type is not valid for Regex scanner") raise ValueError("Provided type is not valid for Regex scanner") @@ -962,46 +945,50 @@ def _create_regex_scanner(self, scanner_config): logger.error("Patterns list is required for Regex scanner") raise TypeError("Patterns list is required for Regex scanner") if is_blocked is not None: - regex_params['is_blocked'] = is_blocked + regex_params["is_blocked"] = is_blocked if match_type is not None and match_type in enabled_match_types: - regex_params['match_type'] = match_type + regex_params["match_type"] = match_type if redact is not None: - regex_params['redact'] = redact + regex_params["redact"] = redact logger.info(f"Creating Regex scanner with params: {regex_params}") return OPEARegexScanner(**regex_params) def _create_relevance_scanner(self, scanner_config): - enabled_models = {'MODEL_EN_BGE_BASE': RELEVANCE_MODEL_EN_BGE_BASE, - 'MODEL_EN_BGE_LARGE': RELEVANCE_MODEL_EN_BGE_LARGE, - 'MODEL_EN_BGE_SMALL': RELEVANCE_MODEL_EN_BGE_SMALL} - relevance_params = {'use_onnx': scanner_config.get('use_onnx', False)} # TODO: onnx off, because of bug on LLM Guard side + enabled_models = { + "MODEL_EN_BGE_BASE": RELEVANCE_MODEL_EN_BGE_BASE, + "MODEL_EN_BGE_LARGE": RELEVANCE_MODEL_EN_BGE_LARGE, + "MODEL_EN_BGE_SMALL": RELEVANCE_MODEL_EN_BGE_SMALL, + } + relevance_params = { + "use_onnx": scanner_config.get("use_onnx", False) + } # TODO: onnx off, because of bug on LLM Guard side - model_name = scanner_config.get('model', None) - threshold = scanner_config.get('threshold', None) + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for Gibberish scanner: {model_name}") - relevance_params['model'] = enabled_models[model_name] + relevance_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for Relevance scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if threshold is not None: - relevance_params['threshold'] = threshold + relevance_params["threshold"] = threshold logger.info(f"Creating Relevance scanner with params: {relevance_params}") return Relevance(**relevance_params) def _create_sensitive_scanner(self, scanner_config): - sensitive_params = {'use_onnx': scanner_config.get('use_onnx', False)} + sensitive_params = {"use_onnx": scanner_config.get("use_onnx", False)} - entity_types = scanner_config.get('entity_types', None) - regex_patterns = scanner_config.get('regex_patterns', None) - redact = scanner_config.get('redact', None) - recognizer_conf = scanner_config.get('recognizer_conf', None) - threshold = scanner_config.get('threshold', None) + entity_types = scanner_config.get("entity_types", None) + regex_patterns = scanner_config.get("regex_patterns", None) + redact = scanner_config.get("redact", None) + recognizer_conf = scanner_config.get("recognizer_conf", None) + threshold = scanner_config.get("threshold", None) if entity_types is not None: if isinstance(entity_types, str): @@ -1009,22 +996,22 @@ def _create_sensitive_scanner(self, scanner_config): if entity_types: if isinstance(entity_types, str): - artifacts = set([',', '', '.']) - sensitive_params['entity_types'] = list(set(entity_types.split(',')) - artifacts) + artifacts = set([",", "", "."]) + sensitive_params["entity_types"] = list(set(entity_types.split(",")) - artifacts) elif isinstance(entity_types, list): - sensitive_params['entity_types'] = entity_types + sensitive_params["entity_types"] = entity_types else: logger.error("Provided type is not valid for Sensitive scanner") raise ValueError("Provided type is not valid for Sensitive scanner") if regex_patterns is not None: - sensitive_params['regex_patterns'] = regex_patterns + sensitive_params["regex_patterns"] = regex_patterns if redact is not None: - sensitive_params['redact'] = redact + sensitive_params["redact"] = redact if recognizer_conf is not None: - sensitive_params['recognizer_conf'] = recognizer_conf + sensitive_params["recognizer_conf"] = recognizer_conf if threshold is not None: - sensitive_params['threshold'] = threshold + sensitive_params["threshold"] = threshold logger.info(f"Creating Sensitive scanner with params: {sensitive_params}") return Sensitive(**sensitive_params) @@ -1033,43 +1020,43 @@ def _create_sentiment_scanner(self, scanner_config): enabled_lexicons = ["vader_lexicon"] sentiment_params = {} - threshold = scanner_config.get('threshold', None) - lexicon = scanner_config.get('lexicon', None) + threshold = scanner_config.get("threshold", None) + lexicon = scanner_config.get("lexicon", None) if threshold is not None: - sentiment_params['threshold'] = threshold + sentiment_params["threshold"] = threshold if lexicon is not None and lexicon in enabled_lexicons: - sentiment_params['lexicon'] = lexicon + sentiment_params["lexicon"] = lexicon logger.info(f"Creating Sentiment scanner with params: {sentiment_params}") return Sentiment(**sentiment_params) def _create_toxicity_scanner(self, scanner_config): - enabled_models = {'DEFAULT_MODEL': TOXICITY_DEFAULT_MODEL} - enabled_match_types = ['sentence', 'full'] - toxicity_params = {'use_onnx': scanner_config.get('use_onnx', False)} + enabled_models = {"DEFAULT_MODEL": TOXICITY_DEFAULT_MODEL} + enabled_match_types = ["sentence", "full"] + toxicity_params = {"use_onnx": scanner_config.get("use_onnx", False)} - model_name = scanner_config.get('model', None) - threshold = scanner_config.get('threshold', None) - match_type = scanner_config.get('match_type', None) + model_name = scanner_config.get("model", None) + threshold = scanner_config.get("threshold", None) + match_type = scanner_config.get("match_type", None) if match_type == "sentence": import nltk - nltk.download('punkt_tab') + nltk.download("punkt_tab") if model_name is not None: if model_name in enabled_models: logger.info(f"Using selected model for Toxicity scanner: {model_name}") - toxicity_params['model'] = enabled_models[model_name] + toxicity_params["model"] = enabled_models[model_name] else: err_msg = f"Model name is not valid for Toxicity scanner. Please provide a valid model name. Provided model: {model_name}. Enabled models: {list(enabled_models.keys())}" logger.error(err_msg) raise ValueError(err_msg) if threshold is not None: - toxicity_params['threshold'] = threshold + toxicity_params["threshold"] = threshold if match_type is not None and match_type in enabled_match_types: - toxicity_params['match_type'] = match_type + toxicity_params["match_type"] = match_type logger.info(f"Creating Toxicity scanner with params: {toxicity_params}") return Toxicity(**toxicity_params) @@ -1077,20 +1064,20 @@ def _create_toxicity_scanner(self, scanner_config): def _create_url_reachability_scanner(self, scanner_config): url_reachability_params = {} - success_status_codes = scanner_config.get('success_status_codes', None) - timeout = scanner_config.get('timeout', None) + success_status_codes = scanner_config.get("success_status_codes", None) + timeout = scanner_config.get("timeout", None) if success_status_codes is not None: if isinstance(success_status_codes, str): - artifacts = set([',', '', '.']) - url_reachability_params['success_status_codes'] = list(set(success_status_codes.split(',')) - artifacts) + artifacts = set([",", "", "."]) + url_reachability_params["success_status_codes"] = list(set(success_status_codes.split(",")) - artifacts) elif isinstance(success_status_codes, list): - url_reachability_params['success_status_codes'] = success_status_codes + url_reachability_params["success_status_codes"] = success_status_codes else: logger.error("Provided type is not valid for Language scanner") raise ValueError("Provided type is not valid for Language scanner") if timeout is not None: - url_reachability_params['timeout'] = timeout + url_reachability_params["timeout"] = timeout logger.info(f"Creating URLReachability scanner with params: {url_reachability_params}") return URLReachability(**url_reachability_params) @@ -1146,16 +1133,17 @@ def _create_output_scanner(self, scanner_name, scanner_config, vault=None): return None def create_enabled_output_scanners(self): - """ - Create and return a list of enabled scanners based on the global configuration. + """Create and return a list of enabled scanners based on the global configuration. Returns: list: A list of enabled scanner instances. """ - enabled_scanners_names_and_configs = {k: v for k, v in self._output_scanners_config.items() if isinstance(v, dict) and v.get("enabled")} + enabled_scanners_names_and_configs = { + k: v for k, v in self._output_scanners_config.items() if isinstance(v, dict) and v.get("enabled") + } enabled_scanners_objects = [] - err_msgs = {} # list for all erronous scanners + err_msgs = {} # list for all erroneous scanners only_validation_errors = True for scanner_name, scanner_config in enabled_scanners_names_and_configs.items(): try: @@ -1163,19 +1151,19 @@ def create_enabled_output_scanners(self): scanner_object = self._create_output_scanner(scanner_name, scanner_config, vault=self.vault) enabled_scanners_objects.append(scanner_object) except ValueError as e: - err_msg = f"A ValueError occured during creating output scanner {scanner_name}: {e}" + err_msg = f"A ValueError occurred during creating output scanner {scanner_name}: {e}" logger.error(err_msg) err_msgs[scanner_name] = err_msg self._output_scanners_config[scanner_name]["enabled"] = False continue except TypeError as e: - err_msg = f"A TypeError occured during creating output scanner {scanner_name}: {e}" + err_msg = f"A TypeError occurred during creating output scanner {scanner_name}: {e}" logger.error(err_msg) err_msgs[scanner_name] = err_msg self._output_scanners_config[scanner_name]["enabled"] = False continue except Exception as e: - err_msg = f"An unexpected error occured during creating output scanner {scanner_name}: {e}" + err_msg = f"An unexpected error occurred during creating output scanner {scanner_name}: {e}" logger.error(err_msg) err_msgs[scanner_name] = err_msg self._output_scanners_config[scanner_name]["enabled"] = False @@ -1184,15 +1172,18 @@ def create_enabled_output_scanners(self): if err_msgs: if only_validation_errors: - raise ValueError(f"Some scanners failed to be created due to validation errors. The details: {err_msgs}") + raise ValueError( + f"Some scanners failed to be created due to validation errors. The details: {err_msgs}" + ) else: - raise Exception(f"Some scanners failed to be created due to validation or unexpected errors. The details: {err_msgs}") + raise Exception( + f"Some scanners failed to be created due to validation or unexpected errors. The details: {err_msgs}" + ) return [s for s in enabled_scanners_objects if s is not None] def changed(self, new_scanners_config): - """ - Check if the scanners configuration has changed. + """Check if the scanners configuration has changed. Args: new_scanners_config (dict): The current scanners configuration. @@ -1200,15 +1191,25 @@ def changed(self, new_scanners_config): Returns: bool: True if the configuration has changed, False otherwise. """ - del new_scanners_config['id'] - newly_enabled_scanners = {k: {in_k: in_v for in_k, in_v in v.items() if in_k != 'id'} for k, v in new_scanners_config.items() if isinstance(v, dict) and v.get("enabled")} - previously_enabled_scanners = {k: v for k, v in self._output_scanners_config.items() if isinstance(v, dict) and v.get("enabled")} - if newly_enabled_scanners == previously_enabled_scanners: # if the enabled scanners are the same we do nothing + del new_scanners_config["id"] + newly_enabled_scanners = { + k: {in_k: in_v for in_k, in_v in v.items() if in_k != "id"} + for k, v in new_scanners_config.items() + if isinstance(v, dict) and v.get("enabled") + } + previously_enabled_scanners = { + k: v for k, v in self._output_scanners_config.items() if isinstance(v, dict) and v.get("enabled") + } + if newly_enabled_scanners == previously_enabled_scanners: # if the enabled scanners are the same we do nothing logger.info("No changes in list for enabled scanners. Checking configuration changes...") return False else: logger.warning("Sanners configuration has been changed, re-creating scanners") self._output_scanners_config.clear() - stripped_new_scanners_config = {k: {in_k: in_v for in_k, in_v in v.items() if in_k != 'id'} for k, v in new_scanners_config.items() if isinstance(v, dict)} + stripped_new_scanners_config = { + k: {in_k: in_v for in_k, in_v in v.items() if in_k != "id"} + for k, v in new_scanners_config.items() + if isinstance(v, dict) + } self._output_scanners_config.update(stripped_new_scanners_config) return True From 80098351b2464c64e5bdec369dc2d2e5e76fdc49 Mon Sep 17 00:00:00 2001 From: WenjiaoYue Date: Mon, 16 Jun 2025 11:31:14 +0800 Subject: [PATCH 3/7] add input/output conf --- .../opea_guardrails_microservice.py | 141 ++++++++--------- .../src/guardrails/requirements.txt | 1 + .../src/guardrails/utils/.input_env | 113 ++++++++++++++ .../src/guardrails/utils/.output_env | 144 ++++++++++++++++++ 4 files changed, 329 insertions(+), 70 deletions(-) create mode 100644 comps/guardrails/src/guardrails/utils/.input_env create mode 100644 comps/guardrails/src/guardrails/utils/.output_env diff --git a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py index 35e6161e71..b1ac7c3fbc 100644 --- a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py +++ b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py @@ -1,22 +1,28 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import asyncio import os import time +import asyncio from typing import Union - from dotenv import dotenv_values from fastapi import HTTPException -from utils.llm_guard_input_guardrail import OPEALLMGuardInputGuardrail -from utils.llm_guard_output_guardrail import OPEALLMGuardOutputGuardrail +from fastapi.responses import StreamingResponse +from pydantic import ValidationError + +from utils.llm_guard_input_guardrail import ( + OPEALLMGuardInputGuardrail +) +from utils.llm_guard_output_guardrail import ( + OPEALLMGuardOutputGuardrail +) from comps import ( CustomLogger, GeneratedDoc, LLMParamsDoc, - OpeaComponentLoader, SearchedDoc, + OpeaComponentLoader, ServiceType, TextDoc, opea_microservices, @@ -24,12 +30,21 @@ register_statistics, statistics_dict, ) + from comps.cores.proto.api_protocol import ChatCompletionRequest, DocSumChatCompletionRequest logger = CustomLogger("opea_guardrails_microservice") logflag = os.getenv("LOGFLAG", False) -usvc_config = {**dotenv_values(".env"), **os.environ} +input_usvc_config = { + **dotenv_values("utils/.input_env"), + **os.environ +} + +output_usvc_config = { + **dotenv_values("utils/.output_env"), + **os.environ +} guardrails_component_name = os.getenv("GUARDRAILS_COMPONENT_NAME", "OPEA_LLAMA_GUARD") # Initialize OpeaComponentLoader @@ -39,9 +54,8 @@ description=f"OPEA Guardrails Component: {guardrails_component_name}", ) -input_guardrail = OPEALLMGuardInputGuardrail(usvc_config) -output_guardrail = OPEALLMGuardOutputGuardrail(usvc_config) - +input_guardrail = OPEALLMGuardInputGuardrail(input_usvc_config) +output_guardrail = OPEALLMGuardOutputGuardrail(output_usvc_config) @register_microservice( name="opea_service@guardrails", @@ -49,82 +63,69 @@ endpoint="/v1/guardrails", host="0.0.0.0", port=9090, - input_datatype=Union[ - LLMParamsDoc, - GeneratedDoc, - ChatCompletionRequest, - SearchedDoc, - ChatCompletionRequest, - DocSumChatCompletionRequest, - ], - output_datatype=Union[ - LLMParamsDoc, - GeneratedDoc, - ChatCompletionRequest, - SearchedDoc, - ChatCompletionRequest, - DocSumChatCompletionRequest, - ], + input_datatype=Union[LLMParamsDoc, GeneratedDoc, TextDoc], + output_datatype=Union[TextDoc, GeneratedDoc, StreamingResponse], ) @register_statistics(names=["opea_service@guardrails"]) -async def safety_guard( - input: Union[ - LLMParamsDoc, - GeneratedDoc, - ChatCompletionRequest, - SearchedDoc, - ChatCompletionRequest, - DocSumChatCompletionRequest, - ], -) -> Union[ - LLMParamsDoc, GeneratedDoc, ChatCompletionRequest, SearchedDoc, ChatCompletionRequest, DocSumChatCompletionRequest -]: +async def safety_guard(input: Union[LLMParamsDoc, GeneratedDoc, TextDoc]) -> Union[TextDoc, GeneratedDoc, StreamingResponse]: start_time = time.time() - + if logflag: logger.info(f"Received input: {input}") - + try: if isinstance(input, LLMParamsDoc): processed = input_guardrail.scan_llm_input(input) - + statistics_dict["opea_service@guardrails"].append_latency( - time.time() - start_time, f"input_guard:{type(input).__name__}" + time.time() - start_time, + f"input_guard:{type(input).__name__}" ) - + if logflag: logger.info(f"Input guard passed: {processed}") return processed - - elif isinstance(input, GeneratedDoc): - processed = output_guardrail.scan_llm_output(input) - - if os.getenv("APPLY_CONTENT_GUARD", "true").lower() == "true": - text_doc = TextDoc(text=processed.text) - content_guard_result = await loader.invoke(text_doc) - processed.text = content_guard_result.text - - statistics_dict["opea_service@guardrails"].append_latency( - time.time() - start_time, f"output_guard:{type(input).__name__}" - ) - - if logflag: - logger.info(f"Output guard passed: {processed}") - return processed - - except HTTPException as e: - if e.status_code == 466: - logger.warning(f"Security rejection: {e.detail}") - statistics_dict["opea_service@guardrails"].append_latency( - time.time() - start_time, f"rejection:{e.status_code}" - ) - raise e + + # Use the loader to invoke the component + guardrails_response = await loader.invoke(processed) + + if isinstance(guardrails_response, GeneratedDoc): + try: + data = await guardrails_response.json() + doc = GeneratedDoc(**data) + except ValidationError as e: + err_msg = f"ValidationError creating GeneratedDoc: {e.errors()}" + logger.error(err_msg) + raise HTTPException(status_code=422, detail=err_msg) from e + except Exception as e: + logger.error(f"Problem with creating GenerateDoc: {e}") + raise HTTPException(status_code=500, detail=f"{e}") from e + + scanned_output = output_guardrail.scan_llm_output(doc) + + if doc.streaming is False: + return GeneratedDoc(text=scanned_output, prompt=doc.prompt, streaming=False) + else: + generator = scanned_output.split() + async def stream_generator(): + chat_response = "" + try: + for text in generator: + chat_response += text + chunk_repr = repr(' ' + text) # Guard takes over LLM streaming + logger.debug("[guard - chat_stream] chunk:{chunk_repr}") + yield f"data: {chunk_repr}\n\n" + await asyncio.sleep(0.02) # Delay of 0.02 second between chunks + logger.debug("[guard - chat_stream] stream response: {chat_response}") + yield "data: [DONE]\n\n" + except Exception as e: + logger.error(f"Error streaming from Guard: {e}") + yield "data: [ERROR]\n\n" + return StreamingResponse(stream_generator(), media_type="text/event-stream") except Exception as e: - logger.error(f"Unexpected error: {str(e)}") - statistics_dict["opea_service@guardrails"].append_latency(time.time() - start_time, f"error:{type(e).__name__}") - raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}") - + logger.error(f"Error during guardrails invocation: {e}") + raise if __name__ == "__main__": opea_microservices["opea_service@guardrails"].start() diff --git a/comps/guardrails/src/guardrails/requirements.txt b/comps/guardrails/src/guardrails/requirements.txt index e299b4ab9f..a57696ca60 100644 --- a/comps/guardrails/src/guardrails/requirements.txt +++ b/comps/guardrails/src/guardrails/requirements.txt @@ -12,3 +12,4 @@ prometheus-fastapi-instrumentator sentencepiece shortuuid uvicorn +llm_guard \ No newline at end of file diff --git a/comps/guardrails/src/guardrails/utils/.input_env b/comps/guardrails/src/guardrails/utils/.input_env new file mode 100644 index 0000000000..e15a6d87ce --- /dev/null +++ b/comps/guardrails/src/guardrails/utils/.input_env @@ -0,0 +1,113 @@ +# Copyright (C) 2024-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +## LLM Guard Input Guardrail Microservice Settings +## Singular input scanners settings +## Anonymize scanner settings +ANONYMIZE_ENABLED=false +ANONYMIZE_USE_ONNX=false +ANONYMIZE_HIDDEN_NAMES +ANONYMIZE_ALLOWED_NAMES +ANONYMIZE_ENTITY_TYPES +ANONYMIZE_PREAMBLE +ANONYMIZE_REGEX_PATTERMS +ANONYMIZE_USE_FAKER +ANONYMIZE_RECOGNIZER_CONF +ANONYMIZE_THRESHOLD +ANONYMIZE_LANGUAGE + + +## BanCode scanner settings +BAN_CODE_ENABLED=false +BAN_CODE_USE_ONNX=false +BAN_CODE_MODEL +BAN_CODE_THRESHOLD + +## BanCompetitors scanner settings +BAN_COMPETITORS_ENABLED=false +BAN_COMPETITORS_USE_ONNX=false +BAN_COMPETITORS_COMPETITORS="Competitor1,Competitor2,Competitor3" +BAN_COMPETITORS_THRESHOLD +BAN_COMPETITORS_REDACT +BAN_COMPETITORS_MODEL + +## BanSubstrings scanner settings +BAN_SUBSTRINGS_ENABLED=false +BAN_SUBSTRINGS_SUBSTRINGS="backdoor,malware,virus" +BAN_SUBSTRINGS_MATCH_TYPE +BAN_SUBSTRINGS_CASE_SENSITIVE +BAN_SUBSTRINGS_REDACT +BAN_SUBSTRINGS_CONTAINS_ALL + +## BanTopics scanner settings +BAN_TOPICS_ENABLED=false +BAN_TOPICS_USE_ONNX=false +BAN_TOPICS_TOPICS="violence,attack,war" +BAN_TOPICS_THRESHOLD +BAN_TOPICS_MODEL + +## Code scanner settings +CODE_ENABLED=false +CODE_USE_ONNX=false +CODE_LANGUAGES="Java,Python" +CODE_MODEL +CODE_IS_BLOCKED +CODE_THRESHOLD + +## Gibberish scanner settings +GIBBERISH_ENABLED=false +GIBBERISH_USE_ONNX=false +GIBBERISH_MODEL +GIBBERISH_THRESHOLD +GIBBERISH_MATCH_TYPE + +## Invisible Text scanner settings +INVISIBLETEXT_ENABLED=false + +## Language scanner settings +LANGUAGE_ENABLED=false +LANGUAGE_USE_ONNX=false +LANGUAGE_VALID_LANGUAGES="en,es" +LANGUAGE_MODEL +LANGUAGE_THRESHOLD +LANGUAGE_MATCH_TYPE + +## Prompt Injection scanner settings +PROMPT_INJECTION_ENABLED=false +PROMPT_INJECTION_USE_ONNX=false +PROMPT_INJECTION_MODEL +PROMPT_INJECTION_THRESHOLD +PROMPT_INJECTION_MATCH_TYPE + +## Regex scanner settings +REGEX_ENABLED=false +REGEX_PATTERNS="Bearer [A-Za-z0-9-._~+/]+" +REGEX_IS_BLOCKED +REGEX_MATCH_TYPE +REGEX_REDACT + +## Secrets scanner settings +SCERETS_ENABLED=false +SECRETS_REDACT_MODE + +## Sentiment scanner settings +SENTIMENT_ENABLED=false +SENTIMENT_THERSHOLD +SENTIMENT_LEXICON + +## TokenLimit scanner settings +TOKEN_LIMIT_ENABLED=false +TOKEN_LIMIT_LIMIT +TOKEN_LIMIT_ENCODING_NAME +TOKEN_LIMIT_MODEL_NAME + +## Toxicity scanner settings +TOXICITY_ENABLED=false +TOXICITY_USE_ONNX=false +TOXICITY_MODEL +TOXICITY_THRESHOLD +TOXICITY_MATCH_TYPE + +## Uncomment to change the microservice part +# LLM_GUARD_INPUT_SCANNER_USVC_PORT=8050 +# OPEA_LOGGER_LEVEL="INFO" diff --git a/comps/guardrails/src/guardrails/utils/.output_env b/comps/guardrails/src/guardrails/utils/.output_env new file mode 100644 index 0000000000..92fb49ad6a --- /dev/null +++ b/comps/guardrails/src/guardrails/utils/.output_env @@ -0,0 +1,144 @@ +# Copyright (C) 2024-2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +## LLM Guard Output Guardrail Microservice Settings +## Singular output scanners settings +## BanCode scanner settings +BAN_CODE_ENABLED=false +BAN_CODE_USE_ONNX=false +BAN_CODE_MODEL +BAN_CODE_THRESHOLD + +## BanCompetitors scanner settings +BAN_COMPETITORS_ENABLED=false +BAN_COMPETITORS_USE_ONNX=false +BAN_COMPETITORS_COMPETITORS="Competitor1,Competitor2,Competitor3" +BAN_COMPETITORS_THRESHOLD +BAN_COMPETITORS_REDACT +BAN_COMPETITORS_MODEL + +## BanSubstrings scanner settings +BAN_SUBSTRINGS_ENABLED=false +BAN_SUBSTRINGS_SUBSTRINGS="backdoor,malware,virus" +BAN_SUBSTRINGS_MATCH_TYPE +BAN_SUBSTRINGS_CASE_SENSITIVE +BAN_SUBSTRINGS_REDACT=true +BAN_SUBSTRINGS_CONTAINS_ALL + +## BanTopics scanner settings +BAN_TOPICS_ENABLED=false +BAN_TOPICS_USE_ONNX=false +BAN_TOPICS_TOPICS="violence,attack,war" +BAN_TOPICS_THRESHOLD +BAN_TOPICS_MODEL + +## Bias scanner settings +BIAS_ENABLED=false +BIAS_USE_ONNX=false +BIAS_MODEL +BIAS_THRESHOLD +BIAS_MATCH_TYPE + +## Codes scanner settings +CODE_ENABLED=false +CODE_USE_ONNX=false +CODE_LANGUAGES="Java,Python" +CODE_MODEL +CODE_IS_BLOCKED +CODE_THRESHOLD + +## Deanonymize scanner settings +DEANONYMIZE_ENABLED=false +DEANONYMIZE_MATCHING_STRATEGY + +## JSON scanner settings +JSON_SCANNER_ENABLED=false +JSON_SCANNER_REQUIRED_ELEMENTS +JSON_SCANNER_REPAIR + +## Language scanner settings +LANGUAGE_ENABLED=false +LANGUAGE_USE_ONNX=false +LANGUAGE_VALID_LANGUAGES="en,es" +LANGUAGE_MODEL +LANGUAGE_THRESHOLD +LANGUAGE_MATCH_TYPE + +## LanguageSame scanner settings +LANGUAGE_SAME_ENABLED=false +LANGUAGE_SAME_USE_ONNX=false +LANGUAGE_SAME_MODEL +LANGUAGE_SAME_THRESHOLD + +## MaliciousURLs scanner settings +MALICIOUS_URLS_ENABLED=false +MALICIOUS_URLS_USE_ONNX=false +MALICIOUS_URLS_MODEL +MALICIOUS_URLS_THRESHOLD + +## NoRefusal scanner settings +NO_REFUSAL_ENABLED=false +NO_REFUSAL_USE_ONNX=false +NO_REFUSAL_MODEL +NO_REFUSAL_THRESHOLD +NO_REFUSAL_MATCH_TYPE + +## NoRefusalLight scanner settings +NO_REFUSAL_LIGHT_ENABLED=false + +## ReadingTime scanner settings +READING_TIME_ENABLED=false +READING_TIME_MAX_TIME=0.5 +READING_TIME_TRUNCATE + +## FactualConsistency scanner settings +FACTUAL_CONSISTENCY_ENABLED=false +FACTUAL_CONSISTENCY_USE_ONNX=false +FACTUAL_CONSISTENCY_MODEL +FACTUAL_CONSISTENCY_MINIMUM_SCORE + +## Gibberish scanner settings +GIBBERISH_ENABLED=false +GIBBERISH_USE_ONNX=false +GIBBERISH_MODEL +GIBBERISH_THRESHOLD +GIBBERISH_MATCH_TYPE + +## Regex scanner settings +REGEX_ENABLED=false +REGEX_PATTERNS="Bearer [A-Za-z0-9-._~+/]+" +REGEX_IS_BLOCKED +REGEX_MATCH_TYPE +REGEX_REDACT + +## Relevance scanner settings +RELEVANCE_ENABLED=false +RELEVANCE_USE_ONNX=false +RELEVANCE_MODEL +RELEVANCE_THRESHOLD + +## Snsitive scanner settings +SENSITIVE_ENABLED=false +SENSITIVE_USE_ONNX=false +SENSITIVE_ENTITY_TYPES +SENSITIVE_REGEX_PATTERNS +SENSITIVE_REDACT +SENSITIVE_RECOGNIZER_CONF +SENSITIVE_THRESHOLD + +## Sentiment scanner settings +SENTIMENT_ENABLED=false +SENTIMENT_THERSHOLD +SENTIMENT_LEXICON + +## Toxicity scanner settings +TOXICITY_ENABLED=false +TOXICITY_USE_ONNX=false +TOXICITY_MODEL +TOXICITY_THRESHOLD +TOXICITY_MATCH_TYPE + +## URLReachability +URL_REACHABILITY_ENABLED=false +URL_REACHABILITY_SUCCESS_STATUS_CODES +URL_REACHABILITY_TIMEOUT \ No newline at end of file From 22ec3c317cd89e92d137339245bfeb16e6790ea6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Jun 2025 03:31:49 +0000 Subject: [PATCH 4/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../opea_guardrails_microservice.py | 48 ++++++++----------- .../src/guardrails/requirements.txt | 2 +- 2 files changed, 22 insertions(+), 28 deletions(-) diff --git a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py index b1ac7c3fbc..51c289a780 100644 --- a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py +++ b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py @@ -1,28 +1,24 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import asyncio import os import time -import asyncio from typing import Union + from dotenv import dotenv_values from fastapi import HTTPException from fastapi.responses import StreamingResponse from pydantic import ValidationError - -from utils.llm_guard_input_guardrail import ( - OPEALLMGuardInputGuardrail -) -from utils.llm_guard_output_guardrail import ( - OPEALLMGuardOutputGuardrail -) +from utils.llm_guard_input_guardrail import OPEALLMGuardInputGuardrail +from utils.llm_guard_output_guardrail import OPEALLMGuardOutputGuardrail from comps import ( CustomLogger, GeneratedDoc, LLMParamsDoc, - SearchedDoc, OpeaComponentLoader, + SearchedDoc, ServiceType, TextDoc, opea_microservices, @@ -30,21 +26,14 @@ register_statistics, statistics_dict, ) - from comps.cores.proto.api_protocol import ChatCompletionRequest, DocSumChatCompletionRequest logger = CustomLogger("opea_guardrails_microservice") logflag = os.getenv("LOGFLAG", False) -input_usvc_config = { - **dotenv_values("utils/.input_env"), - **os.environ -} +input_usvc_config = {**dotenv_values("utils/.input_env"), **os.environ} -output_usvc_config = { - **dotenv_values("utils/.output_env"), - **os.environ -} +output_usvc_config = {**dotenv_values("utils/.output_env"), **os.environ} guardrails_component_name = os.getenv("GUARDRAILS_COMPONENT_NAME", "OPEA_LLAMA_GUARD") # Initialize OpeaComponentLoader @@ -57,6 +46,7 @@ input_guardrail = OPEALLMGuardInputGuardrail(input_usvc_config) output_guardrail = OPEALLMGuardOutputGuardrail(output_usvc_config) + @register_microservice( name="opea_service@guardrails", service_type=ServiceType.GUARDRAIL, @@ -67,25 +57,26 @@ output_datatype=Union[TextDoc, GeneratedDoc, StreamingResponse], ) @register_statistics(names=["opea_service@guardrails"]) -async def safety_guard(input: Union[LLMParamsDoc, GeneratedDoc, TextDoc]) -> Union[TextDoc, GeneratedDoc, StreamingResponse]: +async def safety_guard( + input: Union[LLMParamsDoc, GeneratedDoc, TextDoc], +) -> Union[TextDoc, GeneratedDoc, StreamingResponse]: start_time = time.time() - + if logflag: logger.info(f"Received input: {input}") - + try: if isinstance(input, LLMParamsDoc): processed = input_guardrail.scan_llm_input(input) - + statistics_dict["opea_service@guardrails"].append_latency( - time.time() - start_time, - f"input_guard:{type(input).__name__}" + time.time() - start_time, f"input_guard:{type(input).__name__}" ) - + if logflag: logger.info(f"Input guard passed: {processed}") return processed - + # Use the loader to invoke the component guardrails_response = await loader.invoke(processed) @@ -107,12 +98,13 @@ async def safety_guard(input: Union[LLMParamsDoc, GeneratedDoc, TextDoc]) -> Uni return GeneratedDoc(text=scanned_output, prompt=doc.prompt, streaming=False) else: generator = scanned_output.split() + async def stream_generator(): chat_response = "" try: for text in generator: chat_response += text - chunk_repr = repr(' ' + text) # Guard takes over LLM streaming + chunk_repr = repr(" " + text) # Guard takes over LLM streaming logger.debug("[guard - chat_stream] chunk:{chunk_repr}") yield f"data: {chunk_repr}\n\n" await asyncio.sleep(0.02) # Delay of 0.02 second between chunks @@ -121,12 +113,14 @@ async def stream_generator(): except Exception as e: logger.error(f"Error streaming from Guard: {e}") yield "data: [ERROR]\n\n" + return StreamingResponse(stream_generator(), media_type="text/event-stream") except Exception as e: logger.error(f"Error during guardrails invocation: {e}") raise + if __name__ == "__main__": opea_microservices["opea_service@guardrails"].start() logger.info("OPEA guardrails Microservice is up and running successfully...") diff --git a/comps/guardrails/src/guardrails/requirements.txt b/comps/guardrails/src/guardrails/requirements.txt index a57696ca60..43fdc95858 100644 --- a/comps/guardrails/src/guardrails/requirements.txt +++ b/comps/guardrails/src/guardrails/requirements.txt @@ -5,6 +5,7 @@ huggingface-hub<=0.24.0 langchain-community langchain-huggingface langchain-openai +llm_guard opentelemetry-api opentelemetry-exporter-otlp opentelemetry-sdk @@ -12,4 +13,3 @@ prometheus-fastapi-instrumentator sentencepiece shortuuid uvicorn -llm_guard \ No newline at end of file From 99bbc1e0b5069741927399d90f240445437e58bb Mon Sep 17 00:00:00 2001 From: WenjiaoYue Date: Wed, 18 Jun 2025 16:30:49 +0800 Subject: [PATCH 5/7] Add related scripts --- comps/cores/mega/utils.py | 15 ++ comps/cores/proto/docarray.py | 229 +++++++++++++++++- .../opea_guardrails_microservice.py | 75 +++--- .../src/guardrails/requirements.txt | 3 +- .../utils/llm_guard_input_guardrail.py | 161 ++++-------- .../utils/llm_guard_input_scanners.py | 10 +- .../utils/llm_guard_output_guardrail.py | 4 +- .../utils/llm_guard_output_scanners.py | 7 +- .../src/guardrails/utils/scanners.py | 76 ++++++ ...ails_guardrails_llamaguard_on_intel_hpu.sh | 52 +++- ...rails_guardrails_wildguard_on_intel_hpu.sh | 51 ++++ 11 files changed, 501 insertions(+), 182 deletions(-) create mode 100644 comps/guardrails/src/guardrails/utils/scanners.py diff --git a/comps/cores/mega/utils.py b/comps/cores/mega/utils.py index 4d8dc5b3eb..afeefa1156 100644 --- a/comps/cores/mega/utils.py +++ b/comps/cores/mega/utils.py @@ -181,3 +181,18 @@ def handle_message(messages): return prompt, images else: return prompt + +def sanitize_env(value: Optional[str]) -> Optional[str]: + """Remove quotes from a configuration value if present. + Args: + value (str): The configuration value to sanitize. + Returns: + str: The sanitized configuration value. + """ + if value is None: + return None + if value.startswith('"') and value.endswith('"'): + value = value[1:-1] + elif value.startswith('\'') and value.endswith('\''): + value = value[1:-1] + return value \ No newline at end of file diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py index 2e45ece5e7..267131c68e 100644 --- a/comps/cores/proto/docarray.py +++ b/comps/cores/proto/docarray.py @@ -1,7 +1,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union, Tuple import numpy as np from docarray import BaseDoc, DocList @@ -144,7 +144,6 @@ class Config: class SearchedMultimodalDoc(SearchedDoc): metadata: List[Dict[str, Any]] - class LVMSearchedMultimodalDoc(SearchedMultimodalDoc): max_new_tokens: conint(ge=0, le=1024) = 512 top_k: int = 10 @@ -162,16 +161,222 @@ class LVMSearchedMultimodalDoc(SearchedMultimodalDoc): ), ) - -class GeneratedDoc(BaseDoc): - text: str - prompt: str - - class RerankedDoc(BaseDoc): reranked_docs: DocList[TextDoc] initial_query: str +class AnonymizeModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + hidden_names: Optional[List[str]] = None + allowed_names: Optional[List[str]] = None + entity_types: Optional[List[str]] = None + preamble: Optional[str] = None + regex_patterns: Optional[List[str]] = None + use_faker: Optional[bool] = None + recognizer_conf: Optional[str] = None + threshold: Optional[float] = None + language: Optional[str] = None + +class BanCodeModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + +class BanCompetitorsModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + competitors: List[str] = ["Competitor1", "Competitor2", "Competitor3"] + model: Optional[str] = None + threshold: Optional[float] = None + redact: Optional[bool] = None + +class BanSubstringsModel(BaseDoc): + enabled: bool = False + substrings: List[str] = ["backdoor", "malware", "virus"] + match_type: Optional[str] = None + case_sensitive: bool = False + redact: Optional[bool] = None + contains_all: Optional[bool] = None + +class BanTopicsModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + topics: List[str] = ["violence","attack","war"] + threshold: Optional[float] = None + model: Optional[str] = None + +class CodeModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + languages: List[str] = ["Java", "Python"] + model: Optional[str] = None + is_blocked: Optional[bool] = None + threshold: Optional[float] = None + +class GibberishModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + match_type: Optional[str] = None + +class InvisibleText(BaseDoc): + enabled: bool = False + +class LanguageModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + valid_languages: List[str] = ["en", "es"] + model: Optional[str] = None + threshold: Optional[float] = None + match_type: Optional[str] = None + +class PromptInjectionModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + match_type: Optional[str] = None + +class RegexModel(BaseDoc): + enabled: bool = False + patterns: List[str] = ["Bearer [A-Za-z0-9-._~+/]+"] + is_blocked: Optional[bool] = None + match_type: Optional[str] = None + redact: Optional[bool] = None + +class SecretsModel(BaseDoc): + enabled: bool = False + redact_mode: Optional[str] = None + +class SentimentModel(BaseDoc): + enabled: bool = False + threshold: Optional[float] = None + lexicon: Optional[str] = None + +class TokenLimitModel(BaseDoc): + enabled: bool = False + limit: Optional[int] = None + encoding_name: Optional[str] = None + model_name: Optional[str] = None + +class ToxicityModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + match_type: Optional[str] = None +class BiasModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + match_type: Optional[str] = None + +class DeanonymizeModel(BaseDoc): + enabled: bool = False + matching_strategy: Optional[str] = None + +class JSONModel(BaseDoc): + enabled: bool = False + required_elements: Optional[int] = None + repair: Optional[bool] = None + +class LanguageSameModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + +class MaliciousURLsModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + +class NoRefusalModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + match_type: Optional[str] = None + +class NoRefusalLightModel(BaseDoc): + enabled: bool = False + +class ReadingTimeModel(BaseDoc): + enabled: bool = False + max_time: float = 0.5 + truncate: Optional[bool] = None + +class FactualConsistencyModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + minimum_score: Optional[float] = None + +class RelevanceModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + model: Optional[str] = None + threshold: Optional[float] = None + +class SensitiveModel(BaseDoc): + enabled: bool = False + use_onnx: bool = False + entity_types: Optional[List[str]] = None + regex_patterns: Optional[List[str]] = None + redact: Optional[bool] = None + recognizer_conf: Optional[str] = None + threshold: Optional[float] = None + +class URLReachabilityModel(BaseDoc): + enabled: bool = False + success_status_codes: Optional[List[int]] = None + timeout: Optional[int] = None +class LLMGuardInputGuardrailParams(BaseDoc): + anonymize: AnonymizeModel = AnonymizeModel() + ban_code: BanCodeModel = BanCodeModel() + ban_competitors: BanCompetitorsModel = BanCompetitorsModel() + ban_substrings: BanSubstringsModel = BanSubstringsModel() + ban_topics: BanTopicsModel = BanTopicsModel() + code: CodeModel = CodeModel() + gibberish: GibberishModel = GibberishModel() + invisible_text: InvisibleText = InvisibleText() + language: LanguageModel = LanguageModel() + prompt_injection: PromptInjectionModel = PromptInjectionModel() + regex: RegexModel = RegexModel() + secrets: SecretsModel = SecretsModel() + sentiment: SentimentModel = SentimentModel() + token_limit: TokenLimitModel = TokenLimitModel() + toxicity: ToxicityModel = ToxicityModel() + +class LLMGuardOutputGuardrailParams(BaseDoc): + ban_code: BanCodeModel = BanCodeModel() + ban_competitors: BanCompetitorsModel = BanCompetitorsModel() + ban_substrings: BanSubstringsModel = BanSubstringsModel() + ban_topics: BanTopicsModel = BanTopicsModel() + bias: BiasModel = BiasModel() + code: CodeModel = CodeModel() + deanonymize: DeanonymizeModel = DeanonymizeModel() + json_scanner: JSONModel = JSONModel() + language: LanguageModel = LanguageModel() + language_same: LanguageSameModel = LanguageSameModel() + malicious_urls: MaliciousURLsModel = MaliciousURLsModel() + no_refusal: NoRefusalModel = NoRefusalModel() + no_refusal_light: NoRefusalLightModel = NoRefusalLightModel() + reading_time: ReadingTimeModel = ReadingTimeModel() + factual_consistency: FactualConsistencyModel = FactualConsistencyModel() + gibberish: GibberishModel = GibberishModel() + regex: RegexModel = RegexModel() + relevance: RelevanceModel = RelevanceModel() + sensitive: SensitiveModel = SensitiveModel() + sentiment: SentimentModel = SentimentModel() + toxicity: ToxicityModel = ToxicityModel() + url_reachability: URLReachabilityModel = URLReachabilityModel() + anonymize_vault: Optional[List[Tuple]] = None # the only parameter not available in fingerprint. Used to tramsmit vault class LLMParamsDoc(BaseDoc): model: Optional[str] = None # for openai and ollama @@ -187,6 +392,8 @@ class LLMParamsDoc(BaseDoc): repetition_penalty: float = 1.03 stream: bool = True language: str = "auto" # can be "en", "zh" + input_guardrail_params: Optional[LLMGuardInputGuardrailParams] = None + output_guardrail_params: Optional[LLMGuardOutputGuardrailParams] = None chat_template: Optional[str] = Field( default=None, @@ -212,7 +419,11 @@ class LLMParamsDoc(BaseDoc): def chat_template_must_contain_variables(cls, v): return v - +class GeneratedDoc(BaseDoc): + text: str + prompt: str + output_guardrail_params: Optional[LLMGuardOutputGuardrailParams] = None + class LLMParams(BaseDoc): model: Optional[str] = None max_tokens: int = 1024 diff --git a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py index 51c289a780..dbc3fa7bfa 100644 --- a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py +++ b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py @@ -10,8 +10,16 @@ from fastapi import HTTPException from fastapi.responses import StreamingResponse from pydantic import ValidationError -from utils.llm_guard_input_guardrail import OPEALLMGuardInputGuardrail -from utils.llm_guard_output_guardrail import OPEALLMGuardOutputGuardrail + +from utils.llm_guard_input_guardrail import ( + OPEALLMGuardInputGuardrail +) +from utils.llm_guard_output_guardrail import ( + OPEALLMGuardOutputGuardrail +) + +from integrations.llamaguard import OpeaGuardrailsLlamaGuard +from integrations.wildguard import OpeaGuardrailsWildGuard from comps import ( CustomLogger, @@ -26,7 +34,6 @@ register_statistics, statistics_dict, ) -from comps.cores.proto.api_protocol import ChatCompletionRequest, DocSumChatCompletionRequest logger = CustomLogger("opea_guardrails_microservice") logflag = os.getenv("LOGFLAG", False) @@ -54,67 +61,39 @@ host="0.0.0.0", port=9090, input_datatype=Union[LLMParamsDoc, GeneratedDoc, TextDoc], - output_datatype=Union[TextDoc, GeneratedDoc, StreamingResponse], + output_datatype=Union[TextDoc, GeneratedDoc], ) @register_statistics(names=["opea_service@guardrails"]) -async def safety_guard( - input: Union[LLMParamsDoc, GeneratedDoc, TextDoc], -) -> Union[TextDoc, GeneratedDoc, StreamingResponse]: - start_time = time.time() - +async def safety_guard(input: Union[LLMParamsDoc, GeneratedDoc, TextDoc]) -> Union[TextDoc, GeneratedDoc]: + start = time.time() + if logflag: logger.info(f"Received input: {input}") try: if isinstance(input, LLMParamsDoc): processed = input_guardrail.scan_llm_input(input) - - statistics_dict["opea_service@guardrails"].append_latency( - time.time() - start_time, f"input_guard:{type(input).__name__}" - ) - if logflag: logger.info(f"Input guard passed: {processed}") - return processed - # Use the loader to invoke the component - guardrails_response = await loader.invoke(processed) - - if isinstance(guardrails_response, GeneratedDoc): + elif isinstance(input, GeneratedDoc): try: - data = await guardrails_response.json() - doc = GeneratedDoc(**data) - except ValidationError as e: - err_msg = f"ValidationError creating GeneratedDoc: {e.errors()}" - logger.error(err_msg) - raise HTTPException(status_code=422, detail=err_msg) from e + doc = input except Exception as e: - logger.error(f"Problem with creating GenerateDoc: {e}") + logger.error(f"Problem using input as GeneratedDoc: {e}") raise HTTPException(status_code=500, detail=f"{e}") from e - scanned_output = output_guardrail.scan_llm_output(doc) - if doc.streaming is False: - return GeneratedDoc(text=scanned_output, prompt=doc.prompt, streaming=False) - else: - generator = scanned_output.split() - - async def stream_generator(): - chat_response = "" - try: - for text in generator: - chat_response += text - chunk_repr = repr(" " + text) # Guard takes over LLM streaming - logger.debug("[guard - chat_stream] chunk:{chunk_repr}") - yield f"data: {chunk_repr}\n\n" - await asyncio.sleep(0.02) # Delay of 0.02 second between chunks - logger.debug("[guard - chat_stream] stream response: {chat_response}") - yield "data: [DONE]\n\n" - except Exception as e: - logger.error(f"Error streaming from Guard: {e}") - yield "data: [ERROR]\n\n" - - return StreamingResponse(stream_generator(), media_type="text/event-stream") + processed = scanned_output + else: + processed = input + + # Use the loader to invoke the component + guardrails_response = await loader.invoke(processed) + + # Record statistics + statistics_dict["opea_service@guardrails"].append_latency(time.time() - start, None) + return guardrails_response except Exception as e: logger.error(f"Error during guardrails invocation: {e}") diff --git a/comps/guardrails/src/guardrails/requirements.txt b/comps/guardrails/src/guardrails/requirements.txt index 43fdc95858..39f7af32b3 100644 --- a/comps/guardrails/src/guardrails/requirements.txt +++ b/comps/guardrails/src/guardrails/requirements.txt @@ -5,7 +5,6 @@ huggingface-hub<=0.24.0 langchain-community langchain-huggingface langchain-openai -llm_guard opentelemetry-api opentelemetry-exporter-otlp opentelemetry-sdk @@ -13,3 +12,5 @@ prometheus-fastapi-instrumentator sentencepiece shortuuid uvicorn +llm_guard +presidio_anonymizer \ No newline at end of file diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py b/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py index 1d1b724a25..a32078abc4 100644 --- a/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py +++ b/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py @@ -5,149 +5,82 @@ from llm_guard import scan_prompt from utils.llm_guard_input_scanners import InputScannersConfig -from comps import LLMParamsDoc, get_opea_logger - -logger = get_opea_logger("opea_llm_guard_input_guardrail_microservice") +from comps import LLMParamsDoc, CustomLogger +logger = CustomLogger("opea_llm_guard_input_guardrail_microservice") class OPEALLMGuardInputGuardrail: """OPEALLMGuardInputGuardrail is responsible for scanning and sanitizing LLM input prompts using various input scanners provided by LLM Guard. - - This class initializes the input scanners based on the provided configuration and - scans the input prompts to ensure they meet the required guardrail criteria. - - Attributes: - _scanners (list): A list of enabled scanners. - - Methods: - __init__(usv_config: dict): - Initializes the OPEALLMGuardInputGuardrail with the provided configuration. - - scan_llm_input(input_doc: LLMParamsDoc) -> tuple[str, dict[str, bool], dict[str, float]]: - Scans the prompt from an LLMParamsDoc object and returns the sanitized prompt, - validation results, and scores. """ def __init__(self, usv_config: dict): - """Initializes the OPEALLMGuardInputGuardrail with the provided configuration. - - Args: - usv_config (dict): The configuration dictionary for initializing the input scanners. - - Raises: - Exception: If an unexpected error occurs during the initialization of scanners. - """ try: self._scanners_config = InputScannersConfig(usv_config) self._scanners = self._scanners_config.create_enabled_input_scanners() except ValueError as e: - logger.exception(f"Value Error occurred while initializing LLM Guard Input Guardrail scanners: {e}") + logger.exception(f"Value Error during scanner initialization: {e}") raise except Exception as e: - logger.exception( - f"An unexpected error occurred during initializing \ - LLM Guard Input Guardrail scanners: {e}" - ) + logger.exception(f"Unexpected error during scanner initialization: {e}") raise def _get_anonymize_vault(self): - anon = [item for item in self._scanners if type(item).__name__ == "Anonymize"] - if len(anon) > 0: - return anon[0]._vault.get() + for item in self._scanners: + if type(item).__name__ == "Anonymize": + return item._vault.get() return None def _recreate_anonymize_scanner_if_exists(self): - anon = [item for item in self._scanners if type(item).__name__ == "Anonymize"] - if len(anon) > 0: - logger.info(f"Anonymize scanner found: {len(anon)}") - self._scanners.remove(anon[0]) - self._scanners.append(self._scanners_config._create_anonymize_scanner()) + for item in self._scanners: + if type(item).__name__ == "Anonymize": + logger.info("Recreating Anonymize scanner to clear Vault.") + self._scanners.remove(item) + self._scanners.append(self._scanners_config._create_anonymize_scanner()) + break def _analyze_scan_outputs(self, prompt, results_valid, results_score): - filtered_results_valid_no_redacted = {} - scanners_with_redact = ["BanCompetitors", "BanSubstrings", "OPEABanSubstrings", "Regex", "OPEARegexScanner"] - - for key, value in results_valid.items(): - if_redacted = False - redacted_scanner = [ - item - for item in self._scanners - if type(item).__name__ in scanners_with_redact and type(item).__name__ == key - ] - - if len(redacted_scanner) > 0: - if_redacted = redacted_scanner[0]._redact - - if key != "Anonymize" and not if_redacted: - filtered_results_valid_no_redacted[key] = value + filtered_results = { + key: value + for key, value in results_valid.items() + if key != "Anonymize" and not ( + type(scanner := next((s for s in self._scanners if type(s).__name__ == key), None)).__name__ in { + "BanCompetitors", "BanSubstrings", "OPEABanSubstrings", "Regex", "OPEARegexScanner" + } and getattr(scanner, "_redact", False) + ) + } - if False in filtered_results_valid_no_redacted.values(): - msg = f"Prompt {prompt} is not valid, scores: {results_score}" - logger.error(f"{msg}") - usr_msg = "I'm sorry, I cannot assist you with your prompt." - raise HTTPException(status_code=466, detail=f"{usr_msg}") + if False in filtered_results.values(): + msg = f"Prompt '{prompt}' is not valid, scores: {results_score}" + logger.error(msg) + raise HTTPException(status_code=466, detail="I'm sorry, I cannot assist you with your prompt.") def scan_llm_input(self, input_doc: LLMParamsDoc) -> LLMParamsDoc: - """Scan the prompt from an LLMParamsDoc object. - - Args: - input_doc (LLMParamsDoc): The input document containing the prompt to be scanned. - - Returns: - tuple[str, dict[str, bool], dict[str, float]]: A tuple containing the sanitized prompt, - a dictionary of validation results, and a dictionary of scores. - - Raises: - HTTPException: If the prompt is not valid based on the scanner results. - """ fresh_scanners = False + if input_doc.input_guardrail_params is not None: if self._scanners_config.changed(input_doc.input_guardrail_params.dict()): self._scanners = self._scanners_config.create_enabled_input_scanners() fresh_scanners = True else: - logger.warning("Input guardrail params not found in input document.") - if self._scanners: - if not fresh_scanners: - logger.info("Recreating anonymize scanner if exists to clear the Vault.") - self._recreate_anonymize_scanner_if_exists() - system_prompt = input_doc.messages.system - user_prompt = input_doc.messages.user - - # We want to block only user question with a TokenLimit Scanner - scanners_without_token_limit = [item for item in self._scanners if type(item).__name__ != "TokenLimit"] - if len(self._scanners) != scanners_without_token_limit: - sanitized_system_prompt, system_results_valid, system_results_score = scan_prompt( - scanners_without_token_limit, system_prompt - ) - else: - sanitized_system_prompt, system_results_valid, system_results_score = scan_prompt( - self._scanners, system_prompt - ) - - if "### Question:" in user_prompt: - # Default template is used - prefix = "### Question: " - suffix = " \n ### Answer:" - user_prompt_to_scan = user_prompt.split(prefix)[1].split(suffix)[0] - sanitized_user_prompt, user_results_valid, user_results_score = scan_prompt( - self._scanners, user_prompt_to_scan - ) - sanitized_user_prompt = prefix + sanitized_user_prompt + suffix - else: - sanitized_user_prompt, user_results_valid, user_results_score = scan_prompt(self._scanners, user_prompt) - - self._analyze_scan_outputs(system_prompt, system_results_valid, system_results_score) - self._analyze_scan_outputs(user_prompt, user_results_valid, user_results_score) - - input_doc.messages.system = sanitized_system_prompt - input_doc.messages.user = sanitized_user_prompt - if input_doc.output_guardrail_params is not None and "Anonymize" in user_results_valid: - input_doc.output_guardrail_params.anonymize_vault = self._get_anonymize_vault() - elif input_doc.output_guardrail_params is None and "Anonymize" in user_results_valid: - logger.warning("No output guardrails params, could not append the vault for Anonymize scanner.") - return input_doc - else: - logger.info("No input scanners enabled. Skipping scanning.") + logger.warning("Input guardrail params not found.") + + if not self._scanners: + logger.info("No scanners enabled. Skipping input scan.") return input_doc + + if not fresh_scanners: + self._recreate_anonymize_scanner_if_exists() + + user_prompt = input_doc.query + sanitized_user_prompt, results_valid, results_score = scan_prompt(self._scanners, user_prompt) + self._analyze_scan_outputs(user_prompt, results_valid, results_score) + + input_doc.query = sanitized_user_prompt + + if input_doc.output_guardrail_params is not None and "Anonymize" in results_valid: + input_doc.output_guardrail_params.anonymize_vault = self._get_anonymize_vault() + elif input_doc.output_guardrail_params is None and "Anonymize" in results_valid: + logger.warning("Anonymize scanner result exists, but output_guardrail_params is missing.") + + return input_doc diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py b/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py index 3fc9e2f0b2..3083c3fa48 100644 --- a/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py +++ b/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py @@ -53,10 +53,12 @@ "toxicity", ] -from comps import get_opea_logger, sanitize_env -from comps.guardrails.utils.scanners import OPEABanSubstrings, OPEARegexScanner +from comps import CustomLogger +from comps.cores.mega.utils import sanitize_env -logger = get_opea_logger("opea_llm_guard_input_guardrail_microservice") +from comps.guardrails.src.guardrails.utils.scanners import OPEABanSubstrings, OPEARegexScanner + +logger = CustomLogger("opea_llm_guard_input_guardrail_microservice") class InputScannersConfig: @@ -947,4 +949,4 @@ def changed(self, new_scanners_config): k: {in_k: in_v for in_k, in_v in v.items() if in_k != "id"} for k, v in new_scanners_config.items() } self._input_scanners_config.update(stripped_new_scanners_config) - return True + return True \ No newline at end of file diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py b/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py index 08b79c51f3..4d1abc8368 100644 --- a/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py +++ b/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py @@ -5,9 +5,9 @@ from llm_guard import scan_output from utils.llm_guard_output_scanners import OutputScannersConfig -from comps import GeneratedDoc, get_opea_logger +from comps import GeneratedDoc, CustomLogger -logger = get_opea_logger("opea_llm_guard_output_guardrail_microservice") +logger = CustomLogger("opea_llm_guard_output_guardrail_microservice") class OPEALLMGuardOutputGuardrail: diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py b/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py index c2de5089c9..af5dcea20a 100644 --- a/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py +++ b/comps/guardrails/src/guardrails/utils/llm_guard_output_scanners.py @@ -75,10 +75,11 @@ "url_reachability", ] -from comps import get_opea_logger, sanitize_env -from comps.guardrails.utils.scanners import OPEABanSubstrings, OPEARegexScanner +from comps import CustomLogger +from comps.cores.mega.utils import sanitize_env +from comps.guardrails.src.guardrails.utils.scanners import OPEABanSubstrings, OPEARegexScanner -logger = get_opea_logger("opea_llm_guard_output_guardrail_microservice") +logger = CustomLogger("opea_llm_guard_output_guardrail_microservice") class OutputScannersConfig: diff --git a/comps/guardrails/src/guardrails/utils/scanners.py b/comps/guardrails/src/guardrails/utils/scanners.py new file mode 100644 index 0000000000..ca52f1cb5f --- /dev/null +++ b/comps/guardrails/src/guardrails/utils/scanners.py @@ -0,0 +1,76 @@ +import re + +from collections.abc import Iterable +from llm_guard.input_scanners import BanSubstrings, Regex +from llm_guard.input_scanners.regex import MatchType +from presidio_anonymizer.core.text_replace_builder import TextReplaceBuilder + +from comps import CustomLogger + +logger = CustomLogger("opea_llm_guard_utils_scanners") + +# The bug is reported here: https://github.com/protectai/llm-guard/issues/210 +class OPEABanSubstrings(BanSubstrings): + def _redact_text(self, text: str, substrings: list[str]) -> str: + redacted_text = text + flags = 0 + if not self._case_sensitive: + flags = re.IGNORECASE + for s in substrings: + regex_redacted = re.compile(re.escape(s), flags) + redacted_text = regex_redacted.sub("[REDACTED]", redacted_text) + return redacted_text + + def scan(self, prompt: str, output: str = None) -> tuple[str, bool, float]: + if output is not None: + return super().scan(output) + return super().scan(prompt) + +# LLM Guard's Regex Scanner doesn't replace all occurrences of found patterns. +# The bug is reported here: https://github.com/protectai/llm-guard/issues/229 +class OPEARegexScanner(Regex): + def scan(self, prompt: str, output: str = None) -> tuple[str, bool, float]: + text_to_scan = "" + if output is not None: + text_to_scan = output + else: + text_to_scan = prompt + + text_replace_builder = TextReplaceBuilder(original_text=text_to_scan) + for pattern in self._patterns: + if self._match_type == MatchType.SEARCH: + matches = re.finditer(pattern, text_to_scan) + else: + matches = self._match_type.match(pattern, text_to_scan) + + if matches is None: + continue + elif isinstance(matches, Iterable): + matches = list(matches) + if len(matches) == 0: + continue + else: + matches = [matches] + + if self._is_blocked: + logger.warning(f"Pattern was detected in the text: {pattern}") + + if self._redact: + for match in reversed(matches): + text_replace_builder.replace_text_get_insertion_index( + "[REDACTED]", + match.start(), + match.end(), + ) + + return text_replace_builder.output_text, False, 1.0 + + logger.debug(f"Pattern matched the text: {pattern}") + return text_replace_builder.output_text, True, 0.0 + + if self._is_blocked: + logger.debug("None of the patterns were found in the text") + return text_replace_builder.output_text, True, 0.0 + + logger.warning("None of the patterns matched the text") + return text_replace_builder.output_text, False, 1.0 \ No newline at end of file diff --git a/tests/guardrails/test_guardrails_guardrails_llamaguard_on_intel_hpu.sh b/tests/guardrails/test_guardrails_guardrails_llamaguard_on_intel_hpu.sh index da77898a13..aefa91c89a 100644 --- a/tests/guardrails/test_guardrails_guardrails_llamaguard_on_intel_hpu.sh +++ b/tests/guardrails/test_guardrails_guardrails_llamaguard_on_intel_hpu.sh @@ -52,7 +52,7 @@ function validate_microservice() { fi echo "test 2 - safe" result=$(http_proxy= curl http://localhost:11303/v1/guardrails -X POST -d '{"text":"How do you buy a car in the US?","parameters":{"max_new_tokens":32}}' -H 'Content-Type: application/json') - if [[ $result == *"car"* ]]; then + if [[ $result == *"car"* ]]; then echo "Result correct." else echo "Result wrong." @@ -60,7 +60,57 @@ function validate_microservice() { docker logs llamaguard-guardrails-server exit 1 fi + echo "test 3 - input guard" + result=$(http_proxy= curl http://localhost:11303/v1/guardrails -X POST -d '{ + "query": "What are virus and backdoor?", + "max_new_tokens": 17, + "top_k": 10, + "top_p": 0.95, + "typical_p": 0.95, + "temperature": 0.01, + "repetition_penalty": 1.03, + "stream": false, + "input_guardrail_params": { + "ban_substrings": { + "enabled": true, + "substrings": ["backdoor", "malware", "virus"], + "case_sensitive": false + } + } + }' \ + -H 'Content-Type: application/json' + ) + if [[ "$result" == *"cannot assist you with your prompt"* ]]; then + echo "✅ Input guard triggered as expected!" + else + echo "❌ Input guard failed to trigger!" + fi + + echo "test 4 - output guard" + result=$(http_proxy= curl http://localhost:11303/v1/guardrails -X POST -d '{"prompt":"What is Deep Learning?", + "text": "To create a virus and backdoor you need to...", + "output_guardrail_params": + {"ban_substrings": + {"enabled":true, + "substrings":["backdoor","malware","virus"], + "match_type":null, + "case_sensitive":false, + "redact":null, + "contains_all":null} + } + }' \ + -H 'Content-Type: application/json' + ) + + echo "Output guard result:" + echo "$result" + if [[ "$result" != *"virus"* ]]; then + echo "✅ Output guard triggered successfully (virus removed or replaced)!" + else + echo "❌ Output guard failed to trigger (virus still in output)!" + exit 1 + fi } function stop_docker() { diff --git a/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh b/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh index 054ac861a7..8d1dbec454 100644 --- a/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh +++ b/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh @@ -60,6 +60,57 @@ function validate_microservice() { docker logs wildguard-guardrails-server exit 1 fi + echo "test 3 - input guard" + result=$(http_proxy= curl http://localhost:11304/v1/guardrails -X POST -d '{ + "query": "What are virus and backdoor?", + "max_new_tokens": 17, + "top_k": 10, + "top_p": 0.95, + "typical_p": 0.95, + "temperature": 0.01, + "repetition_penalty": 1.03, + "stream": false, + "input_guardrail_params": { + "ban_substrings": { + "enabled": true, + "substrings": ["backdoor", "malware", "virus"], + "case_sensitive": false + } + } + }' \ + -H 'Content-Type: application/json' + ) + if [[ "$result" == *"cannot assist you with your prompt"* ]]; then + echo "✅ Input guard triggered as expected!" + else + echo "❌ Input guard failed to trigger!" + fi + + echo "test 4 - output guard" + result=$(http_proxy= curl http://localhost:11304/v1/guardrails -X POST -d '{"prompt":"What is Deep Learning?", + "text": "To create a virus and backdoor you need to...", + "output_guardrail_params": + {"ban_substrings": + {"enabled":true, + "substrings":["backdoor","malware","virus"], + "match_type":null, + "case_sensitive":false, + "redact":null, + "contains_all":null} + } + }' \ + -H 'Content-Type: application/json' + ) + + echo "Output guard result:" + echo "$result" + + if [[ "$result" != *"virus"* ]]; then + echo "✅ Output guard triggered successfully (virus removed or replaced)!" + else + echo "❌ Output guard failed to trigger (virus still in output)!" + exit 1 + fi } function stop_docker() { From 4364ef834fa651147ecf0118d4776ccb1692358a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Jun 2025 08:39:19 +0000 Subject: [PATCH 6/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- comps/cores/mega/utils.py | 6 ++- comps/cores/proto/docarray.py | 46 +++++++++++++++++-- .../opea_guardrails_microservice.py | 14 ++---- .../src/guardrails/requirements.txt | 4 +- .../utils/llm_guard_input_guardrail.py | 15 +++--- .../utils/llm_guard_input_scanners.py | 3 +- .../utils/llm_guard_output_guardrail.py | 2 +- .../src/guardrails/utils/scanners.py | 9 +++- ...rails_guardrails_wildguard_on_intel_hpu.sh | 2 +- 9 files changed, 70 insertions(+), 31 deletions(-) diff --git a/comps/cores/mega/utils.py b/comps/cores/mega/utils.py index afeefa1156..001b9ce4f5 100644 --- a/comps/cores/mega/utils.py +++ b/comps/cores/mega/utils.py @@ -182,8 +182,10 @@ def handle_message(messages): else: return prompt + def sanitize_env(value: Optional[str]) -> Optional[str]: """Remove quotes from a configuration value if present. + Args: value (str): The configuration value to sanitize. Returns: @@ -193,6 +195,6 @@ def sanitize_env(value: Optional[str]) -> Optional[str]: return None if value.startswith('"') and value.endswith('"'): value = value[1:-1] - elif value.startswith('\'') and value.endswith('\''): + elif value.startswith("'") and value.endswith("'"): value = value[1:-1] - return value \ No newline at end of file + return value diff --git a/comps/cores/proto/docarray.py b/comps/cores/proto/docarray.py index a3fa4b95ee..2e8cc66bc9 100644 --- a/comps/cores/proto/docarray.py +++ b/comps/cores/proto/docarray.py @@ -1,7 +1,7 @@ # Copyright (C) 2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Union, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np from docarray import BaseDoc, DocList @@ -144,6 +144,7 @@ class Config: class SearchedMultimodalDoc(SearchedDoc): metadata: List[Dict[str, Any]] + class LVMSearchedMultimodalDoc(SearchedMultimodalDoc): max_new_tokens: conint(ge=0, le=1024) = 512 top_k: int = 10 @@ -161,10 +162,12 @@ class LVMSearchedMultimodalDoc(SearchedMultimodalDoc): ), ) + class RerankedDoc(BaseDoc): reranked_docs: DocList[TextDoc] initial_query: str + class AnonymizeModel(BaseDoc): enabled: bool = False use_onnx: bool = False @@ -178,12 +181,14 @@ class AnonymizeModel(BaseDoc): threshold: Optional[float] = None language: Optional[str] = None + class BanCodeModel(BaseDoc): enabled: bool = False use_onnx: bool = False model: Optional[str] = None threshold: Optional[float] = None + class BanCompetitorsModel(BaseDoc): enabled: bool = False use_onnx: bool = False @@ -192,6 +197,7 @@ class BanCompetitorsModel(BaseDoc): threshold: Optional[float] = None redact: Optional[bool] = None + class BanSubstringsModel(BaseDoc): enabled: bool = False substrings: List[str] = ["backdoor", "malware", "virus"] @@ -200,13 +206,15 @@ class BanSubstringsModel(BaseDoc): redact: Optional[bool] = None contains_all: Optional[bool] = None + class BanTopicsModel(BaseDoc): enabled: bool = False use_onnx: bool = False - topics: List[str] = ["violence","attack","war"] + topics: List[str] = ["violence", "attack", "war"] threshold: Optional[float] = None model: Optional[str] = None + class CodeModel(BaseDoc): enabled: bool = False use_onnx: bool = False @@ -215,6 +223,7 @@ class CodeModel(BaseDoc): is_blocked: Optional[bool] = None threshold: Optional[float] = None + class GibberishModel(BaseDoc): enabled: bool = False use_onnx: bool = False @@ -222,9 +231,11 @@ class GibberishModel(BaseDoc): threshold: Optional[float] = None match_type: Optional[str] = None + class InvisibleText(BaseDoc): enabled: bool = False + class LanguageModel(BaseDoc): enabled: bool = False use_onnx: bool = False @@ -233,6 +244,7 @@ class LanguageModel(BaseDoc): threshold: Optional[float] = None match_type: Optional[str] = None + class PromptInjectionModel(BaseDoc): enabled: bool = False use_onnx: bool = False @@ -240,6 +252,7 @@ class PromptInjectionModel(BaseDoc): threshold: Optional[float] = None match_type: Optional[str] = None + class RegexModel(BaseDoc): enabled: bool = False patterns: List[str] = ["Bearer [A-Za-z0-9-._~+/]+"] @@ -247,27 +260,33 @@ class RegexModel(BaseDoc): match_type: Optional[str] = None redact: Optional[bool] = None + class SecretsModel(BaseDoc): enabled: bool = False redact_mode: Optional[str] = None + class SentimentModel(BaseDoc): enabled: bool = False threshold: Optional[float] = None lexicon: Optional[str] = None + class TokenLimitModel(BaseDoc): enabled: bool = False limit: Optional[int] = None encoding_name: Optional[str] = None model_name: Optional[str] = None + class ToxicityModel(BaseDoc): enabled: bool = False use_onnx: bool = False model: Optional[str] = None threshold: Optional[float] = None match_type: Optional[str] = None + + class BiasModel(BaseDoc): enabled: bool = False use_onnx: bool = False @@ -275,27 +294,32 @@ class BiasModel(BaseDoc): threshold: Optional[float] = None match_type: Optional[str] = None + class DeanonymizeModel(BaseDoc): enabled: bool = False matching_strategy: Optional[str] = None + class JSONModel(BaseDoc): enabled: bool = False required_elements: Optional[int] = None repair: Optional[bool] = None + class LanguageSameModel(BaseDoc): enabled: bool = False use_onnx: bool = False model: Optional[str] = None threshold: Optional[float] = None + class MaliciousURLsModel(BaseDoc): enabled: bool = False use_onnx: bool = False model: Optional[str] = None threshold: Optional[float] = None + class NoRefusalModel(BaseDoc): enabled: bool = False use_onnx: bool = False @@ -303,26 +327,31 @@ class NoRefusalModel(BaseDoc): threshold: Optional[float] = None match_type: Optional[str] = None + class NoRefusalLightModel(BaseDoc): enabled: bool = False + class ReadingTimeModel(BaseDoc): enabled: bool = False max_time: float = 0.5 truncate: Optional[bool] = None + class FactualConsistencyModel(BaseDoc): enabled: bool = False use_onnx: bool = False model: Optional[str] = None minimum_score: Optional[float] = None + class RelevanceModel(BaseDoc): enabled: bool = False use_onnx: bool = False model: Optional[str] = None threshold: Optional[float] = None + class SensitiveModel(BaseDoc): enabled: bool = False use_onnx: bool = False @@ -332,10 +361,13 @@ class SensitiveModel(BaseDoc): recognizer_conf: Optional[str] = None threshold: Optional[float] = None + class URLReachabilityModel(BaseDoc): enabled: bool = False success_status_codes: Optional[List[int]] = None timeout: Optional[int] = None + + class LLMGuardInputGuardrailParams(BaseDoc): anonymize: AnonymizeModel = AnonymizeModel() ban_code: BanCodeModel = BanCodeModel() @@ -353,6 +385,7 @@ class LLMGuardInputGuardrailParams(BaseDoc): token_limit: TokenLimitModel = TokenLimitModel() toxicity: ToxicityModel = ToxicityModel() + class LLMGuardOutputGuardrailParams(BaseDoc): ban_code: BanCodeModel = BanCodeModel() ban_competitors: BanCompetitorsModel = BanCompetitorsModel() @@ -376,7 +409,10 @@ class LLMGuardOutputGuardrailParams(BaseDoc): sentiment: SentimentModel = SentimentModel() toxicity: ToxicityModel = ToxicityModel() url_reachability: URLReachabilityModel = URLReachabilityModel() - anonymize_vault: Optional[List[Tuple]] = None # the only parameter not available in fingerprint. Used to tramsmit vault + anonymize_vault: Optional[List[Tuple]] = ( + None # the only parameter not available in fingerprint. Used to transmit vault + ) + class LLMParamsDoc(BaseDoc): model: Optional[str] = None # for openai and ollama @@ -419,11 +455,13 @@ class LLMParamsDoc(BaseDoc): def chat_template_must_contain_variables(cls, v): return v + class GeneratedDoc(BaseDoc): text: str prompt: str output_guardrail_params: Optional[LLMGuardOutputGuardrailParams] = None - + + class LLMParams(BaseDoc): model: Optional[str] = None max_tokens: int = 1024 diff --git a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py index dbc3fa7bfa..ac178badd7 100644 --- a/comps/guardrails/src/guardrails/opea_guardrails_microservice.py +++ b/comps/guardrails/src/guardrails/opea_guardrails_microservice.py @@ -9,17 +9,11 @@ from dotenv import dotenv_values from fastapi import HTTPException from fastapi.responses import StreamingResponse -from pydantic import ValidationError - -from utils.llm_guard_input_guardrail import ( - OPEALLMGuardInputGuardrail -) -from utils.llm_guard_output_guardrail import ( - OPEALLMGuardOutputGuardrail -) - from integrations.llamaguard import OpeaGuardrailsLlamaGuard from integrations.wildguard import OpeaGuardrailsWildGuard +from pydantic import ValidationError +from utils.llm_guard_input_guardrail import OPEALLMGuardInputGuardrail +from utils.llm_guard_output_guardrail import OPEALLMGuardOutputGuardrail from comps import ( CustomLogger, @@ -66,7 +60,7 @@ @register_statistics(names=["opea_service@guardrails"]) async def safety_guard(input: Union[LLMParamsDoc, GeneratedDoc, TextDoc]) -> Union[TextDoc, GeneratedDoc]: start = time.time() - + if logflag: logger.info(f"Received input: {input}") diff --git a/comps/guardrails/src/guardrails/requirements.txt b/comps/guardrails/src/guardrails/requirements.txt index 39f7af32b3..88fd3618ae 100644 --- a/comps/guardrails/src/guardrails/requirements.txt +++ b/comps/guardrails/src/guardrails/requirements.txt @@ -5,12 +5,12 @@ huggingface-hub<=0.24.0 langchain-community langchain-huggingface langchain-openai +llm_guard opentelemetry-api opentelemetry-exporter-otlp opentelemetry-sdk +presidio_anonymizer prometheus-fastapi-instrumentator sentencepiece shortuuid uvicorn -llm_guard -presidio_anonymizer \ No newline at end of file diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py b/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py index a32078abc4..1e058beb6f 100644 --- a/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py +++ b/comps/guardrails/src/guardrails/utils/llm_guard_input_guardrail.py @@ -5,14 +5,14 @@ from llm_guard import scan_prompt from utils.llm_guard_input_scanners import InputScannersConfig -from comps import LLMParamsDoc, CustomLogger +from comps import CustomLogger, LLMParamsDoc logger = CustomLogger("opea_llm_guard_input_guardrail_microservice") + class OPEALLMGuardInputGuardrail: """OPEALLMGuardInputGuardrail is responsible for scanning and sanitizing LLM input prompts - using various input scanners provided by LLM Guard. - """ + using various input scanners provided by LLM Guard.""" def __init__(self, usv_config: dict): try: @@ -43,10 +43,11 @@ def _analyze_scan_outputs(self, prompt, results_valid, results_score): filtered_results = { key: value for key, value in results_valid.items() - if key != "Anonymize" and not ( - type(scanner := next((s for s in self._scanners if type(s).__name__ == key), None)).__name__ in { - "BanCompetitors", "BanSubstrings", "OPEABanSubstrings", "Regex", "OPEARegexScanner" - } and getattr(scanner, "_redact", False) + if key != "Anonymize" + and not ( + type(scanner := next((s for s in self._scanners if type(s).__name__ == key), None)).__name__ + in {"BanCompetitors", "BanSubstrings", "OPEABanSubstrings", "Regex", "OPEARegexScanner"} + and getattr(scanner, "_redact", False) ) } diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py b/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py index 3083c3fa48..2f916f03df 100644 --- a/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py +++ b/comps/guardrails/src/guardrails/utils/llm_guard_input_scanners.py @@ -55,7 +55,6 @@ from comps import CustomLogger from comps.cores.mega.utils import sanitize_env - from comps.guardrails.src.guardrails.utils.scanners import OPEABanSubstrings, OPEARegexScanner logger = CustomLogger("opea_llm_guard_input_guardrail_microservice") @@ -949,4 +948,4 @@ def changed(self, new_scanners_config): k: {in_k: in_v for in_k, in_v in v.items() if in_k != "id"} for k, v in new_scanners_config.items() } self._input_scanners_config.update(stripped_new_scanners_config) - return True \ No newline at end of file + return True diff --git a/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py b/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py index 4d1abc8368..c7c993a214 100644 --- a/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py +++ b/comps/guardrails/src/guardrails/utils/llm_guard_output_guardrail.py @@ -5,7 +5,7 @@ from llm_guard import scan_output from utils.llm_guard_output_scanners import OutputScannersConfig -from comps import GeneratedDoc, CustomLogger +from comps import CustomLogger, GeneratedDoc logger = CustomLogger("opea_llm_guard_output_guardrail_microservice") diff --git a/comps/guardrails/src/guardrails/utils/scanners.py b/comps/guardrails/src/guardrails/utils/scanners.py index ca52f1cb5f..1465bccfa3 100644 --- a/comps/guardrails/src/guardrails/utils/scanners.py +++ b/comps/guardrails/src/guardrails/utils/scanners.py @@ -1,6 +1,9 @@ -import re +# Copyright (C) 2025 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 +import re from collections.abc import Iterable + from llm_guard.input_scanners import BanSubstrings, Regex from llm_guard.input_scanners.regex import MatchType from presidio_anonymizer.core.text_replace_builder import TextReplaceBuilder @@ -9,6 +12,7 @@ logger = CustomLogger("opea_llm_guard_utils_scanners") + # The bug is reported here: https://github.com/protectai/llm-guard/issues/210 class OPEABanSubstrings(BanSubstrings): def _redact_text(self, text: str, substrings: list[str]) -> str: @@ -26,6 +30,7 @@ def scan(self, prompt: str, output: str = None) -> tuple[str, bool, float]: return super().scan(output) return super().scan(prompt) + # LLM Guard's Regex Scanner doesn't replace all occurrences of found patterns. # The bug is reported here: https://github.com/protectai/llm-guard/issues/229 class OPEARegexScanner(Regex): @@ -73,4 +78,4 @@ def scan(self, prompt: str, output: str = None) -> tuple[str, bool, float]: return text_replace_builder.output_text, True, 0.0 logger.warning("None of the patterns matched the text") - return text_replace_builder.output_text, False, 1.0 \ No newline at end of file + return text_replace_builder.output_text, False, 1.0 diff --git a/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh b/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh index 2fdf8daad1..f030d6a66e 100644 --- a/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh +++ b/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh @@ -110,7 +110,7 @@ function validate_microservice() { else echo "❌ Output guard failed to trigger (virus still in output)!" exit 1 - fi + fi } function stop_docker() { From 5b59786020b24fe1ff6ef1f24d22036e123dc954 Mon Sep 17 00:00:00 2001 From: WenjiaoYue Date: Wed, 18 Jun 2025 16:49:20 +0800 Subject: [PATCH 7/7] fix CI issue --- ...st_guardrails_guardrails_llamaguard_on_intel_hpu.sh | 10 +++++----- ...est_guardrails_guardrails_wildguard_on_intel_hpu.sh | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/guardrails/test_guardrails_guardrails_llamaguard_on_intel_hpu.sh b/tests/guardrails/test_guardrails_guardrails_llamaguard_on_intel_hpu.sh index aefa91c89a..8daa5cad93 100644 --- a/tests/guardrails/test_guardrails_guardrails_llamaguard_on_intel_hpu.sh +++ b/tests/guardrails/test_guardrails_guardrails_llamaguard_on_intel_hpu.sh @@ -41,7 +41,7 @@ function start_service() { function validate_microservice() { echo "Validate microservice started" echo "test 1 - violated policies" - result=$(http_proxy= curl http://localhost:11303/v1/guardrails -X POST -d '{"text":"How do you buy a tiger in the US?","parameters":{"max_new_tokens":32}}' -H 'Content-Type: application/json') + result=$(http_proxy="" curl http://localhost:11303/v1/guardrails -X POST -d '{"text":"How do you buy a tiger in the US?","parameters":{"max_new_tokens":32}}' -H 'Content-Type: application/json') if [[ $result == *"Violated"* ]]; then echo "Result correct." else @@ -51,7 +51,7 @@ function validate_microservice() { exit 1 fi echo "test 2 - safe" - result=$(http_proxy= curl http://localhost:11303/v1/guardrails -X POST -d '{"text":"How do you buy a car in the US?","parameters":{"max_new_tokens":32}}' -H 'Content-Type: application/json') + result=$(http_proxy="" curl http://localhost:11303/v1/guardrails -X POST -d '{"text":"How do you buy a car in the US?","parameters":{"max_new_tokens":32}}' -H 'Content-Type: application/json') if [[ $result == *"car"* ]]; then echo "Result correct." else @@ -61,7 +61,7 @@ function validate_microservice() { exit 1 fi echo "test 3 - input guard" - result=$(http_proxy= curl http://localhost:11303/v1/guardrails -X POST -d '{ + result=$(http_proxy="" curl http://localhost:11303/v1/guardrails -X POST -d '{ "query": "What are virus and backdoor?", "max_new_tokens": 17, "top_k": 10, @@ -87,7 +87,7 @@ function validate_microservice() { fi echo "test 4 - output guard" - result=$(http_proxy= curl http://localhost:11303/v1/guardrails -X POST -d '{"prompt":"What is Deep Learning?", + result=$(http_proxy="" curl http://localhost:11303/v1/guardrails -X POST -d '{"prompt":"What is Deep Learning?", "text": "To create a virus and backdoor you need to...", "output_guardrail_params": {"ban_substrings": @@ -130,7 +130,7 @@ function main() { stop_docker echo "cleanup container images and volumes" - echo y | docker system prune 2>&1 > /dev/null + echo y | docker system prune > /dev/null 2>&1 } diff --git a/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh b/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh index f030d6a66e..401402fe39 100644 --- a/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh +++ b/tests/guardrails/test_guardrails_guardrails_wildguard_on_intel_hpu.sh @@ -61,7 +61,7 @@ function validate_microservice() { exit 1 fi echo "test 3 - input guard" - result=$(http_proxy= curl http://localhost:11304/v1/guardrails -X POST -d '{ + result=$(http_proxy="" curl http://localhost:11304/v1/guardrails -X POST -d '{ "query": "What are virus and backdoor?", "max_new_tokens": 17, "top_k": 10, @@ -87,7 +87,7 @@ function validate_microservice() { fi echo "test 4 - output guard" - result=$(http_proxy= curl http://localhost:11304/v1/guardrails -X POST -d '{"prompt":"What is Deep Learning?", + result=$(http_proxy="" curl http://localhost:11304/v1/guardrails -X POST -d '{"prompt":"What is Deep Learning?", "text": "To create a virus and backdoor you need to...", "output_guardrail_params": {"ban_substrings":