Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/user-guides/advanced/embedding-search-providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ knowledge_base:
The default implementation is also designed to support asynchronous execution of the embedding computation process, thereby enhancing the efficiency of the search functionality.

The `cache` configuration is optional. If enabled, it uses the specified `key_generator` and `store` to cache the embeddings. The `store_config` can be used to provide additional configuration options required for the store.
The default `cache` configuration uses the `md5` key generator and the `filesystem` store. The cache is disabled by default.
The default `cache` configuration uses the `md5` key generator and the `filesystem` store. In the rare cases where `md5` is not available, `sha256` is used instead; see [hashing settings](../advanced/hashing-settings.md) for more information. The cache is disabled by default.

## Batch Implementation

Expand Down
42 changes: 42 additions & 0 deletions docs/user-guides/advanced/hashing-settings.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Hashing settings

## Overview

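NeMo Guardrails uses hashing mainly for caching purposes. By default, the `md5` hashing algorithm is used. Caching of embedding search queries is disabled by default, but disabling it does not turn hashing off entirely: hashes are still used elsewhere in the library, for example to name the cached knowledge base index files.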

## FIPS considerations

In some regulated environments, the `md5` hashing algorithm may not be available (e.g., FIPS-compliant Python). In such cases, `sha256` hashing will be used instead. This default applies across the library unless explicitly overridden.

## Setting hashing algorithm

To explicitly set the hashing algorithm, call the following function before running any NeMo Guardrails library code:

```python
from nemoguardrails.hashing import set_default_hash_algorithm
set_default_hash_algorithm('sha256')
```

## Additional considerations

When caching is enabled and `key_generator` is set in the configuration, it overrides the library default for caching embedding searches.

Example:

```yaml
knowledge_base:
embedding_search_provider:
name: default
parameters:
embedding_engine: FastEmbed
embedding_model: all-MiniLM-L6-v2
use_batching: False
max_batch_size: 10
max_batch_hold: 0.01
search_threshold: None
cache:
enabled: True
key_generator: sha256 # <- Overrides the library default.
store: filesystem
store_config: {}
```
9 changes: 9 additions & 0 deletions nemoguardrails/embeddings/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ def generate_key(self, text: str) -> str:
return hashlib.md5(text.encode("utf-8")).hexdigest()


class SHA256KeyGenerator(KeyGenerator):
    """Key generator that derives cache keys from SHA256 digests."""

    name = "sha256"

    def generate_key(self, text: str) -> str:
        """Return the hexadecimal SHA256 digest of ``text`` encoded as UTF-8."""
        encoded = text.encode("utf-8")
        digest = hashlib.sha256(encoded)
        return digest.hexdigest()


class CacheStore(ABC):
"""Abstract class for cache stores."""

Expand Down
90 changes: 90 additions & 0 deletions nemoguardrails/hashing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib

# Name of the hash algorithm used by default (read by generate_hash and
# get_default_hash_algorithm). This line only declares the type; the value
# is assigned by set_default_hash_algorithm().
_default_hash_algorithm: str


def set_default_hash_algorithm(algorithm: str):
    """
    Set the default hash algorithm.

    Parameters
    ----------
    algorithm : str
        The name of the hash algorithm to set as default.
        The available options are: "md5", "sha256".

    Raises
    ------
    ValueError
        If the provided algorithm is not supported. The default is left
        unchanged in that case.
    """
    global _default_hash_algorithm

    _supported = {"md5", "sha256"}

    if algorithm not in _supported:
        # Sort the options: iterating a set has no guaranteed order, which
        # would make the error message differ from run to run.
        raise ValueError(
            f"Unsupported value: {algorithm}, "
            f"use one of {','.join(sorted(_supported))}"
        )

    _default_hash_algorithm = algorithm


def get_default_hash_algorithm() -> str:
    """Return the name of the hash algorithm currently set as the default."""
    current = _default_hash_algorithm
    return current


def generate_hash(text: str) -> str:
    """
    Hash a text with the currently configured default hash algorithm.

    Args:
        text (str): The text to hash.

    Returns:
        str: The hexadecimal digest of the encoded text.
    """
    # Look up the hashlib constructor by the name of the default algorithm.
    hasher_factory = getattr(hashlib, _default_hash_algorithm)
    encoded = text.encode()
    return hasher_factory(encoded).hexdigest()


def _is_md5_available() -> bool:
    """
    Report whether the MD5 hashing function can be used.

    In some FIPS-compliant Python builds, ``hashlib.md5`` may be missing, or it
    may raise an exception when OpenSSL is compiled in FIPS mode:

    - a missing ``hashlib.md5`` surfaces as AttributeError;
    - OpenSSL in FIPS mode raises _hashlib.UnsupportedDigestmodError, which is
      a ValueError.

    Returns
    -------
    bool
        True if MD5 is available, False otherwise.
    """
    try:
        # Probe by constructing a hasher; both failure modes raise here.
        hashlib.md5()
    except (AttributeError, ValueError):
        return False
    else:
        return True


def detect_default_hash_algorithm():
    """Detect and set the default hash algorithm.

    Sets "md5" when MD5 is available and falls back to "sha256" otherwise.
    """
    if _is_md5_available():
        chosen = "md5"
    else:
        chosen = "sha256"
    set_default_hash_algorithm(chosen)


# Initialise the default as soon as this module is imported.
detect_default_hash_algorithm()
12 changes: 5 additions & 7 deletions nemoguardrails/kb/kb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
import logging
import os
from time import time
from typing import Callable, List, Optional, cast

from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
from nemoguardrails.hashing import generate_hash
from nemoguardrails.kb.utils import split_markdown_in_topic_chunks
from nemoguardrails.rails.llm.config import EmbeddingSearchProvider, KnowledgeBaseConfig

Expand Down Expand Up @@ -114,18 +114,16 @@ async def build(self):
if not index_items:
return

# We compute the md5
# We compute the hash using default hash algorithm
# As part of the hash, we also include the embedding engine and the model
# to prevent the cache being used incorrectly when the embedding model changes.
hash_prefix = self.config.embedding_search_provider.parameters.get(
"embedding_engine", ""
) + self.config.embedding_search_provider.parameters.get("embedding_model", "")

md5_hash = hashlib.md5(
(hash_prefix + "".join(all_text_items)).encode("utf-8")
).hexdigest()
cache_file = os.path.join(CACHE_FOLDER, f"{md5_hash}.ann")
embedding_size_file = os.path.join(CACHE_FOLDER, f"{md5_hash}.esize")
hash_value = generate_hash(hash_prefix + "".join(all_text_items))
cache_file = os.path.join(CACHE_FOLDER, f"{hash_value}.ann")
embedding_size_file = os.path.join(CACHE_FOLDER, f"{hash_value}.esize")

# If we have already computed this before, we use it
if (
Expand Down
3 changes: 2 additions & 1 deletion nemoguardrails/rails/llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from nemoguardrails.colang.v2_x.lang.colang_ast import Flow
from nemoguardrails.colang.v2_x.lang.utils import format_colang_parsing_error_message
from nemoguardrails.colang.v2_x.runtime.errors import ColangParsingError
from nemoguardrails.hashing import get_default_hash_algorithm

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -247,7 +248,7 @@ class EmbeddingsCacheConfig(BaseModel):
description="Whether caching of the embeddings should be enabled or not.",
)
key_generator: str = Field(
default="md5",
default=get_default_hash_algorithm(),
description="The method to use for generating the cache keys.",
)
store: str = Field(
Expand Down
24 changes: 24 additions & 0 deletions tests/test_cache_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
KeyGenerator,
MD5KeyGenerator,
RedisCacheStore,
SHA256KeyGenerator,
cache_embeddings,
)
from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig
Expand Down Expand Up @@ -58,6 +59,29 @@ def test_md5_key_generator():
assert len(key) == 32 # MD5 hash is 32 characters long


def test_sha256_key_generator():
    """SHA256KeyGenerator produces a 64-character string key."""
    generated = SHA256KeyGenerator().generate_key("test")
    assert isinstance(generated, str)
    # A SHA256 hex digest is 64 characters long
    assert len(generated) == 64


@pytest.mark.parametrize(
    ("name", "expected_class"),
    [
        ("hash", HashKeyGenerator),
        ("md5", MD5KeyGenerator),
        ("sha256", SHA256KeyGenerator),
    ],
)
def test_key_generator_class(name, expected_class):
    """KeyGenerator.from_name returns the class matching each name."""
    resolved = KeyGenerator.from_name(name)
    assert resolved == expected_class


def test_embedding_cache_config_default():
    """The default key generator follows the library-wide default hash.

    The library default is "md5" when MD5 is available and "sha256" otherwise
    (e.g. FIPS-compliant Python), so the expected value is not hard-coded.
    """
    # Imported locally so the test does not rely on the module's import block.
    from nemoguardrails.hashing import get_default_hash_algorithm

    default_key_generator = EmbeddingsCacheConfig().key_generator
    assert default_key_generator in {"md5", "sha256"}
    assert default_key_generator == get_default_hash_algorithm()


def test_in_memory_cache_store():
cache = InMemoryCacheStore()
cache.set("key", "value")
Expand Down
12 changes: 12 additions & 0 deletions tests/test_configs/with_sha256_hash/config.co
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
define user ask capabilities
"What can you do?"
"What can you help me with?"
"tell me what you can do"
"tell me about you"

define bot inform capabilities
"I am an AI assistant that helps answer questions."

define flow
user ask capabilities
bot inform capabilities
24 changes: 24 additions & 0 deletions tests/test_configs/with_sha256_hash/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
models:
- type: main
engine: openai
model: gpt-3.5-turbo-instruct

core:
embedding_search_provider:
name: default
parameters:
embedding_engine: openai
embedding_model: text-embedding-ada-002
cache:
enabled: True
key_generator: sha256

knowledge_base:
embedding_search_provider:
name: default
parameters:
embedding_engine: openai
embedding_model: text-embedding-ada-002
cache:
enabled: True
key_generator: sha256
78 changes: 78 additions & 0 deletions tests/test_hashing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import patch

import pytest
from _hashlib import UnsupportedDigestmodError

from nemoguardrails.hashing import detect_default_hash_algorithm as setup_hashing
from nemoguardrails.hashing import (
generate_hash,
get_default_hash_algorithm,
set_default_hash_algorithm,
)


@pytest.fixture(scope="function")
def md5_is_missing():
    """Simulate an interpreter where hashlib.md5 is not available."""
    missing_md5 = patch("hashlib.md5", side_effect=AttributeError)
    with missing_md5:
        # Re-run detection while md5 raises AttributeError
        setup_hashing()
        yield

    # cleanup: detect again once the patch has been removed
    setup_hashing()


@pytest.fixture(scope="function")
def md5_unsupported_digest():
    """Simulate hashlib using OpenSSL compiled in FIPS mode."""
    fips_md5 = patch("hashlib.md5", side_effect=UnsupportedDigestmodError)
    with fips_md5:
        # Re-run detection while md5 raises UnsupportedDigestmodError
        setup_hashing()
        yield

    # cleanup: detect again once the patch has been removed
    setup_hashing()


@pytest.fixture(params=["md5_is_missing", "md5_unsupported_digest"])
def md5_not_available(request):
    """Run the requesting test once per way of making MD5 unavailable."""
    fixture_name = request.param
    resolved = request.getfixturevalue(fixture_name)
    yield resolved


def test_default_without_md5(md5_not_available):
    """With MD5 unavailable, the detected default is SHA256."""
    detected = get_default_hash_algorithm()
    assert detected == "sha256"


def test_default_with_md5():
    """Without any patching, the detected default is MD5."""
    detected = get_default_hash_algorithm()
    assert detected == "md5"


def test_hash_without_md5(md5_not_available):
    """With MD5 unavailable, generate_hash returns a SHA256-sized digest."""
    digest = generate_hash("test")
    assert isinstance(digest, str)
    # A SHA256 hex digest is 64 characters long
    assert len(digest) == 64


def test_hash_with_md5():
    """Without any patching, generate_hash returns an MD5-sized digest."""
    digest = generate_hash("test")
    assert isinstance(digest, str)
    # An MD5 hex digest is 32 characters long
    assert len(digest) == 32


def test_invalid_hash_algorithm_not_allowed():
    """Setting an unsupported algorithm name raises ValueError."""
    unsupported = "invalid"
    with pytest.raises(ValueError):
        set_default_hash_algorithm(unsupported)