Skip to content
Merged
6 changes: 6 additions & 0 deletions docs/source/garak.generators.mistral.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
garak.generators.mistral

.. automodule:: garak.generators.mistral
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/source/generators.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ For a detailed oversight into how a generator operates, see :ref:`garak.generato
garak.generators.langchain
garak.generators.langchain_serve
garak.generators.litellm
garak.generators.mistral
garak.generators.octo
garak.generators.ollama
garak.generators.openai
Expand Down
59 changes: 59 additions & 0 deletions garak/generators/mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
DEFAULT_CLASS = "MistralGenerator"
import os
import backoff
from garak.generators.base import Generator
import garak._config as _config
from mistralai import Mistral
from garak import exception


class MistralGenerator(Generator):
"""
Interface for public endpoints of models hosted in Mistral La Plateforme (console.mistral.ai).
Expects API key in MISTRAL_API_TOKEN environment variable.
"""

generator_family_name = "mistral"
fullname = "Mistral AI"
supports_multiple_generations = False
ENV_VAR = "MISTRAL_API_KEY"
DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
"name": "mistral-large-latest",
}

# avoid attempt to pickle the client attribute
def __getstate__(self) -> object:
self._clear_client()
return dict(self.__dict__)

# restore the client attribute
def __setstate__(self, d) -> object:
self.__dict__.update(d)
self._load_client()

def _load_client(self):
self.client = Mistral(api_key=self.api_key)

def _clear_client(self):
self.client = None

def __init__(self, name="", config_root=_config):
super().__init__(name, config_root)
if self.api_key is not None:
# ensure the token is in the expected runtime env var
os.environ[self.ENV_VAR] = self.api_key
self._load_client()

@backoff.on_exception(backoff.fibo, exception.RateLimitHit, max_value=70)
def _call_model(self, prompt, generations_this_call=1):
print(self.name)
chat_response = self.client.chat.complete(
model=self.name,
messages=[
{
"role": "user",
"content": prompt,
},
],
)
return [chat_response.choices[0].message.content]
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ dependencies = [
"xdg-base-dirs>=6.0.1",
"wn==0.9.5",
"ollama>=0.4.7",
"tiktoken>=0.7.0"
"tiktoken>=0.7.0",
"mistralai==1.5.2"
]

[project.optional-dependencies]
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ xdg-base-dirs>=6.0.1
wn==0.9.5
ollama>=0.4.7
tiktoken>=0.7.0
mistralai==1.5.2
# tests
pytest>=8.0
pytest-mock>=3.14.0
Expand Down
35 changes: 35 additions & 0 deletions tests/generators/test_mistral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import os
import pytest
from unittest.mock import patch
from garak.generators.mistral import MistralGenerator

DEFAULT_DEPLOYMENT_NAME = "mistral-small-latest"

@patch.dict(os.environ, {"MISTRAL_API_KEY": "fake_api_key"})
@patch("garak.generators.mistral.MistralGenerator.generate")
def test_mistral_generator(mock_generate):
# Définir le retour simulé
mock_generate.return_value = ["Mocked response"]

# Initialiser le générateur
generator = MistralGenerator()

# Appeler la méthode générer
output = generator.generate("Test prompt")

# Vérifier que la fonction a bien été appelée
mock_generate.assert_called_once_with("Test prompt")

# Vérifier le résultat
assert output == ["Mocked response"]

@pytest.mark.skipif(
os.getenv(MistralGenerator.ENV_VAR, None) is None,
reason=f"Mistral API key is not set in {MistralGenerator.ENV_VAR}",
)
def test_mistral_chat():
generator = MistralGenerator(name=DEFAULT_DEPLOYMENT_NAME)
assert generator.name == DEFAULT_DEPLOYMENT_NAME
output = generator.generate("Hello Mistral!")
assert len(output) == 1 # expect 1 generation by default
print("test passed!")