Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
7316d1d
Added tool calls support for LLM clients
kevinmessiaen Dec 21, 2023
c7eeee2
Migrated LLMBasedEvaluator to use tools instead of function
kevinmessiaen Dec 21, 2023
aaec45b
Migrated BaseDataGenerator to use tools instead of function
kevinmessiaen Dec 21, 2023
81d0e40
Migrated TestcaseRequirementsGenerator to use tools instead of function
kevinmessiaen Dec 21, 2023
cfe3401
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Dec 21, 2023
a59c8ef
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Dec 22, 2023
7776246
Fixed tests
kevinmessiaen Dec 22, 2023
a8874f8
Merge remote-tracking branch 'origin/feature/gsk-2367-migrate-openai-…
kevinmessiaen Dec 22, 2023
70eac12
Fixed CoherencyEvaluator
kevinmessiaen Dec 22, 2023
daa2059
Fixed tests
kevinmessiaen Dec 22, 2023
4f3bfa0
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Dec 22, 2023
2447cf8
Added ID to LLM tool calls
kevinmessiaen Dec 22, 2023
9841ccf
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
mattbit Dec 22, 2023
6000e4c
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 1, 2024
ef2a7a2
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
mattbit Jan 2, 2024
2afcf94
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
andreybavt Jan 2, 2024
67e9fd2
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 5, 2024
0066a14
Import reorganization
kevinmessiaen Jan 5, 2024
ca6c1a3
Fixed test
kevinmessiaen Jan 5, 2024
3b60c31
Added raw_output to LLMOutput
kevinmessiaen Jan 8, 2024
3908086
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 8, 2024
abc37aa
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 24, 2024
3e65e57
Revert "Added raw_output to LLMOutput"
kevinmessiaen Jan 24, 2024
dd48192
Improved code structure
kevinmessiaen Jan 24, 2024
5227b76
Fixed tool calls
kevinmessiaen Jan 24, 2024
3fcc077
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 26, 2024
261b8e7
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 26, 2024
98ae8c7
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 26, 2024
1f66c50
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 31, 2024
1ac608d
Added a toJSON method
kevinmessiaen Jan 31, 2024
2639c81
Use model dump instead of toJSON
kevinmessiaen Jan 31, 2024
bdecc8b
Use model dump instead of toJSON
kevinmessiaen Jan 31, 2024
1687ecc
Fixed model serialization
kevinmessiaen Jan 31, 2024
be657c3
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Feb 1, 2024
947ed6a
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Feb 2, 2024
7d6b53d
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
mattbit Feb 2, 2024
b91e41e
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Feb 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions giskard/llm/client/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from dataclasses import dataclass, field

from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Sequence

from .logger import LLMLogger

Expand All @@ -16,6 +16,7 @@ class LLMFunctionCall:
class LLMOutput:
message: Optional[str] = None
function_call: Optional[LLMFunctionCall] = None
tool_calls: Sequence[LLMFunctionCall] = field(default_factory=list)
Copy link
Copy Markdown
Contributor

@AbSsEnT AbSsEnT Dec 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the moment this is OK for backward compatibility, but I believe that eventually we should keep only the "tool" naming convention rather than "function", which also affects related variables like "function_call"; we may also want to rename "LLMFunctionCall" to "LLMToolCall".

Copy link
Copy Markdown
Member Author

@kevinmessiaen kevinmessiaen Dec 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I'm not sure whether function_call is being used in other PRs, and it might be used by some users who have developed their own tests.

I guess it's okay to keep LLMFunctionCall for now, since the only tool type currently supported is function. I guess we will need to update it once other tools are supported by OpenAI.



class LLMClient(ABC):
Expand All @@ -33,5 +34,7 @@ def complete(
max_tokens=None,
function_call: Optional[Dict] = None,
caller_id: Optional[str] = None,
tools=None,
tool_choice=None,
) -> LLMOutput:
...
30 changes: 28 additions & 2 deletions giskard/llm/client/openai.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
from abc import ABC, abstractmethod
from typing import Dict, Optional, Sequence

from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
from typing import Dict, Optional, Sequence

from . import LLMClient, LLMFunctionCall, LLMLogger, LLMOutput
from ..config import LLMConfigurationError
Expand Down Expand Up @@ -37,6 +37,8 @@ def _completion(
function_call: Optional[Dict] = None,
max_tokens=None,
caller_id: Optional[str] = None,
tools=None,
tool_choice=None,
) -> dict:
...

Expand All @@ -48,6 +50,8 @@ def complete(
max_tokens=None,
function_call: Optional[Dict] = None,
caller_id: Optional[str] = None,
tools=None,
tool_choice=None,
):
cc = self._completion(
messages=messages,
Expand All @@ -56,6 +60,8 @@ def complete(
function_call=function_call,
max_tokens=max_tokens,
caller_id=caller_id,
tools=tools,
tool_choice=tool_choice,
)

function_call = None
Expand All @@ -69,7 +75,15 @@ def complete(
except (json.JSONDecodeError, KeyError) as err:
raise LLMGenerationError("Could not parse function call") from err

return LLMOutput(message=cc["content"], function_call=function_call)
tool_calls = []
for tool_call in cc.get("tool_calls") or []:
tool_calls.append(
LLMFunctionCall(
function=tool_call["function"]["name"], args=json.loads(tool_call["function"]["arguments"])
)
)

return LLMOutput(message=cc["content"], function_call=function_call, tool_calls=tool_calls)


class LegacyOpenAIClient(BaseOpenAIClient):
Expand All @@ -89,12 +103,18 @@ def _completion(
function_call: Optional[Dict] = None,
max_tokens=None,
caller_id: Optional[str] = None,
tools=None,
tool_choice=None,
):
extra_params = dict()
if function_call is not None:
extra_params["function_call"] = function_call
if functions is not None:
extra_params["functions"] = functions
if tools is not None:
extra_params["tools"] = tools
if tool_choice is not None:
extra_params["tool_choice"] = tool_choice

try:
completion = openai.ChatCompletion.create(
Expand Down Expand Up @@ -134,12 +154,18 @@ def _completion(
function_call: Optional[Dict] = None,
max_tokens=None,
caller_id: Optional[str] = None,
tools=None,
tool_choice=None,
):
extra_params = dict()
if function_call is not None:
extra_params["function_call"] = function_call
if functions is not None:
extra_params["functions"] = functions
if tools is not None:
extra_params["tools"] = tools
if tool_choice is not None:
extra_params["tool_choice"] = tool_choice

try:
completion = self._client.chat.completions.create(
Expand Down
37 changes: 20 additions & 17 deletions giskard/llm/evaluators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,25 @@

EVALUATE_MODEL_FUNCTIONS = [
{
"name": "evaluate_model",
"description": "Evaluates if the model passes the test",
"parameters": {
"type": "object",
"properties": {
"passed_test": {
"type": "boolean",
"description": "true if the model successfully passes the test",
},
"reason": {
"type": "string",
"description": "optional short description of why the model does not pass the test, in 1 or 2 short sentences",
"type": "function",
"function": {
"name": "evaluate_model",
"description": "Evaluates if the model passes the test",
"parameters": {
"type": "object",
"properties": {
"passed_test": {
"type": "boolean",
"description": "true if the model successfully passes the test",
},
"reason": {
"type": "string",
"description": "optional short description of why the model does not pass the test, in 1 or 2 short sentences",
},
},
"required": ["passed_test"],
},
},
"required": ["passed_test"],
},
]

Expand Down Expand Up @@ -90,18 +93,18 @@ def evaluate(self, model: BaseModel, dataset: Dataset):
try:
out = self.llm_client.complete(
[{"role": "system", "content": prompt}],
functions=funcs,
function_call={"name": "evaluate_model"},
tools=funcs,
tool_choice={"type": "function", "function": {"name": "evaluate_model"}},
temperature=self.llm_temperature,
caller_id=self.__class__.__name__,
)
if out.function_call is None or "passed_test" not in out.function_call.args:
if len(out.tool_calls) != 1 or "passed_test" not in out.tool_calls[0].args:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this specific case, we do not expect multiple tool calls to be made, right?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, since we are evaluating a single case, we expect it to be evaluated once.

raise LLMGenerationError("Invalid function call arguments received")
except LLMGenerationError as err:
errored.append({"message": str(err), "sample": sample})
continue

args = out.function_call.args
args = out.tool_calls[0].args
if args["passed_test"]:
succeeded.append({"input_vars": input_vars, "model_output": model_output, "reason": args.get("reason")})
else:
Expand Down
8 changes: 4 additions & 4 deletions giskard/llm/evaluators/coherency.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,13 @@ def _eval_pair(self, model: BaseModel, input_1, input_2, output_1, output_2):

out = self.llm_client.complete(
[{"role": "system", "content": prompt}],
functions=EVALUATE_MODEL_FUNCTIONS,
function_call={"name": "evaluate_model"}, # force function call
tools=EVALUATE_MODEL_FUNCTIONS,
tool_choice={"type": "function", "function": {"name": "evaluate_model"}}, # force tool call
temperature=self.llm_temperature,
caller_id=self.__class__.__name__,
)

if out.function_call is None or "passed_test" not in out.function_call.args:
if len(out.tool_calls) != 1 or "passed_test" not in out.tool_calls[0].args:
raise LLMGenerationError("Invalid function call arguments received")

return out.function_call.args["passed_test"], out.function_call.args.get("reason")
return out.tool_calls[0].args["passed_test"], out.tool_calls[0].args.get("reason")
44 changes: 23 additions & 21 deletions giskard/llm/generators/base.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import Optional, Sequence

from abc import ABC, abstractmethod

import pandas as pd
from pydantic import BaseModel
from typing import Optional, Sequence

from ...datasets.base import Dataset
from ..client import LLMClient, get_default_client
from ..errors import LLMGenerationError
from ...datasets.base import Dataset

DEFAULT_GENERATE_INPUTS_PROMPT = """You are auditing an AI model. Your task is to generate typical but varied inputs for this model.

Expand Down Expand Up @@ -61,20 +60,23 @@ def _make_generate_input_prompt(self, model: BaseModel, num_samples: int):
def _make_generate_input_functions(self, model: BaseModel, num_samples: int):
return [
{
"name": "generate_inputs",
"description": "generates inputs for model audit",
"parameters": {
"type": "object",
"properties": {
"inputs": {
"type": "array",
"items": {
"type": "object",
"properties": {name: {"type": "string"} for name in model.meta.feature_names},
},
}
"type": "function",
"function": {
"name": "generate_inputs",
"description": "generates inputs for model audit",
"parameters": {
"type": "object",
"properties": {
"inputs": {
"type": "array",
"items": {
"type": "object",
"properties": {name: {"type": "string"} for name in model.meta.feature_names},
},
}
},
"required": ["inputs"],
},
"required": ["inputs"],
},
}
]
Expand Down Expand Up @@ -107,19 +109,19 @@ def generate_dataset(self, model: BaseModel, num_samples: int = 10, column_types

"""
prompt = self._make_generate_input_prompt(model, num_samples)
functions = self._make_generate_input_functions(model, num_samples)
tools = self._make_generate_input_functions(model, num_samples)

out = self.llm_client.complete(
messages=[{"role": "system", "content": prompt}],
functions=functions,
function_call={"name": "generate_inputs"},
tools=tools,
tool_choice={"type": "function", "function": {"name": "generate_inputs"}},
temperature=self.llm_temperature,
caller_id=self.__class__.__name__,
)

try:
generated = out.function_call.args["inputs"]
except (AttributeError, KeyError) as err:
generated = out.tool_calls[0].args["inputs"]
except (AttributeError, KeyError, IndexError) as err:
raise LLMGenerationError("Could not parse generated inputs") from err

dataset = Dataset(
Expand Down
42 changes: 24 additions & 18 deletions giskard/llm/testcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,32 @@

GENERATE_REQUIREMENTS_FUNCTIONS = [
{
"name": "generate_requirements",
"description": "Generates requirements for model testing",
"parameters": {
"type": "object",
"properties": {
"requirements": {
"type": "array",
"items": {"type": "string", "description": "A requirement the model must satisfy"},
}
"type": "function",
"function": {
"name": "generate_requirements",
"description": "Generates requirements for model testing",
"parameters": {
"type": "object",
"properties": {
"requirements": {
"type": "array",
"items": {"type": "string", "description": "A requirement the model must satisfy"},
}
},
"required": ["requirements"],
},
"required": ["requirements"],
},
},
{
"name": "skip",
"description": "Skips the generation when no relevant requirements can be generated",
"parameters": {
"type": "object",
"properties": {},
"required": [],
"type": "function",
"function": {
"name": "skip",
"description": "Skips the generation when no relevant requirements can be generated",
"parameters": {
"type": "object",
"properties": {},
"required": [],
},
},
},
]
Expand Down Expand Up @@ -70,8 +76,8 @@ def generate_requirements(self, model: BaseModel, max_requirements: int = 5):
functions = self._make_generate_requirements_functions()
out = self.llm_client.complete(
messages=[{"role": "system", "content": prompt}],
functions=functions,
function_call={"name": "generate_requirements"},
tools=functions,
tool_choice={"type": "function", "function": {"name": "generate_requirements"}},
temperature=self.llm_temperature,
caller_id=self.__class__.__name__,
)
Expand Down
Loading