Skip to content
Merged
Show file tree
Hide file tree
Changes from 96 commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
3000d4c
draft postponed import pattern for cohere generator
leondz May 2, 2025
757e0f3
move extra dependency requirements into classdefs, mediate requiremen…
leondz May 5, 2025
9310d0a
actually do the plugin dep load
leondz May 5, 2025
dac569e
migrate generators to 'extra dependencies' pattern
leondz May 5, 2025
35e93fc
prune dupe lazyload
leondz May 7, 2025
bf7f36b
extra_dependency_names in all plugins
leondz May 7, 2025
6a39b0c
active must be False for Probes using extra modules
leondz May 7, 2025
56c6182
make PIL optional in generators.huggingface.LLaVA
leondz May 7, 2025
3657e04
move optional load fail to ModuleNotFoundError
leondz May 7, 2025
865d604
add _load/_clear_deps() into base generator and _load/_clear client
leondz May 7, 2025
d61957d
put the MNFE where it belongs
leondz May 7, 2025
8a7051e
backoff exception placeholder must inherit base exception
leondz May 7, 2025
60775f6
test for reqs presence in pyproject.toml, requirements.txt
leondz May 7, 2025
31e98d4
handle hyphen in pypi pkg names
leondz May 7, 2025
75babb7
rm optional plugin deps
leondz May 7, 2025
83f551a
skip generator tests if optional deps absent
leondz May 8, 2025
dd51196
support sub-package deps
leondz May 8, 2025
b33a46c
scope optimum to nvidia
leondz May 8, 2025
de5b3f1
move import function to _load_deps
leondz May 8, 2025
19c31fe
rm import handling in langchain
leondz May 8, 2025
54fabc5
amend optimum to be nvidia flavour
leondz May 8, 2025
ffac714
dry - use garak._plugins.PLUGIN_TYPES as canonical def of 1st class p…
leondz May 8, 2025
97c8160
unify backoff exception pattern mediated via garak GeneratorBackoffEx…
leondz May 9, 2025
1d4e69c
skip instantiation when modules not present
leondz May 9, 2025
6164bc5
catch straggling backoff exception wrappings
leondz May 9, 2025
85fb7c3
Merge branch 'main' into update/optional_imports
leondz May 9, 2025
0402116
use isinstance for exception matching
leondz May 9, 2025
e287fe9
don't backoff on 404
leondz May 9, 2025
6339648
merge in our good pal main
leondz May 16, 2025
76b1774
switch to pyproject; get tests deps if testing
leondz May 16, 2025
ca133e4
add [dev] target
leondz May 16, 2025
8e8a5b9
add required jsonschema that was previously implicit from now-optiona…
leondz May 16, 2025
aa7500a
specify versions; move to secure versions cf. #1207
leondz May 16, 2025
4f2e5ef
skip internal config mappings for req consistency testing
leondz May 16, 2025
69cfef2
skip test option for non-test workflow
leondz May 16, 2025
a1da5ed
skip ollama tests if no module
leondz May 16, 2025
3a8605d
rm spurious dep check
leondz May 16, 2025
d2d17ad
straggling spurious check
leondz May 16, 2025
13974b8
Merge branch 'main' into update/optional_imports
leondz May 28, 2025
dc83929
Merge branch 'main' into update/optional_imports
leondz Jun 8, 2025
d527650
merge in octo removal
leondz Jun 8, 2025
8c46730
add all_plugins option; handle pkg name != import nonsense in pillow
leondz Jun 11, 2025
06180b6
rm unused import
leondz Jun 11, 2025
7c22dea
cache maint workflow gets deps for all plugins
leondz Jun 11, 2025
ce23d70
merge main
leondz Jun 30, 2025
8ab94bd
Merge branch 'main' into update/optional_imports
leondz Jul 3, 2025
bb67a3e
Merge branch 'main' into update/optional_imports
leondz Jul 3, 2025
b45ba35
use correct backoff exception name
leondz Jul 3, 2025
de86505
merg main / turn & conv
leondz Aug 22, 2025
38e6a15
cohere v2 validation: update backoff errors, remove double unpacking …
leondz Sep 25, 2025
64399a5
merge main
leondz Sep 25, 2025
65ac9fe
rm unconditional top level ollama import in test
leondz Sep 25, 2025
ea71cea
wrap llava test global imports in try/except
leondz Sep 25, 2025
14b958d
Update tests/test_reqs.py to use global plugins def
leondz Sep 25, 2025
d0c90ea
move plugin-general tests to tests/plugins
leondz Sep 25, 2025
de28c75
force cache update to include new plugin param
leondz Sep 25, 2025
1742f60
gate dep-requiring tests
leondz Sep 25, 2025
c877225
migrate to deferred loading
leondz Sep 25, 2025
6eede53
cohere generator partial fixes
leondz Sep 26, 2025
2bcca2c
revent to main for cohere
leondz Sep 26, 2025
9d57b5c
add deferred loading to cohere; migrate to new library exception names
leondz Sep 26, 2025
e195148
skip cohere tests if module not present
leondz Sep 26, 2025
f8e8635
skip tomllib-using tests if lib not present
leondz Sep 26, 2025
246b7dc
rm audio,dra hard deps
leondz Sep 26, 2025
b9e3b73
bring pyproject up to standard, add tests
leondz Sep 26, 2025
40b6bcc
summon librosa in probes.audio
leondz Sep 26, 2025
ae5ad5d
deselect audio achilles heel by default
leondz Sep 26, 2025
1a57c8e
force update cache sorry
leondz Sep 26, 2025
738e8e8
skip tests that fail on import (maybe a custom exception is better)
leondz Sep 26, 2025
46e1794
skip tests where deps not present
leondz Sep 29, 2025
dfbe9ac
scan and report all missing modules in _plugins.load_plugin
leondz Sep 29, 2025
35fa3d6
generalise dep loading & clearing to _plugins; activate in probes also
leondz Sep 29, 2025
2e6054f
merge main
leondz Nov 12, 2025
04b4731
use default deps
leondz Nov 12, 2025
f9138f5
add docs for deferred loading
leondz Nov 12, 2025
4f3aeae
revert requirements tests
leondz Nov 12, 2025
a35409f
revert to requirements.txt route
leondz Nov 12, 2025
e89b3f2
reinsert whitesp
leondz Nov 12, 2025
0f7a618
a one-liner/one-call method for regenerating plugin cache would have …
leondz Nov 12, 2025
1f9998c
litellm api upgr
leondz Nov 12, 2025
dafef3b
update litellm to use local module
leondz Nov 12, 2025
c2708c6
redo model->target in azure
leondz Nov 12, 2025
c5a95a9
skip test on missing imports
leondz Nov 13, 2025
b4b7504
handle optimum package naming
leondz Nov 13, 2025
0e074dc
invert exception selection
leondz Nov 13, 2025
c60d814
make dep name processing consistent
leondz Nov 13, 2025
309a246
refresh plugin cache
leondz Nov 13, 2025
74ff9a3
Merge branch 'main' into update/optional_imports
leondz Nov 13, 2025
677eacb
merge main
leondz Nov 21, 2025
6da2389
rm git merge notations in generated XDG cache plugin json
leondz Nov 21, 2025
27119e3
remove nv optimum / EOL
leondz Nov 26, 2025
b06a634
Merge branch 'main' into update/optional_imports
leondz Nov 26, 2025
07de920
rm more Optimum
leondz Nov 26, 2025
143f84d
add depnames for apikey probes, detector
leondz Nov 26, 2025
69244dc
Merge branch 'main' into update/optional_imports
leondz Nov 27, 2025
53597f0
migrate bedrock generator to extra_dependency_names pattern
leondz Nov 27, 2025
5daff3d
update param names, types in _load and _clear deps
leondz Dec 3, 2025
957cc75
handle MNFE exceptions during `_load_deps`
leondz Dec 4, 2025
621bfaf
cut _plugins.load_optional_module
leondz Dec 5, 2025
0be265b
defer extra dependency checks to plugin class
jmartin-tech Dec 12, 2025
05b9f5f
Merge 'main' into update/optional_imports
jmartin-tech Dec 12, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions docs/source/configurable.rst
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,6 @@ probes and run each prompt just once:

If we save this as ``latent1.yaml`` somewhere, then we can use it with ``garak --config latent1.yaml``.



Using a Custom JSON Config
^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down Expand Up @@ -258,6 +256,26 @@ At plugin load, the plugin instance has attributes named in ``DEFAULT_PARAMS``
automatically created, and populated with either values given in the supplied
config, or the default.

Fixed plugin parameters
^^^^^^^^^^^^^^^^^^^^^^^

Some plugin parameters aren't intended to be altered at instantiation via config.
These are the fixed plugin parameters, and are generally those not given in ``DEFAULT_PARAMS``.
Descriptions of these are as follows (for a probe - other plugins are similar):

* ``description`` - A short description of what the plugin does
* ``active`` - Whether or not the plugin is active (i.e. selected) by default
* ``doc_uri`` - Link to more information about the plugin
* ``extended_detectors`` - Option detectors to use on probe results
* ``extra_dependency_names`` - Extra Python modules that garka should import when instantiatng the plugin
* ``goal`` - Brief description in imperative form of the probe's intent
* ``modality`` - Which modalities the probe supports (as of Nov 2024 the list is ``text``, ``image``, ``audio``, ``video``, ``3d``)
* ``parallelisable_attempts`` - Is the probe parallelisable? Recommended false if it has to use an LLM to develop attacks, particularly a local one
* ``primary_detector`` - What detector should be used on the probe's outputs?
* ``tags`` - List of tags applicable to the plugin, drawn from ``garak/data/tags.misp.tsv``
* ``mod_time`` - Modification timestamp of the plugin source file used to generate this data


.. _config_with_yaml:

Configuring Plugins with YAML
Expand Down
3 changes: 3 additions & 0 deletions docs/source/extending.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ The recipe for writing a new plugin or plugin class isn't outlandish:
* Start a new module inheriting from one of the base classes, e.g. :class:`garak.probes.base.Probe`
* Override as little as possible.

If you use custom modules not included in garak's default list, include these in the plugin's top-level ``extra_dependency_names`` parameter.
Garak's plugin loader (``garak._plugins.load_plugin()``) will manage the import and inject the requested module as ``self.<module>``.


Guides to writing plugins
-------------------------
Expand Down
64 changes: 64 additions & 0 deletions garak/_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,28 @@ def load_plugin(path, break_on_fail=True, config_root=_config) -> object:
) from ve
else:
return False

full_plugin_name = ".".join((category, module_name, plugin_class_name))

# check cache for optional imports
if category in PLUGIN_TYPES:
extra_dependency_names = PluginCache.instance()[category][full_plugin_name][
"extra_dependency_names"
]
if len(extra_dependency_names) > 0:
absent_modules = []
for dependency_module_name in extra_dependency_names:
for (
dependency_path
) in [ # support both plain names and also multi-point names e.g. langchain.llms
".".join(dependency_module_name.split(".")[: n + 1])
for n in range(dependency_module_name.count(".") + 1)
]:
if importlib.util.find_spec(dependency_path) is None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing shows find_spec needs to be provided the runtime package name, currently dependency_path entries are the pypi package name.

This can be tested by attempting to load generators.huggingface.LLaVA.

⛔ Plugin 'generators.huggingface.LLaVA' requires Python modules which aren't installed/available: 'pillow'
💡 Try 'pip install pillow' to get missing module.
Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/jemartin/Projects/nvidia/garak/garak/__main__.py", line 14, in <module>
    main()
  File "/home/jemartin/Projects/nvidia/garak/garak/__main__.py", line 9, in main
    cli.main(sys.argv[1:])
  File "/home/jemartin/Projects/nvidia/garak/garak/cli.py", line 596, in main
    generator = _plugins.load_plugin(
                ^^^^^^^^^^^^^^^^^^^^^
  File "/home/jemartin/Projects/nvidia/garak/garak/_plugins.py", line 428, in load_plugin
    _import_failed(absent_modules, full_plugin_name)
  File "/home/jemartin/Projects/nvidia/garak/garak/_plugins.py", line 493, in _import_failed
    raise ModuleNotFoundError(msg)
ModuleNotFoundError: ⛔ Plugin 'generators.huggingface.LLaVA' requires Python modules which aren't installed/available: 'pillow'

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this code in particular, I don't think this validation is needed here at this time. It may be more appropriate have handle a missing import exception around the module instantiation call.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this code in particular, I don't think this validation is needed here at this time. It may be more appropriate have handle a missing import exception around the module instantiation call.

I believe the intent here is to summarise in one pass a list of all missing module names, which is determined using find_spec.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is simply not the right place for this, load_plugin is called at instantiation so this is only evaluating for this plugin which is fully processed by _load_deps so not needed here.

In the next iteration PR #1475 we might want to add an early preprocessor that takes a comprehensive look at the full run config to determine if all dependencies required for the full run are available however I am thinking that might turn out to be an overly complex goal that may get deferred or shelved in favor of allowing the run to skip probes that happen to be missing dependencies instead for blocking start of the run. More discussion of that can happen in that PR in later.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright, this was extra code added after initial PR was made when a feature was requested to list all missing modules rather than one at a time. Is that additional feature no longer a requirement to land?

absent_modules.append(dependency_module_name)
if len(absent_modules):
_import_failed(absent_modules, full_plugin_name)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since the goal of load_plugin is to return an instance no need to check the dependencies early just let the instance creation raise:

Suggested change
if category in PLUGIN_TYPES:
extra_dependency_names = PluginCache.instance()[category][full_plugin_name][
"extra_dependency_names"
]
if len(extra_dependency_names) > 0:
absent_modules = []
for dependency_module_name in extra_dependency_names:
for (
dependency_path
) in [ # support both plain names and also multi-point names e.g. langchain.llms
".".join(dependency_module_name.split(".")[: n + 1])
for n in range(dependency_module_name.count(".") + 1)
]:
if importlib.util.find_spec(dependency_path) is None:
absent_modules.append(dependency_module_name)
if len(absent_modules):
_import_failed(absent_modules, full_plugin_name)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this in tension with the requested feature of listing all missing modules at once rather than piecemeal? I realise the granularity is different, but don't we want to cause the minimum number of user round trips between execution and dep installation?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really, load_plugin is only called when actually instantiating a plugin, testing in this location we will not evaluate all plugins required for the run. Since generators are the primary plugin type using this pattern just removing this is acceptable for now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should

  1. merge these threads
  2. get clear about reqs for this PR

I don't mind if we advise all missing modules at once or piecemeal. The latter has better UX. Agree this feature should belong in the right PR and code location, if the feature is going to manifest.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood, based on the targeted changes in scope this PR should remove this block.

Also the revisions already to made to load_deps will provide a consistent verbose list of all packages required for the plugin that is attempting to load. This covers the same scope as what this block does with added context available.


module_path = f"garak.{category}.{module_name}"
try:
mod = importlib.import_module(module_path)
Expand Down Expand Up @@ -434,6 +456,7 @@ def load_plugin(path, break_on_fail=True, config_root=_config) -> object:
if plugin_instance is None:
plugin_instance = klass(config_root=config_root)
PluginProvider.storeInstance(plugin_instance, config_root)

except Exception as e:
logging.warning(
"Exception instantiating %s.%s: %s",
Expand All @@ -448,3 +471,44 @@ def load_plugin(path, break_on_fail=True, config_root=_config) -> object:
return False

return plugin_instance


def load_optional_module(module_name: str):
try:
m = importlib.import_module(module_name)
except ModuleNotFoundError:
requesting_module = Path(inspect.stack()[1].filename).name.replace(".py", "")
_import_failed([module_name], requesting_module)
return m


def _import_failed(absent_modules: [str], calling_module: str):
quoted_module_list = "'" + "', '".join(absent_modules) + "'"
module_list = " ".join(absent_modules)
msg = f"⛔ Plugin '{calling_module}' requires Python modules which aren't installed/available: {quoted_module_list}"
hint = f"💡 Try 'pip install {module_list}' to get missing module."
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just noting, it might be nice to have this reference a group in the pyproject.toml. This may be added in #1475 when the as these groups are introduced.

logging.critical(msg)
print(msg + "\n" + hint)
raise ModuleNotFoundError(msg)


def _load_deps(self, deps_override=list()):
# load external dependencies. should be invoked at construction and
# in _client_load (if used)
dep_names = deps_override if deps_override else self.extra_dependency_names
for extra_dependency in dep_names:
extra_dep_name = extra_dependency.replace(".", "_").replace("-", "_")
if not hasattr(self, extra_dep_name) or getattr(self, extra_dep_name) is None:
setattr(
self,
extra_dep_name,
load_optional_module(extra_dependency),
)


def _clear_deps(self):
# unload external dependencies from class. should be invoked before
# serialisation, esp. in _clear_client (if used)
for extra_dependency in self.extra_dependency_names:
extra_dep_name = extra_dependency.replace(".", "_").replace("-", "_")
setattr(self, extra_dep_name, None)
2 changes: 2 additions & 0 deletions garak/buffs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class Buff(Configurable):
doc_uri = ""
lang = None # set of languages this buff should be constrained to
active = True
# list of strings naming modules required but not explicitly in garak by default
extra_dependency_names = []

DEFAULT_PARAMS = {}

Expand Down
2 changes: 2 additions & 0 deletions garak/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class Detector(Configurable):
accuracy = None
active = True
tags = [] # list of taxonomy categories per the MISP format
# list of strings naming modules required but not explicitly in garak by default
extra_dependency_names = []

# support mainstream any-to-any large models
# legal element for str list `modality['in']`: 'text', 'image', 'audio', 'video', '3d'
Expand Down
2 changes: 1 addition & 1 deletion garak/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TargetNameMissingError(GarakException):
"""A generator requires target_name to be set, but it wasn't"""


class GarakBackoffTrigger(GarakException):
class GeneratorBackoffTrigger(GarakException):
"""Thrown when backoff should be triggered"""


Expand Down
1 change: 1 addition & 0 deletions garak/generators/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def _validate_env_var(self):
return super()._validate_env_var()

def _load_client(self):
self._load_deps()
if self.target_name in openai_model_mapping:
self.target_name = openai_model_mapping[self.target_name]

Expand Down
8 changes: 7 additions & 1 deletion garak/generators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from colorama import Fore, Style
import tqdm

from garak import _config
from garak import _config, _plugins
from garak.attempt import Message, Conversation
from garak.configurable import Configurable
from garak.exception import GarakException
Expand Down Expand Up @@ -46,6 +46,8 @@ class Generator(Configurable):
supports_multiple_generations = (
False # can more than one generation be extracted per request?
)
# list of strings naming modules required but not explicitly in garak by default
extra_dependency_names = []

def __init__(self, name="", config_root=_config):
self._load_config(config_root)
Expand All @@ -69,6 +71,10 @@ def __init__(self, name="", config_root=_config):
f"🦜 loading {Style.BRIGHT}{Fore.LIGHTMAGENTA_EX}generator{Style.RESET_ALL}: {self.generator_family_name}: {self.name}"
)
logging.info("generator init: %s", self)
self._load_deps()

_load_deps = _plugins._load_deps
_clear_deps = _plugins._clear_deps

def _call_model(
self, prompt: Conversation, generations_this_call: int = 1
Expand Down
34 changes: 15 additions & 19 deletions garak/generators/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class BedrockGenerator(Generator):
active = True
generator_family_name = "Bedrock"
supports_multiple_generations = False
extra_dependency_names = ["boto3", "botocore"]

DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
"temperature": 0.7,
Expand Down Expand Up @@ -106,7 +107,7 @@ def __init__(self, name="", config_root=_config):

def _validate_env_var(self):
"""Validate and set region from environment variables if not configured.

Checks AWS_REGION and AWS_DEFAULT_REGION environment variables only if
the region parameter is still at its default value.
"""
Expand All @@ -115,23 +116,16 @@ def _validate_env_var(self):
if env_region:
logging.info(f"Using AWS region from environment: {env_region}")
self.region = env_region

return super()._validate_env_var()

def _load_client(self):
"""Load and configure the boto3 bedrock-runtime client.

Uses boto3's standard credential chain for authentication.
"""
try:
import boto3
except ImportError:
raise ImportError(
"boto3 is required for the Bedrock generator. "
"Install it with: pip install boto3"
)

self.client = boto3.client(

self.client = self.boto3.client(
service_name="bedrock-runtime",
region_name=self.region,
)
Expand All @@ -142,6 +136,8 @@ def _clear_client(self):
"""Clear the boto3 client to enable object pickling."""
if hasattr(self, "client"):
self.client = None
for module_name in self.extra_dependency_names:
setattr(self, module_name, None)

def __getstate__(self):
"""Prepare object for pickling by clearing the boto3 client."""
Expand All @@ -156,13 +152,13 @@ def __setstate__(self, state):
@staticmethod
def _conversation_to_list(conversation: Conversation) -> list[dict]:
"""Convert Conversation object to Bedrock Converse API message format.

AWS Bedrock expects messages in the format:
{"role": "user", "content": [{"text": "message text"}]}

Args:
conversation: Conversation object to convert

Returns:
List of message dictionaries in Bedrock format
"""
Expand All @@ -174,7 +170,7 @@ def _conversation_to_list(conversation: Conversation) -> list[dict]:

@backoff.on_exception(
backoff.fibo,
garak.exception.GarakBackoffTrigger,
garak.exception.GeneratorBackoffTrigger,
max_value=70,
)
def _call_model(
Expand Down Expand Up @@ -246,20 +242,20 @@ def _call_model(
return [Message(text=text)]

except Exception as e:
from botocore.exceptions import ClientError

if isinstance(e, ClientError):
if isinstance(e, self.botocore.exceptions.ClientError):
error_code = e.response.get("Error", {}).get("Code", "")
error_message = e.response.get("Error", {}).get("Message", "")

logging.error(f"Bedrock API error [{error_code}]: {error_message}")

if error_code in ["ThrottlingException", "ServiceUnavailableException"]:
raise garak.exception.GarakBackoffTrigger from e
raise garak.exception.GeneratorBackoffTrigger from e

return [None]

logging.exception("Error calling Bedrock model")
return [None]

DEFAULT_CLASS = "BedrockGenerator"

DEFAULT_CLASS = "BedrockGenerator"
23 changes: 16 additions & 7 deletions garak/generators/cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,11 @@
from typing import List, Union

import backoff
from cohere.core.api_error import ApiError
import cohere
import tqdm

from garak import _config
from garak.attempt import Message, Conversation
from garak.exception import GeneratorBackoffTrigger
from garak.generators.base import Generator


Expand Down Expand Up @@ -52,6 +51,8 @@ class CohereGenerator(Generator):
"api_version": "v2", # "v1" for legacy generate API, "v2" for chat API (recommended)
}

extra_dependency_names = ["cohere"]

generator_family_name = "Cohere"

def __init__(self, name="command", config_root=_config):
Expand All @@ -74,11 +75,11 @@ def __init__(self, name="command", config_root=_config):
# Initialize appropriate client based on API version
# Following Cohere's guidance to use Client() for v1 and ClientV2() for v2
if self.api_version == "v1":
self.generator = cohere.Client(api_key=self.api_key)
self.generator = self.cohere.Client(api_key=self.api_key)
else: # api_version == "v2"
self.generator = cohere.ClientV2(api_key=self.api_key)
self.generator = self.cohere.ClientV2(api_key=self.api_key)

@backoff.on_exception(backoff.fibo, ApiError, max_value=70)
@backoff.on_exception(backoff.fibo, GeneratorBackoffTrigger, max_value=70)
def _call_cohere_api(self, prompt_text, request_size=COHERE_GENERATION_LIMIT):
"""Empty prompts raise API errors (e.g. invalid request: prompt must be at least 1 token long).
We catch these using the ApiError base class in Cohere v5+.
Expand Down Expand Up @@ -127,9 +128,17 @@ def _call_cohere_api(self, prompt_text, request_size=COHERE_GENERATION_LIMIT):
"Chat response structure doesn't match expected format"
)
responses.append(str(response))
except ApiError as e:
raise e # bubble up ApiError for backoff handling
except Exception as e:

backoff_exception_types = (
self.cohere.errors.GatewayTimeoutError,
self.cohere.errors.TooManyRequestsError,
self.cohere.errors.ServiceUnavailableError,
self.cohere.errors.InternalServerError,
)
for backoff_exception in backoff_exception_types:
if isinstance(e, backoff_exception):
raise GeneratorBackoffTrigger from e # bubble up ApiError for backoff handling
logging.error(f"Chat API error: {e}")
responses.append(None)

Expand Down
1 change: 1 addition & 0 deletions garak/generators/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class GroqChat(OpenAICompatible):
generator_family_name = "Groq"

def _load_client(self):
self._load_deps()
self.client = openai.OpenAI(base_url=self.uri, api_key=self.api_key)
if self.name in ("", None):
raise ValueError(
Expand Down
19 changes: 7 additions & 12 deletions garak/generators/guardrails.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,22 @@ class NeMoGuardrails(Generator):

supports_multiple_generations = False
generator_family_name = "Guardrails"
extra_dependency_names = ["nemoguardrails"]

def __init__(self, name="", config_root=_config):
# another class that may need to skip testing due to non required dependency
try:
from nemoguardrails import RailsConfig, LLMRails
except ImportError as e:
raise NameError(
"You must first install NeMo Guardrails using `pip install nemoguardrails`."
) from e

self.name = name
self._load_config(config_root)
self.fullname = f"Guardrails {self.name}"

# Currently, we use the target_name as the path to the config
with redirect_stderr(io.StringIO()) as f: # quieten the tqdm
config = RailsConfig.from_path(self.name)
self.rails = LLMRails(config=config)

super().__init__(self.name, config_root=config_root)

set_verbose = self.nemoguardrails.logging.verbose.set_verbose
# Currently, we use the model_name as the path to the config
with redirect_stderr(io.StringIO()) as f: # quieten the tqdm
config = self.nemoguardrails.RailsConfig.from_path(self.name)
self.rails = self.nemoguardrails.LLMRails(config=config)

def _call_model(
self, prompt: Conversation, generations_this_call: int = 1
) -> List[Union[Message, None]]:
Expand Down
Loading