giskard/llm/talk/config.py (16 changes: 13 additions & 3 deletions)
@@ -3,15 +3,25 @@
You interact with the model through different tools. Tools are functions whose responses are used to provide you with the necessary context to answer a user's question.

The descriptions of these tools are provided below:
Tools descriptions: {tools_descriptions}
Tools descriptions:
{tools_descriptions}

Your main goal is to choose and execute an appropriate tool that helps you answer a user's question.
For the chosen tool you need to create an input that follows the provided tool specification.
If there is an error during the tool call, you will get this information. You need to summarise this error and inform the user.

Please provide polite and concise answers to the user and avoid explaining the result, until the user explicitly asks you to do it.
Please provide polite, and concise answers to the user and avoid explaining the result, until the user explicitly asks you to do it.
Please take into account that the user does not necessarily have a computer science background, so your answers must be clear to people from different domains.
Please note that you cannot share any confidential information with the user, and if the question is not related to the model, you must return an explanation of why you cannot answer.

You will interact with the following model:
Model name: {model_name}
Model description: {model_description}.
Model description: {model_description}
Model features: {feature_names}
"""

ERROR_RESPONSE = """There is an error, when calling the tool. Detailed info:
Error message: "{error_msg}"\n
Tool name: "{tool_name}"\n
Tool arguments: "{tool_args}"\n
"""
giskard/llm/talk/tools.py (72 changes: 70 additions & 2 deletions)
@@ -1,7 +1,11 @@
from __future__ import annotations
from abc import ABC, abstractmethod

import numpy as np
from typing import TYPE_CHECKING

from shap import Explanation

if TYPE_CHECKING:
from giskard.models.base import BaseModel

@@ -21,8 +25,8 @@ def __call__(self, *args, **kwargs) -> str:

class PredictFromDatasetTool(BaseTool):
default_name = "predict_from_dataset"
default_description = ("From the user input, it extracts a necessary information to filter rows from the dataset, "
"then it runs the model prediction on that rows and finally returns the prediction result.")
default_description = ("You expect a dictionary with features and their values to filter rows from the dataset, "
"then you run the model prediction on that rows and finally return the prediction result.")

def __init__(self, model: BaseModel, dataset: Dataset, name: str = None, description: str = None):
self._model = model
@@ -84,3 +88,67 @@ def __call__(self, row_filter: dict) -> str:
result = ", ".join(prediction)

return result


class SHAPExplanationTool(PredictFromDatasetTool):
default_name = "shap_prediction_explanation"
default_description = ("You expect a dictionary with feature names as keys and their values as dict values, "
"which you use to filter rows in the dataset, "
"then you run the SHAP explanation on that filtered rows, "
"and finally you return the SHAP explanation result as well as the model prediction result."
"Please note, that the bigger SHAP value - the more important feature is for prediction.")

_result_template = "'{feature_name}' | {attributions_values}"

def _get_shap_explanations(self, filtered_dataset: Dataset) -> Explanation:
from shap import KernelExplainer
from giskard.models.model_explanation import _prepare_for_explanation, _get_background_example, _get_highest_proba_shap

# Prepare background sample to be used in the KernelSHAP.
background_df = self._model.prepare_dataframe(self._dataset.df, self._dataset.column_dtypes, self._dataset.target)
background_sample = _get_background_example(background_df, self._dataset.column_types)

# Prepare input data for an explanation.
data_to_explain = _prepare_for_explanation(filtered_dataset.df, model=self._model, dataset=self._dataset)

def prediction_function(_df):
"""Rolls-back SHAP casting of all columns to the 'object' type."""
return self._model.predict_df(_df.astype(data_to_explain.dtypes))

# Obtain SHAP explanations.
explainer = KernelExplainer(prediction_function, background_sample, data_to_explain.columns, keep_index=True)
shap_values = explainer.shap_values(data_to_explain, silent=True)

if self._model.is_classification:
shap_values = _get_highest_proba_shap(shap_values, self._model, filtered_dataset)

# Put the SHAP values into an Explanation object for convenience.
feature_names = self._model.meta.feature_names or list(self._dataset.df.columns.drop(self._dataset.target, errors="ignore"))
shap_explanations = Explanation(shap_values, data=filtered_dataset.df[feature_names], feature_names=feature_names)
return shap_explanations

def __call__(self, row_filter: dict) -> str:
# 0) Get prediction result from the parent tool.
prediction_result = super().__call__(row_filter)

# 1) Filter the dataset using the LLM-created filter expression.
filtered_dataset = self._get_filtered_dataset(row_filter)

# 2) Get a SHAP explanation.
explanations = self._get_shap_explanations(filtered_dataset)

# 3) Finalise the result.
shap_result = [self._result_template.format(feature_name=f_name,
attributions_values=list(zip(np.abs(explanations[:, f_name].values),
explanations[:, f_name].data)))
for f_name in explanations.feature_names]
shap_result = "\n".join(shap_result)

result = (f"Prediction result:\n"
f"{prediction_result}\n\n"
f"SHAP result: \n"
f"'Feature name' | [('SHAP value', 'Feature value'), ...]\n"
f"--------------------------------------------------------\n"
f"{shap_result}\n\n")

return result
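
To make the tool's return value concrete, here is a small, self-contained sketch of the string it assembles, mirroring the per-feature formatting logic above but driven by a tiny synthetic shap.Explanation instead of a real model run; the feature names and values are invented.

# Illustrative only: synthetic SHAP attributions for two rows and two features.
import numpy as np
import pandas as pd
from shap import Explanation

data = pd.DataFrame({"age": [42, 37], "income": [52000, 61000]})  # hypothetical features
shap_values = np.array([[0.30, -0.12], [0.25, -0.08]])  # hypothetical attributions
explanations = Explanation(shap_values, data=data, feature_names=list(data.columns))

# Same per-feature formatting as SHAPExplanationTool._result_template.
template = "'{feature_name}' | {attributions_values}"
shap_result = "\n".join(
    template.format(feature_name=f_name,
                    attributions_values=list(zip(np.abs(explanations[:, f_name].values),
                                                 explanations[:, f_name].data)))
    for f_name in explanations.feature_names
)
print("'Feature name' | [('SHAP value', 'Feature value'), ...]")
print(shap_result)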
giskard/models/base/model.py (21 changes: 15 additions & 6 deletions)
@@ -23,8 +23,8 @@
from ...core.validation import configured_validate_arguments
from ...datasets.base import Dataset
from ...llm import get_default_client
from ...llm.talk.config import MODEL_INSTRUCTION
from ...llm.talk.tools import BaseTool, PredictFromDatasetTool
from ...llm.talk.config import MODEL_INSTRUCTION, ERROR_RESPONSE
from ...llm.talk.tools import BaseTool, PredictFromDatasetTool, SHAPExplanationTool
from ...ml_worker.exceptions.giskard_exception import GiskardException, python_env_exception_helper
from ...ml_worker.utils.logging import Timer
from ...models.cache import ModelCache
@@ -568,7 +568,8 @@ def to_mlflow(self, *_args, **_kwargs):
def _get_available_tools(self, dataset: Dataset) -> dict[str, BaseTool]:
"""Get the dictionary with available tools"""
tools = {
PredictFromDatasetTool.default_name: PredictFromDatasetTool(self, dataset)
PredictFromDatasetTool.default_name: PredictFromDatasetTool(self, dataset),
SHAPExplanationTool.default_name: SHAPExplanationTool(self, dataset)
}

return tools
@@ -606,9 +607,17 @@ def talk(self, question: str, dataset: Dataset) -> str:

# Get the reference to the chosen function.
function = available_tools[function_name]
function_response = function(
**function_args
)

try:
function_response = function(
**function_args
)
except Exception as error_msg:
function_response = ERROR_RESPONSE.format(
tool_name=function_name,
tool_args=function_args,
error_msg=error_msg.args[0]
)

# Append the tool's response to the conversation.
messages.append(
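
Finally, a minimal runnable sketch of the dispatch-plus-fallback pattern this hunk introduces: look the chosen tool up by name, call it with the LLM-provided arguments, and turn any exception into a message instead of aborting the conversation. The tool, its arguments, and the error below are stand-ins, and the fallback string here is simplified; the real code formats ERROR_RESPONSE from giskard/llm/talk/config.py.

# Illustrative only: stand-ins for the tool registry and the LLM's chosen call.
def shap_prediction_explanation(**row_filter):
    raise ValueError("unknown feature 'age'")  # hypothetical tool failure

available_tools = {"shap_prediction_explanation": shap_prediction_explanation}
function_name, function_args = "shap_prediction_explanation", {"age": 42}

function = available_tools[function_name]
try:
    function_response = function(**function_args)
except Exception as error_msg:
    # error_msg.args[0] assumes the exception carries at least one argument,
    # which holds for exceptions raised with a message.
    function_response = (f'Tool "{function_name}" failed with arguments '
                         f'{function_args}: "{error_msg.args[0]}"')

print(function_response)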