Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
276 changes: 263 additions & 13 deletions giskard/rag/knowledge_base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
from typing import Dict, Optional, Sequence
from typing import Dict, Optional, Sequence, TYPE_CHECKING

import json
import logging
import os
import tempfile
import uuid
from pathlib import Path

import numpy as np
import pandas as pd
from huggingface_hub import DatasetCard, HfApi, HfFolder, hf_hub_download, upload_file
from sklearn.cluster import HDBSCAN

from ..llm.client import ChatMessage, LLMClient, get_default_client
if TYPE_CHECKING:
from ..llm.client import LLMClient
from ..llm.embeddings.base import BaseEmbedding

from ..llm.client import ChatMessage, get_default_client
from ..llm.embeddings import get_default_embedding
from ..llm.embeddings.base import BaseEmbedding
from ..llm.errors import LLMImportError
from ..utils.analytics_collector import analytics
from ..utils.language_detection import detect_lang
Expand Down Expand Up @@ -109,8 +117,8 @@ def __init__(
data: pd.DataFrame,
columns: Optional[Sequence[str]] = None,
seed: int = None,
llm_client: Optional[LLMClient] = None,
embedding_model: Optional[BaseEmbedding] = None,
llm_client: Optional["LLMClient"] = None,
embedding_model: Optional["BaseEmbedding"] = None,
min_topic_size: Optional[int] = None,
chunk_size: int = 2048,
) -> None:
Expand All @@ -134,6 +142,7 @@ def __init__(
self._knowledge_base_df = data
self._columns = columns

self.seed = seed
self._rng = np.random.default_rng(seed=seed)
self._llm_client = llm_client or get_default_client()
self._embedding_model = embedding_model or get_default_embedding()
Expand All @@ -146,6 +155,7 @@ def __init__(
self._topics_inst = None
self._index_inst = None
self._reduced_embeddings_inst = None
self._config_inst = None

# Detect language of the documents, use only the first characters of a few documents to speed up the process
document_languages = [
Expand All @@ -167,8 +177,215 @@ def __init__(
},
)

def save(self, dirpath: str):
"""Save the KnowledgeBase as a parquet file, to a target directory.

Parameters
----------
dirpath : str
The directory path where the KnowledgeBase will be saved.
"""
os.makedirs(dirpath, exist_ok=True)

# save the knowledge base
knowledge_base_df = self._knowledge_base_df.copy()
knowledge_base_df["embeddings"] = (
self._embeddings_inst.tolist() if self._embeddings_inst is not None else [None] * len(self._documents)
)
knowledge_base_df["reduced_embeddings"] = (
self._reduced_embeddings_inst.tolist()
if self._reduced_embeddings_inst is not None
else [None] * len(self._documents)
)
knowledge_base_df.to_parquet(os.path.join(dirpath, "knowledge_base.parquet"))

# save config
with open(os.path.join(dirpath, "config.json"), "w") as f:
json.dump(self._config, f, indent=2)

logger.info("KnowledgeBase saved successfully.")

def push_to_hf_hub(self, repo_id: str, hf_token: Optional[str] = None, private: bool = False):
"""
Push the KnowledgeBase to the Hugging Face Hub.

Parameters
----------
repo_id : str
The repo ID on Hugging Face Hub (e.g., "org-name/my-knowledge-base").
hf_token : str, optional
Hugging Face token for authentication. If None, will use local token.
private : bool
Whether to make the repo private or public.
"""

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should we also pass kwargs for the push to hub/upload file etc?

api = HfApi(token=hf_token or HfFolder.get_token())
api.create_repo(repo_id, repo_type="dataset", private=private, exist_ok=True)

with tempfile.TemporaryDirectory() as tmpdir:
self.save(tmpdir)

# push the files to the hub
for filename in ["knowledge_base.parquet", "config.json"]:
filepath = os.path.join(tmpdir, filename)
if os.path.exists(filepath):
upload_file(
path_or_fileobj=filepath,
path_in_repo=filename,
repo_id=repo_id,
repo_type="dataset",
token=hf_token,
)
logger.info(f"KnowledgeBase pushed to https://huggingface.co/datasets/{repo_id}")

# Load the dataset card template
template_path = Path(__file__).parent / "knowledge_base_card_template.md"
template = template_path.read_text()

# Make and push the dataset card
content = template.format(
repo_id=repo_id, num_items=len(self._knowledge_base_df), config=json.dumps(self._config_inst, indent=4)
)
return DatasetCard(content=content).push_to_hub(repo_id=repo_id, token=hf_token, repo_type="dataset")

@classmethod
def from_pandas(cls, df: pd.DataFrame, columns: Optional[Sequence[str]] = None, **kwargs) -> "KnowledgeBase":
def load(
cls,
dirpath: str,
llm_client: Optional["LLMClient"] = None,
embedding_model: Optional["BaseEmbedding"] = None,
) -> "KnowledgeBase":
"""
Load a KnowledgeBase instance by loading data from specified file paths.

This method initializes a KnowledgeBase object using the provided file paths for configuration, data,
embeddings, and models. It handles loading and deserialization of the necessary components, such as
the knowledge base data, LLM client, embedding model, and precomputed embeddings.

Parameters
----------
dirpath : str
The directory path where the KnowledgeBase files are stored.
llm_client: LLMClient, optional:
The LLM client to use for question generation. If not specified, a default openai client will be used.
embedding_model: BaseEmbedding, optional
The giskard embedding model to use for the knowledge base. By default we use giskard default model which is OpenAI "text-embedding-ada-002".

Returns
-------
KnowledgeBase
An instance of the KnowledgeBase class loaded with the documents, embeddings and configuration.
"""
# Check if the directory exists
assert os.path.exists(dirpath), f"Directory {dirpath} does not exist."
knowledge_base_filepath = os.path.join(dirpath, "knowledge_base.parquet")
config_filepath = os.path.join(dirpath, "config.json")

# Load data from files
data = pd.read_parquet(knowledge_base_filepath)
try:
config = json.loads(open(config_filepath, "r").read())
except Exception as e:
logger.warning(f"Could not load config: {e}. Using default configuration.")
config = {}

# Build the KnowledgeBase instance
logger.info("Building the KnowledgeBase instance...")
kb = cls(
data=data,
columns=config.get("columns"),
llm_client=llm_client or get_default_client(),
embedding_model=embedding_model or get_default_embedding(),
chunk_size=config.get("chunk_size", 2048),
seed=config.get("seed"),
min_topic_size=config.get("min_topic_size"),
)

# Load the embeddings if available
try:
kb.store_embeddings_inst(np.load(data["embeddings"].values.tolist()))
except Exception as e:
logger.warning(f"Could not load embeddings: {e}.")

try:
kb.store_reduced_embeddings_inst(np.load(data["reduced_embeddings"].values.tolist()))
except Exception as e:
logger.warning(f"Could not load reduced embeddings: {e}.")

logger.info("KnowledgeBase loaded successfully.")
return kb

@classmethod
def load_from_hf_hub(
cls,
repo_id: str,
hf_token: Optional[str] = None,
llm_client: Optional["LLMClient"] = None,
embedding_model: Optional["BaseEmbedding"] = None,
**hf_hub_kwargs,
) -> "KnowledgeBase":
"""
Load a KnowledgeBase from the Hugging Face Hub.

This method retrieves the necessary files from the specified Hugging Face Hub repository
and reconstructs a KnowledgeBase instance using the stored data, embeddings, and configuration.

Parameters
----------
repo_id : str
The repository ID on the Hugging Face Hub (e.g., "org-name/my-knowledge-base").
hf_token : str, optional
Hugging Face token for authentication. If None, the local token will be used.
llm_client : LLMClient, optional
An optional LLMClient instance. If not provided, it will be loaded from the repository.
embedding_model : BaseEmbedding, optional
An optional embedding model instance. If not provided, it will be loaded from the repository.
**hf_hub_kwargs : dict
Additional keyword arguments to pass to the Hugging Face Hub download function.

Returns
-------
KnowledgeBase
An instance of the KnowledgeBase class loaded with the data, embeddings and configuration from the specified repository.

Raises
------
ValueError
If the required "knowledge_base.parquet" file cannot be downloaded.

Notes
-----
- The repository must contain the following file:
- "knowledge_base.parquet": The raw data for the KnowledgeBase.
- Optional file:
- "config.json": A JSON file containing configuration parameters.
"""
with tempfile.TemporaryDirectory() as tmpdir:
for filename in ["knowledge_base.parquet", "config.json"]:
try:
hf_hub_download(
repo_id,
filename=filename,
repo_type="dataset",
token=hf_token,
local_dir=tmpdir,
**hf_hub_kwargs,
)
except Exception as e:
logger.warning(f"Failed to download {filename}: {e}")
if filename == "knowledge_base.parquet":
raise ValueError(f"Failed to download {filename}. Cannot load the KnowledgeBase without it.")

return cls.load(tmpdir, llm_client=llm_client, embedding_model=embedding_model)

@classmethod
def from_pandas(
cls,
df: pd.DataFrame,
columns: Optional[Sequence[str]] = None,
embeddings_inst: np.ndarray = None,
reduced_embeddings_inst: np.ndarray = None,
**kwargs,
) -> "KnowledgeBase":
"""Create a KnowledgeBase from a pandas DataFrame.

Parameters
Expand All @@ -180,18 +397,53 @@ def from_pandas(cls, df: pd.DataFrame, columns: Optional[Sequence[str]] = None,
dataframe will be concatenated to produce a single document.
Example: if your knowledge base consists in FAQ data with columns "Q" and "A", we will format each row into a
single document "Q: [question]\\nA: [answer]" to generate questions.
embeddings_inst: np.ndarray
The precomputed embeddings for the knowledge base documents.
reduced_embeddings_inst: np.ndarray
The precomputed reduced embeddings for the knowledge base documents.
kwargs:
Additional settings for knowledge base (see __init__).
"""
return cls(data=df, columns=columns, **kwargs)
kb = cls(data=df, columns=columns, **kwargs)
if embeddings_inst is not None:
logging.info("Storing embeddings in KnowledgeBase.")
kb.store_embeddings_inst(embeddings_inst)
if reduced_embeddings_inst is not None:
logging.info("Storing reduced embeddings in KnowledgeBase.")
kb.store_reduced_embeddings_inst(reduced_embeddings_inst)
logger.info("KnowledgeBase created from pandas DataFrame.")
return kb

def store_embeddings_inst(self, embeddings_inst: np.ndarray):
self._embeddings_inst = embeddings_inst
for doc, emb in zip(self._documents, self._embeddings_inst):
doc.embeddings = emb

def store_reduced_embeddings_inst(self, reduced_embeddings_inst: np.ndarray):
self._reduced_embeddings_inst = reduced_embeddings_inst
for doc, emb in zip(self._documents, self._reduced_embeddings_inst):
doc.reduced_embeddings = emb

@property
def _config(self):
if not self._config_inst:
self._config_inst = {
"columns": self._columns,
"chunk_size": self.chunk_size,
"min_topic_size": self._min_topic_size,
"language": self._language,
"seed": self.seed,
"embedding_model_class": getattr(self._embedding_model, "__class__", None),
"embedding_model": getattr(self._embedding_model, "model", None),
}
return self._config_inst

@property
def _embeddings(self):
if self._embeddings_inst is None:
logger.debug("Computing Knowledge Base embeddings.")
self._embeddings_inst = np.array(self._embedding_model.embed([doc.content for doc in self._documents]))
for doc, emb in zip(self._documents, self._embeddings_inst):
doc.embeddings = emb
embeddings_inst = np.array(self._embedding_model.embed([doc.content for doc in self._documents]))
self.store_embeddings_inst(embeddings_inst)
return self._embeddings_inst

@property
Expand All @@ -205,9 +457,7 @@ def _reduced_embeddings(self):
random_state=1234,
n_jobs=1,
)
self._reduced_embeddings_inst = reducer.fit_transform(self._embeddings)
for doc, emb in zip(self._documents, self._reduced_embeddings_inst):
doc.reduced_embeddings = emb
self.store_reduced_embeddings_inst(reducer.fit_transform(self._embeddings))
return self._reduced_embeddings_inst

@property
Expand Down
51 changes: 51 additions & 0 deletions giskard/rag/knowledge_base_card_template.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
---
tags:
- giskard
- knowledge-base
- information-retrieval

task_categories:
- text-generation
- text2text-generation
- question-answering
- text-retrieval
---

# Dataset Card for {repo_id}
> This repository was created using the [giskard](https://github.com/Giskard-AI/giskard) library, an open-source Python framework designed to evaluate and test AI systems.

This dataset comprises a giskard's `KnowledgeBase` containing {num_items} documents. If embeddings were generated before the saving process, they are included and will be automatically loaded into a vector store when required.

## Usage

You can load this knowledge base using the following code:

```python
from giskard.rag import KnowledgeBase
kb = KnowledgeBase.load_from_hf_hub("{repo_id}")
```

## Configuration

The configuration details for this Knowledge Base (can also be found in the `config.json` file):

```bash
{config}
```

---

<h2 style="text-align: center;">
<span style="display: inline-flex; align-items: center;">
Built with
<a href="https://giskard.ai" target="_blank" style="display: inline-flex;">
<img src="https://cdn.prod.website-files.com/601d6f7d0b9c984f07bf10bc/62983fa8ef716259c397a57d_logo.svg"
alt="Giskard Logo"
width="100">
</a>
</span>
</h2>

<div style="text-align: center;">
<a href="https://github.com/Giskard-AI/giskard" target="_blank" style="display: inline-flex;"> Giskard </a> helps identify performance, bias, and security issues in AI applications, supporting both LLM-based systems like RAG agents and traditional machine learning models for tabular data.
</div>
Loading
Loading