
Commit 290f17f

Merge pull request #2025 from Giskard-AI/dep/upgrade-mistralai

Upgraded `mistralai` dep to >= 1

2 parents 498f138 + 551eb51

File tree

4 files changed: +53 -42 lines changed

giskard/llm/client/mistral.py
pdm.lock
pyproject.toml
tests/llm/test_llm_client.py

giskard/llm/client/mistral.py

Lines changed: 6 additions & 6 deletions

@@ -1,5 +1,6 @@
 from typing import Optional, Sequence

+import os
 from dataclasses import asdict
 from logging import warning

@@ -9,18 +10,17 @@
 from .base import ChatMessage

 try:
-    from mistralai.client import MistralClient as _MistralClient
-    from mistralai.models.chat_completion import ChatMessage as MistralChatMessage
+    from mistralai import Mistral
 except ImportError as err:
     raise LLMImportError(
         flavor="llm", msg="To use Mistral models, please install the `mistralai` package with `pip install mistralai`"
     ) from err


 class MistralClient(LLMClient):
-    def __init__(self, model: str = "mistral-large-latest", client: _MistralClient = None):
+    def __init__(self, model: str = "mistral-large-latest", client: Mistral = None):
         self.model = model
-        self._client = client or _MistralClient()
+        self._client = client or Mistral(api_key=os.getenv("MISTRAL_API_KEY", ""))

     def complete(
         self,
@@ -43,9 +43,9 @@ def complete(
             extra_params["response_format"] = {"type": "json_object"}

         try:
-            completion = self._client.chat(
+            completion = self._client.chat.complete(
                 model=self.model,
-                messages=[MistralChatMessage(**asdict(m)) for m in messages],
+                messages=[asdict(m) for m in messages],
                 temperature=temperature,
                 max_tokens=max_tokens,
                 **extra_params,
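
For orientation, here is a minimal before/after sketch of the client API this diff migrates across. It is illustrative only, based on the calls visible above; the prompt, model name, and sampling values are placeholders.

    import os

    from mistralai import Mistral  # mistralai >= 1

    # In mistralai >= 1 the single entry point is `Mistral`; the old
    # `MistralClient().chat(...)` call becomes `client.chat.complete(...)`,
    # and messages are plain dicts rather than MistralChatMessage objects.
    client = Mistral(api_key=os.getenv("MISTRAL_API_KEY", ""))

    completion = client.chat.complete(
        model="mistral-large-latest",
        messages=[{"role": "user", "content": "Hello"}],
        temperature=0.7,  # illustrative values
        max_tokens=128,
    )
    print(completion.choices[0].message.content)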

pdm.lock

Lines changed: 24 additions & 11 deletions
Generated lockfile; the diff is not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ dev = [
     "pytest-asyncio>=0.21.1",
     "pydantic>=2",
     "avidtools",
-    "mistralai>=0.1.8, <1",
+    "mistralai>=1",
     "boto3>=1.34.88",
     "scikit-learn==1.4.2",
 ]
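
Since the constraint moves from `<1` to `>=1` across a breaking major release, a quick sanity check of the installed version can save a confusing ImportError. A small sketch, assuming the `packaging` package is available in the environment:

    from importlib.metadata import version

    from packaging.version import Version

    # Fails fast if a pre-1.0 mistralai is still installed.
    if Version(version("mistralai")) < Version("1"):
        raise RuntimeError("This giskard revision needs mistralai >= 1; try `pip install -U mistralai`.")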

tests/llm/test_llm_client.py

Lines changed: 22 additions & 24 deletions
@@ -4,9 +4,6 @@
 import pydantic
 import pytest
 from google.generativeai.types import ContentDict
-from mistralai.models.chat_completion import ChatCompletionResponse, ChatCompletionResponseChoice
-from mistralai.models.chat_completion import ChatMessage as MistralChatMessage
-from mistralai.models.chat_completion import FinishReason, UsageInfo
 from openai.types import CompletionUsage
 from openai.types.chat import ChatCompletion, ChatCompletionMessage
 from openai.types.chat.chat_completion import Choice
@@ -35,22 +32,6 @@
 )


-DEMO_MISTRAL_RESPONSE = ChatCompletionResponse(
-    id="2d62260a7a354e02922a4f6ad36930d3",
-    object="chat.completion",
-    created=1630000000,
-    model="mistral-large",
-    choices=[
-        ChatCompletionResponseChoice(
-            index=0,
-            message=MistralChatMessage(role="assistant", content="This is a test!", name=None, tool_calls=None),
-            finish_reason=FinishReason.stop,
-        )
-    ],
-    usage=UsageInfo(prompt_tokens=9, total_tokens=89, completion_tokens=80),
-)
-
-
 def test_llm_complete_message():
     client = Mock()
     client.chat.completions.create.return_value = DEMO_OPENAI_RESPONSE
@@ -69,19 +50,36 @@ def test_llm_complete_message():

 @pytest.mark.skipif(not PYDANTIC_V2, reason="Mistral raise an error with pydantic < 2")
 def test_mistral_client():
+    from mistralai.models import ChatCompletionChoice, ChatCompletionResponse, UsageInfo
+
+    demo_response = ChatCompletionResponse(
+        id="2d62260a7a354e02922a4f6ad36930d3",
+        object="chat.completion",
+        created=1630000000,
+        model="mistral-large",
+        choices=[
+            ChatCompletionChoice(
+                index=0,
+                message={"role": "assistant", "content": "This is a test!"},
+                finish_reason="stop",
+            )
+        ],
+        usage=UsageInfo(prompt_tokens=9, total_tokens=89, completion_tokens=80),
+    )
+
     client = Mock()
-    client.chat.return_value = DEMO_MISTRAL_RESPONSE
+    client.chat.complete.return_value = demo_response

     from giskard.llm.client.mistral import MistralClient

     res = MistralClient(model="mistral-large", client=client).complete(
         [ChatMessage(role="user", content="Hello")], temperature=0.11, max_tokens=12
     )

-    client.chat.assert_called_once()
-    assert client.chat.call_args[1]["messages"] == [MistralChatMessage(role="user", content="Hello")]
-    assert client.chat.call_args[1]["temperature"] == 0.11
-    assert client.chat.call_args[1]["max_tokens"] == 12
+    client.chat.complete.assert_called_once()
+    assert client.chat.complete.call_args[1]["messages"] == [{"role": "user", "content": "Hello"}]
+    assert client.chat.complete.call_args[1]["temperature"] == 0.11
+    assert client.chat.complete.call_args[1]["max_tokens"] == 12

     assert isinstance(res, ChatMessage)
     assert res.content == "This is a test!"
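
Worth noting: the test never builds a real Mistral client. `unittest.mock.Mock` creates child mocks on attribute access, so `client.chat.complete` exists automatically and records its calls. A standalone sketch of that pattern, independent of giskard (names and values are illustrative):

    from unittest.mock import Mock

    # Attribute access on a Mock returns a child Mock, so the whole
    # `client.chat.complete` chain exists without any mistralai import.
    client = Mock()
    client.chat.complete.return_value = "stub-response"

    result = client.chat.complete(model="mistral-large", max_tokens=12)

    assert result == "stub-response"
    client.chat.complete.assert_called_once()
    assert client.chat.complete.call_args[1]["max_tokens"] == 12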
