Commit cf9a4bb

Merge pull request #1905 from celmore25/feature/bedrock-claude3
[GSK-1590] Native support for Claude 3 and Titan embeddings on Bedrock
2 parents ee68c39 + a8e2e8c · commit cf9a4bb

6 files changed: +786 -491 lines

docs/open_source/scan/scan_llm/index.md

Lines changed: 16 additions & 1 deletion

````diff
@@ -91,15 +91,30 @@ giskard.llm.set_default_client(mc)
 ::::::
 ::::::{tab-item} Ollama
 ```python
+import giskard
 from openai import OpenAI
 from giskard.llm.client.openai import OpenAIClient
-from giskard.llm.client.mistral import MistralClient
 
 # Setup the Ollama client with API key and base URL
 _client = OpenAI(base_url="http://localhost:11434/v1/", api_key="ollama")
 oc = OpenAIClient(model="gemma:2b", client=_client)
 giskard.llm.set_default_client(oc)
 ```
+::::::
+::::::{tab-item} Claude 3
+
+```python
+import os
+import boto3
+import giskard
+
+from giskard.llm.client.bedrock import ClaudeBedrockClient
+
+bedrock_runtime = boto3.client("bedrock-runtime", region_name=os.environ["AWS_DEFAULT_REGION"])
+claude_client = ClaudeBedrockClient(bedrock_runtime, model="anthropic.claude-3-haiku-20240307-v1:0")
+giskard.llm.set_default_client(claude_client)
+```
+
 ::::::
 ::::::{tab-item} Custom Client
 ```python
````
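This commit also adds Titan embedding support (`giskard/llm/embeddings/bedrock.py`, below), which the docs page does not yet cover. A docs-style snippet might look like the following sketch; the Titan model ID and the `set_default_embedding` setter are assumptions, not confirmed by this diff:

```python
import os
import boto3
import giskard

from giskard.llm.embeddings.bedrock import BedrockEmbedding

# Hypothetical wiring, assuming giskard exposes a default-embedding setter
# analogous to giskard.llm.set_default_client
bedrock_runtime = boto3.client("bedrock-runtime", region_name=os.environ["AWS_DEFAULT_REGION"])
embedding_model = BedrockEmbedding(client=bedrock_runtime, model="amazon.titan-embed-text-v1")  # assumed Titan model ID
giskard.llm.embeddings.set_default_embedding(embedding_model)
```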

giskard/llm/client/bedrock.py

Lines changed: 86 additions & 0 deletions (new file)

```python
from typing import Optional, Sequence

import json

from ..config import LLMConfigurationError
from ..errors import LLMImportError
from . import LLMClient
from .base import ChatMessage

try:
    import boto3  # noqa: F401
except ImportError as err:
    raise LLMImportError(
        flavor="llm", msg="To use Bedrock models, please install the `boto3` package with `pip install boto3`"
    ) from err


class ClaudeBedrockClient(LLMClient):
    def __init__(
        self,
        bedrock_runtime_client,
        model: str = "anthropic.claude-3-sonnet-20240229-v1:0",
        anthropic_version: str = "bedrock-2023-05-31",
    ):
        self._client = bedrock_runtime_client
        self.model = model
        self.anthropic_version = anthropic_version

    def complete(
        self,
        messages: Sequence[ChatMessage],
        temperature: float = 1,
        max_tokens: Optional[int] = 1000,
        caller_id: Optional[str] = None,
        seed: Optional[int] = None,
        format=None,
    ) -> ChatMessage:
        # Only Claude 3 is supported to start
        if "claude-3" not in self.model:
            raise LLMConfigurationError(f"Only claude-3 models are supported as of now, got {self.model}")

        # Extract the system prompt: if the conversation starts with two
        # consecutive "user" messages, treat the first one as the system prompt
        system_prompt = ""
        if len(messages) > 1:
            if messages[0].role.lower() == "user" and messages[1].role.lower() == "user":
                system_prompt = messages[0].content
                messages = messages[1:]

        # Convert the messages to the format expected by the Bedrock Messages API
        input_msg_prompt = []
        for msg in messages:
            if msg.role.lower() == "assistant":
                input_msg_prompt.append({"role": "assistant", "content": [{"type": "text", "text": msg.content}]})
            else:
                input_msg_prompt.append({"role": "user", "content": [{"type": "text", "text": msg.content}]})

        # Create the JSON body to send to the API
        body = json.dumps(
            {
                "anthropic_version": self.anthropic_version,
                "max_tokens": max_tokens,
                "temperature": temperature,
                "system": system_prompt,
                "messages": input_msg_prompt,
            }
        )

        # Invoke the model and parse the response
        try:
            accept = "application/json"
            contentType = "application/json"
            response = self._client.invoke_model(body=body, modelId=self.model, accept=accept, contentType=contentType)
            completion = json.loads(response.get("body").read())
        except RuntimeError as err:
            raise LLMConfigurationError("Could not get response from Bedrock API") from err

        self.logger.log_call(
            prompt_tokens=completion["usage"]["input_tokens"],
            sampled_tokens=completion["usage"]["output_tokens"],
            model=self.model,
            client_class=self.__class__.__name__,
            caller_id=caller_id,
        )

        msg = completion["content"][0]["text"]
        return ChatMessage(role="assistant", content=msg)
```
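For reference, the client can also be exercised directly. A minimal sketch, assuming AWS credentials and Claude 3 model access are configured in the region; note how two leading `user` messages trigger the system-prompt extraction above:

```python
import os
import boto3

from giskard.llm.client.base import ChatMessage
from giskard.llm.client.bedrock import ClaudeBedrockClient

# Assumes AWS credentials and Claude 3 model access are configured
bedrock_runtime = boto3.client("bedrock-runtime", region_name=os.environ["AWS_DEFAULT_REGION"])
client = ClaudeBedrockClient(bedrock_runtime, model="anthropic.claude-3-haiku-20240307-v1:0")

# Two leading "user" messages: the first is sent as the system prompt
answer = client.complete(
    [
        ChatMessage(role="user", content="You are a concise assistant."),
        ChatMessage(role="user", content="Name one AWS region."),
    ],
    temperature=0.2,
    max_tokens=100,
)
print(answer.content)
```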

giskard/llm/embeddings/bedrock.py

Lines changed: 40 additions & 0 deletions (new file)

```python
from typing import Sequence

import json

import numpy as np

from .base import BaseEmbedding


class BedrockEmbedding(BaseEmbedding):
    def __init__(self, client, model: str):
        """
        Parameters
        ----------
        client : Bedrock
            boto3-based Bedrock runtime client instance.
        model : str
            Model name.
        """
        self.model = model
        self.client = client

    def embed(self, texts: Sequence[str]) -> np.ndarray:
        if "titan" not in self.model:
            raise ValueError(f"Only titan embedding models are supported currently, got {self.model} instead")

        if isinstance(texts, str):
            texts = [texts]

        # Titan embedding models take one input text per request
        accept = "application/json"
        contentType = "application/json"
        embeddings = []
        for text in texts:
            body = json.dumps({"inputText": text})
            response = self.client.invoke_model(body=body, modelId=self.model, accept=accept, contentType=contentType)
            response_body = json.loads(response.get("body").read())
            embedding = response_body.get("embedding")
            embeddings.append(embedding)

        return np.array(embeddings)
```
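A minimal sketch of calling `embed` directly, assuming a configured boto3 Bedrock runtime client and access to a Titan embedding model (the model ID below is an assumption):

```python
import os
import boto3

from giskard.llm.embeddings.bedrock import BedrockEmbedding

bedrock_runtime = boto3.client("bedrock-runtime", region_name=os.environ["AWS_DEFAULT_REGION"])
titan = BedrockEmbedding(client=bedrock_runtime, model="amazon.titan-embed-text-v1")  # assumed model ID

# embed() accepts a single string or a sequence of strings; each text is sent
# as its own invoke_model request and the results are stacked
vectors = titan.embed(["What is Giskard?", "Scan LLM apps for vulnerabilities."])
print(vectors.shape)  # (2, embedding_dim)
```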
