17 changes: 16 additions & 1 deletion docs/open_source/scan/scan_llm/index.md
@@ -91,15 +91,30 @@ giskard.llm.set_default_client(mc)
::::::
::::::{tab-item} Ollama
```python
import giskard
from openai import OpenAI
from giskard.llm.client.openai import OpenAIClient

# Set up the Ollama client with API key and base URL
_client = OpenAI(base_url="http://localhost:11434/v1/", api_key="ollama")
oc = OpenAIClient(model="gemma:2b", client=_client)
giskard.llm.set_default_client(oc)
```
::::::
::::::{tab-item} Claude 3

```python
import os
import boto3
import giskard

from giskard.llm.client.bedrock import ClaudeBedrockClient

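# Set up the Bedrock runtime client using your AWS credentials and region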
bedrock_runtime = boto3.client("bedrock-runtime", region_name=os.environ["AWS_DEFAULT_REGION"])
claude_client = ClaudeBedrockClient(bedrock_runtime, model="anthropic.claude-3-haiku-20240307-v1:0")
giskard.llm.set_default_client(claude_client)
```

::::::
::::::{tab-item} Custom Client
```python
86 changes: 86 additions & 0 deletions giskard/llm/client/bedrock.py
@@ -0,0 +1,86 @@
from typing import Optional, Sequence

import json

from ..config import LLMConfigurationError
from ..errors import LLMImportError
from . import LLMClient
from .base import ChatMessage

try:
    import boto3  # noqa: F401
    from botocore.exceptions import ClientError
except ImportError as err:
    raise LLMImportError(
        flavor="llm", msg="To use Bedrock models, please install the `boto3` package with `pip install boto3`"
    ) from err


class ClaudeBedrockClient(LLMClient):
    def __init__(
        self,
        bedrock_runtime_client,
        model: str = "anthropic.claude-3-sonnet-20240229-v1:0",
        anthropic_version: str = "bedrock-2023-05-31",
    ):
        self._client = bedrock_runtime_client
        self.model = model
        self.anthropic_version = anthropic_version

    def complete(
        self,
        messages: Sequence[ChatMessage],
        temperature: float = 1,
        max_tokens: Optional[int] = 1000,
        caller_id: Optional[str] = None,
        seed: Optional[int] = None,
        format=None,
    ) -> ChatMessage:
        # only supporting Claude 3 to start
        if "claude-3" not in self.model:
            raise LLMConfigurationError(f"Only claude-3 models are supported as of now, got {self.model}")

        # Extract the system prompt: the Claude messages API does not accept two
        # consecutive "user" messages, so a leading pair of them means the first
        # one carries the system prompt.
        system_prompt = ""
        if len(messages) > 1:
            if messages[0].role.lower() == "user" and messages[1].role.lower() == "user":
                system_prompt = messages[0].content
                messages = messages[1:]

        # Convert the messages to the format expected by the Bedrock Anthropic API
        input_msg_prompt = []
        for msg in messages:
            if msg.role.lower() == "assistant":
                input_msg_prompt.append({"role": "assistant", "content": [{"type": "text", "text": msg.content}]})
            else:
                input_msg_prompt.append({"role": "user", "content": [{"type": "text", "text": msg.content}]})

        # Create the JSON body to send to the API
        body = json.dumps(
            {
                "anthropic_version": self.anthropic_version,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "system": system_prompt,
                "messages": input_msg_prompt,
            }
        )

        # Invoke the model and parse the response
        try:
            accept = "application/json"
            contentType = "application/json"
            response = self._client.invoke_model(body=body, modelId=self.model, accept=accept, contentType=contentType)
            completion = json.loads(response.get("body").read())
        except (RuntimeError, ClientError) as err:
            raise LLMConfigurationError("Could not get response from Bedrock API") from err

        self.logger.log_call(
            prompt_tokens=completion["usage"]["input_tokens"],
            sampled_tokens=completion["usage"]["output_tokens"],
            model=self.model,
            client_class=self.__class__.__name__,
            caller_id=caller_id,
        )

        msg = completion["content"][0]["text"]
        return ChatMessage(role="assistant", content=msg)
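
For context, a minimal usage sketch of the new client, mirroring the documentation example above (the region and model ID are illustrative, and AWS credentials are assumed to be configured via the default boto3 credentials chain):

```python
import boto3

from giskard.llm.client.base import ChatMessage
from giskard.llm.client.bedrock import ClaudeBedrockClient

# Assumes credentials are available via env vars, a profile, or an IAM role
bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-east-1")
client = ClaudeBedrockClient(bedrock_runtime, model="anthropic.claude-3-haiku-20240307-v1:0")

# A leading pair of "user" messages: the first is sent as the system prompt
answer = client.complete(
    messages=[
        ChatMessage(role="user", content="You are a concise assistant."),
        ChatMessage(role="user", content="Name one use of Giskard."),
    ],
    temperature=0.2,
)
print(answer.content)
```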
40 changes: 40 additions & 0 deletions giskard/llm/embeddings/bedrock.py
@@ -0,0 +1,40 @@
from typing import Sequence

import json

import numpy as np

from .base import BaseEmbedding


class BedrockEmbedding(BaseEmbedding):
    def __init__(self, client, model: str):
        """
        Parameters
        ----------
        client : Bedrock
            boto3-based Bedrock runtime client instance.
        model : str
            Model name.
        """
        self.model = model
        self.client = client

    def embed(self, texts: Sequence[str]) -> np.ndarray:
        if "titan" not in self.model:
            raise ValueError(f"Only titan embedding models are supported currently, got {self.model} instead")

        if isinstance(texts, str):
            texts = [texts]

        accept = "application/json"
        contentType = "application/json"
        embeddings = []
        for text in texts:
            body = json.dumps({"inputText": text})
            response = self.client.invoke_model(body=body, modelId=self.model, accept=accept, contentType=contentType)
            response_body = json.loads(response.get("body").read())
            embedding = response_body.get("embedding")
            embeddings.append(embedding)

        return np.array(embeddings)
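
Similarly, a minimal sketch of the embedding class with a Titan model (the region and model ID are illustrative):

```python
import boto3

from giskard.llm.embeddings.bedrock import BedrockEmbedding

bedrock_runtime = boto3.client("bedrock-runtime", region_name="us-east-1")
embedder = BedrockEmbedding(client=bedrock_runtime, model="amazon.titan-embed-text-v1")

vectors = embedder.embed(["What is Giskard?", "An open-source ML testing library."])
print(vectors.shape)  # (2, embedding_dim), e.g. (2, 1536) for Titan v1
```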