Merged

Commits
37 commits
7316d1d
Added tool calls support for LLM clients
kevinmessiaen Dec 21, 2023
c7eeee2
Migrated LLMBasedEvaluator to use tools instead of function
kevinmessiaen Dec 21, 2023
aaec45b
Migrated BaseDataGenerator to use tools instead of function
kevinmessiaen Dec 21, 2023
81d0e40
Migrated TestcaseRequirementsGenerator to use tools instead of function
kevinmessiaen Dec 21, 2023
cfe3401
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Dec 21, 2023
a59c8ef
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Dec 22, 2023
7776246
Fixed tests
kevinmessiaen Dec 22, 2023
a8874f8
Merge remote-tracking branch 'origin/feature/gsk-2367-migrate-openai-…
kevinmessiaen Dec 22, 2023
70eac12
Fixed CoherencyEvaluator
kevinmessiaen Dec 22, 2023
daa2059
Fixed tests
kevinmessiaen Dec 22, 2023
4f3bfa0
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Dec 22, 2023
2447cf8
Added ID to LLM tool calls
kevinmessiaen Dec 22, 2023
9841ccf
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
mattbit Dec 22, 2023
6000e4c
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 1, 2024
ef2a7a2
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
mattbit Jan 2, 2024
2afcf94
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
andreybavt Jan 2, 2024
67e9fd2
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 5, 2024
0066a14
Import reorganization
kevinmessiaen Jan 5, 2024
ca6c1a3
Fixed test
kevinmessiaen Jan 5, 2024
3b60c31
Added raw_output to LLMOutput
kevinmessiaen Jan 8, 2024
3908086
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 8, 2024
abc37aa
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 24, 2024
3e65e57
Revert "Added raw_output to LLMOutput"
kevinmessiaen Jan 24, 2024
dd48192
Improved code structure
kevinmessiaen Jan 24, 2024
5227b76
Fixed tool calls
kevinmessiaen Jan 24, 2024
3fcc077
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 26, 2024
261b8e7
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 26, 2024
98ae8c7
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 26, 2024
1f66c50
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Jan 31, 2024
1ac608d
Added a toJSON method
kevinmessiaen Jan 31, 2024
2639c81
Use model dump instead of toJSON
kevinmessiaen Jan 31, 2024
bdecc8b
Use model dump instead of toJSON
kevinmessiaen Jan 31, 2024
1687ecc
Fixed model serialization
kevinmessiaen Jan 31, 2024
be657c3
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Feb 1, 2024
947ed6a
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Feb 2, 2024
7d6b53d
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
mattbit Feb 2, 2024
b91e41e
Merge branch 'main' into feature/gsk-2367-migrate-openai-api-call-fro…
kevinmessiaen Feb 2, 2024
Files changed
5 changes: 3 additions & 2 deletions giskard/llm/client/__init__.py
@@ -3,7 +3,7 @@
import logging
import os

from .base import LLMClient, LLMFunctionCall, LLMOutput
from .base import LLMClient, LLMFunctionCall, LLMMessage, LLMToolCall
from .logger import LLMLogger

_default_client = None
@@ -75,7 +75,8 @@ def get_default_client() -> LLMClient:
__all__ = [
"LLMClient",
"LLMFunctionCall",
"LLMOutput",
"LLMToolCall",
"LLMMessage",
"LLMLogger",
"get_default_client",
"set_llm_model",
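Note for downstream users: LLMOutput disappears from the public exports, so imports must move to the new names. A minimal before/after sketch:

# Before this PR:
#   from giskard.llm.client import LLMOutput
# After it:
from giskard.llm.client import LLMFunctionCall, LLMMessage, LLMToolCall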
31 changes: 23 additions & 8 deletions giskard/llm/client/base.py
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional, Sequence

from abc import ABC, abstractmethod
from dataclasses import dataclass
@@ -8,14 +8,27 @@

@dataclass
class LLMFunctionCall:
function: str
args: Any
name: str
arguments: Any


@dataclass
class LLMOutput:
message: Optional[str] = None
function_call: Optional[LLMFunctionCall] = None
class LLMToolCall:
id: str
type: str
function: LLMFunctionCall


@dataclass
class LLMMessage:
role: str
content: Optional[str]
function_call: Optional[LLMFunctionCall]
tool_calls: Optional[List[LLMToolCall]]

@staticmethod
def create_message(role: str, content: str):
return LLMMessage(role=role, content=content, function_call=None, tool_calls=None)


class LLMClient(ABC):
@@ -27,11 +40,13 @@ def logger(self) -> LLMLogger:
@abstractmethod
def complete(
self,
messages,
messages: Sequence[LLMMessage],
functions=None,
temperature=0.5,
max_tokens=None,
function_call: Optional[Dict] = None,
caller_id: Optional[str] = None,
) -> LLMOutput:
tools=None,
tool_choice=None,
) -> LLMMessage:
...
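The new dataclasses mirror the shape of OpenAI's tool-calling messages: a tool call wraps a function call, and a message can carry several tool calls. A hand-built sketch, with the ID and argument values invented for illustration:

from giskard.llm.client.base import LLMFunctionCall, LLMMessage, LLMToolCall

# An assistant reply carrying a single tool call, shaped like what
# LLMClient.complete() now returns; all field values are illustrative.
call = LLMToolCall(
    id="call_0",  # hypothetical ID; real ones are assigned by the API
    type="function",
    function=LLMFunctionCall(name="evaluate_model", arguments={"passed_test": True}),
)
reply = LLMMessage(role="assistant", content=None, function_call=None, tool_calls=[call])

# Plain text messages are simpler to build via the new helper:
system = LLMMessage.create_message(role="system", content="You are an evaluator.")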
114 changes: 96 additions & 18 deletions giskard/llm/client/openai.py
@@ -1,4 +1,4 @@
from typing import Dict, Optional, Sequence
from typing import Dict, List, Optional, Sequence

import json
from abc import ABC, abstractmethod
@@ -7,7 +7,8 @@

from ..config import LLMConfigurationError
from ..errors import LLMGenerationError, LLMImportError
from . import LLMClient, LLMFunctionCall, LLMLogger, LLMOutput
from . import LLMClient, LLMFunctionCall, LLMLogger, LLMMessage
from .base import LLMToolCall

try:
import openai
@@ -38,39 +39,101 @@ def _completion(
function_call: Optional[Dict] = None,
max_tokens=None,
caller_id: Optional[str] = None,
tools=None,
tool_choice=None,
) -> dict:
...

@staticmethod
def _serialize_function_call(function_call: LLMFunctionCall) -> Dict:
return {"name": function_call.name, "arguments": json.dumps(function_call.arguments)}

@staticmethod
def _serialize_tool_call(tool_call: LLMToolCall) -> Dict:
return {
"id": tool_call.id,
"type": tool_call.type,
"function": BaseOpenAIClient._serialize_function_call(tool_call.function),
}

@staticmethod
def _serialize_tool_calls(tool_calls: List[LLMToolCall]) -> List[Dict]:
return [BaseOpenAIClient._serialize_tool_call(tool_call) for tool_call in tool_calls]

@staticmethod
def _serialize_message(response: LLMMessage) -> Dict:
result = {
"role": response.role,
"content": response.content,
"function_call": BaseOpenAIClient._serialize_function_call(response.function_call)
if response.function_call
else None,
"tool_calls": BaseOpenAIClient._serialize_tool_calls(response.tool_calls) if response.tool_calls else None,
}

return {key: value for key, value in result.items() if value is not None}

@staticmethod
def _parse_function_call(function_call) -> LLMFunctionCall:
try:
return LLMFunctionCall(
name=function_call["name"],
arguments=json.loads(function_call["arguments"]),
)
except (json.JSONDecodeError, KeyError) as err:
raise LLMGenerationError("Could not parse function call") from err

@staticmethod
def _parse_tool_call(tool_call) -> LLMToolCall:
return LLMToolCall(
id=tool_call["id"],
type=tool_call["type"],
function=BaseOpenAIClient._parse_function_call(tool_call["function"]),
)

@staticmethod
def _parse_tool_calls(tool_calls) -> List[LLMToolCall]:
return [BaseOpenAIClient._parse_tool_call(tool_call) for tool_call in tool_calls]

@staticmethod
def _parse_message(response) -> LLMMessage:
return LLMMessage(
role=response["role"],
content=response["content"],
function_call=BaseOpenAIClient._parse_function_call(response["function_call"])
if "function_call" in response and response["function_call"] is not None
else None,
tool_calls=BaseOpenAIClient._parse_tool_calls(response["tool_calls"])
if "tool_calls" in response and response["tool_calls"] is not None
else None,
)

def complete(
self,
messages,
messages: Sequence[LLMMessage],
functions=None,
temperature=0.5,
max_tokens=None,
function_call: Optional[Dict] = None,
caller_id: Optional[str] = None,
tools=None,
tool_choice=None,
):
cc = self._completion(
messages=messages,
llm_message = self._completion(
messages=[
BaseOpenAIClient._serialize_message(message) if isinstance(message, LLMMessage) else message
for message in messages
],
temperature=temperature,
functions=functions,
function_call=function_call,
max_tokens=max_tokens,
caller_id=caller_id,
tools=tools,
tool_choice=tool_choice,
)

function_call = None

if fc := cc.get("function_call"):
try:
function_call = LLMFunctionCall(
function=fc["name"],
args=json.loads(fc["arguments"], strict=False),
)
except (json.JSONDecodeError, KeyError) as err:
raise LLMGenerationError("Could not parse function call") from err

return LLMOutput(message=cc["content"], function_call=function_call)
return BaseOpenAIClient._parse_message(llm_message)


class LegacyOpenAIClient(BaseOpenAIClient):
@@ -90,17 +153,26 @@ def _completion(
function_call: Optional[Dict] = None,
max_tokens=None,
caller_id: Optional[str] = None,
tools=None,
tool_choice=None,
):
extra_params = dict()
if function_call is not None:
extra_params["function_call"] = function_call
if functions is not None:
extra_params["functions"] = functions
if tools is not None:
extra_params["tools"] = tools
if tool_choice is not None:
extra_params["tool_choice"] = tool_choice

try:
completion = openai.ChatCompletion.create(
model=self.model,
messages=messages,
messages=[
BaseOpenAIClient._serialize_message(message) if isinstance(message, LLMMessage) else message
for message in messages
],
temperature=temperature,
max_tokens=max_tokens,
**extra_params,
@@ -135,12 +207,18 @@ def _completion(
function_call: Optional[Dict] = None,
max_tokens=None,
caller_id: Optional[str] = None,
tools=None,
tool_choice=None,
):
extra_params = dict()
if function_call is not None:
extra_params["function_call"] = function_call
if functions is not None:
extra_params["functions"] = functions
if tools is not None:
extra_params["tools"] = tools
if tool_choice is not None:
extra_params["tool_choice"] = tool_choice

try:
completion = self._client.chat.completions.create(
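The helpers above are symmetric: _serialize_message turns an LLMMessage into the dict the OpenAI SDK expects (dropping None-valued keys), while _parse_message rebuilds an LLMMessage from a response dict, JSON-decoding the tool-call arguments on the way in. A round-trip sketch, using these private helpers purely for illustration:

from giskard.llm.client.base import LLMMessage
from giskard.llm.client.openai import BaseOpenAIClient

msg = LLMMessage.create_message(role="user", content="Hello")
BaseOpenAIClient._serialize_message(msg)
# -> {"role": "user", "content": "Hello"}   (None-valued fields are stripped)

# A dict mimicking an OpenAI response message; the ID is hypothetical.
parsed = BaseOpenAIClient._parse_message(
    {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_0",
                "type": "function",
                "function": {"name": "evaluate_model", "arguments": '{"passed_test": true}'},
            }
        ],
    }
)
assert parsed.tool_calls[0].function.arguments == {"passed_test": True}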
37 changes: 20 additions & 17 deletions giskard/llm/evaluators/base.py
@@ -11,22 +11,25 @@

EVALUATE_MODEL_FUNCTIONS = [
{
"name": "evaluate_model",
"description": "Evaluates if the model passes the test",
"parameters": {
"type": "object",
"properties": {
"passed_test": {
"type": "boolean",
"description": "true if the model successfully passes the test",
},
"reason": {
"type": "string",
"description": "optional short description of why the model does not pass the test, in 1 or 2 short sentences",
"type": "function",
"function": {
"name": "evaluate_model",
"description": "Evaluates if the model passes the test",
"parameters": {
"type": "object",
"properties": {
"passed_test": {
"type": "boolean",
"description": "true if the model successfully passes the test",
},
"reason": {
"type": "string",
"description": "optional short description of why the model does not pass the test, in 1 or 2 short sentences",
},
},
"required": ["passed_test"],
},
},
"required": ["passed_test"],
},
]

@@ -101,20 +104,20 @@ def evaluate(self, model: BaseModel, dataset: Dataset):
try:
out = self.llm_client.complete(
[{"role": "system", "content": prompt}],
functions=funcs,
function_call={"name": "evaluate_model"},
tools=funcs,
tool_choice={"type": "function", "function": {"name": "evaluate_model"}},
temperature=self.llm_temperature,
caller_id=self.__class__.__name__,
)
if out.function_call is None or "passed_test" not in out.function_call.args:
if len(out.tool_calls) != 1 or "passed_test" not in out.tool_calls[0].function.arguments:
raise LLMGenerationError("Invalid function call arguments received")
except LLMGenerationError as err:
status.append(TestResultStatus.ERROR)
reasons.append(str(err))
errored.append({"message": str(err), "sample": sample})
continue

args = out.function_call.args
args = out.tool_calls[0].function.arguments
reasons.append(args.get("reason"))
if args["passed_test"]:
status.append(TestResultStatus.PASSED)
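The pattern above replaces the legacy function_call={"name": ...} argument with the tools API: the schema is passed as tools, the call is forced with tool_choice, and the decoded arguments are read from out.tool_calls[0].function.arguments. A sketch of the calling convention, assuming a configured default client (for example, OpenAI credentials in the environment):

from giskard.llm.client import get_default_client
from giskard.llm.evaluators.base import EVALUATE_MODEL_FUNCTIONS

client = get_default_client()  # assumes an OpenAI key is configured
out = client.complete(
    [{"role": "system", "content": "..."}],  # evaluation prompt elided
    tools=EVALUATE_MODEL_FUNCTIONS,
    tool_choice={"type": "function", "function": {"name": "evaluate_model"}},
)
args = out.tool_calls[0].function.arguments  # a decoded dict, not a JSON string
passed_test, reason = args["passed_test"], args.get("reason")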
8 changes: 4 additions & 4 deletions giskard/llm/evaluators/coherency.py
@@ -98,13 +98,13 @@ def _eval_pair(self, model: BaseModel, input_1, input_2, output_1, output_2):

out = self.llm_client.complete(
[{"role": "system", "content": prompt}],
functions=EVALUATE_MODEL_FUNCTIONS,
function_call={"name": "evaluate_model"}, # force function call
tools=EVALUATE_MODEL_FUNCTIONS,
tool_choice={"type": "function", "function": {"name": "evaluate_model"}}, # force tool call
temperature=self.llm_temperature,
caller_id=self.__class__.__name__,
)

if out.function_call is None or "passed_test" not in out.function_call.args:
if len(out.tool_calls) != 1 or "passed_test" not in out.tool_calls[0].function.arguments:
raise LLMGenerationError("Invalid function call arguments received")

return out.function_call.args["passed_test"], out.function_call.args.get("reason")
return out.tool_calls[0].function.arguments["passed_test"], out.tool_calls[0].function.arguments.get("reason")
39 changes: 21 additions & 18 deletions giskard/llm/generators/base.py
@@ -63,20 +63,23 @@ def _make_generate_input_prompt(self, model: BaseModel, num_samples: int):
def _make_generate_input_functions(self, model: BaseModel, num_samples: int):
return [
{
"name": "generate_inputs",
"description": "generates inputs for model audit",
"parameters": {
"type": "object",
"properties": {
"inputs": {
"type": "array",
"items": {
"type": "object",
"properties": {name: {"type": "string"} for name in model.feature_names},
},
}
"type": "function",
"function": {
"name": "generate_inputs",
"description": "generates inputs for model audit",
"parameters": {
"type": "object",
"properties": {
"inputs": {
"type": "array",
"items": {
"type": "object",
"properties": {name: {"type": "string"} for name in model.feature_names},
},
}
},
"required": ["inputs"],
},
"required": ["inputs"],
},
}
]
@@ -109,19 +112,19 @@ def generate_dataset(self, model: BaseModel, num_samples: int = 10, column_types

"""
prompt = self._make_generate_input_prompt(model, num_samples)
functions = self._make_generate_input_functions(model, num_samples)
tools = self._make_generate_input_functions(model, num_samples)

out = self.llm_client.complete(
messages=[{"role": "system", "content": prompt}],
functions=functions,
function_call={"name": "generate_inputs"},
tools=tools,
tool_choice={"type": "function", "function": {"name": "generate_inputs"}},
temperature=self.llm_temperature,
caller_id=self.__class__.__name__,
)

try:
generated = out.function_call.args["inputs"]
except (AttributeError, KeyError) as err:
generated = out.tool_calls[0].function.arguments["inputs"]
except (AttributeError, KeyError, IndexError, TypeError) as err:
raise LLMGenerationError("Could not parse generated inputs") from err

dataset = Dataset(
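As in the evaluators, the generator now reads its structured output from the first tool call; the broadened except clause also covers responses with no tool calls at all (IndexError, TypeError). A sketch of the decoded shape, with the feature name and values invented for illustration:

from giskard.llm.client import LLMFunctionCall, LLMToolCall

# Hand-built tool call mimicking an LLM response to generate_inputs for a
# model with feature_names == ["question"]; values are purely illustrative.
call = LLMToolCall(
    id="call_0",
    type="function",
    function=LLMFunctionCall(
        name="generate_inputs",
        arguments={
            "inputs": [
                {"question": "What is the refund policy?"},
                {"question": "Can I change my delivery address?"},
            ]
        },
    ),
)
rows = call.function.arguments["inputs"]  # one dict per generated sample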