diff --git a/garak/generators/langchain.py b/garak/generators/langchain.py index 124bb9ab4..0299c5af7 100644 --- a/garak/generators/langchain.py +++ b/garak/generators/langchain.py @@ -38,18 +38,36 @@ class LangChainLLMGenerator(Generator): "frequency_penalty": 0.0, "presence_penalty": 0.0, "stop": [], + "model_provider": None, + "configurable_fields": None, } - extra_dependency_names = ["langchain.llms"] + extra_dependency_names = ["langchain.chat_models"] generator_family_name = "LangChain" + _unsafe_attributes = ["generator"] + def __init__(self, name="", config_root=_config): self.name = name self._load_config(config_root) self.fullname = f"LangChain LLM {self.name}" super().__init__(self.name, config_root=config_root) + self._load_unsafe() + + def _load_unsafe(self): + configured_fields = {} + if self.configurable_fields: + for field in self.configurable_fields: + if hasattr(self, field): + configured_fields[field] = getattr(self, field) + + # if not already added pass thru an configured `api_key` + if hasattr(self, "api_key") and not configured_fields.get("api_key", None): + configured_fields["api_key"] = self.api_key - self.generator = getattr(self.langchain_llms, self.name)() + self.generator = self.langchain_chat_models.init_chat_model( + self.name, configurable_fields=self.configurable_fields, **configured_fields + ) def _call_model( self, prompt: Conversation, generations_this_call: int = 1 @@ -60,8 +78,9 @@ def _call_model( This calls invoke once per generation; invoke() seems to have the best support across LangChain LLM integrations. """ - # Should this be expanded to process a whole conversation in some way? - return [Message(r) for r in self.generator.invoke(prompt.last_message().text)] + conv = self._conversation_to_list(prompt) + resp = self.generator.invoke(conv) + return [Message(resp.content)] if hasattr(resp, "content") else [None] DEFAULT_CLASS = "LangChainLLMGenerator" diff --git a/pyproject.toml b/pyproject.toml index f27d3bbb6..0e26d4f53 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,7 @@ dependencies = [ "accelerate>=0.23.0", "avidtools==0.1.2", "stdlibs>=2022.10.9", - "langchain>=0.3.25,<1.0.0", + "langchain>=1.2.1", "nemollm>=0.3.0", "cmd2==2.4.3", "torch>=2.6.0", diff --git a/requirements.txt b/requirements.txt index 35a84f00e..4688b5c0c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ nltk>=3.9.1 accelerate>=0.23.0 avidtools==0.1.2 stdlibs>=2022.10.9 -langchain>=0.3.25,<1.0.0 +langchain>=1.2.1 nemollm>=0.3.0 cmd2==2.4.3 torch>=2.6.0